"""Test driver: load a trained model checkpoint, render a dataset, and
prepare outputs (performance stats, color, depth, ...).

FIX: `json` was used below (`json.dumps(model.args)`) but never imported,
which raised NameError at runtime; it is now imported here.
"""
import argparse
import json
import os
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as nn_f

# Command-line interface. NOTE: the project imports below are deliberately
# deferred until after parse_args() so `--help` / bad arguments fail fast.
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str, help='The model file to load for testing')
parser.add_argument('-r', '--output-res', type=str, help='Output resolution')
parser.add_argument('-o', '--output', nargs='*', type=str, default=['perf', 'color'],
                    help='Specify what to output (perf, color, depth, all)')
parser.add_argument('--output-type', type=str, default='image',
                    help='Specify the output type (image, video, debug)')
parser.add_argument('--views', type=str, help='Specify the range of views to test')
parser.add_argument('-s', '--samples', type=int)
parser.add_argument('-p', '--prompt', action='store_true', help='Interactive prompt mode')
parser.add_argument('--time', action='store_true', help='Enable time measurement')
parser.add_argument('dataset', type=str, help='Dataset description file')
args = parser.parse_args()

import model as mdl
from loss.ssim import ssim
from utils import color
from utils import interact
from utils import device
from utils import img
from utils import netio
from utils import math  # project math module (provides prod/huge/nan), shadows stdlib math
from utils.perf import Perf, enable_perf, get_perf_result
from utils.progress_bar import progress_bar
from data import *

DATA_LOADER_CHUNK_SIZE = 1e8

torch.set_grad_enabled(False)

# Work relative to the dataset's directory so all relative paths
# (e.g. the "_nets" folder) resolve next to the dataset description file.
data_desc_path = get_dataset_desc_path(args.dataset)
os.chdir(data_desc_path.parent)
nets_dir = Path("_nets")
data_desc_path = data_desc_path.name


def set_outputs(args, outputs_str: str):
    """Parse a comma-separated outputs string into args.output (list of names)."""
    args.output = [s.strip() for s in outputs_str.split(',')]


if args.prompt:  # Prompt test model, output resolution, output mode
    model_files = [str(path.relative_to(nets_dir)) for path in nets_dir.rglob("*.tar")] \
        + [str(path.relative_to(nets_dir)) for path in nets_dir.rglob("*.pth")]
    args.model = interact.input_enum('Specify test model:', model_files,
                                     err_msg='No such model file')
    args.output_res = interact.input_ex('Specify output resolution:', default='')
    set_outputs(args, interact.input_ex(
        'Specify the outputs | [perf,color,depth,layers,diffuse,specular]/all:',
        default='perf,color'))
    args.output_type = interact.input_enum('Specify the output type | image/video:',
                                           ['image', 'video'],
                                           err_msg='Wrong output type', default='image')

# "WxH" string -> (H, W) tuple (reversed to match tensor row/column order).
args.output_res = tuple(int(s) for s in reversed(args.output_res.split('x'))) if args.output_res \
    else None
args.output_flags = {
    item: item in args.output or 'all' in args.output
    for item in ['perf', 'color', 'depth', 'layers', 'diffuse', 'specular']
}
# "start-stop" string -> range of view indices, or None for all views.
args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None

if args.time:
    enable_perf()

# Ground-truth images are only needed when measuring performance.
dataset = DatasetFactory.load(data_desc_path, res=args.output_res,
                              load_images=args.output_flags['perf'],
                              views_to_load=args.views)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")

RAYS_PER_BATCH = dataset.res[0] * dataset.res[1] // 4

model_path: Path = nets_dir / args.model
model_name = model_path.parent.name
states, _ = netio.load_checkpoint(model_path)
if args.samples:
    # Override the number of raymarching samples stored in the checkpoint.
    states['args']['n_samples'] = args.samples
model = mdl.deserialize(states,
                        raymarching_early_stop_tolerance=0.01,
                        raymarching_chunk_size_or_sections=None,
                        perturb_sample=False).to(device.default()).eval()
print(f"model: {model_name} ({model._get_name()})")
print("args:", json.dumps(model.args))

run_dir = model_path.parent
# Output directory is tagged with the checkpoint's epoch number
# (assumes the checkpoint stem ends in "_<number>").
output_dir = run_dir / f"output_{int(model_path.stem.split('_')[-1])}"
output_dataset_id = '%s%s' % (
    dataset.name,
    f'_{args.output_res[1]}x{args.output_res[0]}' if args.output_res else ''
)

# 1. Initialize data loader
# Preloading is disabled while timing so I/O does not overlap measurement.
data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                         shuffle=False, enable_preload=not args.time,
                         color=color.from_str(model.args['color']))
# 3. Test on dataset
print("Begin test, batch size is %d" % RAYS_PER_BATCH)

i = 0            # completed batches (drives the progress bar)
offset = 0       # number of rays already written into the flat output buffers
chns = model.chns('color')
n = dataset.n_views
total_pixels = math.prod([n, *dataset.res])

# Flat per-ray output buffers; reshaped into per-view images in the save step.
out = {}
if args.output_flags['perf'] or args.output_flags['color']:
    out['color'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['diffuse']:
    out['diffuse'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['specular']:
    out['specular'] = torch.zeros(total_pixels, chns, device=device.default())
if args.output_flags['depth']:
    # Depth starts at "infinity" (math.huge) for rays the model never fills in.
    out['depth'] = torch.full([total_pixels, 1], math.huge, device=device.default())
# NOTE(review): this indexes out['color'], which only exists when 'perf' or
# 'color' was requested — requesting e.g. only 'depth' on a dataset with
# images would raise KeyError here. Confirm intended usage.
gt_images = torch.empty_like(out['color']) if dataset.image_path else None
tot_time = 0
tot_iters = len(data_loader)

progress_bar(i, tot_iters, 'Inferring...')
for data in data_loader:
    if args.output_flags['perf']:
        # Open a timing node around the forward pass only.
        test_perf = Perf.Node("Test")
    n_rays = data['rays_o'].size(0)
    idx = slice(offset, offset + n_rays)  # this batch's slot in the flat buffers
    ret = model(data, *out.keys())
    if ret is not None:
        for key in out:
            if key not in ret:
                # Model does not produce this output; drop the buffer entirely.
                out[key] = None
            else:
                if 'rays_filter' in ret:
                    # Model returned values only for a subset of rays; scatter
                    # them into the batch's slice via the boolean/index filter.
                    out[key][idx][ret['rays_filter']] = ret[key]
                else:
                    out[key][idx] = ret[key]
    if args.output_flags['perf']:
        test_perf.close()
        # Synchronize so the measured duration includes all queued GPU work.
        torch.cuda.synchronize()
        tot_time += test_perf.duration()
    if gt_images is not None:
        gt_images[idx] = data['color']
    i += 1
    progress_bar(i, tot_iters, 'Inferring...')
    offset += n_rays
# 4. Save results
# FIX: the error-heatmap section previously read `gt_images` without checking
# for None, crashing when the dataset has no ground-truth images; it is now
# guarded. Also removed dead commented-out "bins" code and translated the
# remaining non-English comment.
print('Saving results...')
output_dir.mkdir(parents=True, exist_ok=True)

# Drop outputs the model did not produce, then reshape flat ray buffers into
# per-view images; color-like outputs go channels-first (N, C, H, W).
out = {key: value for key, value in out.items() if value is not None}
for key in out:
    out[key] = out[key].reshape([n, *dataset.res, *out[key].shape[1:]])
    if key == 'color' or key == 'diffuse' or key == 'specular':
        out[key] = out[key].permute(0, 3, 1, 2)

if args.output_flags['perf']:
    # Per-view metrics default to NaN for views without ground truth.
    perf_errors = torch.full([n], math.nan)
    perf_ssims = torch.full([n], math.nan)
    if gt_images is not None:
        gt_images = gt_images.reshape(n, *dataset.res, chns).permute(0, 3, 1, 2)
        for i in range(n):
            perf_errors[i] = nn_f.mse_loss(gt_images[i], out['color'][i]).item()
            perf_ssims[i] = ssim(gt_images[i:i + 1], out['color'][i:i + 1]).item() * 100
    perf_mean_time = tot_time / n
    perf_mean_error = torch.mean(perf_errors).item()
    perf_name = f'perf_{output_dataset_id}_{perf_mean_time:.1f}ms_{perf_mean_error:.2e}.csv'
    # Remove old performance reports
    for file in output_dir.glob(f'perf_{output_dataset_id}*'):
        file.unlink()
    # Save new performance reports
    with (output_dir / perf_name).open('w') as fp:
        fp.write('View, PSNR, SSIM\n')
        fp.writelines([
            f'{dataset.indices[i]}, '
            f'{img.mse2psnr(perf_errors[i].item()):.2f}, {perf_ssims[i].item():.2f}\n'
            for i in range(n)
        ])
    if gt_images is not None:
        # Per-pixel squared-error heatmaps (errors above 1e-2 saturate to 255).
        error_images = ((gt_images - out['color'])**2).sum(1, True) / chns
        error_images = (error_images / 1e-2).clamp(0, 1) * 255
        error_images = img.torch2np(error_images)
        error_images = np.asarray(error_images, dtype=np.uint8)
        output_subdir = output_dir / f"{output_dataset_id}_error"
        output_subdir.mkdir(exist_ok=True)
        for i in range(n):
            # Note: cv2.applyColorMap produces a 3-channel heatmap in OpenCV's
            # BGR channel order, which matches what cv2.imwrite expects.
            heat_img = cv2.applyColorMap(error_images[i], cv2.COLORMAP_JET)
            cv2.imwrite(f'{output_subdir}/{dataset.indices[i]:0>4d}.png', heat_img)

# Save color-like outputs either as one video per output or one PNG per view.
for output_type in ['color', 'diffuse', 'specular']:
    if output_type not in out:
        continue
    if args.output_type == 'video':
        output_file = output_dir / f"{output_dataset_id}_{output_type}.mp4"
        img.save_video(out[output_type], output_file, 30)
    else:
        output_subdir = output_dir / f"{output_dataset_id}_{output_type}"
        output_subdir.mkdir(exist_ok=True)
        img.save(out[output_type],
                 [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])

if 'depth' in out:
    # Map raw depth to a color image using the model's sampling range.
    colored_depths = img.colorize_depthmap(out['depth'][..., 0], model.args['sample_range'])
    if args.output_type == 'video':
        output_file = output_dir / f"{output_dataset_id}_depth.mp4"
        img.save_video(colored_depths, output_file, 30)
    else:
        output_subdir = output_dir / f"{output_dataset_id}_depth"
        output_subdir.mkdir(exist_ok=True)
        img.save(colored_depths,
                 [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])

if args.time:
    # Print the hierarchical timing report collected by utils.perf.
    s = "Performance Report ==>\n"
    res = get_perf_result()
    if res is None:
        s += "No available data.\n"
    else:
        for key, val in res.items():
            path_segs = key.split("/")
            # Indent each entry by its depth in the "a/b/c" perf-node path.
            s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
    print(s)