"""Run a trained model over a dataset and save the requested outputs.

The script parses its command line, loads the dataset description and the
model checkpoint (relative to the ``_nets`` directory next to the dataset
description), infers every ray batch produced by the data loader and writes
the results (performance CSV, color/diffuse/specular/depth images or videos)
into an ``output_<N>`` directory next to the model file.
"""
import os
import argparse
import json
import torch
import torch.nn.functional as nn_f
from math import nan, ceil, prod
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str,
                    help='The model file to load for testing')
parser.add_argument('-r', '--output-res', type=str,
                    help='Output resolution')
parser.add_argument('-o', '--output', nargs='+', type=str, default=['perf', 'color'],
                    help='Specify what to output (perf, color, depth, all)')
parser.add_argument('--output-type', type=str, default='image',
                    help='Specify the output type (image, video, debug)')
parser.add_argument('--views', type=str,
                    help='Specify the range of views to test')
parser.add_argument('-p', '--prompt', action='store_true',
                    help='Interactive prompt mode')
parser.add_argument('--time', action='store_true',
                    help='Enable time measurement')
parser.add_argument('dataset', type=str,
                    help='Dataset description file')
args = parser.parse_args()

# Project modules are imported only after the command line has been parsed
# (presumably so that `--help` and argument errors return without loading
# them -- the order is kept from the original script).
import model as mdl
from loss.ssim import ssim
from utils import color
from utils import interact
from utils import device
from utils import img
from utils.perf import Perf, enable_perf, get_perf_result
from utils.progress_bar import progress_bar
from data.dataset_factory import *
from data.loader import DataLoader
from utils.constants import HUGE_FLOAT

RAYS_PER_BATCH = 2 ** 14
DATA_LOADER_CHUNK_SIZE = 1e8

# All relative paths below (dataset description, `_nets`) are resolved from
# the directory that contains the dataset description file.
data_desc_path = DatasetFactory.get_dataset_desc_path(args.dataset)
os.chdir(data_desc_path.parent)
nets_dir = Path("_nets")
data_desc_path = data_desc_path.name


def set_outputs(args, outputs_str: str):
    """Store the comma-separated names in `outputs_str` as the list `args.output`.

    Surrounding whitespace of every name is stripped.
    """
    args.output = [s.strip() for s in outputs_str.split(',')]


if args.prompt:  # Prompt test model, output resolution, output mode
    model_files = [str(path.relative_to(nets_dir)) for path in nets_dir.rglob("*.tar")] \
        + [str(path.relative_to(nets_dir)) for path in nets_dir.rglob("*.pth")]
    args.model = interact.input_enum('Specify test model:', model_files,
                                     err_msg='No such model file')
    args.output_res = interact.input_ex('Specify output resolution:', default='')
    set_outputs(args, interact.input_ex(
        'Specify the outputs | [perf,color,depth,layers,diffuse,specular]/all:',
        default='perf,color'))
    args.output_type = interact.input_enum('Specify the output type | image/video:',
                                           ['image', 'video'],
                                           err_msg='Wrong output type',
                                           default='image')

# Without a model there is nothing to test; fail with a usage message instead
# of a TypeError when the path is joined further below.
if not args.model:
    parser.error('no model specified: use -m/--model or -p/--prompt')

# "WxH" on the command line becomes the tuple (H, W).
args.output_res = tuple(int(s) for s in reversed(args.output_res.split('x'))) \
    if args.output_res else None
args.output_flags = {
    item: item in args.output or 'all' in args.output
    for item in ['perf', 'color', 'depth', 'layers', 'diffuse', 'specular']
}
# "A-B" is passed to range() as range(A, B).
args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None

if args.time:
    enable_perf()

dataset = DatasetFactory.load(data_desc_path, res=args.output_res,
                              load_images=args.output_flags['perf'],
                              views_to_load=args.views)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")

model_path: Path = nets_dir / args.model
model_name = model_path.parent.name
model = mdl.load(model_path, {
    "raymarching_early_stop_tolerance": 0.01,
    # "raymarching_chunk_size_or_sections": [8],
    "perturb_sample": False
})[0].to(device.default()).eval()
model_class = model.__class__.__name__
model_args = model.args
print(f"model: {model_name} ({model_class})")
print("args:", json.dumps(model.args0))

run_dir = model_path.parent
# The output directory is named after the trailing integer of the model file
# stem (the text after the last underscore).
output_dir = run_dir / f"output_{int(model_path.stem.split('_')[-1])}"
output_dataset_id = '%s%s' % (
    dataset.name,
    f'_{args.output_res[1]}x{args.output_res[0]}' if args.output_res else ''
)


if __name__ == "__main__":
    with torch.no_grad():
        # 1. Initialize data loader
        data_loader = DataLoader(dataset, RAYS_PER_BATCH,
                                 chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                                 shuffle=False, enable_preload=True,
                                 color=color.from_str(model.args['color']))

        # 2. Test on dataset
        print("Begin test, batch size is %d" % RAYS_PER_BATCH)

        i = 0
        offset = 0
        chns = model.chns('color')
        n = dataset.n_views
        total_pixels = prod([n, *dataset.res])

        # One flat (pixels, channels) buffer per requested output.
        out = {}
        if args.output_flags['perf'] or args.output_flags['color']:
            out['color'] = torch.zeros(total_pixels, chns, device=device.default())
        if args.output_flags['diffuse']:
            out['diffuse'] = torch.zeros(total_pixels, chns, device=device.default())
        if args.output_flags['specular']:
            out['specular'] = torch.zeros(total_pixels, chns, device=device.default())
        if args.output_flags['depth']:
            out['depth'] = torch.full([total_pixels, 1], HUGE_FLOAT, device=device.default())

        # The ground-truth buffer mirrors the color buffer, so it can only be
        # created when that buffer exists (it does not for e.g. `-o depth`).
        gt_images = torch.empty_like(out['color']) \
            if 'color' in out and dataset.image_path else None

        tot_time = 0
        tot_iters = len(data_loader)
        progress_bar(i, tot_iters, 'Inferring...')
        for _, rays_o, rays_d, extra in data_loader:
            if args.output_flags['perf']:
                test_perf = Perf.Node("Test")

            n_rays = rays_o.size(0)
            idx = slice(offset, offset + n_rays)
            ret = model(rays_o, rays_d,
                        extra_outputs=[key for key in out.keys() if key != 'color'])
            if ret is not None:
                # Only the rays selected by `rays_mask` receive a value; the
                # others keep the buffer's initial fill.
                for key in out:
                    out[key][idx][ret['rays_mask']] = ret[key]

            if args.output_flags['perf']:
                # NOTE(review): close -> synchronize -> duration order is kept
                # from the original; presumably the node's timing needs the
                # device to be idle before it is read -- confirm in utils.perf.
                test_perf.close()
                torch.cuda.synchronize()
                tot_time += test_perf.duration()

            if gt_images is not None:
                gt_images[idx] = extra['color']
            i += 1
            progress_bar(i, tot_iters, 'Inferring...')
            offset += n_rays

        # 3. Save results
        print('Saving results...')
        output_dir.mkdir(parents=True, exist_ok=True)

        # (pixels, C) -> (n, *res, C); image-like outputs then go channel-first.
        for key in out:
            out[key] = out[key].reshape([n, *dataset.res, *out[key].shape[1:]])
        if 'color' in out:
            out['color'] = out['color'].permute(0, 3, 1, 2)
        if 'diffuse' in out:
            out['diffuse'] = out['diffuse'].permute(0, 3, 1, 2)
        if 'specular' in out:
            out['specular'] = out['specular'].permute(0, 3, 1, 2)

        if args.output_flags['perf']:
            # Metrics stay NaN when no ground-truth images are available.
            perf_errors = torch.full([n], nan)
            perf_ssims = torch.full([n], nan)
            if gt_images is not None:
                gt_images = gt_images.reshape(n, *dataset.res, chns).permute(0, 3, 1, 2)
                for i in range(n):
                    perf_errors[i] = nn_f.mse_loss(gt_images[i], out['color'][i]).item()
                    perf_ssims[i] = ssim(gt_images[i:i + 1],
                                         out['color'][i:i + 1]).item() * 100
            perf_mean_time = tot_time / n
            perf_mean_error = torch.mean(perf_errors).item()
            perf_name = f'perf_{output_dataset_id}_{perf_mean_time:.1f}ms_{perf_mean_error:.2e}.csv'

            # Remove old performance reports
            for file in output_dir.glob(f'perf_{output_dataset_id}*'):
                file.unlink()

            # Save new performance reports
            with (output_dir / perf_name).open('w') as fp:
                fp.write('View, PSNR, SSIM\n')
                fp.writelines([
                    f'{dataset.indices[i]}, '
                    f'{img.mse2psnr(perf_errors[i].item()):.2f}, {perf_ssims[i].item():.2f}\n'
                    for i in range(n)
                ])

        for output_type in ['color', 'diffuse', 'specular']:
            if not args.output_flags[output_type]:
                continue
            if args.output_type == 'video':
                output_file = output_dir / f"{output_dataset_id}_{output_type}.mp4"
                img.save_video(out[output_type], output_file, 30)
            else:
                output_subdir = output_dir / f"{output_dataset_id}_{output_type}"
                output_subdir.mkdir(exist_ok=True)
                img.save(out[output_type],
                         [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])

        if args.output_flags['depth']:
            colored_depths = img.colorize_depthmap(out['depth'][..., 0],
                                                   model_args['sample_range'])
            if args.output_type == 'video':
                output_file = output_dir / f"{output_dataset_id}_depth.mp4"
                img.save_video(colored_depths, output_file, 30)
            else:
                output_subdir = output_dir / f"{output_dataset_id}_depth"
                output_subdir.mkdir(exist_ok=True)
                img.save(colored_depths,
                         [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
                #output_subdir = output_dir / f"{output_dataset_id}_bins"
                # output_dir.mkdir(exist_ok=True)
                #img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])

        if args.time:
            report = ["Performance Report ==>\n"]
            res = get_perf_result()
            if res is None:
                report.append("No available data.\n")
            else:
                for key, val in res.items():
                    # Keys are "/"-separated paths; indent by nesting depth.
                    path_segs = key.split("/")
                    report.append(" " * (len(path_segs) - 1)
                                  + f"{path_segs[-1]}: {val:.1f}ms\n")
            print("".join(report))