"""Render a trained model over a test dataset and report quality/performance.

Depending on --output, saves color/depth images (or videos), per-view error
maps against ground truth, and a CSV report with per-view PSNR/SSIM/LPIPS.
"""
import argparse
import json
import torch
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from matplotlib import pyplot as plt
from tqdm import tqdm, trange

from model import Model
from utils import device, img, netio, math
from utils.loss import ssim, mse_loss, LpipsLoss
from utils.types import *
from utils.view import Trans
from utils.profile import Profiler, enable_profile
from data import *
from components.render import render

parser = argparse.ArgumentParser()
parser.add_argument('-r', '--output-res', type=Resolution.from_str,
                    help='Output resolution')
parser.add_argument('-o', '--output', nargs='+', type=str, default=['perf', 'color'],
                    help='Specify what to output (perf, color, depth, all)')
parser.add_argument('--media', type=str, default='image',
                    help='Specify the media of output (image, video)')
parser.add_argument('--views', type=lambda s: range(*[int(val) for val in s.split('-')]),
                    help='Specify the range of views to test')
parser.add_argument('--batch', type=int,
                    help='Batch size (to avoid out-of-memory)')
parser.add_argument('--profile', action='store_true',
                    help='Enable time measurement')
parser.add_argument("--warm-up", type=int, default=10)
parser.add_argument('--stereo', type=float, default=0,
                    help='Specify the stereo disparity. '
                         'If greater than 0, stereo images will be generated')
parser.add_argument('ckpt_path', type=str, help='Path to the ckpt file')
parser.add_argument('dataset_path', type=str, help='Path to the dataset')
args = parser.parse_args()

torch.set_grad_enabled(False)  # inference only
lpips_loss = LpipsLoss().to(device.default())
# "perf" implies rendering color; the set collapses duplicates
output_types = list({
    "color" if item in ["color", "perf"] else item
    for item in args.output
})

# Load model
ckpt_path = netio.find_checkpoint(Path(args.ckpt_path))
ckpt = torch.load(ckpt_path)
print(f"Load checkpoint: {ckpt_path}")
print("Model arguments:", json.dumps(ckpt["args"]["model_args"]))
model = Model.create(ckpt["args"]["model"], ckpt["args"]["model_args"]
                     # Optional overrides kept for debugging:
                     # raymarching_early_stop_tolerance=0.01,
                     # raymarching_chunk_size_or_sections=None,
                     # perturb_sample=False
                     )
model.load_state_dict(ckpt["states"]["model"])
model.to(device.default()).eval()
print(model)  # debug: print model structure

# Load dataset
dataset = Dataset(args.dataset_path, res=args.output_res, views_to_load=args.views,
                  color_mode=model.color, coord_sys=model.args.coord,
                  device=device.default())
print(f"Load dataset: {dataset.root}/{dataset.name} ({dataset.color_mode}, {dataset.coord_sys})")

run_dir = ckpt_path.parent
out_dir = run_dir / f"output_{ckpt_path.stem.split('_')[-1]}"
out_id = f'{dataset.name}_{args.output_res.w}x{args.output_res.h}' if args.output_res \
    else dataset.name
batch_size = args.batch or dataset.pixels_per_view
n = len(dataset)
executor = ThreadPoolExecutor(8)  # background pool for image encoding/writes
if args.media == "video":
    video_frames = defaultdict(list)  # out_type -> list of frame tensors


def save_image(out: torch.Tensor, out_type: str, view_idx: int):
    """Queue an output image for async saving, or collect it as a video frame."""
    out = out.detach().cpu()
    if args.media == 'video':
        video_frames[out_type].append(out)
    else:
        output_subdir = out_dir / f"{out_id}_{out_type}{'_stereo' if args.stereo > 0 else ''}"
        output_subdir.mkdir(parents=True, exist_ok=True)
        executor.submit(img.save, out, f'{output_subdir}/{view_idx:04d}.png')


def save_error_image(gt: torch.Tensor, out: torch.Tensor, view_idx: int):
    """Queue a false-color (jet) per-pixel error map between `out` and `gt`."""
    # Per-pixel MSE averaged over channels, scaled so 1e-2 saturates to 1
    error_image = (mse_loss(out, gt, reduction='none').mean(-3, True) / 1e-2).clamp(0, 1)
    error_image = img.torch2np(error_image)[..., 0]
    output_subdir = out_dir / f"{out_id}_error"
    # parents=True for consistency with save_image (don't rely on out_dir existing)
    output_subdir.mkdir(parents=True, exist_ok=True)

    def save_fn(error_image, view_idx):
        img.save(plt.get_cmap("jet")(error_image), f'{output_subdir}/{view_idx:04d}.png')

    executor.submit(save_fn, error_image, view_idx)


if args.profile:
    def handle_profile_result(result: Profiler.ProfileResult):
        print(result.get_report())
    enable_profile(0, len(dataset), handle_profile_result)

# Timing/quality measurement only makes sense for mono rendering
measure_perf = "perf" in args.output and args.stereo == 0
perf = defaultdict(list) if measure_perf else None

out_dir.mkdir(parents=True, exist_ok=True)

if measure_perf:
    # Warm-up first for accurate time measurement
    rays_d = Trans(dataset.centers[0], dataset.rots[0]).trans_vector(
        dataset.cam.local_rays[:batch_size])
    rays_o = dataset.centers[:1, None, :].expand_as(rays_d)
    rays = Rays(rays_o=rays_o, rays_d=rays_d).flatten()
    print(rays_o.shape, rays_d.shape)
    for i in trange(args.warm_up, desc="Warm up"):
        model(rays, *output_types)

for i in trange(n, desc="Test"):
    view_idx = dataset.indices[i].item()
    if measure_perf:
        test_perf = Profiler.Node("Test")
    view = Trans(dataset.centers[i], dataset.rots[i])
    if args.stereo > 0:
        # Offset the camera by half the disparity to each side and render twice
        left_view = Trans(
            view.trans_point(torch.tensor([-args.stereo / 2, 0, 0], device=view.device)),
            view.r)
        right_view = Trans(
            view.trans_point(torch.tensor([args.stereo / 2, 0, 0], device=view.device)),
            view.r)
        out_left = render(model, dataset.cam, left_view, *output_types,
                          batch_size=batch_size)
        out_right = render(model, dataset.cam, right_view, *output_types,
                           batch_size=batch_size)
        # Concatenate left/right along dim 2 for side-by-side stereo output
        out = ReturnData({
            key: torch.cat([out_left[key], out_right[key]], dim=2)
            for key in out_left if isinstance(out_left[key], torch.Tensor)
        })
    else:
        out = render(model, dataset.cam, view, *output_types, batch_size=batch_size)
    if measure_perf:
        test_perf.close()
        torch.cuda.synchronize()  # ensure device work finished before reading timings
        perf["view"].append(view_idx)
        perf["time"].append(test_perf.device_duration)
        gt_image = dataset.load_images("color", view_idx)
        out_image = out.color.movedim(-1, -3)  # channels-last -> channels-first for losses
        if gt_image is not None:
            perf["mse"].append(mse_loss(out_image, gt_image).item())
            perf["ssim"].append(ssim(out_image, gt_image).item() * 100)
            perf["lpips"].append(lpips_loss(out_image, gt_image).item())
            save_error_image(gt_image, out_image, view_idx)
        else:
            # No ground truth for this view; keep the metric rows aligned with NaNs
            perf["mse"].append(math.nan)
            perf["ssim"].append(math.nan)
            perf["lpips"].append(math.nan)
    for key, value in out.items():
        save_image(value, key, view_idx)

if measure_perf:
    perf_mean_time = sum(perf['time']) / n
    perf_mean_error = sum(perf['mse']) / n
    perf_name = f'perf_{out_id}_{perf_mean_time:.1f}ms_{perf_mean_error:.2e}.csv'
    # Remove old performance reports
    for file in out_dir.glob(f'perf_{out_id}*'):
        file.unlink()
    # Save new performance report
    with (out_dir / perf_name).open('w') as fp:
        fp.write('PSNR, SSIM, LPIPS\n')
        fp.writelines([
            f'{img.mse2psnr(perf["mse"][i]):.2f}, {perf["ssim"][i]:.2f}, {perf["lpips"][i]:.2e}\n'
            for i in range(n)
        ])

if args.media == "video":
    for key, frames in video_frames.items():
        img.save_video(torch.cat(frames, 0),
                       out_dir / f"{out_id}_{key}{'_stereo' if args.stereo > 0 else ''}.mp4",
                       30)

# Wait for all queued image writes to finish before exiting
executor.shutdown(wait=True)