"""Export one batch of rays and their voxel-sampler samples to JSON.

Loads a trained model and a dataset description, draws a single shuffled
batch of ``--output-rays`` rays, runs ``model.sampler`` on it and writes the
valid/invalid rays, the samples (depths, dists, voxel indices) and the scene
space (bbox, and for voxel spaces the voxel grid) to
``<run_dir>/output_<N>/debug_voxel_sampler_export3d.json`` for 3D inspection.
"""
import os
import sys
import json
import argparse
import torch

# Make the parent of the script's directory importable (project root).
sys.path.append(os.path.abspath(sys.path[0] + '/../'))

parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', type=str,
                    help='The model file to load for testing')
parser.add_argument('-r', '--output-rays', type=int, default=100,
                    help='How many rays to output')
parser.add_argument('-p', '--prompt', action='store_true',
                    help='Interactive prompt mode')
parser.add_argument('--export-access-counts', action='store_true',
                    help='Also export per-voxel access counts gathered over '
                         'a full pass of the dataset (slow, voxel spaces only)')
parser.add_argument('dataset', type=str, help='Dataset description file')
args = parser.parse_args()

# Project imports must come after the sys.path tweak above.
import model as mdl
from utils import misc
from utils import color
from utils import interact
from utils import device
from data.dataset_factory import *
from data.loader import DataLoader
from modules import Samples, Voxels
from model.nsvf import NSVF

model: NSVF
samples: Samples

# Passed to DataLoader as chunk_max_items; kept integral since it counts items.
DATA_LOADER_CHUNK_SIZE = int(1e8)

data_desc_path = args.dataset if args.dataset.endswith('.json') \
    else os.path.join(args.dataset, 'train.json')
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
# Fall back to the current directory when the description path has no
# directory part; otherwise data_dir would become '/' (the filesystem root).
data_dir = (os.path.dirname(data_desc_path) or '.') + '/'


def get_model_files(datadir):
    """Return every ``.tar``/``.pth`` file under *datadir*, relative to it.

    Args:
        datadir: Directory to search recursively.

    Returns:
        List of paths relative to *datadir* (only the leading directory is
        stripped, unlike a plain string replace).
    """
    return [
        os.path.relpath(os.path.join(root, file), datadir)
        for root, _, files in os.walk(datadir)
        for file in files
        if file.endswith(('.tar', '.pth'))
    ]


if args.prompt:  # Prompt test model and number of output rays
    model_files = get_model_files(data_dir)
    args.model = interact.input_enum('Specify test model:', model_files,
                                     err_msg='No such model file')
    args.output_rays = interact.input_ex('Specify number of rays to output:',
                                         interact.input_to_int(),
                                         default=10)

if args.model is None:
    # Without this, os.path.join below fails with an opaque TypeError.
    parser.error('a model file is required: pass -m/--model or use -p/--prompt')

model_path = os.path.join(data_dir, args.model)
model_name = os.path.splitext(os.path.basename(model_path))[0]
model, iters = mdl.load(model_path, {"perturb_sample": False})
model.to(device.default()).eval()
model_class = model.__class__.__name__
model_args = model.args
print(f"model: {model_name} ({model_class})")
print("args:", json.dumps(model.args0))

dataset = DatasetFactory.load(data_desc_path)
print("Dataset loaded: " + data_desc_path)

run_dir = os.path.dirname(model_path) + '/'
try:
    # Checkpoints are expected to be named "<prefix>_<number>".
    output_dir = f"{run_dir}output_{int(model_name.split('_')[-1])}"
except ValueError:
    # No numeric suffix: use the whole model name instead of crashing.
    output_dir = f"{run_dir}output_{model_name}"


if __name__ == "__main__":
    with torch.no_grad():
        # 1. Draw a single shuffled batch of rays
        data_loader = DataLoader(dataset, args.output_rays,
                                 chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                                 shuffle=True, enable_preload=True,
                                 color=color.from_str(model.args['color']))
        sys.stdout.write("Export samples...\r")
        first_batch = next(iter(data_loader), None)
        if first_batch is None:
            sys.exit("Export samples...Failed: the data loader yielded no rays")
        _, rays_o, rays_d, _ = first_batch

        # 2. Sample the batch; rays_filter is used as a boolean mask that
        #    splits the rays into exported "valid" and "invalid" sets.
        samples, rays_filter = model.sampler(rays_o, rays_d, model.space)
        invalid_mask = torch.logical_not(rays_filter)
        invalid_rays_o = rays_o[invalid_mask]
        invalid_rays_d = rays_d[invalid_mask]
        rays_o = rays_o[rays_filter]
        rays_d = rays_d[rays_filter]
        print("Export samples...Done")

        # 3. Collect scene-space data
        os.makedirs(output_dir, exist_ok=True)
        export_data = {}
        if model.space.bbox is not None:
            export_data['bbox'] = model.space.bbox.tolist()
        if isinstance(model.space, Voxels):
            export_data['voxel_size'] = model.space.voxel_size.tolist()
            export_data['voxels'] = model.space.voxels.tolist()

            if args.export_access_counts:
                # Full pass over the dataset with large batches; the model
                # accumulates per-voxel counts into voxel_access_counts.
                voxel_access_counts = torch.zeros(model.space.n_voxels,
                                                  dtype=torch.long,
                                                  device=device.default())
                iters_in_epoch = 0
                data_loader.batch_size = 2 ** 20
                for _, rays_o1, rays_d1, _ in data_loader:
                    model(rays_o1, rays_d1,
                          raymarching_tolerance=0.5,
                          raymarching_chunk_size=0,
                          voxel_access_counts=voxel_access_counts)
                    iters_in_epoch += 1
                    percent = iters_in_epoch / len(data_loader) * 100
                    sys.stdout.write(f'Export voxel access counts...{percent:.1f}%   \r')
                export_data['voxel_access_counts'] = voxel_access_counts.tolist()
                print("Export voxel access counts...Done  ")

        # 4. Rays and samples
        export_data.update({
            'rays_o': rays_o.tolist(),
            'rays_d': rays_d.tolist(),
            'invalid_rays_o': invalid_rays_o.tolist(),
            'invalid_rays_d': invalid_rays_d.tolist(),
            'samples': {
                'depths': samples.depths.tolist(),
                'dists': samples.dists.tolist(),
                'voxel_indices': samples.voxel_indices.tolist()
            }
        })
        with open(os.path.join(output_dir, 'debug_voxel_sampler_export3d.json'), 'w') as fp:
            json.dump(export_data, fp)
        print("Write JSON file...Done")

        # Report the rays actually drawn (the last batch may be smaller than
        # the requested --output-rays).
        n_total = rays_o.size(0) + invalid_rays_o.size(0)
        print(f"Rays: total {n_total}, valid {rays_o.size(0)}")
        print(f"Samples: average {samples.voxel_indices.ne(-1).sum(-1).float().mean().item()} per ray")