From 4860d0af83f1dcd374b10fc8d6c655650bcb714a Mon Sep 17 00:00:00 2001 From: Nianchen Deng Date: Sun, 26 Jun 2022 09:20:52 +0800 Subject: [PATCH] a runable version --- data/pano_dataset.py | 2 +- data/view_dataset.py | 8 ++++---- modules/sampler.py | 21 --------------------- run_spherical_view_syn.py | 18 +++++++++++------- 4 files changed, 16 insertions(+), 33 deletions(-) diff --git a/data/pano_dataset.py b/data/pano_dataset.py index 9953c8f..60e8c5f 100644 --- a/data/pano_dataset.py +++ b/data/pano_dataset.py @@ -106,7 +106,7 @@ class PanoDataset(object): """ self.c = c self.device = device - self._load_desc(desc, res, views_to_load, load_images) + self._load_desc(desc, data_dir, res, views_to_load, load_images) def get_data(self): return { diff --git a/data/view_dataset.py b/data/view_dataset.py index 477629b..c33a292 100644 --- a/data/view_dataset.py +++ b/data/view_dataset.py @@ -91,7 +91,7 @@ class ViewDataset(object): rays_o = self.centers[view_idx] rays_d = self.dataset.cam_rays[pix_idx] # (N, 3) r = self.rots[view_idx].movedim(-1, -2) # (N, 3, 3) - rays_d = torch.matmul(rays_d, r) + rays_d = (rays_d[:, None] @ r)[:, 0] extra_data = {} if self.colors is not None: extra_data['colors'] = self.colors[idx] @@ -150,15 +150,15 @@ class ViewDataset(object): load_depths: bool, load_bins: bool): if load_images and desc.get('view_file_pattern'): - self.image_path = os.path.join(self.data_dir, desc['view_file_pattern']) + self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern']) else: self.image_path = None if load_depths and desc.get('depth_file_pattern'): - self.depth_path = os.path.join(self.data_dir, desc['depth_file_pattern']) + self.depth_path = os.path.join(os.getcwd(), desc['depth_file_pattern']) else: self.depth_path = None if load_bins and desc.get('bins_file_pattern'): - self.bins_path = os.path.join(self.data_dir, desc['bins_file_pattern']) + self.bins_path = os.path.join(os.getcwd(), desc['bins_file_pattern']) else: self.bins_path = None self.res = res if res else misc.values(desc['view_res'], 'y', 'x') diff --git a/modules/sampler.py b/modules/sampler.py index eacd072..5331b12 100644 --- a/modules/sampler.py +++ b/modules/sampler.py @@ -163,25 +163,4 @@ class PdfSampler(nn.Module): return samples - -class VoxelSampler(nn.Module): - - def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool, - lindisp: bool, space): - """ - Initialize a Sampler module - - :param depth_range: depth range for sampler - :param n_samples: count to sample along ray - :param perturb_sample: perturb the sample depths - :param lindisp: If True, sample linearly in inverse depth rather than in depth - """ - super().__init__() - self.lindisp = lindisp - self.perturb_sample = perturb_sample - self.n_samples = n_samples - self.space = space - self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range - - def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False): \ No newline at end of file diff --git a/run_spherical_view_syn.py b/run_spherical_view_syn.py index b87daa9..27539a1 100644 --- a/run_spherical_view_syn.py +++ b/run_spherical_view_syn.py @@ -341,6 +341,7 @@ def test(): load_images=args.output_flags['perf']) data_loader = DataLoader(dataset, TEST_BATCH_SIZE, chunk_max_items=TEST_MAX_CHUNK_ITEMS, shuffle=False) + print(dataset.image_path) # 2. Load trained model netio.load(test_model_path, model) @@ -367,7 +368,10 @@ def test(): if args.output_flags['perf']: perf = Perf(True, start=True) - for _, rays_o, rays_d, _ in data_loader: + gt = [] + for _, rays_o, rays_d, extra_data in data_loader: + if args.output_flags["perf"] and "colors" in extra_data: + gt.append(extra_data["colors"]) n_rays = rays_o.size(0) ret = model(rays_o, rays_d, ret_depth=args.output_flags['depth'], @@ -388,8 +392,8 @@ def test(): * 0.5 + 0.5) * (vals > 0.1) idx = slice(offset, offset + n_rays) for key in out: - print("key ", key, ", idx ", idx, ", out is ", - out[key].shape, ", ret is ", ret[key].shape, ", rays is ", n_rays) + #print("key ", key, ", idx ", idx, ", out is ", + # out[key].shape, ", ret is ", ret[key].shape, ", rays is ", n_rays) out[key][idx] = ret[key] if not args.log_redirect: progress_bar(i, math.ceil(total_pixels / n_rays), 'Inferring...') @@ -416,11 +420,11 @@ def test(): if args.output_flags['perf']: perf_errors = torch.ones(n) * NaN perf_ssims = torch.ones(n) * NaN - if dataset.images != None: + if len(gt) > 0: + gt = torch.cat(gt).reshape(n, *dataset.res, -1).movedim(-1, -3) for i in range(n): - perf_errors[i] = loss_mse(dataset.images[i], out['color'][i]).item() - perf_ssims[i] = ssim(dataset.images[i:i + 1], - out['color'][i:i + 1]).item() * 100 + perf_errors[i] = loss_mse(gt[i], out['color'][i]).item() + perf_ssims[i] = ssim(gt[i:i + 1], out['color'][i:i + 1]).item() * 100 perf_mean_time = tot_time / n perf_mean_error = torch.mean(perf_errors).item() perf_name = 'perf_%s_%.1fms_%.2e.csv' % ( -- GitLab