Commit 4860d0af authored by Nianchen Deng's avatar Nianchen Deng
Browse files

a runable version

parent 338ae906
...@@ -106,7 +106,7 @@ class PanoDataset(object): ...@@ -106,7 +106,7 @@ class PanoDataset(object):
""" """
self.c = c self.c = c
self.device = device self.device = device
self._load_desc(desc, res, views_to_load, load_images) self._load_desc(desc, data_dir, res, views_to_load, load_images)
def get_data(self): def get_data(self):
return { return {
......
...@@ -91,7 +91,7 @@ class ViewDataset(object): ...@@ -91,7 +91,7 @@ class ViewDataset(object):
rays_o = self.centers[view_idx] rays_o = self.centers[view_idx]
rays_d = self.dataset.cam_rays[pix_idx] # (N, 3) rays_d = self.dataset.cam_rays[pix_idx] # (N, 3)
r = self.rots[view_idx].movedim(-1, -2) # (N, 3, 3) r = self.rots[view_idx].movedim(-1, -2) # (N, 3, 3)
rays_d = torch.matmul(rays_d, r) rays_d = (rays_d[:, None] @ r)[:, 0]
extra_data = {} extra_data = {}
if self.colors is not None: if self.colors is not None:
extra_data['colors'] = self.colors[idx] extra_data['colors'] = self.colors[idx]
...@@ -150,15 +150,15 @@ class ViewDataset(object): ...@@ -150,15 +150,15 @@ class ViewDataset(object):
load_depths: bool, load_depths: bool,
load_bins: bool): load_bins: bool):
if load_images and desc.get('view_file_pattern'): if load_images and desc.get('view_file_pattern'):
self.image_path = os.path.join(self.data_dir, desc['view_file_pattern']) self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
else: else:
self.image_path = None self.image_path = None
if load_depths and desc.get('depth_file_pattern'): if load_depths and desc.get('depth_file_pattern'):
self.depth_path = os.path.join(self.data_dir, desc['depth_file_pattern']) self.depth_path = os.path.join(os.getcwd(), desc['depth_file_pattern'])
else: else:
self.depth_path = None self.depth_path = None
if load_bins and desc.get('bins_file_pattern'): if load_bins and desc.get('bins_file_pattern'):
self.bins_path = os.path.join(self.data_dir, desc['bins_file_pattern']) self.bins_path = os.path.join(os.getcwd(), desc['bins_file_pattern'])
else: else:
self.bins_path = None self.bins_path = None
self.res = res if res else misc.values(desc['view_res'], 'y', 'x') self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
......
...@@ -163,25 +163,4 @@ class PdfSampler(nn.Module): ...@@ -163,25 +163,4 @@ class PdfSampler(nn.Module):
return samples return samples
class VoxelSampler(nn.Module):
def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
lindisp: bool, space):
"""
Initialize a Sampler module
:param depth_range: depth range for sampler
:param n_samples: count to sample along ray
:param perturb_sample: perturb the sample depths
:param lindisp: If True, sample linearly in inverse depth rather than in depth
"""
super().__init__()
self.lindisp = lindisp
self.perturb_sample = perturb_sample
self.n_samples = n_samples
self.space = space
self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range
def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
\ No newline at end of file
...@@ -341,6 +341,7 @@ def test(): ...@@ -341,6 +341,7 @@ def test():
load_images=args.output_flags['perf']) load_images=args.output_flags['perf'])
data_loader = DataLoader(dataset, TEST_BATCH_SIZE, chunk_max_items=TEST_MAX_CHUNK_ITEMS, data_loader = DataLoader(dataset, TEST_BATCH_SIZE, chunk_max_items=TEST_MAX_CHUNK_ITEMS,
shuffle=False) shuffle=False)
print(dataset.image_path)
# 2. Load trained model # 2. Load trained model
netio.load(test_model_path, model) netio.load(test_model_path, model)
...@@ -367,7 +368,10 @@ def test(): ...@@ -367,7 +368,10 @@ def test():
if args.output_flags['perf']: if args.output_flags['perf']:
perf = Perf(True, start=True) perf = Perf(True, start=True)
for _, rays_o, rays_d, _ in data_loader: gt = []
for _, rays_o, rays_d, extra_data in data_loader:
if args.output_flags["perf"] and "colors" in extra_data:
gt.append(extra_data["colors"])
n_rays = rays_o.size(0) n_rays = rays_o.size(0)
ret = model(rays_o, rays_d, ret = model(rays_o, rays_d,
ret_depth=args.output_flags['depth'], ret_depth=args.output_flags['depth'],
...@@ -388,8 +392,8 @@ def test(): ...@@ -388,8 +392,8 @@ def test():
* 0.5 + 0.5) * (vals > 0.1) * 0.5 + 0.5) * (vals > 0.1)
idx = slice(offset, offset + n_rays) idx = slice(offset, offset + n_rays)
for key in out: for key in out:
print("key ", key, ", idx ", idx, ", out is ", #print("key ", key, ", idx ", idx, ", out is ",
out[key].shape, ", ret is ", ret[key].shape, ", rays is ", n_rays) # out[key].shape, ", ret is ", ret[key].shape, ", rays is ", n_rays)
out[key][idx] = ret[key] out[key][idx] = ret[key]
if not args.log_redirect: if not args.log_redirect:
progress_bar(i, math.ceil(total_pixels / n_rays), 'Inferring...') progress_bar(i, math.ceil(total_pixels / n_rays), 'Inferring...')
...@@ -416,11 +420,11 @@ def test(): ...@@ -416,11 +420,11 @@ def test():
if args.output_flags['perf']: if args.output_flags['perf']:
perf_errors = torch.ones(n) * NaN perf_errors = torch.ones(n) * NaN
perf_ssims = torch.ones(n) * NaN perf_ssims = torch.ones(n) * NaN
if dataset.images != None: if len(gt) > 0:
gt = torch.cat(gt).reshape(n, *dataset.res, -1).movedim(-1, -3)
for i in range(n): for i in range(n):
perf_errors[i] = loss_mse(dataset.images[i], out['color'][i]).item() perf_errors[i] = loss_mse(gt[i], out['color'][i]).item()
perf_ssims[i] = ssim(dataset.images[i:i + 1], perf_ssims[i] = ssim(gt[i:i + 1], out['color'][i:i + 1]).item() * 100
out['color'][i:i + 1]).item() * 100
perf_mean_time = tot_time / n perf_mean_time = tot_time / n
perf_mean_error = torch.mean(perf_errors).item() perf_mean_error = torch.mean(perf_errors).item()
perf_name = 'perf_%s_%.1fms_%.2e.csv' % ( perf_name = 'perf_%s_%.1fms_%.2e.csv' % (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment