From 5699ccbf8536ceb9e79194f2a1b784da39b511eb Mon Sep 17 00:00:00 2001 From: Nianchen Deng <dengnianchen@sjtu.edu.cn> Date: Fri, 3 Dec 2021 09:36:27 +0800 Subject: [PATCH] sync --- .vscode/launch.json | 57 ++ README.md | 6 + blender/gen_pano.py | 15 + blender/gen_utils.py | 184 ++++ clib/__init__.py | 479 +++++++++++ clib/include/cuda_utils.h | 46 + clib/include/cutil_math.h | 793 ++++++++++++++++++ clib/include/intersect.h | 17 + clib/include/octree.h | 10 + clib/include/sample.h | 16 + clib/include/utils.h | 30 + clib/src/binding.cpp | 21 + clib/src/intersect.cpp | 146 ++++ clib/src/intersect_gpu.cu | 375 +++++++++ clib/src/octree.cpp | 136 +++ clib/src/sample.cpp | 96 +++ clib/src/sample_gpu.cu | 231 +++++ configs/nerf_default.json | 22 + configs/nerf_voxels.json | 24 + configs/nsvf_coarse.json | 21 + configs/nsvf_default.json | 21 + configs/nsvf_voxels.json | 21 + configs/{ => old}/bgnet.py | 0 configs/{ => old}/cnerf.py | 0 configs/{ => old}/dnerfabins.py | 0 configs/{ => old}/fovea.py | 0 configs/{ => old}/fovea_small_rot1.py | 0 configs/{ => old}/fovea_small_trans.py | 0 configs/{ => old}/msl2fast.py | 0 configs/{ => old}/msl_fovea.py | 0 configs/{ => old}/mslfast.py | 0 configs/{ => old}/mslray.py | 0 configs/{ => old}/nerf.py | 0 configs/{ => old}/nerf_horns.py | 0 configs/{ => old}/nerf_horns_4.py | 0 configs/{ => old}/nerf_horns_8.py | 0 configs/{ => old}/nerf_periph.py | 0 configs/{ => old}/nerf_trex.py | 0 configs/{ => old}/nerf_trex_4.py | 0 configs/{ => old}/nerf_trex_8.py | 0 configs/{ => old}/nerfsimple.py | 0 configs/{ => old}/nmsl_fovea.py | 0 configs/{ => old}/nnerf.py | 0 configs/{ => old}/oracle.py | 0 configs/{ => old}/periph.py | 0 configs/{ => old}/periph_new.py | 0 configs/{ => old}/periph_small_trans.py | 0 configs/{ => old}/snerffast_periph.py | 0 configs/{ => old}/snerffastx.py | 0 configs/snerf_fine_voxels.json | 21 + configs/snerf_voxels+ls-d.json | 20 + configs/snerf_voxels+ls.json | 21 + configs/snerf_voxels.json | 19 + configs/snerf_voxels_128x8_x2.json | 22 + configs/snerf_voxels_128x8_x4.json | 22 + configs/snerf_voxels_feat.json | 21 + configs/snerf_voxels_fine.json | 21 + configs/snerfadv_finevoxels+ls.json | 34 + ...erfadv_finevoxels+ls_256x4_256x6_16x2.json | 34 + ...dv_finevoxels+ls_256x4_256x6_combined.json | 30 + configs/snerfadv_finevoxels_ls2.json | 34 + configs/snerfadv_voxels+ls+ns.json | 36 + configs/snerfadv_voxels+ls.json | 33 + configs/snerfadv_voxels+ls1.json | 34 + configs/snerfadv_voxels+ls2.json | 34 + configs/snerfadv_voxels+ls3.json | 34 + configs/snerfadv_voxels+ls4.json | 34 + configs/snerfadv_voxels+ls5.json | 34 + configs/snerfadv_voxels+ls6.json | 34 + configs/snerfadvx_voxels_x16.json | 34 + configs/snerfadvx_voxels_x4.json | 34 + configs/snerfadvx_voxels_x8.json | 34 + configs/snerfx_voxels_128x4_x4.json | 21 + configs/snerfx_voxels_128x4_x8.json | 21 + configs/snerfx_voxels_128x8_x4.json | 21 + configs/snerfx_voxels_256x4_x4.json | 21 + configs/snerfx_voxels_256x4_x4_balance.json | 22 + dash_test.py | 66 +- data/dataset_factory.py | 35 +- data/loader.py | 23 +- data/pano_dataset.py | 100 ++- data/view_dataset.py | 138 +-- debug/voxel_sampler_export3d.py | 134 +++ fntest.py | 12 + loss/__init__.py | 5 + loss/cauchy.py | 16 + loss/ssim.py | 1 - model/__init__.py | 45 + model/base.py | 34 + {nets => model}/bg_net.py | 0 model/nerf.py | 181 ++++ model/nerf_advance.py | 37 + {nets => model}/nerf_depth.py | 2 +- model/nsvf.py | 16 + {nets => model}/oracle.py | 2 +- model/snerf.py | 26 + model/snerf_advance.py | 33 + 
model/snerf_advance_x.py | 74 ++ {nets => model}/snerf_fast.py | 2 +- model/snerf_x.py | 79 ++ modules/__init__.py | 42 +- modules/core.py | 175 ++++ modules/generic.py | 20 +- modules/renderer.py | 384 +++++++-- modules/sampler.py | 264 +++++- modules/space.py | 351 ++++++++ nerf++ | 1 - nets/nerf.py | 78 -- nets/nsvf.py | 71 -- nets/snerf.py | 110 --- notebook/gen_crop.ipynb | 27 +- notebook/gen_demo_mono.ipynb | 4 +- notebook/gen_demo_stereo.ipynb | 4 +- notebook/gen_for_eval.ipynb | 4 +- notebook/gen_teaser.ipynb | 2 +- notebook/gen_test.ipynb | 2 +- notebook/gen_user_study_images.ipynb | 2 +- notebook/gen_video.ipynb | 2 +- notebook/net_insight.ipynb | 4 +- notebook/test_mono_gen.ipynb | 2 +- notebook/test_mono_view.ipynb | 2 +- run_lf_syn.py | 4 +- run_spherical_view_syn.py | 30 +- setup.py | 27 + term_test.py | 15 + test.py | 226 +++++ tools/clean_nets.py | 23 +- tools/depth_downsample.py | 2 +- tools/export_msl.py | 2 +- tools/export_nmsl.py | 2 +- tools/export_onnx.py | 2 +- tools/export_snerf_fast.py | 2 +- tools/gen_video.py | 6 +- tools/image_scale.py | 2 +- tools/merge_dataset.py | 2 +- tools/pano_process.py | 36 + tools/split_dataset.py | 85 +- train.py | 103 +++ train/__init__.py | 26 + train/base.py | 225 +++++ train/train_with_space.py | 127 +++ train_oracle.py | 8 +- upsampling/run_upsampling.py | 4 +- utils/constants.py | 4 +- utils/geometry.py | 284 +++++++ utils/img.py | 38 +- utils/mem_profiler.py | 5 +- utils/misc.py | 101 ++- utils/perf.py | 157 +++- utils/progress_bar.py | 96 +-- utils/sphere.py | 14 +- utils/voxels.py | 174 ++++ 152 files changed, 7186 insertions(+), 805 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 blender/gen_pano.py create mode 100644 blender/gen_utils.py create mode 100644 clib/__init__.py create mode 100644 clib/include/cuda_utils.h create mode 100644 clib/include/cutil_math.h create mode 100644 clib/include/intersect.h create mode 100644 clib/include/octree.h create mode 100644 clib/include/sample.h create mode 100644 clib/include/utils.h create mode 100644 clib/src/binding.cpp create mode 100644 clib/src/intersect.cpp create mode 100644 clib/src/intersect_gpu.cu create mode 100644 clib/src/octree.cpp create mode 100644 clib/src/sample.cpp create mode 100644 clib/src/sample_gpu.cu create mode 100644 configs/nerf_default.json create mode 100644 configs/nerf_voxels.json create mode 100644 configs/nsvf_coarse.json create mode 100644 configs/nsvf_default.json create mode 100644 configs/nsvf_voxels.json rename configs/{ => old}/bgnet.py (100%) rename configs/{ => old}/cnerf.py (100%) rename configs/{ => old}/dnerfabins.py (100%) rename configs/{ => old}/fovea.py (100%) rename configs/{ => old}/fovea_small_rot1.py (100%) rename configs/{ => old}/fovea_small_trans.py (100%) rename configs/{ => old}/msl2fast.py (100%) rename configs/{ => old}/msl_fovea.py (100%) rename configs/{ => old}/mslfast.py (100%) rename configs/{ => old}/mslray.py (100%) rename configs/{ => old}/nerf.py (100%) rename configs/{ => old}/nerf_horns.py (100%) rename configs/{ => old}/nerf_horns_4.py (100%) rename configs/{ => old}/nerf_horns_8.py (100%) rename configs/{ => old}/nerf_periph.py (100%) rename configs/{ => old}/nerf_trex.py (100%) rename configs/{ => old}/nerf_trex_4.py (100%) rename configs/{ => old}/nerf_trex_8.py (100%) rename configs/{ => old}/nerfsimple.py (100%) rename configs/{ => old}/nmsl_fovea.py (100%) rename configs/{ => old}/nnerf.py (100%) rename configs/{ => old}/oracle.py (100%) rename configs/{ => old}/periph.py (100%) 
rename configs/{ => old}/periph_new.py (100%) rename configs/{ => old}/periph_small_trans.py (100%) rename configs/{ => old}/snerffast_periph.py (100%) rename configs/{ => old}/snerffastx.py (100%) create mode 100644 configs/snerf_fine_voxels.json create mode 100644 configs/snerf_voxels+ls-d.json create mode 100644 configs/snerf_voxels+ls.json create mode 100644 configs/snerf_voxels.json create mode 100644 configs/snerf_voxels_128x8_x2.json create mode 100644 configs/snerf_voxels_128x8_x4.json create mode 100644 configs/snerf_voxels_feat.json create mode 100644 configs/snerf_voxels_fine.json create mode 100644 configs/snerfadv_finevoxels+ls.json create mode 100644 configs/snerfadv_finevoxels+ls_256x4_256x6_16x2.json create mode 100644 configs/snerfadv_finevoxels+ls_256x4_256x6_combined.json create mode 100644 configs/snerfadv_finevoxels_ls2.json create mode 100644 configs/snerfadv_voxels+ls+ns.json create mode 100644 configs/snerfadv_voxels+ls.json create mode 100644 configs/snerfadv_voxels+ls1.json create mode 100644 configs/snerfadv_voxels+ls2.json create mode 100644 configs/snerfadv_voxels+ls3.json create mode 100644 configs/snerfadv_voxels+ls4.json create mode 100644 configs/snerfadv_voxels+ls5.json create mode 100644 configs/snerfadv_voxels+ls6.json create mode 100644 configs/snerfadvx_voxels_x16.json create mode 100644 configs/snerfadvx_voxels_x4.json create mode 100644 configs/snerfadvx_voxels_x8.json create mode 100644 configs/snerfx_voxels_128x4_x4.json create mode 100644 configs/snerfx_voxels_128x4_x8.json create mode 100644 configs/snerfx_voxels_128x8_x4.json create mode 100644 configs/snerfx_voxels_256x4_x4.json create mode 100644 configs/snerfx_voxels_256x4_x4_balance.json create mode 100644 debug/voxel_sampler_export3d.py create mode 100644 fntest.py create mode 100644 loss/__init__.py create mode 100644 loss/cauchy.py create mode 100644 model/__init__.py create mode 100644 model/base.py rename {nets => model}/bg_net.py (100%) create mode 100644 model/nerf.py create mode 100644 model/nerf_advance.py rename {nets => model}/nerf_depth.py (96%) create mode 100644 model/nsvf.py rename {nets => model}/oracle.py (96%) create mode 100644 model/snerf.py create mode 100644 model/snerf_advance.py create mode 100644 model/snerf_advance_x.py rename {nets => model}/snerf_fast.py (98%) create mode 100644 model/snerf_x.py create mode 100644 modules/core.py create mode 100644 modules/space.py delete mode 160000 nerf++ delete mode 100644 nets/nerf.py delete mode 100644 nets/nsvf.py delete mode 100644 nets/snerf.py create mode 100644 setup.py create mode 100644 term_test.py create mode 100644 test.py create mode 100644 tools/pano_process.py create mode 100644 train.py create mode 100644 train/__init__.py create mode 100644 train/base.py create mode 100644 train/train_with_space.py create mode 100644 utils/geometry.py create mode 100644 utils/voxels.py diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..70285a7 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,57 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + + + { + "name": "Debug/Voxel Sampler Export 3D", + "type": "python", + "request": "launch", + "program": "debug/voxel_sampler_export3d.py", + "args": [ + "-p", + "data/__new/barbershop_fovea_r360x80_t0.6/train_t0.3.json" + ], + "console": "integratedTerminal" + }, + { + "name": "Train", + "type": "python", + "request": 
"launch", + "program": "train.py", + "args": [ + //"-c", + //"snerf_voxels", + "/home/dengnc/dvs/data/__new/barbershop_fovea_r360x80_t0.6/_nets/train_t0.3/snerfadvx_voxels_x4/checkpoint_10.tar", + "--prune", + "100", + "--split", + "100" + //"data/__new/barbershop_fovea_r360x80_t0.6/train_t0.3.json" + ], + "console": "integratedTerminal" + }, + { + "name": "Test", + "type": "python", + "request": "launch", + "program": "test.py", + "args": [ + "-m", + "/home/dengnc/dvs/data/__new/barbershop_fovea_r360x80_t0.6/_nets/train_t0.3/snerfadv_voxels+ls2/checkpoint_50.tar", + "-o", + "perf", + "color", + "--output-type", + "image", + "/home/dengnc/dvs/data/__new/barbershop_fovea_r360x80_t0.6/test_t0.3.json", + "--views", + "1" + ], + "console": "integratedTerminal" + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index 9005cdb..fee4397 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,12 @@ Or ref to https://pytorch.org/get-started/locally/ for install guide * tensorboard +* plyfile + +``` +$ conda install -c conda-forge plyfile +``` + * (Optional) dash ``` diff --git a/blender/gen_pano.py b/blender/gen_pano.py new file mode 100644 index 0000000..ee0e9ed --- /dev/null +++ b/blender/gen_pano.py @@ -0,0 +1,15 @@ +import sys +import os +import argparse + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from gen_utils import GenPano + +parser = argparse.ArgumentParser() +parser.add_argument('-r', '--radius', type=float, required=True) +parser.add_argument("-n", "--samples", type=int, required=True) +parser.add_argument("--cycles-device", type=str) +args = parser.parse_args(sys.argv[sys.argv.index("--") + 1:]) + +GenPano('output/pano', f'hr_r{args.radius:.1f}', samples=[args.samples], depth_range=[args.radius, 50])() diff --git a/blender/gen_utils.py b/blender/gen_utils.py new file mode 100644 index 0000000..8f100d3 --- /dev/null +++ b/blender/gen_utils.py @@ -0,0 +1,184 @@ +import bpy +import math +import json +import os +import math +import numpy as np +from typing import List, Tuple +from itertools import product + + +class Gen: + def __init__(self, root_dir: str, dataset_name: str, *, + res: Tuple[int, int], + fov: float, + samples: List[int]) -> None: + self.res = res + self.fov = fov + self.samples = samples + + self.scene = bpy.context.scene + self.cam_obj = self.scene.camera + self.cam = self.cam_obj.data + self.scene.render.resolution_x = self.res[0] + self.scene.render.resolution_y = self.res[1] + self.init_camera() + + self.root_dir = root_dir + self.data_dir = f"{root_dir}/{dataset_name}/" + self.data_name = dataset_name + self.data_desc_file = f'{root_dir}/{dataset_name}.json' + + def init_camera(self): + if self.fov < 0: + self.cam.type = 'PANO' + self.cam.cycles.panorama_type = 'EQUIRECTANGULAR' + else: + self.cam.type = 'PERSP' + self.cam.lens_unit = 'FOV' + self.cam.angle = math.radians(self.fov) + self.cam.dof.use_dof = False + self.cam.clip_start = 0.1 + self.cam.clip_end = 1000 + + def init_desc(self): + return None + + def save_desc(self): + with open(self.data_desc_file, 'w') as fp: + json.dump(self.desc, fp, indent=4) + + def add_sample(self, i, x: List[float], render_only=False): + self.cam_obj.location = x[:3] + if len(x) > 3: + self.cam_obj.rotation_euler = [math.radians(x[4]), math.radians(x[3]), 0] + self.scene.render.filepath = self.data_dir + self.desc['view_file_pattern'] % i + bpy.ops.render.render(write_still=True) + if not render_only: + self.desc['view_centers'].append(x[:3]) + if len(x) > 3: + 
self.desc['view_rots'].append(x[3:]) + self.save_desc() + + def gen_grid(self): + start_view = len(self.desc['view_centers']) + ranges = [ + np.linspace(self.desc['range']['min'][i], + self.desc['range']['max'][i], + self.desc['samples'][i]) + for i in range(len(self.desc['samples'])) + ] + for i, x in enumerate(product(*ranges)): + if i >= start_view: + self.add_sample(i, list(x)) + + def gen_rand(self): + pass + + def __call__(self): + os.makedirs(self.data_dir, exist_ok=True) + if os.path.exists(self.data_desc_file): + with open(self.data_desc_file, 'r') as fp: + self.desc = json.load(fp) + else: + self.desc = self.init_desc() + + # Render missing views in data desc + for i in range(len(self.desc['view_centers'])): + if not os.path.exists(self.data_dir + self.desc['view_file_pattern'] % i): + x: List[float] = self.desc['view_centers'][i] + if 'view_rots' in self.desc: + x += self.desc['view_rots'][i] + self.add_sample(i, x, render_only=True) + + if len(self.desc['samples']) == 1: + self.gen_rand() + else: + self.gen_grid() + + +class GenView(Gen): + + def __init__(self, root_dir: str, dataset_name: str, *, + res: Tuple[int, int], fov: float, samples: List[int], + tbox: Tuple[float, float, float], rbox: Tuple[float, float]) -> None: + super().__init__(root_dir, dataset_name, res=res, fov=fov, samples=samples) + self.tbox = tbox + self.rbox = rbox + + def init_desc(self): + return { + 'view_file_pattern': 'view_%04d.png', + "gl_coord": True, + 'view_res': { + 'x': self.res[0], + 'y': self.res[1] + }, + 'cam_params': { + 'fov': self.fov, + 'cx': 0.5, + 'cy': 0.5, + 'normalized': True + }, + 'range': { + 'min': [-self.tbox[0] / 2, -self.tbox[1] / 2, -self.tbox[2] / 2, + -self.rbox[0] / 2, -self.rbox[1] / 2], + 'max': [self.tbox[0] / 2, self.tbox[1] / 2, self.tbox[2] / 2, + self.rbox[0] / 2, self.rbox[1] / 2] + }, + 'samples': self.samples, + 'view_centers': [], + 'view_rots': [] + } + + def gen_rand(self): + start_view = len(self.desc['view_centers']) + n = self.desc['samples'][0] - start_view + range_min = np.array(self.desc['range']['min']) + range_max = np.array(self.desc['range']['max']) + samples = (range_max - range_min) * np.random.rand(n, 5) + range_min + for i in range(n): + self.add_sample(i + start_view, list(samples[i])) + + +class GenPano(Gen): + + def __init__(self, root_dir: str, dataset_name: str, *, + samples: List[int], depth_range: Tuple[float, float], + tbox: Tuple[float, float, float] = None) -> None: + self.depth_range = depth_range + self.tbox = tbox + super().__init__(root_dir, dataset_name, res=[4096, 2048], fov=-1, samples=samples) + + def init_desc(self): + range = { + 'range': { + 'min': [-self.tbox[0] / 2, -self.tbox[1] / 2, -self.tbox[2] / 2], + 'max': [self.tbox[0] / 2, self.tbox[1] / 2, self.tbox[2] / 2] + } + } if self.tbox else {} + return { + "type": "pano", + 'view_file_pattern': 'view_%04d.png', + "gl_coord": True, + 'view_res': { + 'x': self.res[0], + 'y': self.res[1] + }, + **range, + "depth_range": { + "min": self.depth_range[0], + "max": self.depth_range[1] + }, + 'samples': self.samples, + 'view_centers': [] + } + + def gen_rand(self): + start_view = len(self.desc['view_centers']) + n = self.desc['samples'][0] - start_view + r_max = self.desc['depth_range']['min'] + pts = (np.random.rand(n * 5, 3) - 0.5) * 2 * r_max + samples = pts[np.linalg.norm(pts, axis=1) < r_max][:n] + for i in range(n): + self.add_sample(i + start_view, list(samples[i])) diff --git a/clib/__init__.py b/clib/__init__.py new file mode 100644 index 0000000..c740aff --- /dev/null 
+++ b/clib/__init__.py @@ -0,0 +1,479 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +''' Modified from: https://github.com/erikwijmans/Pointnet2_PyTorch ''' +from __future__ import ( + division, + absolute_import, + with_statement, + print_function, + unicode_literals, +) +import os +import sys +from typing import Tuple +import torch +import torch.nn.functional as F +from torch.autograd import Function +import torch.nn as nn +import numpy as np +from utils.geometry import discretize_points +from utils.constants import HUGE_FLOAT + +try: + import builtins +except: + import __builtin__ as builtins + +try: + import clib._ext as _ext +except ImportError: + raise ImportError( + "Could not import _ext module.\n" + "Please see the setup instructions in the README" + ) + + + +class BallRayIntersect(Function): + @staticmethod + def forward(ctx, radius, n_max, points, ray_start, ray_dir): + inds, min_depth, max_depth = _ext.ball_intersect( + ray_start.float(), ray_dir.float(), points.float(), radius, n_max) + min_depth = min_depth.type_as(ray_start) + max_depth = max_depth.type_as(ray_start) + + ctx.mark_non_differentiable(inds) + ctx.mark_non_differentiable(min_depth) + ctx.mark_non_differentiable(max_depth) + return inds, min_depth, max_depth + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None + + +ball_ray_intersect = BallRayIntersect.apply + + +class AABBRayIntersect(Function): + @staticmethod + def forward(ctx, voxelsize, n_max, points, ray_start, ray_dir): + # HACK: speed-up ray-voxel intersection by batching... + G = min(2048, int(2 * 10 ** 9 / points.numel())) # HACK: avoid out-of-memory + S, N = ray_start.shape[:2] + K = int(np.ceil(N / G)) + G, K = 1, N # HACK + H = K * G + if H > N: + ray_start = torch.cat([ray_start, ray_start[:, :H - N]], 1) + ray_dir = torch.cat([ray_dir, ray_dir[:, :H - N]], 1) + ray_start = ray_start.reshape(S * G, K, 3) + ray_dir = ray_dir.reshape(S * G, K, 3) + points = points[None].expand(S * G, *points.size()).contiguous() + + inds, min_depth, max_depth = _ext.aabb_intersect( + ray_start.float(), ray_dir.float(), points.float(), voxelsize, n_max) + min_depth = min_depth.type_as(ray_start) + max_depth = max_depth.type_as(ray_start) + + inds = inds.reshape(S, H, -1) + min_depth = min_depth.reshape(S, H, -1) + max_depth = max_depth.reshape(S, H, -1) + if H > N: + inds = inds[:, :N] + min_depth = min_depth[:, :N] + max_depth = max_depth[:, :N] + + ctx.mark_non_differentiable(inds) + ctx.mark_non_differentiable(min_depth) + ctx.mark_non_differentiable(max_depth) + return inds, min_depth, max_depth + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None + + +def aabb_ray_intersect(voxelsize: float, n_max: int, points: torch.Tensor, ray_start: torch.Tensor, + ray_dir: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + AABB-Ray intersect test + + :param voxelsize `float`: size of a voxel + :param n_max `int`: maximum number of hits + :param points `Tensor(M, 3)`: voxels' centers + :param ray_start `Tensor(S, N, 3)`: rays' start positions + :param ray_dir `Tensor(S, N, 3)`: rays' directions + :return `Tensor(S, N, n_max)`: indices of intersected voxels or -1 + :return `Tensor(S, N, n_max)`: min depth of each intersected voxel + :return `Tensor(S, N, n_max)`: max depth of each intersected voxel + """ + return 
AABBRayIntersect.apply(voxelsize, n_max, points, ray_start, ray_dir) + + +class SparseVoxelOctreeRayIntersect(Function): + @staticmethod + def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir): + # HACK: avoid out-of-memory + G = min(2048, int(2 * 10 ** 9 / (points.numel() + children.numel()))) + S, N = ray_start.shape[:2] + K = int(np.ceil(N / G)) + G, K = 1, N # HACK + H = K * G + if H > N: + ray_start = torch.cat([ray_start, ray_start[:, :H - N]], 1) + ray_dir = torch.cat([ray_dir, ray_dir[:, :H - N]], 1) + ray_start = ray_start.reshape(S * G, K, 3) + ray_dir = ray_dir.reshape(S * G, K, 3) + points = points[None].expand(S * G, *points.size()).contiguous() + children = children[None].expand(S * G, *children.size()).contiguous() + inds, min_depth, max_depth = _ext.svo_intersect( + ray_start.float(), ray_dir.float(), points.float(), children.int(), voxelsize, n_max) + + min_depth = min_depth.type_as(ray_start) + max_depth = max_depth.type_as(ray_start) + + inds = inds.reshape(S, H, -1) + min_depth = min_depth.reshape(S, H, -1) + max_depth = max_depth.reshape(S, H, -1) + if H > N: + inds = inds[:, :N] + min_depth = min_depth[:, :N] + max_depth = max_depth[:, :N] + + ctx.mark_non_differentiable(inds) + ctx.mark_non_differentiable(min_depth) + ctx.mark_non_differentiable(max_depth) + return inds, min_depth, max_depth + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None + + +def octree_ray_intersect(voxelsize: float, n_max: int, points: torch.Tensor, children: torch.Tensor, + ray_start: torch.Tensor, ray_dir: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Octree-Ray intersect test + + :param voxelsize `float`: size of a voxel + :param n_max `int`: maximum number of hits + :param points `Tensor(M, 3)`: voxels' centers + :param children `Tensor(M, 9)`: flattened octree structure + :param ray_start `Tensor(S, N, 3)`: rays' start positions + :param ray_dir `Tensor(S, N, 3)`: rays' directions + :return `Tensor(S, N, n_max)`: indices of intersected voxels or -1 + :return `Tensor(S, N, n_max)`: min depth of each intersected voxel + :return `Tensor(S, N, n_max)`: max depth of each intersected voxel + """ + return SparseVoxelOctreeRayIntersect.apply(voxelsize, n_max, points, children, ray_start, + ray_dir) + + +class TriangleRayIntersect(Function): + @staticmethod + def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start, ray_dir): + # HACK: speed-up ray-voxel intersection by batching... 
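+ # Same batching trick as in AABBRayIntersect above: pad the N rays up to + # H = K * G, reshape them into G groups of K rays, run the CUDA kernel per + # group, then trim the padded tail off the results.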
+ G = min(2048, int(2 * 10 ** 9 / (3 * faces.numel()))) # HACK: avoid out-of-memory + S, N = ray_start.shape[:2] + K = int(np.ceil(N / G)) + H = K * G + if H > N: + ray_start = torch.cat([ray_start, ray_start[:, :H - N]], 1) + ray_dir = torch.cat([ray_dir, ray_dir[:, :H - N]], 1) + ray_start = ray_start.reshape(S * G, K, 3) + ray_dir = ray_dir.reshape(S * G, K, 3) + face_points = F.embedding(faces.reshape(-1, 3), points.reshape(-1, 3)) + face_points = face_points.unsqueeze(0).expand(S * G, *face_points.size()).contiguous() + inds, depth, uv = _ext.triangle_intersect( + ray_start.float(), ray_dir.float(), face_points.float(), cagesize, blur_ratio, n_max) + depth = depth.type_as(ray_start) + uv = uv.type_as(ray_start) + + inds = inds.reshape(S, H, -1) + depth = depth.reshape(S, H, -1, 3) + uv = uv.reshape(S, H, -1) + if H > N: + inds = inds[:, :N] + depth = depth[:, :N] + uv = uv[:, :N] + + ctx.mark_non_differentiable(inds) + ctx.mark_non_differentiable(depth) + ctx.mark_non_differentiable(uv) + return inds, depth, uv + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None, None + + +triangle_ray_intersect = TriangleRayIntersect.apply + + +class UniformRaySampling(Function): + @staticmethod + def forward(ctx, pts_idx, min_depth, max_depth, step_size, max_ray_length, deterministic=False): + G, N, P = 256, pts_idx.size(0), pts_idx.size(1) + H = int(np.ceil(N / G)) * G + if H > N: + pts_idx = torch.cat([pts_idx, pts_idx[:H - N]], 0) + min_depth = torch.cat([min_depth, min_depth[:H - N]], 0) + max_depth = torch.cat([max_depth, max_depth[:H - N]], 0) + pts_idx = pts_idx.reshape(G, -1, P) + min_depth = min_depth.reshape(G, -1, P) + max_depth = max_depth.reshape(G, -1, P) + + # pre-generate noise + max_steps = int(max_ray_length / step_size) + max_steps = max_steps + min_depth.size(-1) * 2 + noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps) + if deterministic: + noise += 0.5 + else: + noise = noise.uniform_() + + # call cuda function + sampled_idx, sampled_depth, sampled_dists = _ext.uniform_ray_sampling( + pts_idx, min_depth.float(), max_depth.float(), noise.float(), step_size, max_steps) + sampled_depth = sampled_depth.type_as(min_depth) + sampled_dists = sampled_dists.type_as(min_depth) + + sampled_idx = sampled_idx.reshape(H, -1) + sampled_depth = sampled_depth.reshape(H, -1) + sampled_dists = sampled_dists.reshape(H, -1) + if H > N: + sampled_idx = sampled_idx[: N] + sampled_depth = sampled_depth[: N] + sampled_dists = sampled_dists[: N] + + max_len = sampled_idx.ne(-1).sum(-1).max() + sampled_idx = sampled_idx[:, :max_len] + sampled_depth = sampled_depth[:, :max_len] + sampled_dists = sampled_dists[:, :max_len] + + ctx.mark_non_differentiable(sampled_idx) + ctx.mark_non_differentiable(sampled_depth) + ctx.mark_non_differentiable(sampled_dists) + return sampled_idx, sampled_depth, sampled_dists + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None, None + + +def uniform_ray_sampling(pts_idx: torch.Tensor, min_depth: torch.Tensor, max_depth: torch.Tensor, + step_size: float, max_ray_length: float, deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sample along rays uniformly + + :param pts_idx `Tensor(N, P)`: indices of voxels intersected with rays + :param min_depth `Tensor(N, P)`: min depth of intersections of rays and voxels + :param max_depth `Tensor(N, P)`: max depth of intersections of rays and voxels + :param step_size `float`: size of sampling step + :param max_ray_length 
`float`: maximum sampling depth along rays + :param deterministic `bool`: (optional) sample deterministically (or randomly), defaults to False + :return `Tensor(N, P')`: voxel indices of sampled points + :return `Tensor(N, P')`: depth of sampled points + :return `Tensor(N, P')`: length of sampled points + """ + return UniformRaySampling.apply(pts_idx, min_depth, max_depth, step_size, max_ray_length, + deterministic) + + +class InverseCDFRaySampling(Function): + @staticmethod + def forward(ctx, pts_idx, min_depth, max_depth, probs, steps, fixed_step_size=-1, deterministic=False): + G, N, P = 200, pts_idx.size(0), pts_idx.size(1) + H = int(np.ceil(N / G)) * G + + if H > N: + pts_idx = torch.cat([pts_idx, pts_idx[:1].expand(H - N, P)], 0) + min_depth = torch.cat([min_depth, min_depth[:1].expand(H - N, P)], 0) + max_depth = torch.cat([max_depth, max_depth[:1].expand(H - N, P)], 0) + probs = torch.cat([probs, probs[:1].expand(H - N, P)], 0) + steps = torch.cat([steps, steps[:1].expand(H - N)], 0) + # print(G, P, np.ceil(N / G), N, H, pts_idx.shape, min_depth.device) + pts_idx = pts_idx.reshape(G, -1, P) + min_depth = min_depth.reshape(G, -1, P) + max_depth = max_depth.reshape(G, -1, P) + probs = probs.reshape(G, -1, P) + steps = steps.reshape(G, -1) + + # pre-generate noise + max_steps = steps.ceil().long().max() + P + noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps) + if deterministic: + noise += 0.5 + else: + noise = noise.uniform_().clamp(min=0.001, max=0.999) # in case + + # call cuda function + chunk_size = 4 * G # to avoid oom? + results = [ + _ext.inverse_cdf_sampling( + pts_idx[:, i:i + chunk_size].contiguous(), + min_depth.float()[:, i:i + chunk_size].contiguous(), + max_depth.float()[:, i:i + chunk_size].contiguous(), + noise.float()[:, i:i + chunk_size].contiguous(), + probs.float()[:, i:i + chunk_size].contiguous(), + steps.float()[:, i:i + chunk_size].contiguous(), + fixed_step_size) + for i in range(0, min_depth.size(1), chunk_size) + ] + sampled_idx, sampled_depth, sampled_dists = [ + torch.cat([r[i] for r in results], 1) + for i in range(3) + ] + sampled_depth = sampled_depth.type_as(min_depth) + sampled_dists = sampled_dists.type_as(min_depth) + + sampled_idx = sampled_idx.reshape(H, -1) + sampled_depth = sampled_depth.reshape(H, -1) + sampled_dists = sampled_dists.reshape(H, -1) + if H > N: + sampled_idx = sampled_idx[: N] + sampled_depth = sampled_depth[: N] + sampled_dists = sampled_dists[: N] + + max_len = sampled_idx.ne(-1).sum(-1).max() + sampled_idx = sampled_idx[:, :max_len] + sampled_depth = sampled_depth[:, :max_len] + sampled_dists = sampled_dists[:, :max_len] + + ctx.mark_non_differentiable(sampled_idx) + ctx.mark_non_differentiable(sampled_depth) + ctx.mark_non_differentiable(sampled_dists) + return sampled_idx, sampled_depth, sampled_dists + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None, None, None + + +def inverse_cdf_sampling(pts_idx: torch.Tensor, min_depth: torch.Tensor, max_depth: torch.Tensor, + probs: torch.Tensor, steps: torch.Tensor, fixed_step_size: float = -1, + deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sample along rays by inverse CDF + + :param pts_idx `Tensor(N, P)`: indices of voxels intersected with rays + :param min_depth `Tensor(N, P)`: min depth of intersections of rays and voxels + :param max_depth `Tensor(N, P)`: max depth of intersections of rays and voxels + :param probs `Tensor(N, P)`: + :param steps `Tensor(N)`: + :param fixed_step_size 
`float`: + :param deterministic `bool`: (optional) sample deterministically (or randomly), defaults to False + :return `Tensor(N, P')`: voxel indices of sampled points + :return `Tensor(N, P')`: depth of sampled points + :return `Tensor(N, P')`: length of sampled points + """ + return InverseCDFRaySampling.apply(pts_idx, min_depth, max_depth, probs, steps, fixed_step_size, + deterministic) + + +# back-up for ray point sampling +@torch.no_grad() +def _parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False): + # uniform sampling + _min_depth = min_depth.min(1)[0] + _max_depth = max_depth.masked_fill(max_depth.eq(HUGE_FLOAT), 0).max(1)[0] + max_ray_length = (_max_depth - _min_depth).max() + + delta = torch.arange(int(max_ray_length / MARCH_SIZE), + device=min_depth.device, dtype=min_depth.dtype) + delta = delta[None, :].expand(min_depth.size(0), delta.size(-1)) + if deterministic: + delta = delta + 0.5 + else: + delta = delta + delta.clone().uniform_().clamp(min=0.01, max=0.99) + delta = delta * MARCH_SIZE + sampled_depth = min_depth[:, :1] + delta + sampled_idx = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1 + sampled_idx = pts_idx.gather(1, sampled_idx) + + # include all boundary points + sampled_depth = torch.cat([min_depth, max_depth, sampled_depth], -1) + sampled_idx = torch.cat([pts_idx, pts_idx, sampled_idx], -1) + + # reorder + sampled_depth, ordered_index = sampled_depth.sort(-1) + sampled_idx = sampled_idx.gather(1, ordered_index) + sampled_dists = sampled_depth[:, 1:] - sampled_depth[:, :-1] # distances + sampled_depth = .5 * (sampled_depth[:, 1:] + sampled_depth[:, :-1]) # mid-points + + # remove all invalid depths + min_ids = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1 + max_ids = (sampled_depth[:, :, None] >= max_depth[:, None, :]).sum(-1) + + sampled_depth.masked_fill_( + (max_ids.ne(min_ids)) | + (sampled_depth > _max_depth[:, None]) | + (sampled_dists == 0.0), HUGE_FLOAT) + sampled_depth, ordered_index = sampled_depth.sort(-1) # sort again + sampled_masks = sampled_depth.eq(HUGE_FLOAT) + num_max_steps = (~sampled_masks).sum(-1).max() + + sampled_depth = sampled_depth[:, :num_max_steps] + sampled_dists = sampled_dists.gather(1, ordered_index).masked_fill_( + sampled_masks, 0.0)[:, :num_max_steps] + sampled_idx = sampled_idx.gather(1, ordered_index).masked_fill_( + sampled_masks, -1)[:, :num_max_steps] + + return sampled_idx, sampled_depth, sampled_dists + + +@torch.no_grad() +def parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False): + chunk_size = 4096 + full_size = min_depth.shape[0] + if full_size <= chunk_size: + return _parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=deterministic) + + outputs = zip(*[ + _parallel_ray_sampling( + MARCH_SIZE, + pts_idx[i:i + chunk_size], min_depth[i:i + chunk_size], max_depth[i:i + chunk_size], + deterministic=deterministic) + for i in range(0, full_size, chunk_size)]) + sampled_idx, sampled_depth, sampled_dists = outputs + + def padding_points(xs, pad): + if len(xs) == 1: + return xs[0] + + maxlen = max([x.size(1) for x in xs]) + full_size = sum([x.size(0) for x in xs]) + xt = xs[0].new_ones(full_size, maxlen).fill_(pad) + st = 0 + for i in range(len(xs)): + xt[st: st + xs[i].size(0), :xs[i].size(1)] = xs[i] + st += xs[i].size(0) + return xt + + sampled_idx = padding_points(sampled_idx, -1) + sampled_depth = padding_points(sampled_depth, HUGE_FLOAT) + sampled_dists = padding_points(sampled_dists, 0.0) + 
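+ # Chunks can yield different numbers of samples per ray, so the per-chunk results + # above are padded to a common length (indices with -1, depths with HUGE_FLOAT, + # distances with 0.0) before being returned as single dense tensors.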
return sampled_idx, sampled_depth, sampled_dists + + +def build_easy_octree(points: torch.Tensor, half_voxel: float) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Build an octree. + + :param points `Tensor(M, 3)`: centers of leaf voxels + :param half_voxel `float`: half size of voxel + :return `Tensor(M', 3)`: centers of all nodes in octree + :return `Tensor(M', 9)`: flattened octree structure + """ + coords, residual = discretize_points(points, half_voxel) + ranges = coords.max(0)[0] - coords.min(0)[0] + depths = torch.log2(ranges.max().float()).ceil_().long() - 1 + center = (coords.max(0)[0] + coords.min(0)[0]) / 2 + centers, children = _ext.build_octree(center, coords, int(depths)) + centers = centers.float() * half_voxel + residual # transform back to float + return centers, children \ No newline at end of file diff --git a/clib/include/cuda_utils.h b/clib/include/cuda_utils.h new file mode 100644 index 0000000..d4c4bb4 --- /dev/null +++ b/clib/include/cuda_utils.h @@ -0,0 +1,46 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include <ATen/ATen.h> +#include <ATen/cuda/CUDAContext.h> +#include <cmath> + +#include <cuda.h> +#include <cuda_runtime.h> + +#include <vector> + +#define TOTAL_THREADS 512 + +inline int opt_n_threads(int work_size) { + const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0); + + return max(min(1 << pow_2, TOTAL_THREADS), 1); +} + +inline dim3 opt_block_config(int x, int y) { + const int x_threads = opt_n_threads(x); + const int y_threads = + max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); + dim3 block_config(x_threads, y_threads, 1); + + return block_config; +} + +#define CUDA_CHECK_ERRORS() \ + do { \ + cudaError_t err = cudaGetLastError(); \ + if (cudaSuccess != err) { \ + fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ + cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ + __FILE__); \ + exit(-1); \ + } \ + } while (0) + +#endif diff --git a/clib/include/cutil_math.h b/clib/include/cutil_math.h new file mode 100644 index 0000000..d8748b9 --- /dev/null +++ b/clib/include/cutil_math.h @@ -0,0 +1,793 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +/* + * Copyright 1993-2009 NVIDIA Corporation. All rights reserved. + * + * NVIDIA Corporation and its licensors retain all intellectual property and + * proprietary rights in and to this software and related documentation and + * any modifications thereto. Any use, reproduction, disclosure, or distribution + * of this software and related documentation without an express license + * agreement from NVIDIA Corporation is strictly prohibited. + * + */ + +/* + This file implements common mathematical operations on vector types + (float3, float4 etc.) since these are not provided as standard by CUDA. + + The syntax is modelled on the Cg standard library. +*/ + +#ifndef CUTIL_MATH_H +#define CUTIL_MATH_H + +#include "cuda_runtime.h" + +//////////////////////////////////////////////////////////////////////////////// +typedef unsigned int uint; +typedef unsigned short ushort; + +#ifndef __CUDACC__ +#include <math.h> + +inline float fminf(float a, float b) +{ + return a < b ? a : b; +} + +inline float fmaxf(float a, float b) +{ + return a > b ? 
a : b; +} + +inline int max(int a, int b) +{ + return a > b ? a : b; +} + +inline int min(int a, int b) +{ + return a < b ? a : b; +} + +inline float rsqrtf(float x) +{ + return 1.0f / sqrtf(x); +} + +#endif + +// float functions +//////////////////////////////////////////////////////////////////////////////// + +// lerp +inline __device__ __host__ float lerp(float a, float b, float t) +{ + return a + t*(b-a); +} + +// clamp +inline __device__ __host__ float clamp(float f, float a, float b) +{ + return fmaxf(a, fminf(f, b)); +} + +inline __device__ __host__ void swap(float &a, float &b) +{ + float c = a; + a = b; + b = c; +} + +inline __device__ __host__ void swap(int &a, int &b) +{ + int c = a; + a = b; + b = c; +} + + +// int2 functions +//////////////////////////////////////////////////////////////////////////////// + +// negate +inline __host__ __device__ int2 operator-(int2 &a) +{ + return make_int2(-a.x, -a.y); +} + +// addition +inline __host__ __device__ int2 operator+(int2 a, int2 b) +{ + return make_int2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(int2 &a, int2 b) +{ + a.x += b.x; a.y += b.y; +} + +// subtract +inline __host__ __device__ int2 operator-(int2 a, int2 b) +{ + return make_int2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(int2 &a, int2 b) +{ + a.x -= b.x; a.y -= b.y; +} + +// multiply +inline __host__ __device__ int2 operator*(int2 a, int2 b) +{ + return make_int2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ int2 operator*(int2 a, int s) +{ + return make_int2(a.x * s, a.y * s); +} +inline __host__ __device__ int2 operator*(int s, int2 a) +{ + return make_int2(a.x * s, a.y * s); +} +inline __host__ __device__ void operator*=(int2 &a, int s) +{ + a.x *= s; a.y *= s; +} + +// float2 functions +//////////////////////////////////////////////////////////////////////////////// + +// additional constructors +inline __host__ __device__ float2 make_float2(float s) +{ + return make_float2(s, s); +} +inline __host__ __device__ float2 make_float2(int2 a) +{ + return make_float2(float(a.x), float(a.y)); +} + +// negate +inline __host__ __device__ float2 operator-(float2 &a) +{ + return make_float2(-a.x, -a.y); +} + +// addition +inline __host__ __device__ float2 operator+(float2 a, float2 b) +{ + return make_float2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(float2 &a, float2 b) +{ + a.x += b.x; a.y += b.y; +} + +// subtract +inline __host__ __device__ float2 operator-(float2 a, float2 b) +{ + return make_float2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(float2 &a, float2 b) +{ + a.x -= b.x; a.y -= b.y; +} + +// multiply +inline __host__ __device__ float2 operator*(float2 a, float2 b) +{ + return make_float2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ float2 operator*(float2 a, float s) +{ + return make_float2(a.x * s, a.y * s); +} +inline __host__ __device__ float2 operator*(float s, float2 a) +{ + return make_float2(a.x * s, a.y * s); +} +inline __host__ __device__ void operator*=(float2 &a, float s) +{ + a.x *= s; a.y *= s; +} + +// divide +inline __host__ __device__ float2 operator/(float2 a, float2 b) +{ + return make_float2(a.x / b.x, a.y / b.y); +} +inline __host__ __device__ float2 operator/(float2 a, float s) +{ + float inv = 1.0f / s; + return a * inv; +} +inline __host__ __device__ float2 operator/(float s, float2 a) +{ + return make_float2(s / a.x, s / a.y); +} +inline __host__ __device__ void operator/=(float2 &a, float s) +{ + float inv = 1.0f 
/ s; + a *= inv; +} + +// lerp +inline __device__ __host__ float2 lerp(float2 a, float2 b, float t) +{ + return a + t*(b-a); +} + +// clamp +inline __device__ __host__ float2 clamp(float2 v, float a, float b) +{ + return make_float2(clamp(v.x, a, b), clamp(v.y, a, b)); +} + +inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b) +{ + return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} + +// dot product +inline __host__ __device__ float dot(float2 a, float2 b) +{ + return a.x * b.x + a.y * b.y; +} + +// length +inline __host__ __device__ float length(float2 v) +{ + return sqrtf(dot(v, v)); +} + +// normalize +inline __host__ __device__ float2 normalize(float2 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} + +// floor +inline __host__ __device__ float2 floor(const float2 v) +{ + return make_float2(floor(v.x), floor(v.y)); +} + +// reflect +inline __host__ __device__ float2 reflect(float2 i, float2 n) +{ + return i - 2.0f * n * dot(n,i); +} + +// absolute value +inline __host__ __device__ float2 fabs(float2 v) +{ + return make_float2(fabs(v.x), fabs(v.y)); +} + +// float3 functions +//////////////////////////////////////////////////////////////////////////////// + +// additional constructors +inline __host__ __device__ float3 make_float3(float s) +{ + return make_float3(s, s, s); +} +inline __host__ __device__ float3 make_float3(float2 a) +{ + return make_float3(a.x, a.y, 0.0f); +} +inline __host__ __device__ float3 make_float3(float2 a, float s) +{ + return make_float3(a.x, a.y, s); +} +inline __host__ __device__ float3 make_float3(float4 a) +{ + return make_float3(a.x, a.y, a.z); // discards w +} +inline __host__ __device__ float3 make_float3(int3 a) +{ + return make_float3(float(a.x), float(a.y), float(a.z)); +} + +// negate +inline __host__ __device__ float3 operator-(float3 &a) +{ + return make_float3(-a.x, -a.y, -a.z); +} + +// min +static __inline__ __host__ __device__ float3 fminf(float3 a, float3 b) +{ + return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z)); +} + +// max +static __inline__ __host__ __device__ float3 fmaxf(float3 a, float3 b) +{ + return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z)); +} + +// addition +inline __host__ __device__ float3 operator+(float3 a, float3 b) +{ + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ float3 operator+(float3 a, float b) +{ + return make_float3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(float3 &a, float3 b) +{ + a.x += b.x; a.y += b.y; a.z += b.z; +} + +// subtract +inline __host__ __device__ float3 operator-(float3 a, float3 b) +{ + return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ float3 operator-(float3 a, float b) +{ + return make_float3(a.x - b, a.y - b, a.z - b); +} +inline __host__ __device__ void operator-=(float3 &a, float3 b) +{ + a.x -= b.x; a.y -= b.y; a.z -= b.z; +} + +// multiply +inline __host__ __device__ float3 operator*(float3 a, float3 b) +{ + return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ float3 operator*(float3 a, float s) +{ + return make_float3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ float3 operator*(float s, float3 a) +{ + return make_float3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ void operator*=(float3 &a, float s) +{ + a.x *= s; a.y *= s; a.z *= s; +} +inline __host__ __device__ void operator*=(float3 &a, float3 b) +{ + a.x *= b.x; a.y *= b.y; a.z *= b.z; +} + 
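+// Note: division of a vector by a scalar below multiplies by a precomputed +// reciprocal (one division plus component-wise multiplies), which is cheaper +// than dividing each component on the GPU.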
+// divide +inline __host__ __device__ float3 operator/(float3 a, float3 b) +{ + return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); +} +inline __host__ __device__ float3 operator/(float3 a, float s) +{ + float inv = 1.0f / s; + return a * inv; +} +inline __host__ __device__ float3 operator/(float s, float3 a) +{ + return make_float3(s / a.x, s / a.y, s / a.z); +} +inline __host__ __device__ void operator/=(float3 &a, float s) +{ + float inv = 1.0f / s; + a *= inv; +} + +// lerp +inline __device__ __host__ float3 lerp(float3 a, float3 b, float t) +{ + return a + t*(b-a); +} + +// clamp +inline __device__ __host__ float3 clamp(float3 v, float a, float b) +{ + return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} + +inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b) +{ + return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} + +// dot product +inline __host__ __device__ float dot(float3 a, float3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} + +// cross product +inline __host__ __device__ float3 cross(float3 a, float3 b) +{ + return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x); +} + +// length +inline __host__ __device__ float length(float3 v) +{ + return sqrtf(dot(v, v)); +} + +// normalize +inline __host__ __device__ float3 normalize(float3 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} + +// floor +inline __host__ __device__ float3 floor(const float3 v) +{ + return make_float3(floor(v.x), floor(v.y), floor(v.z)); +} + +// reflect +inline __host__ __device__ float3 reflect(float3 i, float3 n) +{ + return i - 2.0f * n * dot(n,i); +} + +// absolute value +inline __host__ __device__ float3 fabs(float3 v) +{ + return make_float3(fabs(v.x), fabs(v.y), fabs(v.z)); +} + +// float4 functions +//////////////////////////////////////////////////////////////////////////////// + +// additional constructors +inline __host__ __device__ float4 make_float4(float s) +{ + return make_float4(s, s, s, s); +} +inline __host__ __device__ float4 make_float4(float3 a) +{ + return make_float4(a.x, a.y, a.z, 0.0f); +} +inline __host__ __device__ float4 make_float4(float3 a, float w) +{ + return make_float4(a.x, a.y, a.z, w); +} +inline __host__ __device__ float4 make_float4(int4 a) +{ + return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); +} + +// negate +inline __host__ __device__ float4 operator-(float4 &a) +{ + return make_float4(-a.x, -a.y, -a.z, -a.w); +} + +// min +static __inline__ __host__ __device__ float4 fminf(float4 a, float4 b) +{ + return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w)); +} + +// max +static __inline__ __host__ __device__ float4 fmaxf(float4 a, float4 b) +{ + return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w)); +} + +// addition +inline __host__ __device__ float4 operator+(float4 a, float4 b) +{ + return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(float4 &a, float4 b) +{ + a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; +} + +// subtract +inline __host__ __device__ float4 operator-(float4 a, float4 b) +{ + return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(float4 &a, float4 b) +{ + a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; +} + +// multiply +inline __host__ __device__ float4 operator*(float4 a, float s) +{ + return make_float4(a.x * s, a.y * s, a.z * s, a.w * s); +} 
+inline __host__ __device__ float4 operator*(float s, float4 a) +{ + return make_float4(a.x * s, a.y * s, a.z * s, a.w * s); +} +inline __host__ __device__ void operator*=(float4 &a, float s) +{ + a.x *= s; a.y *= s; a.z *= s; a.w *= s; +} + +// divide +inline __host__ __device__ float4 operator/(float4 a, float4 b) +{ + return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); +} +inline __host__ __device__ float4 operator/(float4 a, float s) +{ + float inv = 1.0f / s; + return a * inv; +} +inline __host__ __device__ float4 operator/(float s, float4 a) +{ + return make_float4(s / a.x, s / a.y, s / a.z, s / a.w); +} +inline __host__ __device__ void operator/=(float4 &a, float s) +{ + float inv = 1.0f / s; + a *= inv; +} + +// lerp +inline __device__ __host__ float4 lerp(float4 a, float4 b, float t) +{ + return a + t*(b-a); +} + +// clamp +inline __device__ __host__ float4 clamp(float4 v, float a, float b) +{ + return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} + +inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b) +{ + return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +// dot product +inline __host__ __device__ float dot(float4 a, float4 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +// length +inline __host__ __device__ float length(float4 r) +{ + return sqrtf(dot(r, r)); +} + +// normalize +inline __host__ __device__ float4 normalize(float4 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} + +// floor +inline __host__ __device__ float4 floor(const float4 v) +{ + return make_float4(floor(v.x), floor(v.y), floor(v.z), floor(v.w)); +} + +// absolute value +inline __host__ __device__ float4 fabs(float4 v) +{ + return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w)); +} + +// int3 functions +//////////////////////////////////////////////////////////////////////////////// + +// additional constructors +inline __host__ __device__ int3 make_int3(int s) +{ + return make_int3(s, s, s); +} +inline __host__ __device__ int3 make_int3(float3 a) +{ + return make_int3(int(a.x), int(a.y), int(a.z)); +} + +// negate +inline __host__ __device__ int3 operator-(int3 &a) +{ + return make_int3(-a.x, -a.y, -a.z); +} + +// min +inline __host__ __device__ int3 min(int3 a, int3 b) +{ + return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z)); +} + +// max +inline __host__ __device__ int3 max(int3 a, int3 b) +{ + return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z)); +} + +// addition +inline __host__ __device__ int3 operator+(int3 a, int3 b) +{ + return make_int3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(int3 &a, int3 b) +{ + a.x += b.x; a.y += b.y; a.z += b.z; +} + +// subtract +inline __host__ __device__ int3 operator-(int3 a, int3 b) +{ + return make_int3(a.x - b.x, a.y - b.y, a.z - b.z); +} + +inline __host__ __device__ void operator-=(int3 &a, int3 b) +{ + a.x -= b.x; a.y -= b.y; a.z -= b.z; +} + +// multiply +inline __host__ __device__ int3 operator*(int3 a, int3 b) +{ + return make_int3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ int3 operator*(int3 a, int s) +{ + return make_int3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ int3 operator*(int s, int3 a) +{ + return make_int3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ void operator*=(int3 &a, int s) +{ + a.x *= s; a.y *= s; a.z *= s; +} + +// divide +inline __host__ __device__ int3 operator/(int3 a, int3 
b) +{ + return make_int3(a.x / b.x, a.y / b.y, a.z / b.z); +} +inline __host__ __device__ int3 operator/(int3 a, int s) +{ + return make_int3(a.x / s, a.y / s, a.z / s); +} +inline __host__ __device__ int3 operator/(int s, int3 a) +{ + return make_int3(s / a.x, s / a.y, s / a.z); +} +inline __host__ __device__ void operator/=(int3 &a, int s) +{ + a.x /= s; a.y /= s; a.z /= s; +} + +// clamp +inline __device__ __host__ int clamp(int f, int a, int b) +{ + return max(a, min(f, b)); +} + +inline __device__ __host__ int3 clamp(int3 v, int a, int b) +{ + return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} + +inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b) +{ + return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} + + +// uint3 functions +//////////////////////////////////////////////////////////////////////////////// + +// additional constructors +inline __host__ __device__ uint3 make_uint3(uint s) +{ + return make_uint3(s, s, s); +} +inline __host__ __device__ uint3 make_uint3(float3 a) +{ + return make_uint3(uint(a.x), uint(a.y), uint(a.z)); +} + +// min +inline __host__ __device__ uint3 min(uint3 a, uint3 b) +{ + return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z)); +} + +// max +inline __host__ __device__ uint3 max(uint3 a, uint3 b) +{ + return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z)); +} + +// addition +inline __host__ __device__ uint3 operator+(uint3 a, uint3 b) +{ + return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(uint3 &a, uint3 b) +{ + a.x += b.x; a.y += b.y; a.z += b.z; +} + +// subtract +inline __host__ __device__ uint3 operator-(uint3 a, uint3 b) +{ + return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z); +} + +inline __host__ __device__ void operator-=(uint3 &a, uint3 b) +{ + a.x -= b.x; a.y -= b.y; a.z -= b.z; +} + +// multiply +inline __host__ __device__ uint3 operator*(uint3 a, uint3 b) +{ + return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ uint3 operator*(uint3 a, uint s) +{ + return make_uint3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ uint3 operator*(uint s, uint3 a) +{ + return make_uint3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ void operator*=(uint3 &a, uint s) +{ + a.x *= s; a.y *= s; a.z *= s; +} + +// divide +inline __host__ __device__ uint3 operator/(uint3 a, uint3 b) +{ + return make_uint3(a.x / b.x, a.y / b.y, a.z / b.z); +} +inline __host__ __device__ uint3 operator/(uint3 a, uint s) +{ + return make_uint3(a.x / s, a.y / s, a.z / s); +} +inline __host__ __device__ uint3 operator/(uint s, uint3 a) +{ + return make_uint3(s / a.x, s / a.y, s / a.z); +} +inline __host__ __device__ void operator/=(uint3 &a, uint s) +{ + a.x /= s; a.y /= s; a.z /= s; +} + +// clamp +inline __device__ __host__ uint clamp(uint f, uint a, uint b) +{ + return max(a, min(f, b)); +} + +inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b) +{ + return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} + +inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b) +{ + return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} + + + +#endif \ No newline at end of file diff --git a/clib/include/intersect.h b/clib/include/intersect.h new file mode 100644 index 0000000..757b137 --- /dev/null +++ b/clib/include/intersect.h @@ -0,0 +1,17 @@ +// Copyright (c) Facebook, Inc. and its affiliates. 
+// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include <torch/extension.h> +#include <utility> + +std::tuple<at::Tensor, at::Tensor, at::Tensor> ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, + const float radius, const int n_max); +std::tuple<at::Tensor, at::Tensor, at::Tensor> aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, + const float voxelsize, const int n_max); +std::tuple<at::Tensor, at::Tensor, at::Tensor> svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, at::Tensor children, + const float voxelsize, const int n_max); +std::tuple< at::Tensor, at::Tensor, at::Tensor > triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points, + const float cagesize, const float blur, const int n_max); diff --git a/clib/include/octree.h b/clib/include/octree.h new file mode 100644 index 0000000..429053e --- /dev/null +++ b/clib/include/octree.h @@ -0,0 +1,10 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include <torch/extension.h> +#include <utility> + +std::tuple<at::Tensor, at::Tensor> build_octree(at::Tensor center, at::Tensor points, int depth); \ No newline at end of file diff --git a/clib/include/sample.h b/clib/include/sample.h new file mode 100644 index 0000000..7547710 --- /dev/null +++ b/clib/include/sample.h @@ -0,0 +1,16 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include <torch/extension.h> +#include <utility> + + +std::tuple<at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling( + at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, + const float step_size, const int max_steps); +std::tuple<at::Tensor, at::Tensor, at::Tensor> inverse_cdf_sampling( + at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, + at::Tensor probs, at::Tensor steps, float fixed_step_size); \ No newline at end of file diff --git a/clib/include/utils.h b/clib/include/utils.h new file mode 100644 index 0000000..925f769 --- /dev/null +++ b/clib/include/utils.h @@ -0,0 +1,30 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include <ATen/cuda/CUDAContext.h> +#include <torch/extension.h> + +#define CHECK_CUDA(x) \ + do { \ + TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ + } while (0) + +#define CHECK_CONTIGUOUS(x) \ + do { \ + TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ + } while (0) + +#define CHECK_IS_INT(x) \ + do { \ + TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ + #x " must be an int tensor"); \ + } while (0) + +#define CHECK_IS_FLOAT(x) \ + do { \ + TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ + #x " must be a float tensor"); \ + } while (0) diff --git a/clib/src/binding.cpp b/clib/src/binding.cpp new file mode 100644 index 0000000..a7274d0 --- /dev/null +++ b/clib/src/binding.cpp @@ -0,0 +1,21 @@ +// Copyright (c) Facebook, Inc. and its affiliates. 
+// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "intersect.h" +#include "octree.h" +#include "sample.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ball_intersect", &ball_intersect); + m.def("aabb_intersect", &aabb_intersect); + m.def("svo_intersect", &svo_intersect); + m.def("triangle_intersect", &triangle_intersect); + + m.def("uniform_ray_sampling", &uniform_ray_sampling); + m.def("inverse_cdf_sampling", &inverse_cdf_sampling); + + m.def("build_octree", &build_octree); +} \ No newline at end of file diff --git a/clib/src/intersect.cpp b/clib/src/intersect.cpp new file mode 100644 index 0000000..5e5bab4 --- /dev/null +++ b/clib/src/intersect.cpp @@ -0,0 +1,146 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "intersect.h" +#include "utils.h" +#include <utility> + +void ball_intersect_point_kernel_wrapper( + int b, int n, int m, float radius, int n_max, + const float *ray_start, const float *ray_dir, const float *points, + int *idx, float *min_depth, float *max_depth); + +std::tuple< at::Tensor, at::Tensor, at::Tensor > ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, + const float radius, const int n_max){ + CHECK_CONTIGUOUS(ray_start); + CHECK_CONTIGUOUS(ray_dir); + CHECK_CONTIGUOUS(points); + CHECK_IS_FLOAT(ray_start); + CHECK_IS_FLOAT(ray_dir); + CHECK_IS_FLOAT(points); + CHECK_CUDA(ray_start); + CHECK_CUDA(ray_dir); + CHECK_CUDA(points); + + at::Tensor idx = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Int)); + at::Tensor min_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + at::Tensor max_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + ball_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), + radius, n_max, + ray_start.data_ptr <float>(), ray_dir.data_ptr <float>(), points.data_ptr <float>(), + idx.data_ptr <int>(), min_depth.data_ptr <float>(), max_depth.data_ptr <float>()); + return std::make_tuple(idx, min_depth, max_depth); +} + + +void aabb_intersect_point_kernel_wrapper( + int b, int n, int m, float voxelsize, int n_max, + const float *ray_start, const float *ray_dir, const float *points, + int *idx, float *min_depth, float *max_depth); + +std::tuple< at::Tensor, at::Tensor, at::Tensor > aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, + const float voxelsize, const int n_max){ + CHECK_CONTIGUOUS(ray_start); + CHECK_CONTIGUOUS(ray_dir); + CHECK_CONTIGUOUS(points); + CHECK_IS_FLOAT(ray_start); + CHECK_IS_FLOAT(ray_dir); + CHECK_IS_FLOAT(points); + CHECK_CUDA(ray_start); + CHECK_CUDA(ray_dir); + CHECK_CUDA(points); + + at::Tensor idx = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Int)); + at::Tensor min_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + at::Tensor max_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + 
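+    // (Editor's note) All three outputs have shape [n_batches, n_rays, n_max];
+    // rays that hit fewer than n_max voxels are padded with idx == -1, which
+    // the downstream sampling kernels treat as "no hit".
+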
aabb_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), + voxelsize, n_max, + ray_start.data_ptr <float>(), ray_dir.data_ptr <float>(), points.data_ptr <float>(), + idx.data_ptr <int>(), min_depth.data_ptr <float>(), max_depth.data_ptr <float>()); + return std::make_tuple(idx, min_depth, max_depth); +} + + +void svo_intersect_point_kernel_wrapper( + int b, int n, int m, float voxelsize, int n_max, + const float *ray_start, const float *ray_dir, const float *points, const int *children, + int *idx, float *min_depth, float *max_depth); + + +std::tuple< at::Tensor, at::Tensor, at::Tensor > svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, + at::Tensor children, const float voxelsize, const int n_max){ + CHECK_CONTIGUOUS(ray_start); + CHECK_CONTIGUOUS(ray_dir); + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(children); + CHECK_IS_FLOAT(ray_start); + CHECK_IS_FLOAT(ray_dir); + CHECK_IS_FLOAT(points); + CHECK_CUDA(ray_start); + CHECK_CUDA(ray_dir); + CHECK_CUDA(points); + CHECK_CUDA(children); + + at::Tensor idx = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Int)); + at::Tensor min_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + at::Tensor max_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + svo_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), + voxelsize, n_max, + ray_start.data_ptr <float>(), ray_dir.data_ptr <float>(), points.data_ptr <float>(), + children.data_ptr <int>(), idx.data_ptr <int>(), min_depth.data_ptr <float>(), max_depth.data_ptr <float>()); + return std::make_tuple(idx, min_depth, max_depth); +} + + +void triangle_intersect_point_kernel_wrapper( + int b, int n, int m, float cagesize, float blur, int n_max, + const float *ray_start, const float *ray_dir, const float *face_points, + int *idx, float *depth, float *uv); + +std::tuple< at::Tensor, at::Tensor, at::Tensor > triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points, + const float cagesize, const float blur, const int n_max){ + CHECK_CONTIGUOUS(ray_start); + CHECK_CONTIGUOUS(ray_dir); + CHECK_CONTIGUOUS(face_points); + CHECK_IS_FLOAT(ray_start); + CHECK_IS_FLOAT(ray_dir); + CHECK_IS_FLOAT(face_points); + CHECK_CUDA(ray_start); + CHECK_CUDA(ray_dir); + CHECK_CUDA(face_points); + + at::Tensor idx = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Int)); + at::Tensor depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 3}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + at::Tensor uv = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 2}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + triangle_intersect_point_kernel_wrapper(face_points.size(0), face_points.size(1), ray_start.size(1), + cagesize, blur, n_max, + ray_start.data_ptr <float>(), ray_dir.data_ptr <float>(), face_points.data_ptr <float>(), + idx.data_ptr <int>(), depth.data_ptr <float>(), uv.data_ptr <float>()); + return std::make_tuple(idx, depth, uv); +} diff --git a/clib/src/intersect_gpu.cu b/clib/src/intersect_gpu.cu new file mode 100644 index 0000000..fa25cda --- /dev/null +++ b/clib/src/intersect_gpu.cu @@ -0,0 +1,375 @@ +// Copyright (c) Facebook, Inc. 
and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "cutil_math.h" // required for float3 vector math
+
+__global__ void ball_intersect_point_kernel(int b, int n, int m, float radius, int n_max,
+                                            const float *__restrict__ ray_start,
+                                            const float *__restrict__ ray_dir,
+                                            const float *__restrict__ points, int *__restrict__ idx,
+                                            float *__restrict__ min_depth,
+                                            float *__restrict__ max_depth) {
+
+    int batch_index = blockIdx.x;
+    points += batch_index * n * 3;
+    ray_start += batch_index * m * 3;
+    ray_dir += batch_index * m * 3;
+    idx += batch_index * m * n_max;
+    min_depth += batch_index * m * n_max;
+    max_depth += batch_index * m * n_max;
+
+    int index = threadIdx.x;
+    int stride = blockDim.x;
+    float radius2 = radius * radius;
+
+    for (int j = index; j < m; j += stride) {
+
+        float x0 = ray_start[j * 3 + 0];
+        float y0 = ray_start[j * 3 + 1];
+        float z0 = ray_start[j * 3 + 2];
+        float xw = ray_dir[j * 3 + 0];
+        float yw = ray_dir[j * 3 + 1];
+        float zw = ray_dir[j * 3 + 2];
+
+        for (int l = 0; l < n_max; ++l) {
+            idx[j * n_max + l] = -1;
+        }
+
+        for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k) {
+            float x = points[k * 3 + 0] - x0;
+            float y = points[k * 3 + 1] - y0;
+            float z = points[k * 3 + 2] - z0;
+            float d2 = x * x + y * y + z * z;
+            float d2_proj = pow(x * xw + y * yw + z * zw, 2);
+            float r2 = d2 - d2_proj;
+
+            if (r2 < radius2) {
+                idx[j * n_max + cnt] = k;
+
+                float depth = sqrt(d2_proj);
+                float depth_blur = sqrt(radius2 - r2);
+
+                min_depth[j * n_max + cnt] = depth - depth_blur;
+                max_depth[j * n_max + cnt] = depth + depth_blur;
+                ++cnt;
+            }
+        }
+    }
+}
+
+__device__ float2 RayAABBIntersection(const float3 &ori, const float3 &dir, const float3 &center,
+                                      float half_voxel) {
+
+    float f_low = 0;
+    float f_high = 100000.;
+    float f_dim_low, f_dim_high, temp, inv_ray_dir, start, aabb;
+
+    for (int d = 0; d < 3; ++d) {
+        switch (d) {
+        case 0:
+            inv_ray_dir = __fdividef(1.0f, dir.x);
+            start = ori.x;
+            aabb = center.x;
+            break;
+        case 1:
+            inv_ray_dir = __fdividef(1.0f, dir.y);
+            start = ori.y;
+            aabb = center.y;
+            break;
+        case 2:
+            inv_ray_dir = __fdividef(1.0f, dir.z);
+            start = ori.z;
+            aabb = center.z;
+            break;
+        }
+
+        f_dim_low = (aabb - half_voxel - start) * inv_ray_dir;
+        f_dim_high = (aabb + half_voxel - start) * inv_ray_dir;
+
+        // Make sure low is less than high
+        if (f_dim_high < f_dim_low) {
+            temp = f_dim_low;
+            f_dim_low = f_dim_high;
+            f_dim_high = temp;
+        }
+
+        // If this dimension's high is less than the low we got then we definitely missed.
+        // Likewise if this dimension's low is greater than the high we got.
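+        // (Editor's note, added for clarity) This is the standard slab test:
+        // each axis clips the ray's entry/exit interval [f_low, f_high]; a
+        // miss is reported as the sentinel pair (-1, -1), which the kernels
+        // below check via depths.x > -1.0f.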
+        if (f_dim_high < f_low || f_dim_low > f_high)
+            return make_float2(-1.0f, -1.0f);
+
+        // Add the clip from this dimension to the previous results
+        f_low = max(f_dim_low, f_low);
+        f_high = min(f_dim_high, f_high);
+        if (f_low >= f_high - 1e-5f)
+            return make_float2(-1.0f, -1.0f);
+    }
+    return make_float2(f_low, f_high);
+}
+
+__global__ void aabb_intersect_point_kernel(int b, int n, int m, float voxelsize, int n_max,
+                                            const float *__restrict__ ray_start,
+                                            const float *__restrict__ ray_dir,
+                                            const float *__restrict__ points, int *__restrict__ idx,
+                                            float *__restrict__ min_depth,
+                                            float *__restrict__ max_depth) {
+
+    int batch_index = blockIdx.x;
+    points += batch_index * n * 3;
+    ray_start += batch_index * m * 3;
+    ray_dir += batch_index * m * 3;
+    idx += batch_index * m * n_max;
+    min_depth += batch_index * m * n_max;
+    max_depth += batch_index * m * n_max;
+
+    int index = threadIdx.x;
+    int stride = blockDim.x;
+    float half_voxel = voxelsize * 0.5;
+
+    for (int j = index; j < m; j += stride) {
+        for (int l = 0; l < n_max; ++l) {
+            idx[j * n_max + l] = -1;
+        }
+
+        for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k) {
+            float2 depths = RayAABBIntersection(
+                make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]),
+                make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]),
+                make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]), half_voxel);
+
+            if (depths.x > -1.0f) {
+                idx[j * n_max + cnt] = k;
+                min_depth[j * n_max + cnt] = depths.x;
+                max_depth[j * n_max + cnt] = depths.y;
+                ++cnt;
+            }
+        }
+    }
+}
+
+__global__ void svo_intersect_point_kernel(int b, int n, int m, float voxelsize, int n_max,
+                                           const float *__restrict__ ray_start,
+                                           const float *__restrict__ ray_dir,
+                                           const float *__restrict__ points,
+                                           const int *__restrict__ children, int *__restrict__ idx,
+                                           float *__restrict__ min_depth,
+                                           float *__restrict__ max_depth) {
+    /*
+    TODO: this is an inefficient implementation of the
+    naive Ray -- Sparse Voxel Octree Intersection.
+    It can be further improved using:
+
+    Revelles, Jorge, Carlos Urena, and Miguel Lastra.
+    "An efficient parametric algorithm for octree traversal." (2000).
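+
+    Note on the data layout assumed here (editor's note): `children` stores
+    9 ints per node; slots 0-7 hold child node indices (-1 where absent) and
+    slot 8 holds the node's side length measured in finest-level voxels
+    (1 << (depth + 1), i.e. 1 for a terminal node), which is why the kernel
+    below scales half_voxel by children[k * 9 + 8].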
+ */ + int batch_index = blockIdx.x; + points += batch_index * n * 3; + children += batch_index * n * 9; + ray_start += batch_index * m * 3; + ray_dir += batch_index * m * 3; + idx += batch_index * m * n_max; + min_depth += batch_index * m * n_max; + max_depth += batch_index * m * n_max; + + int index = threadIdx.x; + int stride = blockDim.x; + float half_voxel = voxelsize * 0.5; + + for (int j = index; j < m; j += stride) { + for (int l = 0; l < n_max; ++l) { + idx[j * n_max + l] = -1; + } + int stack[256] = {-1}; // DFS, initialize the stack + int ptr = 0, cnt = 0, k = -1; + stack[ptr] = n - 1; // ROOT node is always the last + while (ptr > -1 && cnt < n_max) { + assert((ptr < 256)); + + // evaluate the current node + k = stack[ptr]; + float2 depths = RayAABBIntersection( + make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]), + make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]), + make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]), + half_voxel * float(children[k * 9 + 8])); + stack[ptr] = -1; + ptr--; + + if (depths.x > -1.0f) { // ray did not miss the voxel + // TODO: here it should be able to know which children is ok, further optimize the + // code + if (children[k * 9 + 8] == 1) { // this is a terminal node + idx[j * n_max + cnt] = k; + min_depth[j * n_max + cnt] = depths.x; + max_depth[j * n_max + cnt] = depths.y; + ++cnt; + continue; + } + + for (int u = 0; u < 8; u++) { + if (children[k * 9 + u] > -1) { + ptr++; + stack[ptr] = children[k * 9 + u]; // push child to the stack + } + } + } + } + } +} + +__device__ float3 RayTriangleIntersection(const float3 &ori, const float3 &dir, const float3 &v0, + const float3 &v1, const float3 &v2, float blur) { + + float3 v0v1 = v1 - v0; + float3 v0v2 = v2 - v0; + float3 v0O = ori - v0; + float3 dir_crs_v0v2 = cross(dir, v0v2); + + float det = dot(v0v1, dir_crs_v0v2); + det = __fdividef(1.0f, det); // CUDA intrinsic function + + float u = dot(v0O, dir_crs_v0v2) * det; + if ((u < 0.0f - blur) || (u > 1.0f + blur)) + return make_float3(-1.0f, 0.0f, 0.0f); + + float3 v0O_crs_v0v1 = cross(v0O, v0v1); + float v = dot(dir, v0O_crs_v0v1) * det; + if ((v < 0.0f - blur) || (v > 1.0f + blur)) + return make_float3(-1.0f, 0.0f, 0.0f); + + if (((u + v) < 0.0f - blur) || ((u + v) > 1.0f + blur)) + return make_float3(-1.0f, 0.0f, 0.0f); + + float t = dot(v0v2, v0O_crs_v0v1) * det; + return make_float3(t, u, v); +} + +__global__ void triangle_intersect_point_kernel(int b, int n, int m, float cagesize, float blur, + int n_max, const float *__restrict__ ray_start, + const float *__restrict__ ray_dir, + const float *__restrict__ face_points, + int *__restrict__ idx, float *__restrict__ depth, + float *__restrict__ uv) { + + int batch_index = blockIdx.x; + face_points += batch_index * n * 9; + ray_start += batch_index * m * 3; + ray_dir += batch_index * m * 3; + idx += batch_index * m * n_max; + depth += batch_index * m * n_max * 3; + uv += batch_index * m * n_max * 2; + + int index = threadIdx.x; + int stride = blockDim.x; + for (int j = index; j < m; j += stride) { + // go over rays + for (int l = 0; l < n_max; ++l) { + idx[j * n_max + l] = -1; + } + + int cnt = 0; + for (int k = 0; k < n && cnt < n_max; ++k) { + // go over triangles + float3 tuv = RayTriangleIntersection( + make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]), + make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]), + make_float3(face_points[k * 9 + 0], face_points[k * 9 + 1], 
face_points[k * 9 + 2]), + make_float3(face_points[k * 9 + 3], face_points[k * 9 + 4], face_points[k * 9 + 5]), + make_float3(face_points[k * 9 + 6], face_points[k * 9 + 7], face_points[k * 9 + 8]), + blur); + + if (tuv.x > 0) { + int ki = k; + float d = tuv.x, u = tuv.y, v = tuv.z; + + // sort + for (int l = 0; l < cnt; l++) { + if (d < depth[j * n_max * 3 + l * 3]) { + swap(ki, idx[j * n_max + l]); + swap(d, depth[j * n_max * 3 + l * 3]); + swap(u, uv[j * n_max * 2 + l * 2]); + swap(v, uv[j * n_max * 2 + l * 2 + 1]); + } + } + idx[j * n_max + cnt] = ki; + depth[j * n_max * 3 + cnt * 3] = d; + uv[j * n_max * 2 + cnt * 2] = u; + uv[j * n_max * 2 + cnt * 2 + 1] = v; + cnt++; + } + } + + for (int l = 0; l < cnt; l++) { + // compute min_depth + if (l == 0) + depth[j * n_max * 3 + l * 3 + 1] = -cagesize; + else + depth[j * n_max * 3 + l * 3 + 1] = + -fminf(cagesize, + .5 * (depth[j * n_max * 3 + l * 3] - depth[j * n_max * 3 + l * 3 - 3])); + + // compute max_depth + if (l == cnt - 1) + depth[j * n_max * 3 + l * 3 + 2] = cagesize; + else + depth[j * n_max * 3 + l * 3 + 2] = + fminf(cagesize, + .5 * (depth[j * n_max * 3 + l * 3 + 3] - depth[j * n_max * 3 + l * 3])); + } + } +} + +void ball_intersect_point_kernel_wrapper(int b, int n, int m, float radius, int n_max, + const float *ray_start, const float *ray_dir, + const float *points, int *idx, float *min_depth, + float *max_depth) { + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + ball_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>( + b, n, m, radius, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth); + + CUDA_CHECK_ERRORS(); +} + +void aabb_intersect_point_kernel_wrapper(int b, int n, int m, float voxelsize, int n_max, + const float *ray_start, const float *ray_dir, + const float *points, int *idx, float *min_depth, + float *max_depth) { + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + aabb_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>( + b, n, m, voxelsize, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth); + + CUDA_CHECK_ERRORS(); +} + +void svo_intersect_point_kernel_wrapper(int b, int n, int m, float voxelsize, int n_max, + const float *ray_start, const float *ray_dir, + const float *points, const int *children, int *idx, + float *min_depth, float *max_depth) { + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + svo_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>( + b, n, m, voxelsize, n_max, ray_start, ray_dir, points, children, idx, min_depth, max_depth); + + CUDA_CHECK_ERRORS(); +} + +void triangle_intersect_point_kernel_wrapper(int b, int n, int m, float cagesize, float blur, + int n_max, const float *ray_start, + const float *ray_dir, const float *face_points, + int *idx, float *depth, float *uv) { + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + triangle_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>( + b, n, m, cagesize, blur, n_max, ray_start, ray_dir, face_points, idx, depth, uv); + + CUDA_CHECK_ERRORS(); +} diff --git a/clib/src/octree.cpp b/clib/src/octree.cpp new file mode 100644 index 0000000..e1c8ab0 --- /dev/null +++ b/clib/src/octree.cpp @@ -0,0 +1,136 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
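+//
+// (Editor's note) Overview: build_octree() inserts each point into an
+// EasyOctree of the given depth and then flattens the tree into two int
+// tensors: all_centers [n_nodes, 3] holds grid-space node centers and
+// all_children [n_nodes, 9] holds eight child indices plus the node size;
+// the root is stored last. A minimal usage sketch from Python (hypothetical
+// call, assuming the compiled extension is exposed through the clib package):
+//
+//     centers, children = clib.build_octree(center, points, depth)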
+
+#include "octree.h"
+#include "utils.h"
+#include <utility>
+#include <queue>
+#include <cassert>
+#include <chrono>
+using namespace std::chrono;
+
+
+typedef struct OcTree
+{
+    int depth;
+    int index;
+    at::Tensor center;
+    struct OcTree *children[8];
+    void init(at::Tensor center, int d, int i) {
+        this->center = center;
+        this->depth = d;
+        this->index = i;
+        for (int i=0; i<8; i++) this->children[i] = nullptr;
+    }
+}OcTree;
+
+class EasyOctree {
+  public:
+    OcTree *root;
+    int total;
+    int terminal;
+
+    at::Tensor all_centers;
+    at::Tensor all_children;
+
+    EasyOctree(at::Tensor center, int depth) {
+        root = new OcTree;
+        root->init(center, depth, -1);
+        total = -1;
+        terminal = -1;
+    }
+    ~EasyOctree() {
+        OcTree *p = root;
+        destroy(p);
+    }
+    void destroy(OcTree * &p);
+    void insert(OcTree * &p, at::Tensor point, int index);
+    void finalize();
+    std::pair<int, int> count(OcTree * &p);
+};
+
+void EasyOctree::destroy(OcTree * &p){
+    if (p != nullptr) {
+        for (int i=0; i<8; i++) {
+            if (p->children[i] != nullptr) destroy(p->children[i]);
+        }
+        delete p;
+        p = nullptr;
+    }
+}
+
+void EasyOctree::insert(OcTree * &p, at::Tensor point, int index) {
+    at::Tensor diff = (point > p->center).to(at::kInt);
+    int idx = diff[0].item<int>() + 2 * diff[1].item<int>() + 4 * diff[2].item<int>();
+    if (p->depth == 0) {
+        p->children[idx] = new OcTree;
+        p->children[idx]->init(point, -1, index);
+    } else {
+        if (p->children[idx] == nullptr) {
+            int length = 1 << (p->depth - 1);
+            at::Tensor new_center = p->center + (2 * diff - 1) * length;
+            p->children[idx] = new OcTree;
+            p->children[idx]->init(new_center, p->depth-1, -1);
+        }
+        insert(p->children[idx], point, index);
+    }
+}
+
+std::pair<int, int> EasyOctree::count(OcTree * &p) {
+    int total = 0, terminal = 0;
+    for (int i=0; i<8; i++) {
+        if (p->children[i] != nullptr) {
+            std::pair<int, int> sub = count(p->children[i]);
+            total += sub.first;
+            terminal += sub.second;
+        }
+    }
+    total += 1;
+    if (p->depth == -1) terminal += 1;
+    return std::make_pair(total, terminal);
+}
+
+void EasyOctree::finalize() {
+    std::pair<int, int> outs = count(root);
+    total = outs.first; terminal = outs.second;
+
+    all_centers =
+        torch::zeros({outs.first, 3}, at::device(root->center.device()).dtype(at::ScalarType::Int));
+    all_children =
+        -torch::ones({outs.first, 9}, at::device(root->center.device()).dtype(at::ScalarType::Int));
+
+    int node_idx = outs.first - 1;
+    root->index = node_idx;
+
+    std::queue<OcTree*> all_leaves; all_leaves.push(root);
+    while (!all_leaves.empty()) {
+        OcTree* node_ptr = all_leaves.front();
+        all_leaves.pop();
+        for (int i=0; i<8; i++) {
+            if (node_ptr->children[i] != nullptr) {
+                if (node_ptr->children[i]->depth > -1) {
+                    node_idx--;
+                    node_ptr->children[i]->index = node_idx;
+                }
+                all_leaves.push(node_ptr->children[i]);
+                all_children[node_ptr->index][i] = node_ptr->children[i]->index;
+            }
+        }
+        all_children[node_ptr->index][8] = 1 << (node_ptr->depth + 1);
+        all_centers[node_ptr->index] = node_ptr->center;
+    }
+    assert (node_idx == outs.second);
+}
+
+std::tuple<at::Tensor, at::Tensor> build_octree(at::Tensor center, at::Tensor points, int depth) {
+    auto start = high_resolution_clock::now();
+    EasyOctree tree(center, depth);
+    for (int k=0; k<points.size(0); k++)
+        tree.insert(tree.root, points[k], k);
+    tree.finalize();
+    auto stop = high_resolution_clock::now();
+    auto duration = duration_cast<microseconds>(stop - start);
+    printf("Building EasyOctree done.
total #nodes = %d, terminal #nodes = %d (time taken %f s)\n", + tree.total, tree.terminal, float(duration.count()) / 1000000.); + return std::make_tuple(tree.all_centers, tree.all_children); +} \ No newline at end of file diff --git a/clib/src/sample.cpp b/clib/src/sample.cpp new file mode 100644 index 0000000..a67c2f7 --- /dev/null +++ b/clib/src/sample.cpp @@ -0,0 +1,96 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "sample.h" +#include "utils.h" +#include <utility> + + +void uniform_ray_sampling_kernel_wrapper( + int b, int num_rays, int max_hits, int max_steps, float step_size, + const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise, + int *sampled_idx, float *sampled_depth, float *sampled_dists); + +void inverse_cdf_sampling_kernel_wrapper( + int b, int num_rays, int max_hits, int max_steps, float fixed_step_size, + const int *pts_idx, const float *min_depth, const float *max_depth, + const float *uniform_noise, const float *probs, const float *steps, + int *sampled_idx, float *sampled_depth, float *sampled_dists); + + +std::tuple< at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling( + at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, + const float step_size, const int max_steps){ + + CHECK_CONTIGUOUS(pts_idx); + CHECK_CONTIGUOUS(min_depth); + CHECK_CONTIGUOUS(max_depth); + CHECK_CONTIGUOUS(uniform_noise); + CHECK_IS_FLOAT(min_depth); + CHECK_IS_FLOAT(max_depth); + CHECK_IS_FLOAT(uniform_noise); + CHECK_IS_INT(pts_idx); + CHECK_CUDA(pts_idx); + CHECK_CUDA(min_depth); + CHECK_CUDA(max_depth); + CHECK_CUDA(uniform_noise); + + at::Tensor sampled_idx = + -torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps}, + at::device(pts_idx.device()).dtype(at::ScalarType::Int)); + at::Tensor sampled_depth = + torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, + at::device(min_depth.device()).dtype(at::ScalarType::Float)); + at::Tensor sampled_dists = + torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, + at::device(min_depth.device()).dtype(at::ScalarType::Float)); + uniform_ray_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2), + step_size, + pts_idx.data_ptr <int>(), min_depth.data_ptr <float>(), max_depth.data_ptr <float>(), + uniform_noise.data_ptr <float>(), sampled_idx.data_ptr <int>(), + sampled_depth.data_ptr <float>(), sampled_dists.data_ptr <float>()); + return std::make_tuple(sampled_idx, sampled_depth, sampled_dists); +} + + +std::tuple<at::Tensor, at::Tensor, at::Tensor> inverse_cdf_sampling( + at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, + at::Tensor probs, at::Tensor steps, float fixed_step_size) { + + CHECK_CONTIGUOUS(pts_idx); + CHECK_CONTIGUOUS(min_depth); + CHECK_CONTIGUOUS(max_depth); + CHECK_CONTIGUOUS(probs); + CHECK_CONTIGUOUS(steps); + CHECK_CONTIGUOUS(uniform_noise); + CHECK_IS_FLOAT(min_depth); + CHECK_IS_FLOAT(max_depth); + CHECK_IS_FLOAT(uniform_noise); + CHECK_IS_FLOAT(probs); + CHECK_IS_FLOAT(steps); + CHECK_IS_INT(pts_idx); + CHECK_CUDA(pts_idx); + CHECK_CUDA(min_depth); + CHECK_CUDA(max_depth); + CHECK_CUDA(uniform_noise); + CHECK_CUDA(probs); + CHECK_CUDA(steps); + + int max_steps = uniform_noise.size(-1); + at::Tensor sampled_idx = + -torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps}, + 
at::device(pts_idx.device()).dtype(at::ScalarType::Int)); + at::Tensor sampled_depth = + torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, + at::device(min_depth.device()).dtype(at::ScalarType::Float)); + at::Tensor sampled_dists = + torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, + at::device(min_depth.device()).dtype(at::ScalarType::Float)); + inverse_cdf_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2), fixed_step_size, + pts_idx.data_ptr <int>(), min_depth.data_ptr <float>(), max_depth.data_ptr <float>(), + uniform_noise.data_ptr <float>(), probs.data_ptr <float>(), steps.data_ptr <float>(), + sampled_idx.data_ptr <int>(), sampled_depth.data_ptr <float>(), sampled_dists.data_ptr <float>()); + return std::make_tuple(sampled_idx, sampled_depth, sampled_dists); +} \ No newline at end of file diff --git a/clib/src/sample_gpu.cu b/clib/src/sample_gpu.cu new file mode 100644 index 0000000..7e4e212 --- /dev/null +++ b/clib/src/sample_gpu.cu @@ -0,0 +1,231 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + + +#include <math.h> +#include <stdio.h> +#include <stdlib.h> + +#include "cuda_utils.h" +#include "cutil_math.h" // required for float3 vector math + + +__global__ void uniform_ray_sampling_kernel( + int b, int num_rays, + int max_hits, + int max_steps, + float step_size, + const int *__restrict__ pts_idx, + const float *__restrict__ min_depth, + const float *__restrict__ max_depth, + const float *__restrict__ uniform_noise, + int *__restrict__ sampled_idx, + float *__restrict__ sampled_depth, + float *__restrict__ sampled_dists) { + + int batch_index = blockIdx.x; + int index = threadIdx.x; + int stride = blockDim.x; + + pts_idx += batch_index * num_rays * max_hits; + min_depth += batch_index * num_rays * max_hits; + max_depth += batch_index * num_rays * max_hits; + + uniform_noise += batch_index * num_rays * max_steps; + sampled_idx += batch_index * num_rays * max_steps; + sampled_depth += batch_index * num_rays * max_steps; + sampled_dists += batch_index * num_rays * max_steps; + + // loop over all rays + for (int j = index; j < num_rays; j += stride) { + int H = j * max_hits, K = j * max_steps; + int s = 0, ucur = 0, umin = 0, umax = 0; + float last_min_depth, last_max_depth, curr_depth; + + // sort all depths + while (true) { + if ((umax == max_hits) || (ucur == max_steps) || (pts_idx[H + umax] == -1)) { + break; // reach the maximum + } + if (umin < max_hits) { + last_min_depth = min_depth[H + umin]; + } else { + last_min_depth = 10000.0; + } + if (umax < max_hits) { + last_max_depth = max_depth[H + umax]; + } else { + last_max_depth = 10000.0; + } + if (ucur < max_steps) { + curr_depth = min_depth[H] + (float(ucur) + uniform_noise[K + ucur]) * step_size; + } + + if ((last_max_depth <= curr_depth) && (last_max_depth <= last_min_depth)) { + sampled_depth[K + s] = last_max_depth; + sampled_idx[K + s] = pts_idx[H + umax]; + umax++; s++; continue; + } + if ((curr_depth <= last_min_depth) && (curr_depth <= last_max_depth)) { + sampled_depth[K + s] = curr_depth; + sampled_idx[K + s] = pts_idx[H + umin - 1]; + ucur++; s++; continue; + } + if ((last_min_depth <= curr_depth) && (last_min_depth <= last_max_depth)) { + sampled_depth[K + s] = last_min_depth; + sampled_idx[K + s] = pts_idx[H + umin]; + umin++; s++; continue; + } + } + + float l_depth, r_depth; + int step = 0; + 
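+        // (Editor's note) Second pass: turn the sorted boundary depths into
+        // per-sample midpoints and interval lengths, keep only samples that
+        // fall inside a hit voxel (umin - 1 == umax) by compacting them to
+        // the front, and reset the unused tail to idx == -1 afterwards.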
for (ucur = 0, umin = 0, umax = 0; ucur < max_steps - 1; ucur++) { + if (sampled_idx[K + ucur + 1] == -1) break; + l_depth = sampled_depth[K + ucur]; + r_depth = sampled_depth[K + ucur + 1]; + sampled_depth[K + ucur] = (l_depth + r_depth) * .5; + sampled_dists[K + ucur] = (r_depth - l_depth); + if ((umin < max_hits) && (sampled_depth[K + ucur] >= min_depth[H + umin]) && (pts_idx[H + umin] > -1)) umin++; + if ((umax < max_hits) && (sampled_depth[K + ucur] >= max_depth[H + umax]) && (pts_idx[H + umax] > -1)) umax++; + if ((umax == max_hits) || (pts_idx[H + umax] == -1)) break; + if ((umin - 1 == umax) && (sampled_dists[K + ucur] > 0)) { + sampled_depth[K + step] = sampled_depth[K + ucur]; + sampled_dists[K + step] = sampled_dists[K + ucur]; + sampled_idx[K + step] = sampled_idx[K + ucur]; + step++; + } + } + + for (int s = step; s < max_steps; s++) { + sampled_idx[K + s] = -1; + } + } +} + +__global__ void inverse_cdf_sampling_kernel( + int b, int num_rays, + int max_hits, + int max_steps, + float fixed_step_size, + const int *__restrict__ pts_idx, + const float *__restrict__ min_depth, + const float *__restrict__ max_depth, + const float *__restrict__ uniform_noise, + const float *__restrict__ probs, + const float *__restrict__ steps, + int *__restrict__ sampled_idx, + float *__restrict__ sampled_depth, + float *__restrict__ sampled_dists) { + + int batch_index = blockIdx.x; + int index = threadIdx.x; + int stride = blockDim.x; + + pts_idx += batch_index * num_rays * max_hits; + min_depth += batch_index * num_rays * max_hits; + max_depth += batch_index * num_rays * max_hits; + probs += batch_index * num_rays * max_hits; + steps += batch_index * num_rays; + + uniform_noise += batch_index * num_rays * max_steps; + sampled_idx += batch_index * num_rays * max_steps; + sampled_depth += batch_index * num_rays * max_steps; + sampled_dists += batch_index * num_rays * max_steps; + + // loop over all rays + for (int j = index; j < num_rays; j += stride) { + int H = j * max_hits, K = j * max_steps; + int curr_bin = 0, s = 0; // current index (bin) + + float curr_min_depth = min_depth[H]; // lower depth + float curr_max_depth = max_depth[H]; // upper depth + float curr_min_cdf = 0; + float curr_max_cdf = probs[H]; + float step_size = 1.0 / steps[j]; + float z_low = curr_min_depth; + int total_steps = int(ceil(steps[j])); + bool done = false; + + // optional use a fixed step size + if (fixed_step_size > 0.0) step_size = fixed_step_size; + + // sample points + for (int curr_step = 0; curr_step < total_steps; curr_step++) { + float curr_cdf = (float(curr_step) + uniform_noise[K + curr_step]) * step_size; + while (curr_cdf > curr_max_cdf) { + // first include max cdf + sampled_idx[K + s] = pts_idx[H + curr_bin]; + sampled_dists[K + s] = (curr_max_depth - z_low); + sampled_depth[K + s] = (curr_max_depth + z_low) * .5; + + // move to next cdf + curr_bin++; + s++; + if ((curr_bin >= max_hits) || (pts_idx[H + curr_bin] == -1)) { + done = true; break; + } + curr_min_depth = min_depth[H + curr_bin]; + curr_max_depth = max_depth[H + curr_bin]; + curr_min_cdf = curr_max_cdf; + curr_max_cdf = curr_max_cdf + probs[H + curr_bin]; + z_low = curr_min_depth; + } + if (done) break; + + // if the sampled cdf is inside bin + float u = (curr_cdf - curr_min_cdf) / (curr_max_cdf - curr_min_cdf); + float z = curr_min_depth + u * (curr_max_depth - curr_min_depth); + sampled_idx[K + s] = pts_idx[H + curr_bin]; + sampled_dists[K + s] = (z - z_low); + sampled_depth[K + s] = (z + z_low) * .5; + z_low = z; s++; + } + + // if there 
are bins still remaining
+        while ((z_low < curr_max_depth) && (!done)) {
+            sampled_idx[K + s] = pts_idx[H + curr_bin];
+            sampled_dists[K + s] = (curr_max_depth - z_low);
+            sampled_depth[K + s] = (curr_max_depth + z_low) * .5;
+            curr_bin++;
+            s++;
+            if ((curr_bin >= max_hits) || (pts_idx[H + curr_bin] == -1))
+                break;
+
+            curr_min_depth = min_depth[H + curr_bin];
+            curr_max_depth = max_depth[H + curr_bin];
+            z_low = curr_min_depth;
+        }
+    }
+}
+
+void uniform_ray_sampling_kernel_wrapper(
+    int b, int num_rays, int max_hits, int max_steps, float step_size,
+    const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise,
+    int *sampled_idx, float *sampled_depth, float *sampled_dists) {
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    uniform_ray_sampling_kernel<<<b, opt_n_threads(num_rays), 0, stream>>>(
+        b, num_rays, max_hits, max_steps, step_size, pts_idx,
+        min_depth, max_depth, uniform_noise, sampled_idx, sampled_depth, sampled_dists);
+
+    CUDA_CHECK_ERRORS();
+}
+
+void inverse_cdf_sampling_kernel_wrapper(
+    int b, int num_rays, int max_hits, int max_steps, float fixed_step_size,
+    const int *pts_idx, const float *min_depth, const float *max_depth,
+    const float *uniform_noise, const float *probs, const float *steps,
+    int *sampled_idx, float *sampled_depth, float *sampled_dists) {
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    inverse_cdf_sampling_kernel<<<b, opt_n_threads(num_rays), 0, stream>>>(
+        b, num_rays, max_hits, max_steps, fixed_step_size,
+        pts_idx, min_depth, max_depth, uniform_noise, probs, steps,
+        sampled_idx, sampled_depth, sampled_dists);
+
+    CUDA_CHECK_ERRORS();
+}
+
\ No newline at end of file
diff --git a/configs/nerf_default.json b/configs/nerf_default.json
new file mode 100644
index 0000000..3f9165b
--- /dev/null
+++ b/configs/nerf_default.json
@@ -0,0 +1,22 @@
+{
+    "model": "NeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "sample_range": [0, 10],
+        "n_samples": 256,
+        "perturb_sample": true,
+        "spherical": false,
+        "lindisp": false,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/nerf_voxels.json b/configs/nerf_voxels.json
new file mode 100644
index 0000000..411ab9d
--- /dev/null
+++ b/configs/nerf_voxels.json
@@ -0,0 +1,24 @@
+{
+    "model": "NeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "voxel_size": 0.5,
+        "sample_range": [0, 10],
+        "n_samples": 50,
+        "perturb_sample": true,
+        "spherical": false,
+        "lindisp": false,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/nsvf_coarse.json b/configs/nsvf_coarse.json
new file mode 100644
index 0000000..f9b341c
--- /dev/null
+++ b/configs/nsvf_coarse.json
@@ -0,0 +1,21 @@
+{
+    "model": "NSVF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 128,
+            "n_layers": 4,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "octree",
+        "voxel_size": 0.5,
+        "sample_step_ratio": 0.2,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/nsvf_default.json b/configs/nsvf_default.json
new file mode 100644 index 0000000..ad6faaf --- /dev/null +++ b/configs/nsvf_default.json @@ -0,0 +1,21 @@ +{ + "model": "NSVF", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 256, + "n_layers": 8, + "activation": "relu", + "skips": [ 4 ] + }, + "n_featdim": 0, + "space": "octree", + "voxel_size": 0.5, + "sample_step_ratio": 0.2, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1 + } +} \ No newline at end of file diff --git a/configs/nsvf_voxels.json b/configs/nsvf_voxels.json new file mode 100644 index 0000000..b60ae89 --- /dev/null +++ b/configs/nsvf_voxels.json @@ -0,0 +1,21 @@ +{ + "model": "NSVF", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 256, + "n_layers": 8, + "activation": "relu", + "skips": [ 4 ] + }, + "n_featdim": 0, + "space": "voxels", + "voxel_size": 0.5, + "sample_step_ratio": 0.2, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1 + } +} \ No newline at end of file diff --git a/configs/bgnet.py b/configs/old/bgnet.py similarity index 100% rename from configs/bgnet.py rename to configs/old/bgnet.py diff --git a/configs/cnerf.py b/configs/old/cnerf.py similarity index 100% rename from configs/cnerf.py rename to configs/old/cnerf.py diff --git a/configs/dnerfabins.py b/configs/old/dnerfabins.py similarity index 100% rename from configs/dnerfabins.py rename to configs/old/dnerfabins.py diff --git a/configs/fovea.py b/configs/old/fovea.py similarity index 100% rename from configs/fovea.py rename to configs/old/fovea.py diff --git a/configs/fovea_small_rot1.py b/configs/old/fovea_small_rot1.py similarity index 100% rename from configs/fovea_small_rot1.py rename to configs/old/fovea_small_rot1.py diff --git a/configs/fovea_small_trans.py b/configs/old/fovea_small_trans.py similarity index 100% rename from configs/fovea_small_trans.py rename to configs/old/fovea_small_trans.py diff --git a/configs/msl2fast.py b/configs/old/msl2fast.py similarity index 100% rename from configs/msl2fast.py rename to configs/old/msl2fast.py diff --git a/configs/msl_fovea.py b/configs/old/msl_fovea.py similarity index 100% rename from configs/msl_fovea.py rename to configs/old/msl_fovea.py diff --git a/configs/mslfast.py b/configs/old/mslfast.py similarity index 100% rename from configs/mslfast.py rename to configs/old/mslfast.py diff --git a/configs/mslray.py b/configs/old/mslray.py similarity index 100% rename from configs/mslray.py rename to configs/old/mslray.py diff --git a/configs/nerf.py b/configs/old/nerf.py similarity index 100% rename from configs/nerf.py rename to configs/old/nerf.py diff --git a/configs/nerf_horns.py b/configs/old/nerf_horns.py similarity index 100% rename from configs/nerf_horns.py rename to configs/old/nerf_horns.py diff --git a/configs/nerf_horns_4.py b/configs/old/nerf_horns_4.py similarity index 100% rename from configs/nerf_horns_4.py rename to configs/old/nerf_horns_4.py diff --git a/configs/nerf_horns_8.py b/configs/old/nerf_horns_8.py similarity index 100% rename from configs/nerf_horns_8.py rename to configs/old/nerf_horns_8.py diff --git a/configs/nerf_periph.py b/configs/old/nerf_periph.py similarity index 100% rename from configs/nerf_periph.py rename to configs/old/nerf_periph.py diff --git a/configs/nerf_trex.py b/configs/old/nerf_trex.py similarity index 100% rename from configs/nerf_trex.py rename to configs/old/nerf_trex.py diff --git a/configs/nerf_trex_4.py 
b/configs/old/nerf_trex_4.py similarity index 100% rename from configs/nerf_trex_4.py rename to configs/old/nerf_trex_4.py diff --git a/configs/nerf_trex_8.py b/configs/old/nerf_trex_8.py similarity index 100% rename from configs/nerf_trex_8.py rename to configs/old/nerf_trex_8.py diff --git a/configs/nerfsimple.py b/configs/old/nerfsimple.py similarity index 100% rename from configs/nerfsimple.py rename to configs/old/nerfsimple.py diff --git a/configs/nmsl_fovea.py b/configs/old/nmsl_fovea.py similarity index 100% rename from configs/nmsl_fovea.py rename to configs/old/nmsl_fovea.py diff --git a/configs/nnerf.py b/configs/old/nnerf.py similarity index 100% rename from configs/nnerf.py rename to configs/old/nnerf.py diff --git a/configs/oracle.py b/configs/old/oracle.py similarity index 100% rename from configs/oracle.py rename to configs/old/oracle.py diff --git a/configs/periph.py b/configs/old/periph.py similarity index 100% rename from configs/periph.py rename to configs/old/periph.py diff --git a/configs/periph_new.py b/configs/old/periph_new.py similarity index 100% rename from configs/periph_new.py rename to configs/old/periph_new.py diff --git a/configs/periph_small_trans.py b/configs/old/periph_small_trans.py similarity index 100% rename from configs/periph_small_trans.py rename to configs/old/periph_small_trans.py diff --git a/configs/snerffast_periph.py b/configs/old/snerffast_periph.py similarity index 100% rename from configs/snerffast_periph.py rename to configs/old/snerffast_periph.py diff --git a/configs/snerffastx.py b/configs/old/snerffastx.py similarity index 100% rename from configs/snerffastx.py rename to configs/old/snerffastx.py diff --git a/configs/snerf_fine_voxels.json b/configs/snerf_fine_voxels.json new file mode 100644 index 0000000..8079e6f --- /dev/null +++ b/configs/snerf_fine_voxels.json @@ -0,0 +1,21 @@ +{ + "model": "SNeRF", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 256, + "n_layers": 8, + "activation": "relu", + "skips": [ 4 ] + }, + "n_featdim": 0, + "space": "voxels", + "steps": [8, 32, 16], + "n_samples": 16, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1 + } +} \ No newline at end of file diff --git a/configs/snerf_voxels+ls-d.json b/configs/snerf_voxels+ls-d.json new file mode 100644 index 0000000..2eebf9b --- /dev/null +++ b/configs/snerf_voxels+ls-d.json @@ -0,0 +1,20 @@ +{ + "model": "SNeRF", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "fc_params": { + "nf": 256, + "n_layers": 8, + "activation": "relu", + "skips": [ 4 ] + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerf_voxels+ls.json b/configs/snerf_voxels+ls.json new file mode 100644 index 0000000..a7cde45 --- /dev/null +++ b/configs/snerf_voxels+ls.json @@ -0,0 +1,21 @@ +{ + "model": "SNeRF", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 256, + "n_layers": 8, + "activation": "relu", + "skips": [ 4 ] + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerf_voxels.json b/configs/snerf_voxels.json new file mode 100644 index 0000000..7e68cb4 --- 
/dev/null +++ b/configs/snerf_voxels.json @@ -0,0 +1,19 @@ +{ + "model": "SNeRF", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 256, + "n_layers": 8, + "activation": "relu", + "skips": [ 4 ] + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true + } +} \ No newline at end of file diff --git a/configs/snerf_voxels_128x8_x2.json b/configs/snerf_voxels_128x8_x2.json new file mode 100644 index 0000000..0052515 --- /dev/null +++ b/configs/snerf_voxels_128x8_x2.json @@ -0,0 +1,22 @@ +{ + "model": "SNeRF", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 128, + "n_layers": 8, + "activation": "relu", + "skips": [ 4 ] + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1, + "multi_nets": 2 + } +} \ No newline at end of file diff --git a/configs/snerf_voxels_128x8_x4.json b/configs/snerf_voxels_128x8_x4.json new file mode 100644 index 0000000..268498c --- /dev/null +++ b/configs/snerf_voxels_128x8_x4.json @@ -0,0 +1,22 @@ +{ + "model": "SNeRF", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 128, + "n_layers": 8, + "activation": "relu", + "skips": [ 4 ] + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1, + "multi_nets": 4 + } +} \ No newline at end of file diff --git a/configs/snerf_voxels_feat.json b/configs/snerf_voxels_feat.json new file mode 100644 index 0000000..fcd8dce --- /dev/null +++ b/configs/snerf_voxels_feat.json @@ -0,0 +1,21 @@ +{ + "model": "SNeRF", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 256, + "n_layers": 8, + "activation": "relu", + "skips": [ 4 ] + }, + "n_featdim": 32, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1 + } +} \ No newline at end of file diff --git a/configs/snerf_voxels_fine.json b/configs/snerf_voxels_fine.json new file mode 100644 index 0000000..2a356ed --- /dev/null +++ b/configs/snerf_voxels_fine.json @@ -0,0 +1,21 @@ +{ + "model": "SNeRF", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 512, + "n_layers": 8, + "activation": "relu", + "skips": [4] + }, + "n_featdim": 0, + "space": "voxels", + "steps": [32, 128, 64], + "n_samples": 128, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1 + } +} \ No newline at end of file diff --git a/configs/snerfadv_finevoxels+ls.json b/configs/snerfadv_finevoxels+ls.json new file mode 100644 index 0000000..5511733 --- /dev/null +++ b/configs/snerfadv_finevoxels+ls.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 256, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 3, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [8, 32, 16], + "n_samples": 64, + "perturb_sample": true, + "appearance": "combined", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + 
"density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerfadv_finevoxels+ls_256x4_256x6_16x2.json b/configs/snerfadv_finevoxels+ls_256x4_256x6_16x2.json new file mode 100644 index 0000000..8ee790a --- /dev/null +++ b/configs/snerfadv_finevoxels+ls_256x4_256x6_16x2.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 256, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 6, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 16, + "n_layers": 2, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [16, 64, 32], + "n_samples": 64, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerfadv_finevoxels+ls_256x4_256x6_combined.json b/configs/snerfadv_finevoxels+ls_256x4_256x6_combined.json new file mode 100644 index 0000000..ce9dd1b --- /dev/null +++ b/configs/snerfadv_finevoxels+ls_256x4_256x6_combined.json @@ -0,0 +1,30 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 256, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 6, + "act": "relu", + "skips": [] + }, + "n_featdim": 0, + "space": "voxels", + "steps": [16, 64, 32], + "n_samples": 64, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4, + "appearance": "combined" + } +} \ No newline at end of file diff --git a/configs/snerfadv_finevoxels_ls2.json b/configs/snerfadv_finevoxels_ls2.json new file mode 100644 index 0000000..94d4fd6 --- /dev/null +++ b/configs/snerfadv_finevoxels_ls2.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 256, + "n_layers": 2, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 3, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [16, 64, 32], + "n_samples": 64, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerfadv_voxels+ls+ns.json b/configs/snerfadv_voxels+ls+ns.json new file mode 100644 index 0000000..9e69a9d --- /dev/null +++ b/configs/snerfadv_voxels+ls+ns.json @@ -0,0 +1,36 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 256, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 3, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "appearance": "newtype", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4, + "specular_regularization_weight": 1e-1, + "specular_regularization_scale": 1e4 + } +} \ No newline at end of file diff 
--git a/configs/snerfadv_voxels+ls.json b/configs/snerfadv_voxels+ls.json new file mode 100644 index 0000000..e533e59 --- /dev/null +++ b/configs/snerfadv_voxels+ls.json @@ -0,0 +1,33 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 256, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 3, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "appearance": "combined", + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerfadv_voxels+ls1.json b/configs/snerfadv_voxels+ls1.json new file mode 100644 index 0000000..d634054 --- /dev/null +++ b/configs/snerfadv_voxels+ls1.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 256, + "n_layers": 5, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 2, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "appearance": "combined", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerfadv_voxels+ls2.json b/configs/snerfadv_voxels+ls2.json new file mode 100644 index 0000000..bbe5723 --- /dev/null +++ b/configs/snerfadv_voxels+ls2.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 256, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 3, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "appearance": "combined", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerfadv_voxels+ls3.json b/configs/snerfadv_voxels+ls3.json new file mode 100644 index 0000000..51aa71f --- /dev/null +++ b/configs/snerfadv_voxels+ls3.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 256, + "n_layers": 3, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "appearance": "combined", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerfadv_voxels+ls4.json b/configs/snerfadv_voxels+ls4.json new file mode 100644 index 0000000..99daf4d --- /dev/null +++ b/configs/snerfadv_voxels+ls4.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + 
"density_net": { + "nf": 256, + "n_layers": 2, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 5, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "appearance": "combined", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerfadv_voxels+ls5.json b/configs/snerfadv_voxels+ls5.json new file mode 100644 index 0000000..897dcb9 --- /dev/null +++ b/configs/snerfadv_voxels+ls5.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 256, + "n_layers": 8, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 256, + "n_layers": 6, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 2, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "appearance": "combined", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerfadv_voxels+ls6.json b/configs/snerfadv_voxels+ls6.json new file mode 100644 index 0000000..3971cae --- /dev/null +++ b/configs/snerfadv_voxels+ls6.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvance", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 512, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 512, + "n_layers": 3, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 256, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "appearance": "combined", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4 + } +} \ No newline at end of file diff --git a/configs/snerfadvx_voxels_x16.json b/configs/snerfadvx_voxels_x16.json new file mode 100644 index 0000000..7e55bea --- /dev/null +++ b/configs/snerfadvx_voxels_x16.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvanceX", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 128, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 128, + "n_layers": 3, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "_nets/hr_r0.8s/snerfadv_voxels+ls6/checkpoint_50.tar", + "n_samples": 256, + "perturb_sample": true, + "appearance": "combined", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4, + "multi_nets": 16 + } +} \ No newline at end of file diff --git a/configs/snerfadvx_voxels_x4.json b/configs/snerfadvx_voxels_x4.json new file mode 100644 index 0000000..1e55fff --- /dev/null +++ b/configs/snerfadvx_voxels_x4.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvanceX", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 128, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 128, + "n_layers": 3, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + 
"act": "relu" + }, + "n_featdim": 0, + "space": "_nets/train_t0.3/snerfadv_voxels+ls2/checkpoint_50.tar", + "n_samples": 256, + "perturb_sample": true, + "appearance": "combined", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4, + "multi_nets": 4 + } +} \ No newline at end of file diff --git a/configs/snerfadvx_voxels_x8.json b/configs/snerfadvx_voxels_x8.json new file mode 100644 index 0000000..1285ead --- /dev/null +++ b/configs/snerfadvx_voxels_x8.json @@ -0,0 +1,34 @@ +{ + "model": "SNeRFAdvanceX", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "density_net": { + "nf": 128, + "n_layers": 4, + "act": "relu", + "skips": [] + }, + "color_net": { + "nf": 128, + "n_layers": 3, + "act": "relu", + "skips": [] + }, + "specular_net": { + "nf": 128, + "n_layers": 1, + "act": "relu" + }, + "n_featdim": 0, + "space": "_nets/hr_t1.0s/snerfadv_voxels+ls2/checkpoint_50.tar", + "n_samples": 256, + "perturb_sample": true, + "appearance": "combined", + "density_color_connection": true, + "density_regularization_weight": 1e-4, + "density_regularization_scale": 1e4, + "multi_nets": 8 + } +} \ No newline at end of file diff --git a/configs/snerfx_voxels_128x4_x4.json b/configs/snerfx_voxels_128x4_x4.json new file mode 100644 index 0000000..3d3af48 --- /dev/null +++ b/configs/snerfx_voxels_128x4_x4.json @@ -0,0 +1,21 @@ +{ + "model": "SNeRFX", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 128, + "n_layers": 4, + "activation": "relu", + "skips": [] + }, + "n_featdim": 0, + "space": "nets/train1/snerf_voxels/checkpoint_50.tar", + "n_samples": 256, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1, + "multi_nets": 4 + } +} \ No newline at end of file diff --git a/configs/snerfx_voxels_128x4_x8.json b/configs/snerfx_voxels_128x4_x8.json new file mode 100644 index 0000000..374d8fc --- /dev/null +++ b/configs/snerfx_voxels_128x4_x8.json @@ -0,0 +1,21 @@ +{ + "model": "SNeRFX", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 128, + "n_layers": 4, + "activation": "relu", + "skips": [] + }, + "n_featdim": 0, + "space": "nets/train_t0.3/snerf_voxels/checkpoint_50.tar", + "n_samples": 256, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1, + "multi_nets": 8 + } +} \ No newline at end of file diff --git a/configs/snerfx_voxels_128x8_x4.json b/configs/snerfx_voxels_128x8_x4.json new file mode 100644 index 0000000..6a2fe8a --- /dev/null +++ b/configs/snerfx_voxels_128x8_x4.json @@ -0,0 +1,21 @@ +{ + "model": "SNeRFX", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 128, + "n_layers": 8, + "activation": "relu", + "skips": [4] + }, + "n_featdim": 0, + "space": "nets/train_t0.3/snerf_voxels/checkpoint_50.tar", + "n_samples": 256, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1, + "multi_nets": 4 + } +} \ No newline at end of file diff --git a/configs/snerfx_voxels_256x4_x4.json b/configs/snerfx_voxels_256x4_x4.json new file mode 100644 index 0000000..706d22f --- /dev/null +++ b/configs/snerfx_voxels_256x4_x4.json @@ -0,0 +1,21 @@ +{ + "model": "SNeRFX", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 256, + "n_layers": 4, + "activation": "relu", + "skips": [] + }, + "n_featdim": 0, + "space": 
"nets/train1/snerf_voxels/checkpoint_50.tar", + "n_samples": 256, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1, + "multi_nets": 4 + } +} \ No newline at end of file diff --git a/configs/snerfx_voxels_256x4_x4_balance.json b/configs/snerfx_voxels_256x4_x4_balance.json new file mode 100644 index 0000000..11bc5e8 --- /dev/null +++ b/configs/snerfx_voxels_256x4_x4_balance.json @@ -0,0 +1,22 @@ +{ + "model": "SNeRFX", + "args": { + "color": "rgb", + "n_pot_encode": 10, + "n_dir_encode": 4, + "fc_params": { + "nf": 256, + "n_layers": 4, + "activation": "relu", + "skips": [] + }, + "n_featdim": 0, + "space": "voxels", + "steps": [4, 16, 8], + "n_samples": 16, + "perturb_sample": true, + "raymarching_tolerance": 0, + "raymarching_chunk_size": -1, + "multi_nets": 4 + } +} \ No newline at end of file diff --git a/dash_test.py b/dash_test.py index 46f1434..9761996 100644 --- a/dash_test.py +++ b/dash_test.py @@ -1,67 +1,38 @@ import os -import argparse import torch import json import dash import dash_core_components as dcc import dash_html_components as html import plotly.express as px -import pandas as pd import numpy as np # from skimage import data +from pathlib import Path from dash.dependencies import Input, Output from dash.exceptions import PreventUpdate -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--device', type=int, default=0, - help='Which CUDA device to use.') - opt = parser.parse_args() - - # Select device - torch.cuda.set_device(opt.device) - print("Set CUDA:%d as current device." % torch.cuda.current_device()) torch.autograd.set_grad_enabled(False) -from data.spherical_view_syn import * -from configs.spherical_view_syn import SphericalViewSynConfig -from utils import netio from utils import device from utils import view from utils import img from utils import misc -from nets.modules import AlphaComposition, Sampler +import model as mdl +from modules import AlphaComposition, Sampler -datadir = 'data/__new/lobby_fovea_r360x80_t1.0/' -data_desc_file = 'train1.json' +datadir = Path('data/__new/classroom_fovea_r360x80_t0.6') +data_desc_file = 'r120x80.json' net_config = 'fovea@snerffast4-rgb_e6_fc512x4_d2.00-50.00_s64_~p' -net_path = datadir + net_config + '/model-epoch_200.pth' +model_path = datadir / 'snerf_voxels/checkpoint_50.tar' fov = 40 res = (256, 256) pix_img_res = (256, 256) center = (0, 0) -def load_net(path): - print(path) - config = SphericalViewSynConfig() - config.from_id(net_config) - config.sa['perturb_sample'] = False - net = config.create_net().to(device.default()) - netio.load(path, net) - return net - - -def load_net_by_name(name): - for path in os.listdir(datadir): - if path.startswith(name + '@'): - return load_net(datadir + path) - return None - - def load_data_desc(data_desc_file) -> view.Trans: with open(datadir + data_desc_file, 'r', encoding='utf-8') as file: data_desc = json.loads(file.read()) @@ -85,7 +56,10 @@ cam = view.CameraParam({ 'cy': 0.5, 'normalized': True }, res, device=device.default()) -net = load_net(net_path) +model, _ = mdl.load(model_path, { + "perturb_sample": False +}) + # Global states x = y = None @@ -159,7 +133,7 @@ app.layout = html.Div([ def plot_alpha_and_density(ray_o, ray_d): # colors, densities, depths = net.sample_and_infer(ray_o, ray_d, sampler=sampler) - ret = net(ray_o, ray_d, ret_depth=True, debug=True) + ret = model(ray_o, ray_d, extra_outputs=['depth', 'layers']) colors = ret['layers'][..., : 3] densities = ret['sample_densities'] depths = 
ret['sample_depths'] @@ -202,7 +176,7 @@ def plot_pixel_image(ray_o, ray_d, r=1): ], dim=-1).to(device.default()) rays_d = pixel_point - rays_o rays_d /= rays_d.norm(dim=-1, keepdim=True) - image = net(rays_o.view(-1, 3), rays_d.view(-1, 3))['color'] \ + image = model(rays_o.view(-1, 3), rays_d.view(-1, 3))['color'] \ .view(1, *pix_img_res, -1).permute(0, 3, 1, 2) fig = px.imshow(img.torch2np(image)[0]) return fig @@ -230,10 +204,10 @@ def render_view(tx, ty, tz, rx, ry): torch.tensor(view.euler_to_matrix([ry, rx, 0]), device=device.default()).view(-1, 3, 3) ) rays_o, rays_d = cam.get_global_rays(test_view, True) - ret = net(rays_o.view(-1, 3), rays_d.view(-1, 3), debug=True) - image = ret['color'].view(1, res[0], res[1], 3).permute(0, 3, 1, 2) - layers = ret['layers'].view(res[0], res[1], -1, 4) - layer_weights = ret['weight'].view(res[0], res[1], -1) + ret = model(rays_o.view(-1, 3), rays_d.view(-1, 3), extra_outputs=['layers', 'weights']) + image = ret['color'].view(1, *res, 3).permute(0, 3, 1, 2) + layers = ret['layers'].view(*res, -1, 4) + layer_weights = ret['weight'].view(*res, -1) fig = px.imshow(img.torch2np(image)[0]) return fig @@ -241,17 +215,13 @@ def render_view(tx, ty, tz, rx, ry): def render_layer(layer): if layer is None: return None - layer_data = torch.sum(layers[..., range(*layer), :3] * layer_weights[..., range(*layer), None], - dim=-2) - #layer_data = layer_data[..., :3] * layer_data[..., 3:] + layer_data = torch.sum((layers * layer_weights)[..., range(*layer), :3], dim=-2) fig = px.imshow(img.torch2np(layer_data)) return fig def view_pixel(fig, x, y, samples): - sampler = Sampler(depth_range=(1, 50), n_samples=samples, - perturb_sample=False, spherical=True, - lindisp=True, inverse_r=True) + sampler = model.sampler if x is None or y is None: return None p = torch.tensor([x, y], device=device.default()) diff --git a/data/dataset_factory.py b/data/dataset_factory.py index 7a1f7c2..53793da 100644 --- a/data/dataset_factory.py +++ b/data/dataset_factory.py @@ -1,5 +1,8 @@ -import os import json +import os +from pathlib import Path +from typing import Union + import utils.device from .pano_dataset import PanoDataset from .view_dataset import ViewDataset @@ -8,16 +11,26 @@ from .view_dataset import ViewDataset class DatasetFactory(object): @staticmethod - def load(path, device=None, **kwargs): + def get_dataset_desc_path(path: Union[Path, str]): + if isinstance(path, str): + path = Path(path) + if path.suffix != ".json": + if os.path.exists(f"{path}.json"): + path = Path(f"{path}.json") + else: + path = path / "train.json" + return path + + @staticmethod + def load(path: Path, device=None, **kwargs): device = device or utils.device.default() - data_dir = os.path.dirname(path) + path = DatasetFactory.get_dataset_desc_path(path) with open(path, 'r', encoding='utf-8') as file: - data_desc = json.loads(file.read()) - cwd = os.getcwd() - os.chdir(data_dir) - if 'type' in data_desc and data_desc['type'] == 'pano': - dataset = PanoDataset(data_desc, device=device, **kwargs) + data_desc: dict = json.loads(file.read()) + if data_desc.get('type') == 'pano': + dataset_class = PanoDataset else: - dataset = ViewDataset(data_desc, device=device, **kwargs) - os.chdir(cwd) - return dataset \ No newline at end of file + dataset_class = ViewDataset + dataset = dataset_class(data_desc, root=path.absolute().parent, name=path.stem, + device=device, **kwargs) + return dataset diff --git a/data/loader.py b/data/loader.py index 49163cd..bcc474c 100644 --- a/data/loader.py +++ b/data/loader.py @@ 
-1,8 +1,8 @@ -from doctest import debug_script -from logging import * import threading import torch import math +from logging import * +from typing import Dict class Preloader(object): @@ -75,17 +75,18 @@ class DataLoader(object): self.chunk_idx += 1 self.current_chunk = self.chunks[self.chunk_idx] self.offset = 0 - self.indices = torch.randperm(len(self.current_chunk), device=self.device) \ + self.indices = torch.randperm(len(self.current_chunk)).to(device=self.device) \ if self.shuffle else None if self.preloader is not None: self.preloader.preload_chunk(self.chunks[(self.chunk_idx + 1) % len(self.chunks)]) def __init__(self, dataset, batch_size, *, - chunk_max_items=None, shuffle=False, enable_preload=True): + chunk_max_items=None, shuffle=False, enable_preload=True, **chunk_args): super().__init__() self.dataset = dataset self.batch_size = batch_size self.shuffle = shuffle + self.chunk_args = chunk_args self.preloader = Preloader(self.dataset.device) if enable_preload else None self._init_chunks(chunk_max_items) @@ -97,20 +98,18 @@ class DataLoader(object): return sum(math.ceil(len(chunk) / self.batch_size) for chunk in self.chunks) def _init_chunks(self, chunk_max_items): - data = self.dataset.get_data() + data: Dict[str, torch.Tensor] = self.dataset.get_data() if self.shuffle: - rand_seq = torch.randperm(self.dataset.n_views, device=self.dataset.device) - for key in data: - data[key] = data[key][rand_seq] + rand_seq = torch.randperm(self.dataset.n_views).to(device=self.dataset.device) + data = {key: val[rand_seq] for key, val in data.items()} self.chunks = [] n_chunks = 1 if chunk_max_items is None else \ math.ceil(self.dataset.n_pixels / chunk_max_items) views_per_chunk = math.ceil(self.dataset.n_views / n_chunks) for offset in range(0, self.dataset.n_views, views_per_chunk): sel = slice(offset, offset + views_per_chunk) - chunk_data = {} - for key in data: - chunk_data[key] = data[key][sel] - self.chunks.append(self.dataset.Chunk(len(self.chunks), self.dataset, **chunk_data)) + chunk_data = {key: val[sel] for key, val in data.items()} + self.chunks.append(self.dataset.Chunk(len(self.chunks), self.dataset, + chunk_data=chunk_data, **self.chunk_args)) if self.preloader is not None: self.preloader.preload_chunk(self.chunks[0]) diff --git a/data/pano_dataset.py b/data/pano_dataset.py index 9953c8f..918ba87 100644 --- a/data/pano_dataset.py +++ b/data/pano_dataset.py @@ -1,10 +1,12 @@ import os import torch import torch.nn.functional as nn_f -from typing import Tuple, Union +from typing import Dict, Tuple, Union +from operator import itemgetter +from pathlib import Path + from utils import img from utils import color -from utils import misc from utils import sphere from utils.mem_profiler import * from utils.constants import * @@ -27,8 +29,16 @@ class PanoDataset(object): class Chunk(object): - def __init__(self, id, dataset, *, - indices: torch.Tensor, centers: torch.Tensor): + @property + def n_views(self): + return self.indices.size(0) + + @property + def n_pixels_per_view(self): + return self.dataset.n_pixels_per_view + + def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *, + color: int, **kwargs): """ [summary] @@ -38,10 +48,9 @@ class PanoDataset(object): """ self.id = id self.dataset = dataset - self.indices = indices - self.centers = centers - self.n_views = self.indices.size(0) - self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1] + self.indices = chunk_data['indices'] + self.centers = chunk_data['centers'] + self.color = color 
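The chunking in `DataLoader._init_chunks` above keeps whole views together: it derives the number of chunks from a pixel budget, then slices the (optionally shuffled) per-view data into equal view ranges. A minimal standalone sketch of that arithmetic, with made-up sizes (not part of the patch):

    import math

    def chunk_slices(n_views: int, n_pixels: int, chunk_max_items=None):
        # One chunk when no budget is given, otherwise just enough chunks to
        # keep each chunk's pixel count under the budget (as in _init_chunks).
        n_chunks = 1 if chunk_max_items is None else math.ceil(n_pixels / chunk_max_items)
        views_per_chunk = math.ceil(n_views / n_chunks)
        return [slice(offset, offset + views_per_chunk)
                for offset in range(0, n_views, views_per_chunk)]

    # 120 views of 256x256 pixels under a 2e6-pixel budget -> 4 chunks of 30 views
    print(chunk_slices(120, 120 * 256 * 256, 2e6))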
self.colors_cpu = None self.colors = None self.loaded = False @@ -53,12 +62,12 @@ class PanoDataset(object): def load(self): if self.dataset.image_path is not None and self.colors_cpu is None: - images = color.cvt( - img.load(self.dataset.image_path % i for i in self.indices), - color.RGB, self.dataset.c) - if self.dataset.res != list(images.shape[-2:]): + images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices), + color.RGB, self.color) + if self.dataset.res != tuple(images.shape[-2:]): images = nn_f.interpolate(images, self.dataset.res) - self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2) + self.colors_cpu = images.permute(0, 2, 3, 1) \ + [:, self.dataset.pixels[:, 0], self.dataset.pixels[:, 1]].flatten(0, 1) if self.colors_cpu is not None: self.colors = self.colors_cpu.to(self.dataset.device) self.loaded = True @@ -74,15 +83,27 @@ class PanoDataset(object): self.load() view_idx = idx // self.n_pixels_per_view pix_idx = idx % self.n_pixels_per_view + global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx extra_data = {} if self.colors is not None: - extra_data['colors'] = self.colors[idx] + extra_data['color'] = self.colors[idx] rays_o = self.centers[view_idx] - rays_d = self.dataset.pano_rays[pix_idx] - return idx, rays_o, rays_d, extra_data + rays_d = self.dataset.rays[pix_idx] + return global_idx, rays_o, rays_d, extra_data + + @property + def n_views(self): + return self.centers.size(0) + + @property + def n_pixels_per_view(self): + return self.pixels.size(0) + + @property + def n_pixels(self): + return self.n_views * self.n_pixels_per_view - def __init__(self, desc: dict, *, - c: int = color.RGB, + def __init__(self, desc: dict, root: Path, name: str, *, load_images: bool = True, res: Tuple[int, int] = None, views_to_load: Union[range, torch.Tensor] = None, @@ -104,7 +125,8 @@ class PanoDataset(object): :param c ```int```: color space to convert view images to :param calculate_rays ```bool```: whether calculate rays """ - self.c = c + self.root = root + self.name = name self.device = device self._load_desc(desc, res, views_to_load, load_images) @@ -119,26 +141,26 @@ class PanoDataset(object): views_to_load: Union[range, torch.Tensor], load_images: bool): if load_images and desc.get('view_file_pattern'): - self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern']) + file_pattern = desc['view_file_pattern'] + if "/" not in file_pattern: + file_pattern = f"{self.name}/{file_pattern}" + self.image_path = str(self.root / file_pattern) else: self.image_path = None - self.res = res if res else misc.values(desc['view_res'], 'y', 'x') - self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \ + self.res = res if res else itemgetter("y", "x")(desc['view_res']) + self.depth_range = itemgetter("min", "max")(desc['depth_range']) \ if 'depth_range' in desc else None - self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None + self.bbox = None self.samples = desc.get('samples') self.centers = torch.tensor(desc['view_centers'], device=self.device) # (N, 3) - self.indices = torch.tensor( - desc['views'] if 'views' in desc else list(range(self.centers.size(0))), - device=self.device) + self.indices = torch.tensor(desc.get('views') or [*range(self.centers.size(0))], + device=self.device) if views_to_load is not None: self.centers = self.centers[views_to_load] self.indices = self.indices[views_to_load] - self.n_views = self.centers.size(0) - self.n_pixels = self.n_views * self.res[0] * self.res[1] - 
self.pano_rays = self._get_pano_rays() # [H*W, 3] + self.pixels, self.rays = self._get_pano_rays() if desc.get('gl_coord'): print('Convert from OGL coordinate to DX coordinate (i. e. flip z axis)') @@ -148,12 +170,16 @@ class PanoDataset(object): """ Get unprojected rays of pixels on a panorama - :return `Tensor(H*W, 3)`: rays' directions with one unit length + :return `Tensor(N, 2)`: rays' pixel coordinates in pano image + :return `Tensor(N, 3)`: rays' directions with one unit length """ - spher_coords = torch.cat([ - torch.ones(*self.res, 1), - ((misc.meshgrid(*self.res, normalize=True)) * - torch.tensor([-2.0, 1.0]) + torch.tensor([1.5, 0.0])) * PI - ], dim=-1).to(device=self.device) - coords = sphere.spherical2cartesian(spher_coords) - return coords.flatten(0, 1) # [H*W, 3] + phi = (torch.arange(self.res[0], device=self.device) + 0.5) / self.res[0] * PI # (H) + length = (phi.sin() * self.res[1] * 0.5).ceil() * 2 + cols = torch.arange(self.res[1], device=self.device)[None, :].expand(*self.res) # (H, W) + mask = torch.logical_and(cols >= (self.res[1] - length[:, None]) / 2, + cols < (self.res[1] + length[:, None]) / 2) # (H, W) + pixs = mask.nonzero() # (N, 2) + pixs_phi = (0.5 - (pixs[:, 0] + 0.5) / self.res[0]) * PI + pixs_theta = (pixs[:, 1] * 2 + 1 - self.res[1]) / length[pixs[:, 0]] * PI + spher_coords = torch.stack([torch.ones_like(pixs_phi), pixs_theta, pixs_phi], dim=-1) + return pixs, sphere.spherical2cartesian(spher_coords) # (N, 3) diff --git a/data/view_dataset.py b/data/view_dataset.py index 477629b..34acf39 100644 --- a/data/view_dataset.py +++ b/data/view_dataset.py @@ -1,11 +1,13 @@ import os import torch import torch.nn.functional as nn_f -from typing import Tuple, Union +from typing import Dict, Tuple, Union +from operator import itemgetter +from pathlib import Path + from utils import img from utils import view from utils import color -from utils import misc class ViewDataset(object): @@ -25,20 +27,21 @@ class ViewDataset(object): class Chunk(object): - def __init__(self, id, dataset, *, - indices: torch.Tensor, centers: torch.Tensor, rots: torch.Tensor): + def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *, + color: int, **kwargs): """ [summary] - :param dataset `PanoDataset`: dataset object + :param dataset `ViewDataset`: dataset object :param indices `Tensor(N)`: indices of views :param centers `Tensor(N, 3)`: centers of views """ self.id = id self.dataset = dataset - self.indices = indices - self.centers = centers - self.rots = rots + self.indices = chunk_data['indices'] + self.centers = chunk_data['centers'] + self.rots = chunk_data['rots'] + self.color = color self.n_views = self.indices.size(0) self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1] self.colors = self.depths = self.bins = None @@ -50,35 +53,39 @@ class ViewDataset(object): self.loaded = False def load(self): - if self.dataset.image_path and self.colors_cpu is None: - images = color.cvt( - img.load(self.dataset.image_path % i for i in self.indices), - color.RGB, self.dataset.c) - if self.dataset.res != list(images.shape[-2:]): - images = nn_f.interpolate(images, self.dataset.res) - self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2) - if self.colors_cpu is not None: - self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True) - - if self.dataset.depth_path and self.depths_cpu is None: - depths = self.dataset._decode_depth_images( - img.load(self.depth_path % i for i in self.indices)) - if self.dataset.res != list(depths.shape[-2:]): - 
depths = nn_f.interpolate(depths, self.dataset.res) - self.depths_cpu = depths.flatten(0, 2) - if self.depths_cpu is not None: - self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True) - - if self.dataset.bins_path and self.bins_cpu is None: - bins = img.load([self.dataset.bins_path % i for i in self.indices]) - if self.dataset.res != list(bins.shape[-2:]): - bins = nn_f.interpolate(bins, self.dataset.res) - self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2) - if self.bins_cpu is not None: - self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True) - - torch.cuda.current_stream(self.dataset.device).synchronize() - self.loaded = True + #print("chunk load") + try: + if self.dataset.image_path and self.colors_cpu is None: + images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices), + color.RGB, self.color) + if self.dataset.res != tuple(images.shape[-2:]): + images = nn_f.interpolate(images, self.dataset.res) + self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2) + if self.colors_cpu is not None: + self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True) + + if self.dataset.depth_path and self.depths_cpu is None: + depths = self.dataset._decode_depth_images( + img.load(self.dataset.depth_path % i for i in self.indices)) + if self.dataset.res != tuple(depths.shape[-2:]): + depths = nn_f.interpolate(depths, self.dataset.res) + self.depths_cpu = depths.flatten(0, 2) + if self.depths_cpu is not None: + self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True) + + if self.dataset.bins_path and self.bins_cpu is None: + bins = img.load([self.dataset.bins_path % i for i in self.indices]) + if self.dataset.res != tuple(bins.shape[-2:]): + bins = nn_f.interpolate(bins, self.dataset.res) + self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2) + if self.bins_cpu is not None: + self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True) + + torch.cuda.current_stream(self.dataset.device).synchronize() + self.loaded = True + except Exception as ex: + print(ex) + exit(-1) def __len__(self): return self.n_views * self.n_pixels_per_view @@ -88,21 +95,24 @@ class ViewDataset(object): self.load() view_idx = idx // self.n_pixels_per_view pix_idx = idx % self.n_pixels_per_view + global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx rays_o = self.centers[view_idx] - rays_d = self.dataset.cam_rays[pix_idx] # (N, 3) - r = self.rots[view_idx].movedim(-1, -2) # (N, 3, 3) - rays_d = torch.matmul(rays_d, r) - extra_data = {} + rays_d = self.dataset.cam_rays[pix_idx][:, None] # (N, 1, 3) + r = self.rots[view_idx].movedim(-1, -2) # (N, 3, 3) + rays_d = torch.matmul(rays_d, r)[:, 0] # (N, 3) + extra_data = { + 'view_idx': view_idx, + 'pix_idx': pix_idx + } # TBR if self.colors is not None: - extra_data['colors'] = self.colors[idx] + extra_data['color'] = self.colors[idx] if self.depths is not None: - extra_data['depths'] = self.depths[idx] + extra_data['depth'] = self.depths[idx] if self.bins is not None: - extra_data['bins'] = self.bins[idx] - return idx, rays_o, rays_d, extra_data + extra_data['bin'] = self.bins[idx] + return global_idx, rays_o, rays_d, extra_data - def __init__(self, desc: dict, *, - c: int = color.RGB, + def __init__(self, desc: dict, root: Path, name: str, *, load_images: bool = True, load_depths: bool = False, load_bins: bool = False, @@ -127,7 +137,8 @@ class ViewDataset(object): :param c ```int```: color space to convert view images to :param calculate_rays ```bool```: whether calculate rays """
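Both dataset classes now resolve `*_file_pattern` entries the same way: a pattern with no directory part is taken relative to the dataset's own sub-directory under `root`. A standalone sketch of that rule, with hypothetical paths:

    from pathlib import Path

    def resolve_pattern(root: Path, name: str, file_pattern: str) -> str:
        # Bare patterns live in the dataset's sub-directory (mirrors _load_desc)
        if "/" not in file_pattern:
            file_pattern = f"{name}/{file_pattern}"
        return str(root / file_pattern)

    # 'view_%04d.png' -> 'data/scene/train/view_%04d.png'
    print(resolve_pattern(Path("data/scene"), "train", "view_%04d.png"))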
- self.c = c + self.root = root + self.name = name self.device = device self._load_desc(desc, res, views_to_load, load_images, load_depths, load_bins) @@ -137,7 +148,7 @@ class ViewDataset(object): 'centers': self.centers, 'rots': self.rots } - + def _decode_depth_images(self, input): disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1]) disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0] @@ -150,22 +161,32 @@ class ViewDataset(object): load_depths: bool, load_bins: bool): if load_images and desc.get('view_file_pattern'): - self.image_path = os.path.join(self.data_dir, desc['view_file_pattern']) + file_pattern = desc['view_file_pattern'] + if "/" not in file_pattern: + file_pattern = f"{self.name}/{file_pattern}" + self.image_path = str(self.root / file_pattern) else: self.image_path = None if load_depths and desc.get('depth_file_pattern'): - self.depth_path = os.path.join(self.data_dir, desc['depth_file_pattern']) + file_pattern = desc['depth_file_pattern'] + if "/" not in file_pattern: + file_pattern = f"{self.name}/{file_pattern}" + self.depth_path = str(self.root / file_pattern) else: self.depth_path = None if load_bins and desc.get('bins_file_pattern'): - self.bins_path = os.path.join(self.data_dir, desc['bins_file_pattern']) + file_pattern = desc['bins_file_pattern'] + if "/" not in file_pattern: + file_pattern = f"{self.name}/{file_pattern}" + self.bins_path = str(self.root / file_pattern) else: self.bins_path = None - self.res = res if res else misc.values(desc['view_res'], 'y', 'x') + self.res = res or itemgetter("y", "x")(desc['view_res']) self.cam = view.CameraParam(desc['cam_params'], self.res, device=self.device) - self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \ + self.depth_range = itemgetter("min", "max")(desc['depth_range']) \ if 'depth_range' in desc else None - self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None + self.range = itemgetter("min", "max")(desc['range']) if 'range' in desc else None + self.bbox = desc.get('bbox') self.samples = desc.get('samples') self.centers = torch.tensor(desc['view_centers'], device=self.device) # (N, 3) self.rots = torch.tensor( @@ -175,9 +196,8 @@ class ViewDataset(object): ] if len(desc['view_rots'][0]) == 2 else desc['view_rots'], device=self.device).view(-1, 3, 3) # (N, 3, 3) - self.indices = torch.tensor( - desc['views'] if 'views' in desc else list(range(self.centers.size(0))), - device=self.device) + self.indices = torch.tensor(desc.get('views') or [*range(self.centers.size(0))], + device=self.device) if views_to_load is not None: self.centers = self.centers[views_to_load] @@ -194,5 +214,5 @@ class ViewDataset(object): self.centers[:, 2] *= -1 self.rots[:, 2] *= -1 self.rots[..., 2] *= -1 - + self.cam_rays = self.cam.get_local_rays(flatten=True) diff --git a/debug/voxel_sampler_export3d.py b/debug/voxel_sampler_export3d.py new file mode 100644 index 0000000..bb24d54 --- /dev/null +++ b/debug/voxel_sampler_export3d.py @@ -0,0 +1,134 @@ +import os +import sys +import argparse +import torch + +sys.path.append(os.path.abspath(sys.path[0] + '/../')) + +parser = argparse.ArgumentParser() +parser.add_argument('-m', '--model', type=str, + help='The model file to load for testing') +parser.add_argument('-r', '--output-rays', type=int, default=100, + help='How many rays to output') +parser.add_argument('-p', '--prompt', action='store_true', + help='Interactive prompt mode') +parser.add_argument('dataset', type=str, + help='Dataset description 
file') +args = parser.parse_args() + + +import model as mdl +from utils import misc +from utils import color +from utils import interact +from utils import device +from data.dataset_factory import * +from data.loader import DataLoader +from modules import Samples, Voxels +from model.nsvf import NSVF + +model: NSVF +samples: Samples + +DATA_LOADER_CHUNK_SIZE = 1e8 + + +data_desc_path = args.dataset if args.dataset.endswith('.json') \ + else os.path.join(args.dataset, 'train.json') +data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0] +data_dir = os.path.dirname(data_desc_path) + '/' + + +def get_model_files(datadir): + model_files = [] + for root, _, files in os.walk(datadir): + model_files += [ + os.path.join(root, file).replace(datadir, '') + for file in files if file.endswith('.tar') or file.endswith('.pth') + ] + return model_files + + +if args.prompt: # Prompt test model, output resolution, output mode + model_files = get_model_files(data_dir) + args.model = interact.input_enum('Specify test model:', model_files, + err_msg='No such model file') + args.output_rays = interact.input_ex('Specify number of rays to output:', + interact.input_to_int(), default=10) + +model_path = os.path.join(data_dir, args.model) +model_name = os.path.splitext(os.path.basename(model_path))[0] +model, iters = mdl.load(model_path, {"perturb_sample": False}) +model.to(device.default()).eval() +model_class = model.__class__.__name__ +model_args = model.args +print(f"model: {model_name} ({model_class})") +print("args:", json.dumps(model.args0)) + +dataset = DatasetFactory.load(data_desc_path) +print("Dataset loaded: " + data_desc_path) + +run_dir = os.path.dirname(model_path) + '/' +output_dir = f"{run_dir}output_{int(model_name.split('_')[-1])}" + + +if __name__ == "__main__": + with torch.no_grad(): + # 1. 
Initialize data loader + data_loader = DataLoader(dataset, args.output_rays, chunk_max_items=DATA_LOADER_CHUNK_SIZE, + shuffle=True, enable_preload=True, + color=color.from_str(model.args['color'])) + sys.stdout.write("Export samples...\r") + for _, rays_o, rays_d, extra in data_loader: + samples, rays_mask = model.sampler(rays_o, rays_d, model.space) + invalid_rays_o = rays_o[torch.logical_not(rays_mask)] + invalid_rays_d = rays_d[torch.logical_not(rays_mask)] + rays_o = rays_o[rays_mask] + rays_d = rays_d[rays_mask] + break + print("Export samples...Done") + + os.makedirs(output_dir, exist_ok=True) + + export_data = {} + + if model.space.bbox is not None: + export_data['bbox'] = model.space.bbox.tolist() + if isinstance(model.space, Voxels): + export_data['voxel_size'] = model.space.voxel_size.tolist() + export_data['voxels'] = model.space.voxels.tolist() + + if False: + voxel_access_counts = torch.zeros(model.space.n_voxels, dtype=torch.long, + device=device.default()) + iters_in_epoch = 0 + data_loader.batch_size = 2 ** 20 + for _, rays_o1, rays_d1, _ in data_loader: + model(rays_o1, rays_d1, + raymarching_tolerance=0.5, + raymarching_chunk_size=0, + voxel_access_counts=voxel_access_counts) + iters_in_epoch += 1 + percent = iters_in_epoch / len(data_loader) * 100 + sys.stdout.write(f'Export voxel access counts...{percent:.1f}% \r') + export_data['voxel_access_counts'] = voxel_access_counts.tolist() + print("Export voxel access counts...Done ") + + export_data.update({ + 'rays_o': rays_o.tolist(), + 'rays_d': rays_d.tolist(), + 'invalid_rays_o': invalid_rays_o.tolist(), + 'invalid_rays_d': invalid_rays_d.tolist(), + 'samples': { + 'depths': samples.depths.tolist(), + 'dists': samples.dists.tolist(), + 'voxel_indices': samples.voxel_indices.tolist() + } + }) + with open(f'{output_dir}/debug_voxel_sampler_export3d.json', 'w') as fp: + json.dump(export_data, fp) + print("Write JSON file...Done") + + args.output_rays + print(f"Rays: total {args.output_rays}, valid {rays_o.size(0)}") + print(f"Samples: average {samples.voxel_indices.ne(-1).sum(-1).float().mean().item()} per ray") diff --git a/fntest.py b/fntest.py new file mode 100644 index 0000000..5f5a879 --- /dev/null +++ b/fntest.py @@ -0,0 +1,12 @@ +from math import ceil + +cdf = [2.2, 3.5, 3.6, 3.7, 4.0] +bins = [] +part = 1 +offset = 0 +for i in range(len(cdf)): + if cdf[i] >= part: + bins.append(i + 1 - offset) + offset = i + 1 + part = int(cdf[i]) + 1 +print(bins) \ No newline at end of file diff --git a/loss/__init__.py b/loss/__init__.py new file mode 100644 index 0000000..a4eecbc --- /dev/null +++ b/loss/__init__.py @@ -0,0 +1,5 @@ +from torch.nn import L1Loss, MSELoss +from torch.nn.functional import l1_loss, mse_loss +from .ssim import SSIM +from .perc_loss import VGGPerceptualLoss +from .cauchy import cauchy_loss, CauchyLoss \ No newline at end of file diff --git a/loss/cauchy.py b/loss/cauchy.py new file mode 100644 index 0000000..5dd213e --- /dev/null +++ b/loss/cauchy.py @@ -0,0 +1,16 @@ +import torch + + +def cauchy_loss(input: torch.Tensor, target: torch.Tensor = None, *, s = 1.0): + x = input - target if target is not None else input + return (s * x * x * 0.5 + 1).log().mean() + + +class CauchyLoss(torch.nn.Module): + + def __init__(self, s = 1.0): + super().__init__() + self.s = s + + def forward(self, input: torch.Tensor, target: torch.Tensor = None): + return cauchy_loss(input, target, s=self.s) diff --git a/loss/ssim.py b/loss/ssim.py index 93f390b..cd38987 100644 --- a/loss/ssim.py +++ b/loss/ssim.py @@ -1,7 +1,6 @@ 
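The `cauchy_loss` added above computes log(1 + s·x²/2) averaged over elements; unlike MSE it grows only logarithmically in the residual, so outliers are damped. A quick numeric check of that definition (values chosen arbitrarily):

    import torch

    def cauchy_loss(input: torch.Tensor, target: torch.Tensor = None, *, s=1.0):
        x = input - target if target is not None else input
        return (s * x * x * 0.5 + 1).log().mean()

    residuals = torch.tensor([0.0, 1.0, 10.0])
    print(cauchy_loss(residuals))  # mean of log([1.0, 1.5, 51.0]) ~= 1.446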
import torch import torch.nn.functional as F from torch.autograd import Variable -import numpy as np from math import exp def gaussian(window_size, sigma): diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..60f9f6b --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,45 @@ +import importlib +import os +import torch +from typing import Tuple, Union +from . import base + + +# Automatically import any python files in this directory +package_dir = os.path.dirname(__file__) +package = os.path.basename(package_dir) +for file in os.listdir(package_dir): + path = os.path.join(package_dir, file) + if file.startswith('_') or file.startswith('.'): + continue + if file.endswith('.py') or os.path.isdir(path): + model_name = file[:-3] if file.endswith('.py') else file + importlib.import_module(f'{package}.{model_name}') + + +def get_class(model_class_name: str) -> type: + return base.model_classes[model_class_name] + + +def create(model_class_name: str, args0: dict, **extra_args) -> base.BaseModel: + model_class = get_class(model_class_name) + return model_class(args0, extra_args) + + +def load(path: Union[str, os.PathLike], args0: dict = {}, **extra_args) -> Tuple[base.BaseModel, dict]: + states: dict = torch.load(path) + states['args'].update(args0) + model = create(states['model'], states['args'], **extra_args) + model.load_state_dict(states['states']) + return model, states + + +def save(path: Union[str, os.PathLike], model: base.BaseModel, **extra_states): + #print(f'Save model to {path}...') + states = { + 'model': model.__class__.__name__, + 'args': model.args0, + 'states': model.state_dict(), + **extra_states + } + torch.save(states, path) diff --git a/model/base.py b/model/base.py new file mode 100644 index 0000000..324ad93 --- /dev/null +++ b/model/base.py @@ -0,0 +1,34 @@ +import torch.nn as nn +from utils import color + + +model_classes = {} + + +class BaseModelMeta(type): + + def __new__(cls, name, bases, attrs): + new_cls = type.__new__(cls, name, bases, attrs) + if name != 'BaseModel': + model_classes[name] = new_cls + return new_cls + + +class BaseModel(nn.Module, metaclass=BaseModelMeta): + + trainer = "Train" + + @property + def args(self): + return {**self.args0, **self.args1} + + def __init__(self, args0: dict, args1: dict = {}): + super().__init__() + self.args0 = args0 + self.args1 = args1 + self._chns = { + "color": color.chns(color.from_str(self.args['color'])) + } + + def chns(self, name: str): + return self._chns.get(name, 1) \ No newline at end of file diff --git a/nets/bg_net.py b/model/bg_net.py similarity index 100% rename from nets/bg_net.py rename to model/bg_net.py diff --git a/model/nerf.py b/model/nerf.py new file mode 100644 index 0000000..35adb39 --- /dev/null +++ b/model/nerf.py @@ -0,0 +1,181 @@ +import torch + +import model +from .base import * +from modules import * +from utils.mem_profiler import MemProfiler +from utils.perf import perf +from utils.misc import masked_scatter + + +class NeRF(BaseModel): + + trainer = "TrainWithSpace" + SamplerClass = Sampler + RendererClass = VolumnRenderer + + def __init__(self, args0: dict, args1: dict = {}): + """ + Initialize a NeRF model + + :param args0 `dict`: basic arguments + :param args1 `dict`: extra arguments, defaults to {} + """ + if "sample_step_ratio" in args0: + args1["sample_step"] = args0["voxel_size"] * args0["sample_step_ratio"] + super().__init__(args0, args1) + + # Initialize components + self._init_space() + self._init_encoders() + self._init_core() + self.sampler =
self.SamplerClass(**self.args) + self.rendering = self.RendererClass(**self.args) + + def _init_encoders(self): + self.pot_encoder = InputEncoder.Get(self.args['n_pot_encode'], + self.args.get('n_featdim') or 3) + if self.args.get('n_dir_encode'): + self.dir_chns = 3 + self.dir_encoder = InputEncoder.Get(self.args['n_dir_encode'], self.dir_chns) + else: + self.dir_chns = 0 + self.dir_encoder = None + + def _init_space(self): + if 'space' not in self.args: + self.space = Space(**self.args) + elif self.args['space'] == 'octree': + self.space = Octree(**self.args) + elif self.args['space'] == 'voxels': + self.space = Voxels(**self.args) + else: + self.space = model.load(self.args['space'])[0].space + if self.args.get('n_featdim'): + self.space.create_embedding(self.args['n_featdim']) + + def _new_core_unit(self): + return NerfCore(coord_chns=self.pot_encoder.out_dim, + density_chns=self.chns('density'), + color_chns=self.chns('color'), + core_nf=self.args['fc_params']['nf'], + core_layers=self.args['fc_params']['n_layers'], + dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0, + dir_nf=self.args['fc_params']['nf'] // 2, + act=self.args['fc_params']['activation'], + skips=self.args['fc_params']['skips']) + + def _create_core(self, n_nets=1): + return self._new_core_unit() if n_nets == 1 else nn.ModuleList([ + self._new_core_unit() for _ in range(n_nets) + ]) + + def _init_core(self): + if not self.args.get("net_bounds"): + self.core = self._create_core() + else: + self.register_buffer("net_bounds", torch.tensor(self.args["net_bounds"]), False) + self.cores = self._create_core(self.net_bounds.size(0)) + + def render(self, samples: Samples, *outputs: str, **kwargs) -> Dict[str, torch.Tensor]: + """ + Render colors, energies and other values (specified by `outputs`) of samples + (invalid items are filtered out) + + :param samples `Samples(N)`: samples + :param outputs `str...`: which types of inferred data should be returned + :return `Dict[str, Tensor(N, *)]`: outputs of cores + """ + x = self.encode_x(samples) + d = self.encode_d(samples) + return self.infer(x, d, *outputs, pts=samples.pts, **kwargs) + + def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, pts: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + """ + Infer colors, energies and other values (specified by `outputs`) of samples + (invalid items are filtered out) given their encoded positions and directions + + :param x `Tensor(N, Ex)`: encoded positions + :param d `Tensor(N, Ed)`: encoded directions + :param outputs `str...`: which types of inferred data should be returned + :param pts `Tensor(N, 3)`: raw sample positions + :return `Dict[str, Tensor(N, *)]`: outputs of cores + """ + if getattr(self, "core", None): + return self.core(x, d, outputs) + ret = {} + for i, core in enumerate(self.cores): + selector = ((pts >= self.net_bounds[i, 0]) & (pts < self.net_bounds[i, 1])).all(-1) + partial_ret = core(x[selector], d[selector], outputs) + for key, value in partial_ret.items(): + if value is None: + ret[key] = None + continue + if key not in ret: + ret[key] = torch.zeros(*x.shape[:-1], value.shape[-1], device=x.device) + ret[key] = masked_scatter(selector, value, ret[key]) + return ret + + def embed(self, samples: Samples) -> torch.Tensor: + return self.space.extract_embedding(samples.pts, samples.voxel_indices) + + def encode_x(self, samples: Samples) -> torch.Tensor: + x = self.embed(samples) if self.args.get('n_featdim') else samples.pts + return self.pot_encoder(x) + + def encode_d(self, samples: Samples) ->
torch.Tensor: + return self.dir_encoder(samples.dirs) if self.dir_encoder is not None else None + + @torch.no_grad() + def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor: + densities = self.render(Samples(sampled_points, None, None, None, sampled_voxel_indices), + 'density')['density'] + return 1 - (-densities).exp() + + @torch.no_grad() + def pruning(self, threshold: float = 0.5, train_stats=False): + return self.space.pruning(self.get_scores, threshold, train_stats) + + @torch.no_grad() + def splitting(self): + ret = self.space.splitting() + if 'n_samples' in self.args0: + self.args0['n_samples'] *= 2 + if 'voxel_size' in self.args0: + self.args0['voxel_size'] /= 2 + if "sample_step_ratio" in self.args0: + self.args1["sample_step"] = self.args0["voxel_size"] \ + * self.args0["sample_step_ratio"] + if 'sample_step' in self.args0: + self.args0['sample_step'] /= 2 + self.sampler = self.SamplerClass(**self.args) + return ret + + @torch.no_grad() + def double_samples(self): + pass + + @perf + def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *, + extra_outputs: List[str] = [], **kwargs) -> torch.Tensor: + """ + Perform rendering for given rays. + + :param rays_o `Tensor(N, 3)`: rays' origin + :param rays_d `Tensor(N, 3)`: rays' direction + :param extra_outputs `list[str]`: extra items should be contained in the rendering result, + defaults to [] + :return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation + """ + args = {**self.args, **kwargs} + with MemProfiler(f"{self.__class__}.forward: before sampling"): + samples, rays_mask = self.sampler(rays_o, rays_d, self.space, **args) + MemProfiler.print_memory_stats(f"{self.__class__}.forward: after sampling") + with MemProfiler(f"{self.__class__}.forward: rendering"): + if samples is None: + return None + return { + **self.rendering(self, samples, extra_outputs, **args), + 'samples': samples, + 'rays_mask': rays_mask + } diff --git a/model/nerf_advance.py b/model/nerf_advance.py new file mode 100644 index 0000000..b3d9716 --- /dev/null +++ b/model/nerf_advance.py @@ -0,0 +1,37 @@ +import torch +from modules import * +from .nerf import * + + +class NeRFAdvance(NeRF): + + RendererClass = DensityFirstVolumnRenderer + + def __init__(self, args0: dict, args1: dict = {}): + super().__init__(args0, args1) + + def _new_core_unit(self): + return NerfAdvCore( + x_chns=self.pot_encoder.out_dim, + d_chns=self.dir_encoder.out_dim, + density_chns=self.chns('density'), + color_chns=self.chns('color'), + density_net_params=self.args["density_net"], + color_net_params=self.args["color_net"], + specular_net_params=self.args.get("specular_net"), + appearance=self.args.get("appearance", "decomposite"), + density_color_connection=self.args.get("density_color_connection", False) + ) + + def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, extras={}, **kwargs) -> Dict[str, torch.Tensor]: + """ + Infer colors, energies and other values (specified by `outputs`) of samples + (invalid items are filtered out) given their encoded positions and directions + + :param x `Tensor(N, Ex)`: encoded positions + :param d `Tensor(N, Ed)`: encoded directions + :param outputs `str...`: which types of inferred data should be returned + :param extras `dict`: extra data needed by cores + :return `Dict[str, Tensor(N, *)]`: outputs of cores + """ + return self.core(x, d, outputs, **extras) diff --git a/nets/nerf_depth.py b/model/nerf_depth.py similarity index 96% rename from nets/nerf_depth.py
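`get_scores` above maps raw densities to pruning scores via 1 - exp(-σ), i.e. the opacity a sample would contribute at unit step length, so the `pruning` threshold operates on values in [0, 1). A small numeric illustration:

    import torch

    densities = torch.tensor([0.0, 0.1, 1.0, 10.0])
    scores = 1 - (-densities).exp()  # opacity at unit step length, in [0, 1)
    print(scores)  # tensor([0.0000, 0.0952, 0.6321, 1.0000])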
rename to model/nerf_depth.py index fe5fafd..8826dfc 100644 --- a/nets/nerf_depth.py +++ b/model/nerf_depth.py @@ -27,7 +27,7 @@ class NerfDepth(nn.Module): color_chns=self.color_chns, core_nf=fc_params['nf'], core_layers=fc_params['n_layers'], - activation=fc_params['activation'], + act=fc_params['activation'], skips=fc_params['skips']) self.sampler = AdaptiveSampler(**sampler_params, n_bins=n_bins, include_neighbor_bins=include_neighbor_bins) diff --git a/model/nsvf.py b/model/nsvf.py new file mode 100644 index 0000000..08bfaeb --- /dev/null +++ b/model/nsvf.py @@ -0,0 +1,16 @@ +from .nerf import * +from utils.geometry import * + + +class NSVF(NeRF): + + SamplerClass = VoxelSampler + + def __init__(self, args0: dict, args1: dict = {}): + """ + Initialize a NSVF model + + :param args0 `dict`: basic arguments + :param args1 `dict`: extra arguments, defaults to {} + """ + super().__init__(args0, args1) diff --git a/nets/oracle.py b/model/oracle.py similarity index 96% rename from nets/oracle.py rename to model/oracle.py index 7f61890..1fa3e01 100644 --- a/nets/oracle.py +++ b/model/oracle.py @@ -27,7 +27,7 @@ class Oracle(nn.Module): self.net = nn.Sequential( FcNet(in_chns=self.pos_encoder.out_dim * self.n_samples, out_chns=0, nf=fc_params['nf'], n_layers=fc_params['n_layers'], - skips=[], activation=fc_params['activation']), + skips=[], act=fc_params['activation']), FcLayer(fc_params['nf'], self.n_samples, out_activation) ) diff --git a/model/snerf.py b/model/snerf.py new file mode 100644 index 0000000..534c515 --- /dev/null +++ b/model/snerf.py @@ -0,0 +1,26 @@ +import math +from .nerf import * + + +class SNeRF(NeRF): + SamplerClass = SphericalSampler + + def __init__(self, args0: dict, args1: dict = {}): + """ + Initialize a multi-sphere-layer net + + :param fc_params: parameters for full-connection network + :param sampler_params: parameters for sampler + :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode + :param c: color mode + :param encode_to_dim: encode input to number of dimensions + """ + sample_range = [1 / args0['depth_range'][0], 1 / args0['depth_range'][1]] \ + if args0.get('depth_range') else [1, 0] + rot_range = [[-180, -90], [180, 90]] + args1['bbox'] = [ + [sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])], + [sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])] + ] + args1['sample_range'] = sample_range + super().__init__(args0, args1) \ No newline at end of file diff --git a/model/snerf_advance.py b/model/snerf_advance.py new file mode 100644 index 0000000..dd321c8 --- /dev/null +++ b/model/snerf_advance.py @@ -0,0 +1,33 @@ +import math +from .nerf_advance import * + + +class SNeRFAdvance(NeRFAdvance): + SamplerClass = SphericalSampler + + def __init__(self, args0: dict, args1: dict = {}): + """ + Initialize a multi-sphere-layer net + + :param fc_params: parameters for full-connection network + :param sampler_params: parameters for sampler + :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode + :param c: color mode + :param encode_to_dim: encode input to number of dimensions + """ + sample_range = [1 / args0['depth_range'][0], 1 / args0['depth_range'][1]] \ + if args0.get('depth_range') else [1, 0] + rot_range = [[-180, -90], [180, 90]] + args1['bbox'] = [ + [sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])], + [sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])] + ] + 
args1['sample_range'] = sample_range + if args0.get('multi_nets'): + n = args0['multi_nets'] + step = (sample_range[1] - sample_range[0]) / n + args1['net_bounds'] = [[ + [sample_range[0] + step * (i + 1), *args1['bbox'][0][1:]], + [sample_range[0] + step * i, *args1['bbox'][1][1:]] + ] for i in range(n)] + super().__init__(args0, args1) \ No newline at end of file diff --git a/model/snerf_advance_x.py b/model/snerf_advance_x.py new file mode 100644 index 0000000..de2d71c --- /dev/null +++ b/model/snerf_advance_x.py @@ -0,0 +1,74 @@ +from utils.misc import print_and_log +from .snerf_advance import * + + +class SNeRFAdvanceX(SNeRFAdvance): + + RendererClass = DensityFirstVolumnRenderer + + def __init__(self, args0: dict, args1: dict = {}): + """ + Initialize a multi-sphere-layer net + + :param fc_params: parameters for full-connection network + :param sampler_params: parameters for sampler + :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode + :param c: color mode + :param encode_to_dim: encode input to number of dimensions + """ + super().__init__(args0, args1) + + def _init_core(self): + if "net_samples" not in self.args: + n_nets = self.args.get("multi_nets", 1) + k = self.args["n_samples"] // self.space.steps[0].item() + self.args0["net_samples"] = [val * k for val in self.space.balance_cut(0, n_nets)] + self.cores = self._create_core(len(self.args0["net_samples"])) + + def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, chunk_id: int, extras={}, **kwargs) -> Dict[str, torch.Tensor]: + """ + Infer colors, energies and other values (specified by `outputs`) of samples + (invalid items are filtered out) given their encoded positions and directions + + :param x `Tensor(N, Ex)`: encoded positions + :param d `Tensor(N, Ed)`: encoded directions + :param outputs `str...`: which types of inferred data should be returned + :param chunk_id `int`: current index of sample chunk in renderer + :param extras `dict`: extra data needed by cores + :return `Dict[str, Tensor(N, *)]`: outputs of cores + """ + return self.cores[chunk_id](x, d, outputs, **extras) + + @torch.no_grad() + def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor: + raise NotImplementedError() + + @torch.no_grad() + def pruning(self, threshold: float = 0.5, train_stats=False): + raise NotImplementedError() + + @torch.no_grad() + def splitting(self): + ret = super().splitting() + k = self.args["n_samples"] // self.space.steps[0].item() + net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))] + if len(net_samples) != len(self.cores): + print_and_log('Note: the result of balance cut has no enough bins. Keep origin cut.') + net_samples = [val * 2 for val in self.args0["net_samples"]] + self.args0['net_samples'] = net_samples + return ret + + @perf + def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *, + extra_outputs: List[str] = [], **kwargs) -> torch.Tensor: + """ + Perform rendering for given rays. 
+ + :param rays_o `Tensor(N, 3)`: rays' origin + :param rays_d `Tensor(N, 3)`: rays' direction + :param extra_outputs `list[str]`: extra items should be contained in the rendering result, + defaults to [] + :return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation + """ + return super().forward(rays_o, rays_d, extra_outputs=extra_outputs, **kwargs, + raymarching_chunk_size_or_sections=self.args["net_samples"]) diff --git a/nets/snerf_fast.py b/model/snerf_fast.py similarity index 98% rename from nets/snerf_fast.py rename to model/snerf_fast.py index d99165e..4d627e8 100644 --- a/nets/snerf_fast.py +++ b/model/snerf_fast.py @@ -48,7 +48,7 @@ class SnerfFast(nn.Module): core_layers=fc_params['n_layers'], dir_chns=self.dir_chns_per_part, dir_nf=fc_params['nf'] // 2, - activation=fc_params['activation']) + act=fc_params['activation']) for _ in range(self.n_parts) ] for i in range(self.n_parts): diff --git a/model/snerf_x.py b/model/snerf_x.py new file mode 100644 index 0000000..64bfcf9 --- /dev/null +++ b/model/snerf_x.py @@ -0,0 +1,79 @@ +from utils.misc import print_and_log +from .snerf import * + + +class SNeRFX(SNeRF): + + trainer = "TrainWithSpace" + SamplerClass = SphericalSampler + RendererClass = VolumnRenderer + + def __init__(self, args0: dict, args1: dict = {}): + """ + Initialize a multi-sphere-layer net + + :param fc_params: parameters for full-connection network + :param sampler_params: parameters for sampler + :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode + :param c: color mode + :param encode_to_dim: encode input to number of dimensions + """ + super().__init__(args0, args1) + + def _init_core(self): + if "net_samples" not in self.args: + n_nets = self.args.get("multi_nets", 1) + k = self.args["n_samples"] // self.space.steps[0].item() + self.args0["net_samples"] = [val * k for val in self.space.balance_cut(0, n_nets)] + self.cores = self._create_core(len(self.args0["net_samples"])) + + def render(self, samples: Samples, *outputs: str, chunk_id: int, **kwargs) -> Dict[str, torch.Tensor]: + """ + Infer colors, energies and other values (specified by `outputs`) of samples + (invalid items are filtered out) + + :param samples `Samples(N)`: samples + :param outputs `str...`: which types of inferred data should be returned + :param chunk_id `int`: current index of sample chunk in renderer + :return `Dict[str, Tensor(N, *)]`: outputs of cores + """ + x = self.encode_x(samples) + d = self.encode_d(samples) + return self.cores[chunk_id](x, d, outputs) + + @torch.no_grad() + def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor: + raise NotImplementedError() + + @torch.no_grad() + def pruning(self, threshold: float = 0.5, train_stats=False): + raise NotImplementedError() + + @torch.no_grad() + def splitting(self): + ret = super().splitting() + k = self.args["n_samples"] // self.space.steps[0].item() + net_samples = [ + val * k for val in self.space.balance_cut(0, len(self.cores)) + ] + if len(net_samples) != len(self.cores): + print_and_log('Note: the result of balance cut has no enough bins. Keep origin cut.') + net_samples = [val * 2 for val in self.args0["net_samples"]] + self.args0['net_samples'] = net_samples + self.sampler = self.SamplerClass(**self.args) + return ret + + @perf + def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *, + extra_outputs: List[str] = [], **kwargs) -> torch.Tensor: + """ + Perform rendering for given rays. 
+ + :param rays_o `Tensor(N, 3)`: rays' origin + :param rays_d `Tensor(N, 3)`: rays' direction + :param extra_outputs `list[str]`: extra items should be contained in the rendering result, + defaults to [] + :return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation + """ + return super().forward(rays_o, rays_d, extra_outputs=extra_outputs, **kwargs, + raymarching_chunk_size_or_sections=self.args["net_samples"]) diff --git a/modules/__init__.py b/modules/__init__.py index 2facaa3..c45b69e 100644 --- a/modules/__init__.py +++ b/modules/__init__.py @@ -1,43 +1,5 @@ -from typing import Tuple -import torch -import torch.nn as nn -from torch.nn.modules.linear import Identity -from utils.constants import * -from .generic import * from .sampler import * from .input_encoder import * from .renderer import * - - -class NerfCore(nn.Module): - - def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers, - dir_chns=0, dir_nf=0, activation='relu', skips=[]): - super().__init__() - self.core = FcNet(in_chns=coord_chns, out_chns=0, nf=core_nf, n_layers=core_layers, - skips=skips, activation=activation) - self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None - if color_chns == 0: - self.feature_out = None - self.color_out = None - elif dir_chns > 0: - self.feature_out = FcLayer(core_nf, core_nf) - self.color_out = nn.Sequential( - FcLayer(core_nf + dir_chns, dir_nf, activation), - FcLayer(dir_nf, color_chns) - ) - else: - self.feature_out = Identity() - self.color_out = FcLayer(core_nf, color_chns) - - def forward(self, coord: torch.Tensor, dir: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - core_output = self.core(coord) - density = self.density_out(core_output) if self.density_out is not None else None - if self.color_out is None: - color = None - else: - feature = self.feature_out(core_output) - if dir is not None: - feature = torch.cat([feature, dir], dim=-1) - color = torch.sigmoid(self.color_out(feature)) - return color, density \ No newline at end of file +from .space import * +from .core import * \ No newline at end of file diff --git a/modules/core.py b/modules/core.py new file mode 100644 index 0000000..28219fd --- /dev/null +++ b/modules/core.py @@ -0,0 +1,175 @@ +from .generic import * +from typing import Dict + + +class NerfCore(nn.Module): + + def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers, + dir_chns=0, dir_nf=0, act='relu', skips=[]): + super().__init__() + self.core = FcNet(in_chns=coord_chns, out_chns=None, nf=core_nf, n_layers=core_layers, + skips=skips, act=act) + self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None + if color_chns == 0: + self.feature_out = None + self.color_out = None + elif dir_chns > 0: + self.feature_out = FcLayer(core_nf, core_nf) + self.color_out = nn.Sequential( + FcLayer(core_nf + dir_chns, dir_nf, act), + FcLayer(dir_nf, color_chns) + ) + else: + self.feature_out = torch.nn.Identity() + self.color_out = FcLayer(core_nf, color_chns) + + def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str]) -> Dict[str, torch.Tensor]: + ret = {} + core_output = self.core(x) + if 'density' in outputs: + ret['density'] = torch.relu(self.density_out(core_output)) \ + if self.density_out is not None else None + if 'color' in outputs: + if self.color_out is None: + ret['color'] = None + else: + feature = self.feature_out(core_output) + if d is not None: + feature = torch.cat([feature, d], dim=-1) +
+                ret['color'] = self.color_out(feature).sigmoid()
+        for key in outputs:
+            if key == 'density' or key == 'color':
+                continue
+            ret[key] = None
+        return ret
+
+
+class NerfAdvCore(nn.Module):
+
+    def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
+                 density_net_params: dict, color_net_params: dict,
+                 specular_net_params: dict = None,
+                 appearance="decomposite",
+                 density_color_connection=False):
+        """
+        Create a NeRF-Adv core net.
+        Required parameters for the sub-MLPs include "nf", "n_layers", "skips" and "act".
+        Other parameters will be set automatically.
+
+        :param x_chns `int`: the channels of input "position"
+        :param d_chns `int`: the channels of input "direction"
+        :param density_chns `int`: the channels of output "density"
+        :param color_chns `int`: the channels of output "color"
+        :param density_net_params `dict`: parameters for the density net
+        :param color_net_params `dict`: parameters for the color net
+        :param specular_net_params `dict`: (optional) parameters for the optional specular net, defaults to None
+        :param appearance `str`: (optional) options are [decomposite|combined], defaults to "decomposite"
+        :param density_color_connection `bool`: (optional) whether to add connections between
+            density net and color net, defaults to False
+        """
+        super().__init__()
+        self.density_chns = density_chns
+        self.color_chns = color_chns
+        self.specular_feature_chns = color_net_params["nf"] if specular_net_params else 0
+        self.color_feature_chns = density_net_params["nf"] if density_color_connection else 0
+        self.appearance = appearance
+        self.density_color_connection = density_color_connection
+        self.density_net = FcNet(**density_net_params,
+                                 in_chns=x_chns,
+                                 out_chns=self.density_chns + self.color_feature_chns,
+                                 out_act='relu')
+        if self.appearance == "newtype":
+            self.specular_feature_chns = d_chns * 3
+            self.color_net = FcNet(**color_net_params,
+                                   in_chns=x_chns + self.color_feature_chns,
+                                   out_chns=self.color_chns + self.specular_feature_chns)
+            self.specular_net = "Placeholder"
+        else:
+            if self.appearance == "decomposite":
+                self.color_net = FcNet(**color_net_params,
+                                       in_chns=x_chns + self.color_feature_chns,
+                                       out_chns=self.color_chns + self.specular_feature_chns)
+            else:
+                if specular_net_params:
+                    self.color_net = FcNet(**color_net_params,
+                                           in_chns=x_chns + self.color_feature_chns,
+                                           out_chns=self.specular_feature_chns)
+                else:
+                    self.color_net = FcNet(**color_net_params,
+                                           in_chns=x_chns + d_chns + self.color_feature_chns,
+                                           out_chns=self.color_chns)
+            self.specular_net = FcNet(**specular_net_params,
+                                      in_chns=d_chns + self.specular_feature_chns,
+                                      out_chns=self.color_chns) if specular_net_params else None
+
+    def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str], *,
+                color_feats: torch.Tensor = None) -> Dict[str, torch.Tensor]:
+        input_shape = x.shape[:-1]
+        if len(input_shape) > 1:
+            x = x.flatten(0, -2)
+            d = d.flatten(0, -2)
+        n = x.shape[0]
+        c = self.color_chns
+
+        ret: Dict[str, torch.Tensor] = {}
+
+        if 'density' in outputs:
+            density_net_out: torch.Tensor = self.density_net(x)
+            ret['density'] = density_net_out[:, :self.density_chns]
+            color_feats = density_net_out[:, self.density_chns:]
+        if 'color_feat' in outputs:
+            ret['color_feat'] = color_feats
+
+        if 'color' in outputs or 'specular' in outputs:
+            if 'density' in ret:
+                valid_mask = ret['density'][:, 0].detach() >= 1e-4
+                indices = valid_mask.nonzero()[:, 0]
+                x, d, color_feats = x[indices], d[indices], color_feats[indices]
+            else:
+                indices = None
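+            # Only samples with non-negligible density reach the color/specular nets;
+            # their outputs are scattered back into full-size tensors by the
+            # index_copy calls further below.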
+
+            speculars = None
+            color_net_in = [x]
+            if not self.specular_net:
+                color_net_in.append(d)
+            if self.density_color_connection:
+                color_net_in.append(color_feats)
+            color_net_in = torch.cat(color_net_in, -1)
+            color_net_out: torch.Tensor = self.color_net(color_net_in)
+            diffuses = color_net_out[:, :c]
+            specular_features = color_net_out[:, -self.specular_feature_chns:]
+
+            if self.appearance == "newtype":
+                speculars = torch.bmm(specular_features.reshape(n, 3, d.shape[-1]),
+                                      d[..., None])[..., 0]
+                # TODO relu or not?
+                diffuses = diffuses.relu()
+                speculars = speculars.relu()
+                colors = diffuses + speculars
+            else:
+                if not self.specular_net:
+                    colors = diffuses
+                    diffuses = None
+                else:
+                    specular_net_in = torch.cat([d, specular_features], -1)
+                    specular_net_out = self.specular_net(specular_net_in)
+                    if self.appearance == "decomposite":
+                        speculars = specular_net_out
+                        colors = diffuses + speculars
+                    else:
+                        diffuses = None
+                        colors = specular_net_out
+                colors = torch.sigmoid(colors)  # TODO indent or not?
+            if 'color' in outputs:
+                ret['color'] = colors.new_zeros(n, c).index_copy(0, indices, colors) \
+                    if indices is not None else colors
+            if 'diffuse' in outputs:
+                ret['diffuse'] = diffuses.new_zeros(n, c).index_copy(0, indices, diffuses) \
+                    if indices is not None and diffuses is not None else diffuses
+            if 'specular' in outputs:
+                ret['specular'] = speculars.new_zeros(n, c).index_copy(0, indices, speculars) \
+                    if indices is not None and speculars is not None else speculars
+
+        if len(input_shape) > 1:
+            ret = {key: val.reshape(*input_shape, -1) for key, val in ret.items()}
+        return ret
diff --git a/modules/generic.py b/modules/generic.py
index fe9e234..c8b0987 100644
--- a/modules/generic.py
+++ b/modules/generic.py
@@ -34,7 +34,7 @@ class Sine(nn.Module):
 
 class FcLayer(nn.Module):
 
-    def __init__(self, in_chns: int, out_chns: int, activation: str = 'linear', skip_chns: int = 0):
+    def __init__(self, in_chns: int, out_chns: int, act: str = 'linear', skip_chns: int = 0):
         super().__init__()
         nls_and_inits = {
             'sine': (Sine(), sine_init),
@@ -48,7 +48,7 @@ class FcLayer(nn.Module):
             'logsoftmax': (nn.LogSoftmax(dim=-1), softmax_init),
             'linear': (None, None)
         }
-        nl, nl_weight_init = nls_and_inits[activation]
+        nl, nl_weight_init = nls_and_inits[act]
         self.net = nn.Sequential(
             nn.Linear(in_chns + skip_chns, out_chns),
@@ -59,7 +59,7 @@ class FcLayer(nn.Module):
         if nl_weight_init is not None:
             nl_weight_init(self.net if isinstance(self.net, nn.Linear) else self.net[0])
         else:
-            self.init_params(activation)
+            self.init_params(act)
 
     def forward(self, x: torch.Tensor, x0: torch.Tensor = None) -> torch.Tensor:
         return self.net(torch.cat([x0, x], dim=-1) if self.skip else x)
@@ -68,9 +68,9 @@ class FcLayer(nn.Module):
         linear_net = self.net if isinstance(self.net, nn.Linear) else self.net[0]
         return linear_net.weight, linear_net.bias
 
-    def init_params(self, activation):
+    def init_params(self, act):
         weight, bias = self.get_params()
-        nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(activation))
+        nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(act))
         nn.init.zeros_(bias)
 
     def copy_to(self, layer):
@@ -83,7 +83,7 @@ class FcLayer(nn.Module):
 class FcNet(nn.Module):
 
     def __init__(self, *, in_chns: int, out_chns: int, nf: int, n_layers: int,
-                 skips: List[int] = [], activation: str = 'relu'):
+                 skips: List[int] = [], act: str = 'relu', out_act = 'linear'):
         """
         Initialize a full-connection net
 
@@ -95,12 +95,12 @@
         """
         super().__init__()
-        self.layers = [FcLayer(in_chns, nf, activation)] + [
-            FcLayer(nf, nf, activation, skip_chns=in_chns if i in skips else 0)
+        self.layers = [FcLayer(in_chns, nf, act)] + [
+            FcLayer(nf, nf, act, skip_chns=in_chns if i in skips else 0)
             for i in range(n_layers - 1)
         ]
-        if out_chns > 0:
-            self.layers.append(FcLayer(nf, out_chns))
+        if out_chns:
+            self.layers.append(FcLayer(nf, out_chns, out_act))
         for i, layer in enumerate(self.layers):
             self.add_module(f"layer{i}", layer)
diff --git a/modules/renderer.py b/modules/renderer.py
index 4f8b467..84a86b4 100644
--- a/modules/renderer.py
+++ b/modules/renderer.py
@@ -1,8 +1,44 @@
+from itertools import cycle
+from math import ceil
+from typing import Dict, Tuple, Union
 import torch
 import torch.nn as nn
-import torch.nn.functional as nn_f
+
 from utils.constants import *
+from utils.perf import perf
 from .generic import *
+from .sampler import Samples
+
+
+def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
+    """
+    Calculate energies from densities inferred by model.
+
+    :param densities `Tensor(N..., 1)`: model's output densities
+    :param dists `Tensor(N...)`: integration time intervals
+    :param raw_noise_std `float`: the noise std used to regularize the network during training
+        (prevents floater artifacts), defaults to 0, meaning no noise is added
+    :return `Tensor(N..., 1)`: energies which block light rays
+    """
+    if raw_noise_std > 0:
+        # Add noise to model's predictions for density. Can be used to
+        # regularize network during training (prevents floater artifacts).
+        densities = densities + torch.normal(0.0, raw_noise_std, densities.size())
+    return densities * dists[..., None]
+
+
+def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
+    """
+    Calculate alphas from densities inferred by model.
+
+    :param densities `Tensor(N..., 1)`: model's output densities
+    :param dists `Tensor(N...)`: integration time intervals
+    :param raw_noise_std `float`: the noise std used to regularize the network during training
+        (prevents floater artifacts), defaults to 0, meaning no noise is added
+    :return `Tensor(N..., 1)`: alphas
+    """
+    energies = density2energy(densities, dists, raw_noise_std)
+    return 1.0 - torch.exp(-energies)
 
 
 class AlphaComposition(nn.Module):
 
@@ -11,18 +47,26 @@ class AlphaComposition(nn.Module):
         super().__init__()
 
     def forward(self, colors, alphas, bg=None):
+        """
+        Composite colors of samples along rays using their alpha values.
+
+        :param colors `Tensor(N, P, C)`: colors of samples along rays
+        :param alphas `Tensor(N, P, 1)`: alpha values of samples along rays
+        :param bg `Tensor([N, ]C)`: background color, defaults to None
+        :return `Tensor(N, C)`: composited colors of rays
+        """
         # Compute weight for RGB of each sample along each ray. A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
-        one_minus_alpha = torch.cumprod(1 - alphas[..., :-1] + TINY_FLOAT, dim=-1)
+        one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + TINY_FLOAT, dim=-2)
         one_minus_alpha = torch.cat([
-            torch.ones_like(one_minus_alpha[..., 0:1]),
+            torch.ones_like(one_minus_alpha[..., :1, :]),
             one_minus_alpha
-        ], dim=-1)
-        weights = alphas * one_minus_alpha  # (N_rays, N)
+        ], dim=-2)
+        weights = alphas * one_minus_alpha  # (N, P, 1)
 
-        # (N_rays, 1|3), computed weighted color of each sample along each ray.
-        final_color = torch.sum(weights[..., None] * colors, dim=-2)
+        # (N, C), computed weighted color of each sample along each ray.
+        final_color = torch.sum(weights * colors, dim=-2)
 
         # To composite onto a white background, use the accumulated alpha map.
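+        # The background color, if given, is blended in with the transmittance that
+        # remains after compositing all samples along each ray.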
         if bg is not None:
@@ -38,58 +82,290 @@ class AlphaComposition(nn.Module):
 
 class VolumnRenderer(nn.Module):
 
-    def __init__(self, *, raw_noise_std=0.0, sigma_as_density=True):
-        """
-        Initialize a Rendering module
-        """
+    class States:
+        kernel: nn.Module
+        samples: Samples
+        hit_mask: torch.Tensor
+        early_stop_tolerance: float
+        N: int
+        P: int
+
+        colors: torch.Tensor
+        diffuses: torch.Tensor
+        speculars: torch.Tensor
+        energies: torch.Tensor
+        weights: torch.Tensor
+        cum_energies: torch.Tensor
+        exp_energies: torch.Tensor
+        tot_evaluations: Dict[str, int]
+
+        chunk: Tuple[slice, slice]
+        cum_chunk: Tuple[slice, slice]
+        cum_last: Tuple[slice, slice]
+        chunk_id: int
+
+        @property
+        def start(self) -> int:
+            return self.chunk[1].start
+
+        @property
+        def end(self) -> int:
+            return self.chunk[1].stop
+
+        def __init__(self, kernel: nn.Module, samples: Samples, early_stop_tolerance: float) -> None:
+            self.kernel = kernel
+            self.samples = samples
+            self.early_stop_tolerance = early_stop_tolerance
+
+            N, P = samples.size
+            self.hit_mask = samples.voxel_indices != -1  # (N, P)
+            self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
+            self.diffuses = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
+            self.speculars = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
+            self.energies = torch.zeros(N, P, 1, device=samples.device)
+            self.weights = torch.zeros(N, P, 1, device=samples.device)
+            self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
+            self.exp_energies = torch.ones(N, P + 1, 1, device=samples.device)
+            self.tot_evaluations = {}
+            self.N, self.P = N, P
+            self.chunk_id = -1
+
+        def n_hits(self, start: int = None, end: int = None) -> int:
+            if start is None:
+                return self.hit_mask.count_nonzero().item()
+            if end is None:
+                return self.hit_mask[:, start].count_nonzero().item()
+            return self.hit_mask[:, start:end].count_nonzero().item()
+
+        def accumulate_tot_evaluations(self, key: str, n: int):
+            if key not in self.tot_evaluations:
+                self.tot_evaluations[key] = 0
+            self.tot_evaluations[key] += n
+
+        def next_chunk(self, *, length=None, end=None):
+            start = 0 if not hasattr(self, "chunk") else self.end
+            length = length or self.P
+            end = min(end or start + length, self.P)
+            self.chunk = slice(None), slice(start, end)
+            self.cum_chunk = slice(None), slice(start + 1, end + 1)
+            self.cum_last = slice(None), slice(start, start + 1)
+            self.chunk_id += 1
+            return self
+
+    def __init__(self, **kwargs):
         super().__init__()
-        self.alpha_composition = AlphaComposition()
-        self.sigma_as_density = sigma_as_density
-        self.raw_noise_std = raw_noise_std
-
-    def forward(self, colors, sigmas, z_vals, bg_color=None, ret_depth=False, debug=False):
-        """Transforms model's predictions to semantically meaningful values.
-
-        Args:
-            color: [num_rays, num_samples along ray, 1|3]. Predicted color from model.
-            density: [num_rays, num_samples along ray]. Predicted density from model.
-            z_vals: [num_rays, num_samples along ray]. Integration time.
-
-        Returns:
-            rgb_map: [num_rays, 1|3]. Estimated RGB color of a ray.
-            disp_map: [num_rays]. Disparity map. Inverse of depth map.
-            acc_map: [num_rays]. Sum of weights along each ray.
-            weights: [num_rays, num_samples]. Weights assigned to each sampled color.
-            depth_map: [num_rays]. Estimated distance to object.
+
+    @perf
+    def forward(self, kernel: nn.Module, samples: Samples, extra_outputs: List[str] = [], *,
+                raymarching_early_stop_tolerance: float = 0,
+                raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
+                **kwargs):
+        """
+        Perform volume rendering.
+
+        :param kernel: render kernel
+        :param samples `Samples(N, P)`: samples
+        :param extra_outputs `list[str]`: extra items to include in the result dict.
+            Optional values include 'depth', 'layers', 'states' and attribute names in class `States` (e.g. 'weights'). Defaults to []
+        :param raymarching_early_stop_tolerance `float`: tolerance of raymarching early stop.
+            Should be between 0 and 1 (0 means no early stop). Defaults to 0
+        :param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split the raymarching process.
+            Use a list of integers to specify the number of samples in every chunk, or a positive integer to specify the number of chunks.
+            Use a negative integer to split by the number of hits per chunk; its absolute value is the maximum number of hits allowed in one chunk.
+            0 and `None` mean no splitting of the raymarching process. Defaults to `None`
+        :return `dict`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
         """
-        alphas = self.density2alpha(sigmas, z_vals) if self.sigma_as_density \
-            else nn_f.sigmoid(sigmas)
-        ret = self.alpha_composition(colors, alphas, bg_color)
-        if ret_depth:
-            ret['depth'] = torch.sum(ret['weights'] * z_vals, dim=-1)
-        if debug:
-            ret['layers'] = torch.cat([colors, alphas[..., None]], dim=-1)
+        if samples.size[1] == 0:
+            print("VolumnRenderer.forward(): # of samples is zero")
+            return None
+
+        s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance)
+
+        if not raymarching_chunk_size_or_sections:
+            raymarching_chunk_size_or_sections = [s.P]
+        elif isinstance(raymarching_chunk_size_or_sections, int) and \
+                raymarching_chunk_size_or_sections > 0:
+            raymarching_chunk_size_or_sections = [ceil(s.P / raymarching_chunk_size_or_sections)]
+
+        if isinstance(raymarching_chunk_size_or_sections, list):
+            chunk_sections = raymarching_chunk_size_or_sections
+            for chunk_samples in cycle(chunk_sections):
+                self._forward_chunk(s.next_chunk(length=chunk_samples))
+                if s.end >= s.P:
+                    break
+        else:
+            chunk_size = -raymarching_chunk_size_or_sections
+            chunk_hits = s.n_hits(0)
+            for i in range(1, s.P):
+                n_hits = s.n_hits(i)
+                if chunk_hits + n_hits > chunk_size:
+                    self._forward_chunk(s.next_chunk(end=i))
+                    n_hits = s.n_hits(i)
+                    chunk_hits = 0
+                chunk_hits += n_hits
+            self._forward_chunk(s.next_chunk())
+
+        ret = {
+            'color': torch.sum(s.colors * s.weights, 1),
+            'tot_evaluations': s.tot_evaluations
+        }
+        for key in extra_outputs:
+            if key == 'depth':
+                ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
+            elif key == 'diffuse':
+                ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
+            elif key == 'specular':
+                ret['specular'] = torch.sum(s.speculars * s.weights, 1)
+            elif key == 'layers':
+                ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
+            elif key == 'states':
+                ret['states'] = s
+            else:
+                ret[key] = getattr(s, key)
         return ret
 
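+        # The commented-out block below keeps an older chunking and early-stop
+        # implementation around for reference.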
-    def density2alpha(self, densities: torch.Tensor, z_vals: torch.Tensor):
+        # if raymarching_chunk_size == 0:
+        #     raymarching_chunk_samples = 1
+        # if raymarching_chunk_samples != 0:
+        #     if isinstance(raymarching_chunk_samples, int):
+        #         raymarching_chunk_samples = repeat(raymarching_chunk_samples,
+        #                                            ceil(s.P / raymarching_chunk_samples))
+        #     chunk_offset = 0
+        #     for chunk_samples in raymarching_chunk_samples:
+        #         start, end = chunk_offset, chunk_offset + chunk_samples
+        #         n_hits = self._forward_chunk(s, start, end)
+        #         if n_hits > 0 and tolerance > 0:  # Early stop
+        #             s.hit_mask[s.cum_energies[:, end, 0] > tolerance] = 0
+        #         chunk_offset += chunk_samples
+        # elif raymarching_chunk_size > 0:
+        #     chunk_offset, chunk_hits = 0, s.n_hits(0)
+        #     for i in range(1, s.P):
+        #         n_hits = s.n_hits(i)
+        #         if chunk_hits + n_hits > raymarching_chunk_size:
+        #             self._forward_chunk(s, chunk_offset, i, chunk_hits)
+        #             if chunk_hits > 0 and tolerance > 0:  # Early stop
+        #                 s.hit_mask[s.cum_energies[:, i, 0] > tolerance] = 0
+        #             n_hits = s.n_hits(i)
+        #             chunk_hits, chunk_offset = 0, i
+        #         chunk_hits += n_hits
+        #     self._forward_chunk(s, chunk_offset, s.P, chunk_hits)
+        # else:
+        #     self._forward_chunk(s, 0, s.P)
+
+        # return self._composite(s, extra_outputs)
+        # original_depth = samples.get('original_point_depth', None)
+        # if original_depth is not None:
+        #     results['z'] = (original_depth * probs).sum(-1)
+        # if getattr(input_fn, "track_max_probs", False) and (not self.training):
+        #     input_fn.track_voxel_probs(samples['sampled_point_voxel_idx'].long(), results['probs'])
+
+    def _calc_weights(self, s: States):
+        """
+        Calculate weights of samples in composited outputs
+
+        :param s `States`: render states of the current chunk
         """
-        Raw value inferred from model to color and alpha
+        s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
+            + s.cum_energies[s.cum_last]
+        s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
+        s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]
 
-        :param densities `Tensor(N.rays, N.samples)`: model's output density
-        :param z_vals `Tensor(N.rays, N.samples)`: integration time
-        :return `Tensor(N.rays, N.samples)`: alpha
+    def _apply_early_stop(self, s: States):
         """
+        Stop rays whose accumulated opacity is larger than a threshold
+
+        :param s `States`: render states of the current chunk
+        """
+        if s.end < s.P and s.early_stop_tolerance > 0:
+            rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
+            s.hit_mask[rays_to_stop, s.end:] = 0
+
+    def _forward_chunk(self, s: States) -> int:
+        fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)  # (N')
+        fi_idxs[1].add_(s.start)
+
+        if fi_idxs[0].size(0) == 0:
+            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
+            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
+            return 0
+
+        # fi_* means "filtered" by hit mask
+        fi_samples = s.samples[fi_idxs]  # N -> N'
+
+        # Infer densities and colors
+        fi_outputs = s.kernel.render(fi_samples, 'color', 'density', 'specular', 'diffuse',
+                                     chunk_id=s.chunk_id)
+        s.colors.index_put_(fi_idxs, fi_outputs['color'])
+        if fi_outputs['specular'] is not None:
+            s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
+        if fi_outputs['diffuse'] is not None:
+            s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
+        s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
+        s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
+
+        self._calc_weights(s)
+        self._apply_early_stop(s)
+
+
+class DensityFirstVolumnRenderer(VolumnRenderer):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _forward_chunk(self, s: VolumnRenderer.States) -> int:
+        fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)  # (N')
+        fi_idxs[1].add_(s.start)
+
+        if fi_idxs[0].size(0) == 0:
+            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
+            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
+            return 0
+
+        # fi_* means "filtered" by hit mask
+        fi_samples = s.samples[fi_idxs]  # N -> N'
+
+        # For all valid samples: encode X
+        fi_encoded_x = s.kernel.encode_x(fi_samples)  # (N', Ex)
+
+        # Infer densities (shape)
+        fi_outputs = s.kernel.infer(fi_encoded_x, None, 'density', 'color_feat',
+                                    chunk_id=s.chunk_id)
+        s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
+        s.accumulate_tot_evaluations("density", fi_idxs[0].size(0))
+
+        self._calc_weights(s)
+        self._apply_early_stop(s)
+
+        # Remove samples whose weights are less than a threshold
+        s.hit_mask[s.chunk][s.weights[s.chunk][..., 0] < 0.01] = 0
+
+        # Update "filtered" tensors
+        fi_mask = s.hit_mask[fi_idxs]
+        fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask])  # N' -> N"
+        fi_encoded_x = fi_encoded_x[fi_mask]  # (N", Ex)
+        fi_color_feats = fi_outputs['color_feat'][fi_mask]
+
+        # For all valid samples: encode D
+        fi_encoded_d = s.kernel.encode_d(s.samples[fi_idxs])  # (N", Ed)
-        # Compute 'distance' (in time) between each integration time along a ray.
-        # The 'distance' from the last integration time is infinity.
-        # dists: (N_rays, N)
-        dists = z_vals[..., 1:] - z_vals[..., :-1]
-        last_dist = torch.zeros_like(z_vals[..., 0:1]) + TINY_FLOAT
-        dists = torch.cat([dists, last_dist], -1)
-
-        if self.raw_noise_std > 0.:
-            # Add noise to model's predictions for density. Can be used to
-            # regularize network during training (prevents floater artifacts).
-            noise = torch.normal(0.0, self.raw_noise_std, densities.size())
-            densities = densities + noise
-        return -torch.exp(-torch.relu(densities) * dists) + 1.0
+
+        # Infer colors (appearance)
+        fi_outputs = s.kernel.infer(fi_encoded_x, fi_encoded_d, 'color', 'specular', 'diffuse',
+                                    chunk_id=s.chunk_id,
+                                    extras={"color_feats": fi_color_feats})
+        # if s.chunk_id == 0:
+        #     fi_colors[:] *= fi_colors.new_tensor([1, 0, 0])
+        # elif s.chunk_id == 1:
+        #     fi_colors[:] *= fi_colors.new_tensor([0, 1, 0])
+        # elif s.chunk_id == 2:
+        #     fi_colors[:] *= fi_colors.new_tensor([0, 0, 1])
+        # else:
+        #     fi_colors[:] *= fi_colors.new_tensor([1, 1, 0])
+        s.colors.index_put_(fi_idxs, fi_outputs['color'])
+        if fi_outputs['specular'] is not None:
+            s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
+        if fi_outputs['diffuse'] is not None:
+            s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
+        s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
diff --git a/modules/sampler.py b/modules/sampler.py
index eacd072..1eae990 100644
--- a/modules/sampler.py
+++ b/modules/sampler.py
@@ -1,14 +1,26 @@
-from typing import Tuple
+from .space import Space, Voxels
 import torch
 import torch.nn as nn
+from typing import Tuple
+
 from utils import device
 from utils import sphere
 from utils.constants import *
+from utils.perf import perf, checkpoint
 from .generic import *
+from clib import *
 
 
 class Bins(object):
 
+    @property
+    def up(self):
+        return self.bounds[1:]
+
+    @property
+    def lo(self):
+        return self.bounds[:-1]
+
     def __init__(self, vals: torch.Tensor):
         self.vals = vals
         self.bounds = torch.cat([
@@ -16,8 +28,6 @@ class Bins(object):
             0.5 * (self.vals[1:] + self.vals[:-1]),
             self.vals[-1:]
         ])
-        self.up = self.bounds[1:]
-        self.lo = self.bounds[:-1]
 
     @staticmethod
     def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
@@ -26,14 +36,60 @@ class Bins(object):
     def to(self, device: torch.device):
         self.vals = self.vals.to(device)
         self.bounds = self.bounds.to(device)
-        self.up = self.bounds[1:]
-        self.lo = self.bounds[:-1]
+
+
+class Samples:
+    pts: torch.Tensor
+    """`Tensor(N[, P], 3)`"""
+
+    dirs: torch.Tensor
+    """`Tensor(N[, P], 3)`"""
+
+    depths: torch.Tensor
+    """`Tensor(N[, P])`"""
+
+    dists: torch.Tensor
+    """`Tensor(N[, P])`"""
+
+    voxel_indices: torch.Tensor
+    """`Tensor(N[, P])`"""
+
+    @property
+    def size(self):
+        return self.pts.size()[:-1]
+
+    @property
+    def device(self):
+        return self.pts.device
+
+    def __init__(self, pts: torch.Tensor, dirs: torch.Tensor, depths: torch.Tensor,
+                 dists: torch.Tensor, voxel_indices: torch.Tensor) -> None:
+        self.pts = pts
+        self.dirs = dirs
+        self.depths = depths
+        self.dists = dists
+        self.voxel_indices = voxel_indices
+
+    def __getitem__(self, index):
+        return Samples(
+            pts=self.pts[index],
+            dirs=self.dirs[index],
+            depths=self.depths[index],
+            dists=self.dists[index],
+            voxel_indices=self.voxel_indices[index])
+
+    def reshape(self, *shape: int):
+        return Samples(
+            pts=self.pts.reshape(*shape, 3),
+            dirs=self.dirs.reshape(*shape, 3),
+            depths=self.depths.reshape(*shape),
+            dists=self.dists.reshape(*shape),
+            voxel_indices=self.voxel_indices.reshape(*shape))
 
 
 class Sampler(nn.Module):
 
-    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
-                 perturb_sample: bool, spherical: bool, lindisp: bool):
+    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int, lindisp: bool, **kwargs):
         """
         Initialize a Sampler module
 
@@ -44,37 +100,81 @@ class Sampler(nn.Module):
         """
         super().__init__()
         self.lindisp = lindisp
-        self.spherical = spherical
-        self.perturb_sample = perturb_sample
         s_range = (1 / sample_range[0], 1 / sample_range[1]) if self.lindisp else sample_range
+        s_range = list(s_range)
+        if s_range[1] > s_range[0]:
+            s_range[0] += 1e-4
+            s_range[1] -= 1e-4
+        else:
+            s_range[0] -= 1e-4
+            s_range[1] += 1e-4
         self.bins = Bins.linspace(s_range, n_samples, device=device.default())
 
-    def forward(self, rays_o, rays_d):
+    @perf
+    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
+                perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
         """
         Sample points along rays.
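+        When `perturb_sample` is True, each sample depth is jittered uniformly within
+        its bin (stratified sampling).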
 
-        :param rays_o `Tensor(B, 3)`: rays' origin
-        :param rays_d `Tensor(B, 3)`: rays' direction
-        :return `Tensor(B, N, 3)`: sampled points
-        :return `Tensor(B, N)`: corresponding depths along rays
+        :param rays_o `Tensor(N, 3)`: rays' origin
+        :param rays_d `Tensor(N, 3)`: rays' direction
+        :return `Samples(N, P)`: samples
         """
         s = self.bins.vals.expand(rays_o.size(0), -1)
-        if self.perturb_sample:
+        if perturb_sample:
             s = self.bins.lo + (self.bins.up - self.bins.lo) * torch.rand_like(s)
+        pts, depths = self._get_sample_points(rays_o, rays_d, s)
+        voxel_indices = space_module.get_voxel_indices(pts)
+        valid_rays_mask = voxel_indices.ne(-1).any(dim=-1)
+        return Samples(
+            pts=pts,
+            dirs=rays_d[:, None].expand(-1, depths.size(1), -1),
+            depths=depths,
+            dists=self._calc_dists(depths),
+            voxel_indices=voxel_indices
+        )[valid_rays_mask], valid_rays_mask
+
+    def _get_sample_points(self, rays_o, rays_d, s):
         z = torch.reciprocal(s) if self.lindisp else s
-        if self.spherical:
-            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
-            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
-            return sphers, depths, s, pts
-        else:
-            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s, None
+        pts = rays_o[:, None] + rays_d[:, None] * z[..., None]
+        depths = z
+        return pts, depths
+
+    def _calc_dists(self, vals):
+        # Compute 'distance' (in time) between each integration time along a ray.
+        # The 'distance' from the last integration time is infinity.
+        # dists: (N_rays, N)
+        dists = vals[..., 1:] - vals[..., :-1]
+        last_dist = torch.zeros_like(vals[..., :1]) + TINY_FLOAT
+        return torch.cat([dists, last_dist], -1)
+
+
+class SphericalSampler(Sampler):
+
+    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
+                 perturb_sample: bool, **kwargs):
+        """
+        Initialize a SphericalSampler module
+
+        :param sample_range: range of radii to sample
+        :param n_samples: count to sample along ray
+        :param perturb_sample: perturb the sample depths
+        """
+        super().__init__(sample_range=sample_range, n_samples=n_samples,
+                         perturb_sample=perturb_sample, lindisp=False)
+
+    def _get_sample_points(self, rays_o, rays_d, s):
+        r = torch.reciprocal(s)
+        pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, r)
+        pts = sphere.cartesian2spherical(pts, inverse_r=True)
+        return pts, depths
 
 
 class PdfSampler(nn.Module):
 
     def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
-                 spherical: bool, lindisp: bool):
+                 spherical: bool, lindisp: bool, **kwargs):
         """
         Initialize a Sampler module
 
@@ -90,7 +190,7 @@ class PdfSampler(nn.Module):
         self.n_samples = n_samples
         self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range
 
-    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
+    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
         """
         Sample points along rays.
         Return Spherical or Cartesian coordinates, specified by `self.spherical`.
@@ -166,22 +266,116 @@ class PdfSampler(nn.Module):
 
 class VoxelSampler(nn.Module):
 
-    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
-                 lindisp: bool, space):
+    def __init__(self, *, perturb_sample: bool, sample_step: float, **kwargs):
         """
-        Initialize a Sampler module
+        Initialize a VoxelSampler module
 
-        :param depth_range: depth range for sampler
-        :param n_samples: count to sample along ray
         :param perturb_sample: perturb the sample depths
-        :param lindisp: If True, sample linearly in inverse depth rather than in depth
+        :param sample_step: size of the sampling step along rays
         """
         super().__init__()
-        self.lindisp = lindisp
         self.perturb_sample = perturb_sample
-        self.n_samples = n_samples
-        self.space = space
-        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range
+        self.sample_step = sample_step
+
+    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
+                 **kwargs) -> Tuple[Samples, torch.Tensor]:
+        """
+        Sample points along rays within the voxels they hit.
 
-    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
-
\ No newline at end of file
+        :param rays_o `Tensor(N, 3)`: rays' origin positions
+        :param rays_d `Tensor(N, 3)`: rays' directions
+        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
+        :return `Tensor(N)`: valid rays mask
+        """
+        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
+        valid_rays_mask = intersections.hits > 0
+        rays_o = rays_o[valid_rays_mask]
+        rays_d = rays_d[valid_rays_mask]
+        intersections = intersections[valid_rays_mask]  # (N) -> (N')
+        n_rays = rays_o.size(0)
+        ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long)  # (N')
+
+        hits = intersections.hits
+        min_depths = intersections.min_depths
+        max_depths = intersections.max_depths
+        voxel_indices = intersections.voxel_indices
+
+        rays_near_depth = min_depths[:, :1]  # (N', 1)
+        rays_far_depth = max_depths[ray_index_list, hits - 1][:, None]  # (N', 1)
+        rays_length = rays_far_depth - rays_near_depth
+        rays_steps = (rays_length / self.sample_step).ceil().long()
+        rays_step_size = rays_length / rays_steps
+        max_steps = rays_steps.max().item()
+        rays_step = torch.arange(max_steps, device=rays_o.device,
+                                 dtype=torch.float)[None].repeat(n_rays, 1)  # (N', P)
+        invalid_samples_mask = rays_step >= rays_steps
+        samples_min_depth = rays_near_depth + rays_step * rays_step_size
+        samples_depth = samples_min_depth + rays_step_size \
+            * (torch.rand_like(samples_min_depth) if self.perturb_sample else 0.5)  # (N', P)
+        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)
+        samples_voxel_index = voxel_indices[
+            ray_index_list[:, None],
+            torch.searchsorted(max_depths, samples_depth)
+        ]  # (N', P)
+        samples_depth[invalid_samples_mask] = HUGE_FLOAT
+        samples_dist[invalid_samples_mask] = 0
+        samples_voxel_index[invalid_samples_mask] = -1
+
+        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
+        return Samples(
+            pts=rays_o + rays_d * samples_depth[..., None],
+            dirs=rays_d.expand(-1, max_steps, -1),
+            depths=samples_depth,
+            dists=samples_dist,
+            voxel_indices=samples_voxel_index
+        ), valid_rays_mask
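+
+    # NOTE: `_forward` above keeps a pure-PyTorch sampling path, while the `forward`
+    # method below relies on the CUDA `inverse_cdf_sampling` kernel from `clib` and
+    # is the entry point actually called during rendering.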
+
+    @perf
+    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
+                **kwargs) -> Tuple[Samples, torch.Tensor]:
+        """
+        Sample points along rays within the voxels they hit.
+
+        :param rays_o `Tensor(N, 3)`: rays' origin positions
+        :param rays_d `Tensor(N, 3)`: rays' directions
+        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
+        :return `Tensor(N)`: valid rays mask
+        """
+        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
+        valid_rays_mask = intersections.hits > 0
+        rays_o = rays_o[valid_rays_mask]
+        rays_d = rays_d[valid_rays_mask]
+        intersections = intersections[valid_rays_mask]  # (N) -> (N')
+
+        checkpoint("Ray intersect")
+
+        if intersections.size == 0:
+            return None, valid_rays_mask
+        else:
+            min_depth = intersections.min_depths
+            max_depth = intersections.max_depths
+            pts_idx = intersections.voxel_indices
+            dists = max_depth - min_depth
+            tot_dists = dists.sum(dim=-1, keepdim=True)  # (N, 1)
+            probs = dists / tot_dists
+            steps = tot_dists[:, 0] / self.sample_step
+
+            # sample points and use middle point approximation
+            sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
                pts_idx, min_depth, max_depth, probs, steps, -1, not self.perturb_sample)
+            sampled_indices = sampled_indices.long()
+            invalid_idx_mask = sampled_indices.eq(-1)
+            sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
+            sampled_depths.masked_fill_(invalid_idx_mask, HUGE_FLOAT)
+
+            checkpoint("Inverse CDF sampling")
+
+            rays_o, rays_d = rays_o[:, None], rays_d[:, None]
+            return Samples(
+                pts=rays_o + rays_d * sampled_depths[..., None],
+                dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
+                depths=sampled_depths,
+                dists=sampled_dists,
+                voxel_indices=sampled_indices
+            ), valid_rays_mask
diff --git a/modules/space.py b/modules/space.py
new file mode 100644
index 0000000..26dac98
--- /dev/null
+++ b/modules/space.py
@@ -0,0 +1,351 @@
+from math import ceil
+import torch
+import numpy as np
+from typing import List, NoReturn, Tuple, Union
+from torch import nn
+from plyfile import PlyData, PlyElement
+
+from utils.geometry import *
+from utils.constants import *
+from utils.voxels import *
+from utils.perf import perf
+from clib import *
+
+
+class Intersections:
+    min_depths: torch.Tensor
+    """`Tensor(N, P)` Min ray depths of intersected voxels"""
+
+    max_depths: torch.Tensor
+    """`Tensor(N, P)` Max ray depths of intersected voxels"""
+
+    voxel_indices: torch.Tensor
+    """`Tensor(N, P)` Indices of intersected voxels"""
+
+    hits: torch.Tensor
+    """`Tensor(N)` Number of hits"""
+
+    @property
+    def size(self):
+        return self.hits.size(0)
+
+    def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor,
+                 voxel_indices: torch.Tensor, hits: torch.Tensor) -> None:
+        self.min_depths = min_depths
+        self.max_depths = max_depths
+        self.voxel_indices = voxel_indices
+        self.hits = hits
+
+    def __getitem__(self, index):
+        return Intersections(
+            min_depths=self.min_depths[index],
+            max_depths=self.max_depths[index],
+            voxel_indices=self.voxel_indices[index],
+            hits=self.hits[index])
+
+
+class Space(nn.Module):
+    bbox: Union[torch.Tensor, None]
+    """`Tensor(2, 3)` Bounding box"""
+
+    def __init__(self, *, bbox: List[float] = None, **kwargs):
+        super().__init__()
+        if bbox is None:
+            self.bbox = None
+        else:
+            self.register_buffer('bbox', torch.Tensor(bbox).reshape(2, 3), persistent=False)
+
+    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
+        raise NotImplementedError
+
+    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
+                          name: str = 'default') -> torch.Tensor:
+        raise NotImplementedError
+
+    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
+        raise NotImplementedError
+
+    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
+        voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
+        if self.bbox is not None:
+            out_bbox = torch.logical_or(pts < self.bbox[0], pts >= self.bbox[1]).any(-1)  # (N...)
+            voxel_indices[out_bbox] = -1
+        return voxel_indices
+
+    @torch.no_grad()
+    def pruning(self, score_fn, threshold: float = 0.5, train_stats=False):
+        raise NotImplementedError()
+
+    @torch.no_grad()
+    def splitting(self):
+        raise NotImplementedError()
+
+
+class Voxels(Space):
+    steps: torch.Tensor
+    """`Tensor(3)` Steps along each dimension"""
+
+    corners: torch.Tensor
+    """`Tensor(C, 3)` Corner positions"""
+
+    voxels: torch.Tensor
+    """`Tensor(M, 3)` Voxel centers"""
+
+    corner_indices: torch.Tensor
+    """`Tensor(M, 8)` Voxel corner indices"""
+
+    voxel_indices_in_grid: torch.Tensor
+    """`Tensor(G)` Indices in voxel list or -1 for pruned space"""
+
+    @property
+    def dims(self) -> int:
+        """`int` Number of dimensions"""
+        return self.steps.size(0)
+
+    @property
+    def n_voxels(self) -> int:
+        """`int` Number of voxels"""
+        return self.voxels.size(0)
+
+    @property
+    def n_corners(self) -> int:
+        """`int` Number of corners"""
+        return self.corners.size(0)
+
+    @property
+    def voxel_size(self) -> torch.Tensor:
+        """`Tensor(3)` Voxel size"""
+        return (self.bbox[1] - self.bbox[0]) / self.steps
+
+    @property
+    def device(self) -> torch.device:
+        return self.voxels.device
+
+    def __init__(self, *, voxel_size: float = None,
+                 steps: Union[torch.Tensor, Tuple[int, int, int]] = None, **kwargs) -> None:
+        super().__init__(**kwargs)
+        if self.bbox is None:
+            raise ValueError("Missing argument 'bbox'")
+        if voxel_size is not None:
+            self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
+        else:
+            self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
+        self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
+        corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
+        self.register_buffer("corners", corners)
+        self.register_buffer("corner_indices", corner_indices)
+        self.register_buffer('voxel_indices_in_grid', torch.arange(self.n_voxels))
+        self._register_load_state_dict_pre_hook(self._before_load_state_dict)
+
+    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
+        """
+        Create an embedding on voxel corners.
+
+        :param n_dims `int`: embedding dimension
+        :param name `str`: embedding name
+        :return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
+        """
+        name = f'emb_{name}'
+        self.add_module(name, torch.nn.Embedding(self.n_corners, n_dims))
+        return self.__getattr__(name)
+
+    def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
+        return getattr(self, f'emb_{name}')
+
+    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
+                          name: str = 'default') -> torch.Tensor:
+        """
+        Extract embedding values at given points using trilinear interpolation.
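+
+        Each point is first mapped to normalized coordinates in [0, 1]^3 within its
+        voxel (see `p` below); the eight corner embeddings are then blended with
+        standard trilinear interpolation weights.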
+
+        :param pts `Tensor(N, 3)`: points to extract values
+        :param voxel_indices `Tensor(N)`: corresponding voxel indices
+        :param name `str`: embedding name, defaults to 'default'
+        :return `Tensor(N, X)`: extracted values
+        """
+        emb = self.get_embedding(name)
+        if emb is None:
+            raise KeyError(f"Embedding '{name}' doesn't exist")
+        voxels = self.voxels[voxel_indices]  # (N, 3)
+        corner_indices = self.corner_indices[voxel_indices]  # (N, 8)
+        p = (pts - voxels) / self.voxel_size + 0.5  # (N, 3) normed-coords in voxel
+        features = emb(corner_indices).reshape(pts.size(0), 8, -1)  # (N, 8, X)
+        return trilinear_interp(p, features)
+
+    @perf
+    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
+        """
+        Calculate intersections of rays and voxels.
+
+        :param rays_o `Tensor(N, 3)`: rays' origin
+        :param rays_d `Tensor(N, 3)`: rays' direction
+        :param n_max_hits `int`: maximum number of hits (for allocating enough space)
+        :return `Intersections`: intersections of rays and voxels
+        """
+        # Prepend a dim to meet the requirement of external call
+        rays_o = rays_o[None].contiguous()
+        rays_d = rays_d[None].contiguous()
+
+        voxel_indices, min_depths, max_depths = self._ray_intersect(rays_o, rays_d, n_max_hits)
+        invalid_voxel_mask = voxel_indices.eq(-1)
+        hits = n_max_hits - invalid_voxel_mask.sum(-1)
+
+        # Sort intersections according to their depths
+        min_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
+        max_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
+        min_depths, sorted_idx = min_depths.sort(dim=-1)
+        max_depths = max_depths.gather(-1, sorted_idx)
+        voxel_indices = voxel_indices.gather(-1, sorted_idx)
+
+        return Intersections(
+            min_depths=min_depths[0],
+            max_depths=max_depths[0],
+            voxel_indices=voxel_indices[0],
+            hits=hits[0]
+        )
+
+    @perf
+    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
+        """
+        Get voxel indices of points.
+
+        If a point is not in any valid voxels, its corresponding voxel index is -1.
+
+        :param pts `Tensor(N..., 3)`: points
+        :return `Tensor(N...)`: corresponding voxel indices
+        """
+        grid_indices, out_mask = to_grid_indices(pts, self.bbox, steps=self.steps)
+        grid_indices[out_mask] = 0
+        voxel_indices = self.voxel_indices_in_grid[grid_indices]
+        voxel_indices[out_mask] = -1
+        return voxel_indices
+
+    @torch.no_grad()
+    def splitting(self) -> Tuple[int, int]:
+        """
+        Split voxels into smaller voxels with half size.
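+
+        :return: the numbers of voxels before and after splitting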
+        """
+        n_voxels_before = self.n_voxels
+        self.steps *= 2
+        self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
+            .reshape(-1, 3)
+        self._update_corners()
+        self._update_voxel_indices_in_grid()
+        return n_voxels_before, self.n_voxels
+
+    @torch.no_grad()
+    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
+        self.voxels = self.voxels[keeps]
+        self.corner_indices = self.corner_indices[keeps]
+        self._update_voxel_indices_in_grid()
+        return keeps.size(0), keeps.sum().item()
+
+    @torch.no_grad()
+    def pruning(self, score_fn, threshold: float = 0.5) -> Tuple[int, int]:
+        scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0])  # (M)
+        return self.prune(scores > threshold)
+
+    def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
+        sum_dims = [val for val in range(self.dims) if val != dim]
+        return self.voxel_indices_in_grid.reshape(*self.steps).ne(-1).sum(sum_dims)
+
+    def balance_cut(self, dim: int, n_parts: int) -> List[int]:
+        n_voxels_list = self.n_voxels_along_dim(dim)
+        cdf = (n_voxels_list.cumsum(0) / self.n_voxels * n_parts).tolist()
+        bins = []
+        part = 1
+        offset = 0
+        for i in range(len(cdf)):
+            if cdf[i] >= part:
+                bins.append(i + 1 - offset)
+                offset = i + 1
+                part = int(cdf[i]) + 1
+        return bins
+
+    def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+        sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits)
+        sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand(
+            *sampled_xyz.shape[:2])
+        sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.flatten()
+        return sampled_xyz, sampled_idx
+
+    @torch.no_grad()
+    def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor:
+        def get_scores_once(pts, idxs):
+            scores = score_fn(pts, idxs).reshape(-1, bits ** 3)  # (B, P)
+            if reduce_fn is not None:
+                scores = reduce_fn(scores)  # (B[, ...])
+            return scores
+
+        sampled_xyz, sampled_idx = self.sample(bits)
+        chunk_size = 64
+        return torch.cat([
+            get_scores_once(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
+            for i in range(0, self.voxels.size(0), chunk_size)
+        ], 0)  # (M[, ...])
+
+    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
+
+    def _update_corners(self):
+        """
+        Update voxel corners.
+        """
+        corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
+        self.register_buffer("corners", corners)
+        self.register_buffer("corner_indices", corner_indices)
+
+    def _update_voxel_indices_in_grid(self):
+        """
+        Update voxel indices in grid.
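+
+        Grid cells whose voxels were pruned map to -1, so `get_voxel_indices` can
+        reject points that fall into them.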
+        """
+        grid_indices, _ = to_grid_indices(self.voxels, self.bbox, steps=self.steps)
+        self.voxel_indices_in_grid = grid_indices.new_full([self.steps.prod().item()], -1)
+        self.voxel_indices_in_grid[grid_indices] = torch.arange(self.n_voxels, device=self.device)
+
+    @torch.no_grad()
+    def _before_load_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
+                                unexpected_keys, error_msgs):
+        # Handle buffers
+        for name, buffer in self.named_buffers(recurse=False):
+            if name in self._non_persistent_buffers_set:
+                continue
+            buffer.resize_as_(state_dict[prefix + name])
+
+        # Handle embeddings
+        for name, module in self.named_modules():
+            if name.startswith('emb_'):
+                setattr(self, name, torch.nn.Embedding(self.n_corners, module.embedding_dim))
+
+
+class Octree(Voxels):
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.nodes_cached = None
+        self.tree_cached = None
+
+    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.nodes_cached is None:
+            self.nodes_cached, self.tree_cached = build_easy_octree(
+                self.voxels, 0.5 * self.voxel_size)
+        return self.nodes_cached, self.tree_cached
+
+    def clear(self):
+        self.nodes_cached = None
+        self.tree_cached = None
+
+    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int):
+        nodes, tree = self.get()
+        return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)
+
+    @torch.no_grad()
+    def splitting(self):
+        ret = super().splitting()
+        self.clear()
+        return ret
+
+    @torch.no_grad()
+    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
+        ret = super().prune(keeps)
+        self.clear()
+        return ret
diff --git a/nerf++ b/nerf++
deleted file mode 160000
index a30f1a5..0000000
--- a/nerf++
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit a30f1a5ad116e43aad90c426a966b2a3fcedaf7e
diff --git a/nets/nerf.py b/nets/nerf.py
deleted file mode 100644
index 2d424d4..0000000
--- a/nets/nerf.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import torch
-import torch.nn as nn
-from modules import *
-from utils import color
-
-
-class Nerf(nn.Module):
-
-    def __init__(self, fc_params, sampler_params, *,
-                 c: int = color.RGB,
-                 n_pos_encode: int = 0,
-                 n_dir_encode: int = None,
-                 coarse_net=None, **kwargs):
-        """
-        Initialize a NeRF unit
-
-        :param fc_params `dict`: parameters for full-connection network
-        :param sampler_params `dict`: parameters for sampler
-        :param c `int`: color mode
-        :param n_pos_encode `int`: encode position to number of dimensions
-        :param n_dir_encode `int`: encode direction to number of dimensions, `None` means direction is ignored
-        :param coarse_net `NerfUnit`: optional coarse net
-        """
-        super().__init__()
-        self.coarse_net = coarse_net
-        self.color = c
-        self.coord_chns = 3
-        self.color_chns = color.chns(self.color)
-
-        self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
-
-        if n_dir_encode is not None:
-            self.dir_chns = 3
-            self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
-        else:
-            self.dir_chns = 0
-            self.dir_encoder = None
-        self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
-                             density_chns=1,
-                             color_chns=self.color_chns,
-                             core_nf=fc_params['nf'],
-                             core_layers=fc_params['n_layers'],
-                             dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
-                             dir_nf=fc_params['nf'] // 2,
-                             activation=fc_params['activation'],
-                             skips=fc_params['skips'])
-        sampler_params['spherical'] = False
-        self.sampler = PdfSampler(**sampler_params) if self.coarse_net is not None \
-            else Sampler(**sampler_params)
-        self.rendering = VolumnRenderer()
-
-    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
-                ret_depth=False, debug=False) -> torch.Tensor:
-        """
-        rays -> colors
-
-        :param rays_o `Tensor(B, 3)`: rays' origin
-        :param rays_d `Tensor(B, 3)`: rays' direction
-        :param prev_ret `Mapping`:
-        :param ret_depth `bool`:
-        :return: `Tensor(B, C)``, inferred images/pixels
-        """
-        if self.coarse_net is not None:
-            coarse_ret = self.coarse_net(rays_o, rays_d, ret_depth=ret_depth, debug=debug)
-            coords, depths, s_vals, _ = self.sampler(rays_o, rays_d, coarse_ret['sample'],
-                                                     coarse_ret['weight'])
-        else:
-            coords, depths, s_vals, _ = self.sampler(rays_o, rays_d)
-        coords_encoded = self.pos_encoder(coords)
-        dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, s_vals.size(-1), -1) \
-            if self.dir_encoder is not None else None
-        colors, densities = self.core(coords_encoded, dirs_encoded)
-        ret = self.rendering(colors, densities[..., 0], depths, ret_depth=ret_depth, debug=debug)
-        ret['sample'] = s_vals
-        if self.coarse_net is not None:
-            ret['coarse'] = coarse_ret
-        return ret
-
diff --git a/nets/nsvf.py b/nets/nsvf.py
deleted file mode 100644
index cd1fb7b..0000000
--- a/nets/nsvf.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-import torch.nn as nn
-from modules import *
-from utils import color
-
-
-class NSVF(nn.Module):
-
-    def __init__(self, fc_params, sampler_params, *,
-                 c: int = color.RGB,
-                 n_featdim: int = 32,
-                 n_pos_encode: int = 0,
-                 n_dir_encode: int = None,
-                 **kwargs):
-        """
-        Initialize a NSVF model
-
-        :param fc_params `dict`: parameters for full-connection network
-        :param sampler_params `dict`: parameters for sampler
-        :param c `int`: color mode
-        :param n_pos_encode `int`: encode position to number of dimensions
-        :param n_dir_encode `int`: encode direction to number of dimensions, `None` means direction is ignored
-        :param coarse_net `NerfUnit`: optional coarse net
-        """
-        super().__init__()
-        self.color = c
-        self.coord_chns = n_featdim
-        self.color_chns = color.chns(self.color)
-
-        self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
-        if n_dir_encode is not None:
-            self.dir_chns = 3
-            self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
-        else:
-            self.dir_chns = 0
-            self.dir_encoder = None
-        self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
-                             density_chns=1,
-                             color_chns=self.color_chns,
-                             core_nf=fc_params['nf'],
-                             core_layers=fc_params['n_layers'],
-                             dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
-                             dir_nf=fc_params['nf'] // 2,
-                             activation=fc_params['activation'],
-                             skips=fc_params['skips'])
-
-        self.space = OctTreeSpace()
-
-        sampler_params['space'] = self.space
-        self.sampler = VoxelSampler(**sampler_params)
-        self.rendering = VolumnRenderer()
-
-    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
-                ret_depth=False, debug=False) -> torch.Tensor:
-        """
-        rays -> colors
-
-        :param rays_o `Tensor(B, 3)`: rays' origin
-        :param rays_d `Tensor(B, 3)`: rays' direction
-        :param prev_ret `Mapping`:
-        :param ret_depth `bool`:
-        :return: `Tensor(B, C)``, inferred images/pixels
-        """
-        feats, dirs, z_s, dz_s = self.sampler(rays_o, rays_d)
-        feats_encoded = self.pos_encoder(feats)
-        dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, z_s.size(-1), -1) \
-            if self.dir_encoder is not None else None
-        colors, densities = self.core(feats_encoded, dirs_encoded)
-        ret = self.rendering(colors, densities[..., 0], z_s, dz_s, ret_depth=ret_depth, debug=debug)
-        return ret
-
diff --git a/nets/snerf.py b/nets/snerf.py
deleted file mode 100644 index b2fdfb0..0000000 --- a/nets/snerf.py +++ /dev/null @@ -1,110 +0,0 @@ -import torch -import torch.nn as nn -from modules import * -from utils import sphere -from utils import color - - -class Snerf(nn.Module): - - def __init__(self, fc_params, sampler_params, *, - n_parts: int = 1, - c: int = color.RGB, - pos_encode: int = 10, - dir_encode: int = None, - spherical_dir: bool = False, **kwargs): - """ - Initialize a multi-sphere-layer net - - :param fc_params: parameters for full-connection network - :param sampler_params: parameters for sampler - :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode - :param c: color mode - :param encode_to_dim: encode input to number of dimensions - """ - super().__init__() - self.color = c - self.spherical_dir = spherical_dir - self.n_samples = sampler_params['n_samples'] - self.n_parts = n_parts - self.samples_per_part = self.n_samples // self.n_parts - self.coord_chns = 3 - self.color_chns = color.chns(self.color) - self.pos_encoder = InputEncoder.Get(pos_encode, self.coord_chns) - - if dir_encode is not None: - self.dir_encoder = InputEncoder.Get(dir_encode, 2 if self.spherical_dir else 3) - self.dir_chns_encoded = self.dir_encoder.out_dim - else: - self.dir_encoder = None - self.dir_chns_encoded = 0 - - self.nets = nn.ModuleList( - NerfCore(coord_chns=self.pos_encoder.out_dim, - density_chns=1, - color_chns=self.color_chns, - core_nf=fc_params['nf'], - core_layers=fc_params['n_layers'], - dir_chns=self.dir_chns_encoded, - dir_nf=fc_params['nf'] // 2, - activation=fc_params['activation']) - for _ in range(self.n_parts)) - sampler_params['spherical'] = True - self.sampler = Sampler(**sampler_params) - self.rendering = VolumnRenderer() - - def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, - ret_depth=False, debug=False) -> torch.Tensor: - """ - rays -> colors - - :param rays_o `Tensor(B, 3)`: rays' origin - :param rays_d `Tensor(B, 3)`: rays' direction - :return: `Tensor(B, C)``, inferred images/pixels - """ - n_rays = rays_o.size(0) - coords, depths, _, pts = self.sampler(rays_o, rays_d) - coords_encoded = self.pos_encoder(coords) - if self.dir_encoder is not None: - if self.spherical_dir: - dirs_encoded = self.dir_encoder(sphere.calc_local_dir(rays_d, coords, pts)) - else: - dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, self.n_samples, -1) - else: - dirs_encoded = None - - densities = torch.empty(n_rays, self.n_samples, device=device.default()) - colors = torch.empty(n_rays, self.n_samples, self.color_chns, device=device.default()) - for i, net in enumerate(self.nets): - s = slice(i * self.samples_per_part, (i + 1) * self.samples_per_part) - c, d = net(coords_encoded[:, s], - dirs_encoded[:, s] if dirs_encoded is not None else None) - colors[:, s] = c - densities[:, s] = d - ret = self.rendering(colors.view(-1, self.n_samples, self.color_chns), - densities, depths, ret_depth=ret_depth, debug=debug) - if debug: - ret['sample_densities'] = densities - ret['sample_depths'] = depths - return ret - - -class SnerfExport(nn.Module): - - def __init__(self, net: Snerf): - super().__init__() - self.net = net - - def forward(self, coords_encoded, z_vals): - colors = [] - densities = [] - for i in range(self.net.n_parts): - s = slice(i * self.net.samples_per_part, (i + 1) * self.net.samples_per_part) - mlp = self.net.nets[i] if self.net.nets is not None else self.net.net - c, d = mlp(coords_encoded[:, s].flatten(1, 2)) - colors.append(c.view(-1, self.net.samples_per_part, 
self.net.color_chns)) - densities.append(d) - colors = torch.cat(colors, 1) - densities = torch.cat(densities, 1) - alphas = self.net.rendering.density2alpha(densities, z_vals) - return torch.cat([colors, alphas[..., None]], -1) diff --git a/notebook/gen_crop.ipynb b/notebook/gen_crop.ipynb index 5c191cf..e7f6a9b 100644 --- a/notebook/gen_crop.ipynb +++ b/notebook/gen_crop.ipynb @@ -3,13 +3,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "import sys\n", "import os\n", "import torch\n", - "import torch.nn as nn\n", + "import torch.nn.functional as nn_f\n", "import matplotlib.pyplot as plt\n", "\n", "rootdir = os.path.abspath(sys.path[0] + '/../')\n", @@ -18,21 +16,16 @@ "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n", "torch.autograd.set_grad_enabled(False)\n", "\n", - "from data.spherical_view_syn import *\n", - "from configs.spherical_view_syn import SphericalViewSynConfig\n", - "from utils import netio\n", "from utils import img\n", - "from utils import device\n", "from utils.view import *\n", - "from components.fnr import FoveatedNeuralRenderer\n", "\n", "datadir = f\"{rootdir}/data/__new/__demo/for_crop\"\n", "figs = ['our', 'gt', 'nerf', 'fgt']\n", "crops = {\n", - " 'classroom_0': [[720, 800, 128], [1097, 982, 256]],\n", - " 'lobby_1': [[570, 1000, 100], [1049, 1049, 256]],\n", - " 'stones_2': [[720, 800, 100], [680, 1317, 256]],\n", - " 'barbershop_3': [[745, 810, 100], [1135, 627, 256]]\n", + " 'classroom_0': [[720, 790, 100], [370, 1160, 200]],\n", + " 'lobby_1': [[570, 1000, 100], [1300, 1000, 200]],\n", + " 'stones_2': [[720, 800, 100], [680, 1317, 200]],\n", + " 'barbershop_3': [[745, 810, 100], [950, 900, 200]]\n", "}\n", "colors = torch.tensor([[0, 1, 0, 1], [1, 1, 0, 1]], dtype=torch.float)\n", "border = 10\n", @@ -78,16 +71,18 @@ " img.save(torch.cat([fovea_patches, periph_patches], dim=-1),\n", " [f\"{datadir}/patch/{scene}_{fig}.png\" for fig in figs])\n", " img.save(overlay, f\"{datadir}/overlay/{scene}.png\")\n" - ] + ], + "outputs": [], + "metadata": {} } ], "metadata": { "interpreter": { - "hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5" + "hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c" }, "kernelspec": { - "display_name": "Python 3.8.5 64-bit ('base': conda)", - "name": "python3" + "name": "python3", + "display_name": "Python 3.8.5 64-bit ('base': conda)" }, "language_info": { "codemirror_mode": { diff --git a/notebook/gen_demo_mono.ipynb b/notebook/gen_demo_mono.ipynb index 0371863..5968a13 100644 --- a/notebook/gen_demo_mono.ipynb +++ b/notebook/gen_demo_mono.ipynb @@ -170,7 +170,7 @@ " images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n", " if True:\n", " outputdir = '../__demo/mono/'\n", - " misc.create_dir(outputdir)\n", + " os.makedirs(outputdir, exist_ok=True)\n", " img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n", " img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n", " img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n", @@ -203,7 +203,7 @@ " center = (0, 0)\n", " images = renderer(views.get(view_idx), center, using_mask=True)\n", " outputdir = 'panorama'\n", - " misc.create_dir(outputdir)\n", + " os.makedirs(outputdir, exist_ok=True)\n", " img.save(images['blended'], f'{outputdir}/{view_idx:04d}.png')" ], "outputs": [ diff --git a/notebook/gen_demo_stereo.ipynb b/notebook/gen_demo_stereo.ipynb index 
f44e65e..138950d 100644 --- a/notebook/gen_demo_stereo.ipynb +++ b/notebook/gen_demo_stereo.ipynb @@ -216,7 +216,7 @@ " ret_raw=False)\n", " if True:\n", " outputdir = '../__demo/stereo_m%d' % mono_periph if mono_periph else '../__demo/stereo'\n", - " misc.create_dir(outputdir)\n", + " os.makedirs(outputdir, exist_ok=True)\n", " img.save(torch.cat([\n", " left_images['blended'],\n", " right_images['blended']\n", @@ -228,7 +228,7 @@ " right_images['blended'][:, 1:3]\n", " ], dim=1)\n", " img.save(stereo_overlap, '%s/%s_%d_stereo.png' % (outputdir, scene, i))\n", - " #misc.create_dir(outputdir + '/mid')\n", + " #os.makedirs(outputdir + '/mid', exist_ok=True)\n", " #img.save(left_images['layers_img'][1], '%s/mid/%s_%d_l.png' % (outputdir, scene, i))\n", " #img.save(right_images['layers_img'][1], '%s/mid/%s_%d_r.png' % (outputdir, scene, i))\n", " print(\"%s %d Saved\" % (scene, i))\n", diff --git a/notebook/gen_for_eval.ipynb b/notebook/gen_for_eval.ipynb index cada07b..a861f25 100644 --- a/notebook/gen_for_eval.ipynb +++ b/notebook/gen_for_eval.ipynb @@ -110,7 +110,7 @@ " #plot_figures(images, center)\n", "\n", " outputdir = '../__1_eval/output_mono_periph/ref_as_right_eye/%s/' % scene\n", - " misc.create_dir(outputdir)\n", + " os.makedirs(outputdir, exist_ok=True)\n", " #for key in images:\n", " key = 'blended'\n", " img.save(images[key], outputdir + 'view%04d_%s.png' % (view_idx, key))\n" @@ -131,7 +131,7 @@ " images = gen.gen(center, test_view, True)\n", " #plot_figures(images, center)\n", "\n", - " misc.create_dir('output/eval_gaze')\n", + " os.makedirs('output/eval_gaze', exist_ok=True)\n", " out_path = 'output/eval_gaze/gaze%03d_%d,%d.png' % (gaze_idx, x, y)\n", " img.save(images['blended'], out_path)\n", " print('Output ' + out_path)\n", diff --git a/notebook/gen_teaser.ipynb b/notebook/gen_teaser.ipynb index ea770f6..2c20f68 100644 --- a/notebook/gen_teaser.ipynb +++ b/notebook/gen_teaser.ipynb @@ -130,7 +130,7 @@ " images = gen.gen(center, test_view, True)\n", " #plot_figures(images, center)\n", "\n", - " misc.create_dir('output/teasers')\n", + " os.makedirs('output/teasers', exist_ok=True)\n", " for key in images:\n", " img.save(\n", " images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n" diff --git a/notebook/gen_test.ipynb b/notebook/gen_test.ipynb index eabf583..cd6197d 100644 --- a/notebook/gen_test.ipynb +++ b/notebook/gen_test.ipynb @@ -150,7 +150,7 @@ "print(\"Encoded:\", encoded)\n", "#plot_figures(images, center)\n", "\n", - "#misc.create_dir('output/teasers')\n", + "#os.makedirs('output/teasers', exist_ok=True)\n", "#for key in images:\n", "# img.save(\n", "# images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n" diff --git a/notebook/gen_user_study_images.ipynb b/notebook/gen_user_study_images.ipynb index 238185d..4518b11 100644 --- a/notebook/gen_user_study_images.ipynb +++ b/notebook/gen_user_study_images.ipynb @@ -188,7 +188,7 @@ "\n", "#plot_figures(left_images, right_images, centers[set_id][0], centers[set_id][1])\n", "\n", - "misc.create_dir('output')\n", + "os.makedirs('output', exist_ok=True)\n", "for key in left_images:\n", " img.save(\n", " left_images[key], 'output/set%d_%s_l.png' % (set_id, key))\n", diff --git a/notebook/gen_video.ipynb b/notebook/gen_video.ipynb index a5feb72..df80a3c 100644 --- a/notebook/gen_video.ipynb +++ b/notebook/gen_video.ipynb @@ -117,7 +117,7 @@ " left_images = gen.gen(left_center, left_view, mono_trans=mono_trans)\n", " right_images = gen.gen(right_center, right_view, mono_trans=mono_trans)\n", " \n", - 
" misc.create_dir('output/video_frames/hmd2')\n", + " os.makedirs('output/video_frames/hmd2', exist_ok=True)\n", " img.save(torch.cat([left_images['blended'], right_images['blended']], -1),\n", " 'output/video_frames/hmd2/view%04d.png' % view_idx)\n", " print('Frame %d saved' % view_idx)\n" diff --git a/notebook/net_insight.ipynb b/notebook/net_insight.ipynb index e71898e..00eb3e9 100644 --- a/notebook/net_insight.ipynb +++ b/notebook/net_insight.ipynb @@ -155,7 +155,7 @@ " images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n", " if True:\n", " outputdir = '../__demo/mono/'\n", - " misc.create_dir(outputdir)\n", + " os.makedirs(outputdir, exist_ok=True)\n", " img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n", " img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n", " img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n", @@ -196,7 +196,7 @@ " center = (0, 0)\n", " images = renderer(views.get(view_idx), center, using_mask=True)\n", " outputdir = 'nerf_our'\n", - " misc.create_dir(outputdir)\n", + " os.makedirs(outputdir, exist_ok=True)\n", " img.save(images['blended'], f'{outputdir}/{view_idx:04d}.png')" ] } diff --git a/notebook/test_mono_gen.ipynb b/notebook/test_mono_gen.ipynb index e278546..545ee19 100644 --- a/notebook/test_mono_gen.ipynb +++ b/notebook/test_mono_gen.ipynb @@ -101,7 +101,7 @@ "gaze = [37.55656052, 20.7297554]\n", "images = renderer(view, gaze, using_mask=False, ret_raw=True)\n", "outputdir = '../__demo/mono_f60&m110/'\n", - "misc.create_dir(outputdir)\n", + "os.makedirs(outputdir, exist_ok=True)\n", "img.save(images['layers_img'][0], f'{outputdir}{scene}_fovea.png')\n", "img.save(images['blended'], f'{outputdir}{scene}.png')\n", "img.save(images['blended_raw'], f'{outputdir}{scene}_noCE.png')" diff --git a/notebook/test_mono_view.ipynb b/notebook/test_mono_view.ipynb index 72d05c5..323287e 100644 --- a/notebook/test_mono_view.ipynb +++ b/notebook/test_mono_view.ipynb @@ -249,7 +249,7 @@ "\n", "plot_figures(left_images, right_images, left_center, right_center)\n", "\n", - "misc.create_dir('output/mono_test')\n", + "os.makedirs('output/mono_test', exist_ok=True)\n", "for key in left_images:\n", " img.save(\n", " left_images[key], 'output/mono_test/set%d_%s_l.png' % (set_id, key))\n", diff --git a/run_lf_syn.py b/run_lf_syn.py index 8c1b5e9..23b3bf1 100644 --- a/run_lf_syn.py +++ b/run_lf_syn.py @@ -58,7 +58,7 @@ def train(): epoch = EPOCH_BEGIN iters = EPOCH_BEGIN * len(train_data_loader) * BATCH_SIZE - misc.create_dir(RUN_DIR) + os.makedirs(RUN_DIR, exist_ok=True) perf = Perf(enable=(MODE == "Perf"), start=True) writer = SummaryWriter(RUN_DIR) @@ -129,7 +129,7 @@ def test(net_file: str): # 3. Test on train dataset print("Begin test on train dataset...") - misc.create_dir(OUTPUT_DIR) + os.makedirs(OUTPUT_DIR, exist_ok=True) for view_idxs, view_images, _, view_positions in train_data_loader: out_view_images = model(view_positions) img.save(view_images, diff --git a/run_spherical_view_syn.py b/run_spherical_view_syn.py index b87daa9..d5ef230 100644 --- a/run_spherical_view_syn.py +++ b/run_spherical_view_syn.py @@ -316,8 +316,8 @@ def train(): if epochRange.start > 1: iters = netio.load(f'{run_dir}model-epoch_{epochRange.start - 1}.pth', model) else: - misc.create_dir(run_dir) - misc.create_dir(log_dir) + os.makedirs(run_dir, exist_ok=True) + os.makedirs(log_dir, exist_ok=True) iters = 0 # 3. Train @@ -400,7 +400,7 @@ def test(): # 4. 
Save results print('Saving results...') - misc.create_dir(output_dir) + os.makedirs(output_dir, exist_ok=True) for key in out: shape = [n] + list(dataset.res) + list(out[key].size()[1:]) @@ -446,7 +446,7 @@ def test(): img.save_video(out['color'], output_file, 30) else: output_subdir = f"{output_dir}/{output_dataset_id}_color" - misc.create_dir(output_subdir) + os.makedirs(output_subdir, exist_ok=True) img.save(out['color'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices]) if args.output_flags['depth']: @@ -457,13 +457,13 @@ def test(): img.save_video(colorized_depths, output_file, 30) else: output_subdir = f"{output_dir}/{output_dataset_id}_depth" - misc.create_dir(output_subdir) + os.makedirs(output_subdir, exist_ok=True) img.save(colorized_depths, [ f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices ]) output_subdir = f"{output_dir}/{output_dataset_id}_bins" - misc.create_dir(output_subdir) + os.makedirs(output_subdir, exist_ok=True) img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices]) if args.output_flags['layers']: @@ -473,7 +473,7 @@ def test(): img.save_video(out['layers'][j], output_file, 30) else: output_subdir = f"{output_dir}/{output_dataset_id}_layers" - misc.create_dir(output_subdir) + os.makedirs(output_subdir, exist_ok=True) for j in range(config.sa['n_samples']): img.save(out['layers'][j], [ f'{output_subdir}/{i:0>4d}[{j:0>3d}].png' @@ -543,7 +543,7 @@ def test1(): # 4. Save results print('Saving results...') - misc.create_dir(output_dir) + os.makedirs(output_dir, exist_ok=True) for key in out: shape = [n] + list(dataset.res) + list(out[key].size()[1:]) @@ -587,7 +587,7 @@ def test1(): img.save_video(out['color'], output_file, 30) else: output_subdir = f"{output_dir}/{output_dataset_id}_color" - misc.create_dir(output_subdir) + os.makedirs(output_subdir, exist_ok=True) img.save(out['color'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices]) if args.output_flags['depth']: @@ -598,7 +598,7 @@ def test1(): img.save_video(colorized_depths, output_file, 30) else: output_subdir = f"{output_dir}/{output_dataset_id}_depth" - misc.create_dir(output_subdir) + os.makedirs(output_subdir, exist_ok=True) img.save(colorized_depths, [ f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices @@ -611,7 +611,7 @@ def test1(): img.save_video(out['layers'][j], output_file, 30) else: output_subdir = f"{output_dir}/{output_dataset_id}_layers" - misc.create_dir(output_subdir) + os.makedirs(output_subdir, exist_ok=True) for j in range(config.sa['n_samples']): img.save(out['layers'][j], [ f'{output_subdir}/{i:0>4d}[{j:0>3d}].png' @@ -679,7 +679,7 @@ def test2(): # 4. 
Save results print('Saving results...') - misc.create_dir(output_dir) + os.makedirs(output_dir, exist_ok=True) for key in out: shape = [n] + list(dataset.res) + list(out[key].size()[1:]) @@ -723,7 +723,7 @@ def test2(): img.save_video(out['color'], output_file, 30) else: output_subdir = f"{output_dir}/{output_dataset_id}_color" - misc.create_dir(output_subdir) + os.makedirs(output_subdir, exist_ok=True) img.save(out['color'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices]) if args.output_flags['depth']: @@ -734,7 +734,7 @@ def test2(): img.save_video(colorized_depths, output_file, 30) else: output_subdir = f"{output_dir}/{output_dataset_id}_depth" - misc.create_dir(output_subdir) + os.makedirs(output_subdir, exist_ok=True) img.save(colorized_depths, [ f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices @@ -747,7 +747,7 @@ def test2(): img.save_video(out['layers'][j], output_file, 30) else: output_subdir = f"{output_dir}/{output_dataset_id}_layers" - misc.create_dir(output_subdir) + os.makedirs(output_subdir, exist_ok=True) for j in range(config.sa['n_samples']): img.save(out['layers'][j], [ f'{output_subdir}/{i:0>4d}[{j:0>3d}].png' diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..26d489e --- /dev/null +++ b/setup.py @@ -0,0 +1,27 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import glob +import os +import sys + +# build clib +src_root = "clib" +sources = glob.glob(f"{src_root}/src/*.cpp") + glob.glob(f"{src_root}/src/*.cu") +includes = f"{sys.path[0]}/{src_root}/include" + +setup( + name='dvs', + ext_modules=[ + CUDAExtension( + name='clib._ext', + sources=sources, + extra_compile_args={ + "cxx": ["-O2", f"-I{includes}"], + "nvcc": ["-O2", f"-I{includes}"], + }, + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) \ No newline at end of file diff --git a/term_test.py b/term_test.py new file mode 100644 index 0000000..249bf9a --- /dev/null +++ b/term_test.py @@ -0,0 +1,15 @@ +import os +import shutil +from sys import stdout +from time import sleep +from utils.progress_bar import * + +i = 0 +while True: + rows = shutil.get_terminal_size().lines + cols = shutil.get_terminal_size().columns + os.system('cls' if os.name == 'nt' else 'clear') + stdout.write("\n" * (rows - 1)) + progress_bar(i, 10000, "Test", "XXX") + i += 1 + sleep(0.02) diff --git a/test.py b/test.py new file mode 100644 index 0000000..eab0340 --- /dev/null +++ b/test.py @@ -0,0 +1,226 @@ +import os +import argparse +import torch +import torch.nn.functional as nn_f +from math import nan, ceil, prod +from pathlib import Path + +parser = argparse.ArgumentParser() +parser.add_argument('-m', '--model', type=str, + help='The model file to load for testing') +parser.add_argument('-r', '--output-res', type=str, + help='Output resolution') +parser.add_argument('-o', '--output', nargs='+', type=str, default=['perf', 'color'], + help='Specify what to output (perf, color, depth, all)') +parser.add_argument('--output-type', type=str, default='image', + help='Specify the output type (image, video, debug)') +parser.add_argument('--views', type=str, + help='Specify the range of views to test') +parser.add_argument('-p', '--prompt', action='store_true', + help='Interactive prompt mode') +parser.add_argument('--time', action='store_true', + help='Enable time measurement') +parser.add_argument('dataset', type=str, + help='Dataset description file') +args = parser.parse_args() + + +import model as mdl +from loss.ssim import ssim +from utils import 
color
+from utils import interact
+from utils import device
+from utils import img
+from utils.perf import Perf, enable_perf, get_perf_result
+from utils.progress_bar import progress_bar
+from data.dataset_factory import *
+from data.loader import DataLoader
+from utils.constants import HUGE_FLOAT
+
+
+RAYS_PER_BATCH = 2 ** 14
+DATA_LOADER_CHUNK_SIZE = 1e8
+
+
+data_desc_path = DatasetFactory.get_dataset_desc_path(args.dataset)
+os.chdir(data_desc_path.parent)
+nets_dir = Path("_nets")
+data_desc_path = data_desc_path.name
+
+
+def set_outputs(args, outputs_str: str):
+    args.output = [s.strip() for s in outputs_str.split(',')]
+
+
+if args.prompt:  # Prompt for test model, output resolution and output type
+    model_files = [str(path.relative_to(nets_dir)) for path in nets_dir.rglob("*.tar")] \
+        + [str(path.relative_to(nets_dir)) for path in nets_dir.rglob("*.pth")]
+    args.model = interact.input_enum('Specify test model:', model_files,
+                                     err_msg='No such model file')
+    args.output_res = interact.input_ex('Specify output resolution:',
+                                        default='')
+    set_outputs(args, interact.input_ex('Specify the outputs | [perf,color,depth,layers,diffuse,specular]/all:',
+                                        default='perf,color'))
+    args.output_type = interact.input_enum('Specify the output type | image/video:',
+                                           ['image', 'video'],
+                                           err_msg='Wrong output type',
+                                           default='image')
+args.output_res = tuple(int(s) for s in reversed(args.output_res.split('x'))) if args.output_res \
+    else None
+args.output_flags = {
+    item: item in args.output or 'all' in args.output
+    for item in ['perf', 'color', 'depth', 'layers', 'diffuse', 'specular']
+}
+args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None
+
+if args.time:
+    enable_perf()
+
+dataset = DatasetFactory.load(data_desc_path, res=args.output_res,
+                              load_images=args.output_flags['perf'],
+                              views_to_load=args.views)
+print(f"Dataset loaded: {dataset.root}/{dataset.name}")
+
+
+model_path: Path = nets_dir / args.model
+model_name = model_path.parent.name
+model = mdl.load(model_path, {
+    "raymarching_early_stop_tolerance": 0.01,
+    # "raymarching_chunk_size_or_sections": [8],
+    "perturb_sample": False
+})[0].to(device.default()).eval()
+model_class = model.__class__.__name__
+model_args = model.args
+print(f"model: {model_name} ({model_class})")
+print("args:", json.dumps(model.args0))
+
+run_dir = model_path.parent
+output_dir = run_dir / f"output_{int(model_path.stem.split('_')[-1])}"
+output_dataset_id = '%s%s' % (
+    dataset.name,
+    f'_{args.output_res[1]}x{args.output_res[0]}' if args.output_res else ''
+)
+
+
+if __name__ == "__main__":
+    with torch.no_grad():
+        # 1. Initialize data loader
+        data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
+                                 shuffle=False, enable_preload=True,
+                                 color=color.from_str(model.args['color']))
+
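For reference, the perf report assembled further below converts per-view MSE to PSNR through `img.mse2psnr` (defined in utils/img.py, which this patch does not show). A minimal sketch of the standard conversion it presumably implements for [0, 1]-normalized images:

    import math

    def mse2psnr_ref(mse: float) -> float:
        # PSNR = 10 * log10(MAX^2 / MSE); with MAX = 1 this is -10 * log10(MSE)
        return -10 * math.log10(mse)

    print(f"{mse2psnr_ref(1e-3):.2f} dB")  # -> "30.00 dB"

+        # 3.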
Test on dataset + print("Begin test, batch size is %d" % RAYS_PER_BATCH) + + i = 0 + offset = 0 + chns = model.chns('color') + n = dataset.n_views + total_pixels = prod([n, *dataset.res]) + + out = {} + if args.output_flags['perf'] or args.output_flags['color']: + out['color'] = torch.zeros(total_pixels, chns, device=device.default()) + if args.output_flags['diffuse']: + out['diffuse'] = torch.zeros(total_pixels, chns, device=device.default()) + if args.output_flags['specular']: + out['specular'] = torch.zeros(total_pixels, chns, device=device.default()) + if args.output_flags['depth']: + out['depth'] = torch.full([total_pixels, 1], HUGE_FLOAT, device=device.default()) + gt_images = torch.empty_like(out['color']) if dataset.image_path else None + + tot_time = 0 + tot_iters = len(data_loader) + progress_bar(i, tot_iters, 'Inferring...') + for _, rays_o, rays_d, extra in data_loader: + if args.output_flags['perf']: + test_perf = Perf.Node("Test") + n_rays = rays_o.size(0) + idx = slice(offset, offset + n_rays) + ret = model(rays_o, rays_d, extra_outputs=[key for key in out.keys() if key != 'color']) + if ret is not None: + for key in out: + out[key][idx][ret['rays_mask']] = ret[key] + if args.output_flags['perf']: + test_perf.close() + torch.cuda.synchronize() + tot_time += test_perf.duration() + if gt_images is not None: + gt_images[idx] = extra['color'] + i += 1 + progress_bar(i, tot_iters, 'Inferring...') + offset += n_rays + + # 4. Save results + print('Saving results...') + output_dir.mkdir(parents=True, exist_ok=True) + + for key in out: + out[key] = out[key].reshape([n, *dataset.res, *out[key].shape[1:]]) + if 'color' in out: + out['color'] = out['color'].permute(0, 3, 1, 2) + if 'diffuse' in out: + out['diffuse'] = out['diffuse'].permute(0, 3, 1, 2) + if 'specular' in out: + out['specular'] = out['specular'].permute(0, 3, 1, 2) + + if args.output_flags['perf']: + perf_errors = torch.full([n], nan) + perf_ssims = torch.full([n], nan) + if gt_images is not None: + gt_images = gt_images.reshape(n, *dataset.res, chns).permute(0, 3, 1, 2) + for i in range(n): + perf_errors[i] = nn_f.mse_loss(gt_images[i], out['color'][i]).item() + perf_ssims[i] = ssim(gt_images[i:i + 1], out['color'][i:i + 1]).item() * 100 + perf_mean_time = tot_time / n + perf_mean_error = torch.mean(perf_errors).item() + perf_name = f'perf_{output_dataset_id}_{perf_mean_time:.1f}ms_{perf_mean_error:.2e}.csv' + + # Remove old performance reports + for file in output_dir.glob(f'perf_{output_dataset_id}*'): + file.unlink() + + # Save new performance reports + with (output_dir / perf_name).open('w') as fp: + fp.write('View, PSNR, SSIM\n') + fp.writelines([ + f'{dataset.indices[i]}, ' + f'{img.mse2psnr(perf_errors[i].item()):.2f}, {perf_ssims[i].item():.2f}\n' + for i in range(n) + ]) + + for output_type in ['color', 'diffuse', 'specular']: + if not args.output_flags[output_type]: + continue + if args.output_type == 'video': + output_file = output_dir / f"{output_dataset_id}_{output_type}.mp4" + img.save_video(out[output_type], output_file, 30) + else: + output_subdir = output_dir / f"{output_dataset_id}_{output_type}" + output_subdir.mkdir(exist_ok=True) + img.save(out[output_type], + [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices]) + + if args.output_flags['depth']: + colored_depths = img.colorize_depthmap(out['depth'][..., 0], model_args['sample_range']) + if args.output_type == 'video': + output_file = output_dir / f"{output_dataset_id}_depth.mp4" + img.save_video(colored_depths, output_file, 30) + else: + 
output_subdir = output_dir / f"{output_dataset_id}_depth" + output_subdir.mkdir(exist_ok=True) + img.save(colored_depths, [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices]) + #output_subdir = output_dir / f"{output_dataset_id}_bins" + # output_dir.mkdir(exist_ok=True) + #img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices]) + + if args.time: + s = "Performance Report ==>\n" + res = get_perf_result() + if res is None: + s += "No available data.\n" + else: + for key, val in res.items(): + path_segs = key.split("/") + s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n" + print(s) diff --git a/tools/clean_nets.py b/tools/clean_nets.py index 8140163..e2ac2e4 100644 --- a/tools/clean_nets.py +++ b/tools/clean_nets.py @@ -1,21 +1,26 @@ """ -Clean trained nets (*/model-epoch_#.pth) whose epoch is neither the largest nor a multiple of 50 +Clean trained nets (*/checkpoint_#.tar) whose epoch is neither the largest nor a multiple of 10 """ import sys import os -sys.path.append(os.path.abspath(sys.path[0] + '/../')) - +base_dir = os.path.abspath(sys.path[0] + '/../') +sys.path.append(base_dir) if __name__ == "__main__": - for dirpath, _, filenames in os.walk('../data'): - epoch_list = [int(filename[12:-4]) for filename in filenames - if filename.startswith("model-epoch_")] + root = sys.argv[1] if len(sys.argv) > 1 else f'{base_dir}/data' + print(f"Clean model files in {root}...") + + for dirpath, _, filenames in os.walk(root): + epoch_list = [int(filename[11:-4]) for filename in filenames + if filename.startswith("checkpoint_")] if len(epoch_list) <= 1: continue epoch_list.sort() for epoch in epoch_list[:-1]: - if epoch % 50 != 0: - file_to_del = f"{dirpath}/model-epoch_{epoch}.pth" + if epoch % 10 != 0: + file_to_del = f"{dirpath}/checkpoint_{epoch}.tar" print(f"Clean model file: {file_to_del}") - os.remove(file_to_del) \ No newline at end of file + os.remove(file_to_del) + + print("Finished.") \ No newline at end of file diff --git a/tools/depth_downsample.py b/tools/depth_downsample.py index fd0a5d4..684fd6d 100644 --- a/tools/depth_downsample.py +++ b/tools/depth_downsample.py @@ -18,6 +18,6 @@ os.chdir(in_set) depthmaps = img.load(img_names) depthmaps = torch.floor((depthmaps * 16)) / 16 -misc.create_dir(out_set) +os.makedirs(out_set, exist_ok=True) os.chdir(out_set) img.save(depthmaps, img_names) \ No newline at end of file diff --git a/tools/export_msl.py b/tools/export_msl.py index 79b4c1a..e006e56 100644 --- a/tools/export_msl.py +++ b/tools/export_msl.py @@ -74,7 +74,7 @@ if __name__ == "__main__": # Load model` net, name = load_net(model_file) - misc.create_dir(os.path.join(opt.outdir, config.to_id())) + os.makedirs(os.path.join(opt.outdir, config.to_id()), exist_ok=True) # Export Sampler export_net(ExportNet(net), 'msl', { diff --git a/tools/export_nmsl.py b/tools/export_nmsl.py index af5e6c7..1c72a77 100644 --- a/tools/export_nmsl.py +++ b/tools/export_nmsl.py @@ -74,7 +74,7 @@ if __name__ == "__main__": # Load model` net, name = load_net(model_file) - misc.create_dir(os.path.join(opt.outdir, config.to_id())) + os.makedirs(os.path.join(opt.outdir, config.to_id()), exist_ok=True) # Export Sampler export_net(Sampler(net), 'sampler', { diff --git a/tools/export_onnx.py b/tools/export_onnx.py index f2745b3..f8a247c 100644 --- a/tools/export_onnx.py +++ b/tools/export_onnx.py @@ -54,7 +54,7 @@ if __name__ == "__main__": rays_o = torch.empty(batch_size, 3, device=device.default()) rays_d = torch.empty(batch_size, 3, device=device.default()) - 
misc.create_dir(opt.outdir) + os.makedirs(opt.outdir, exist_ok=True) # Export the model outpath = os.path.join(opt.outdir, config.to_id() + ".onnx") diff --git a/tools/export_snerf_fast.py b/tools/export_snerf_fast.py index 68a955b..a2d036c 100644 --- a/tools/export_snerf_fast.py +++ b/tools/export_snerf_fast.py @@ -44,7 +44,7 @@ if not opt.output: else: outdir = f"{dir_path}/export" output = os.path.join(outdir, f"{model_file.split('@')[0]}@{batch_size_str}.onnx") - misc.create_dir(outdir) + os.makedirs(outdir, exist_ok=True) else: output = opt.output outname = os.path.splitext(os.path.split(output)[-1])[0] diff --git a/tools/gen_video.py b/tools/gen_video.py index f78c14c..2c2e73a 100644 --- a/tools/gen_video.py +++ b/tools/gen_video.py @@ -172,7 +172,7 @@ print('Dataset loaded. Views:', n_views) videodir = os.path.dirname(os.path.abspath(opt.view_file)) -tempdir = '/dev/shm/dvs_tmp/realvideo' +tempdir = '/dev/shm/dvs_tmp/video' videoname = f"{os.path.splitext(os.path.split(opt.view_file)[-1])[0]}_{'stereo' if opt.stereo else 'mono'}" gazeout = f"{videodir}/{videoname}_gaze.csv" if opt.noCE: @@ -220,8 +220,8 @@ def add_hint(image, center, right_center=None): exit() -misc.create_dir(os.path.dirname(inferout)) -misc.create_dir(os.path.dirname(hintout)) +os.makedirs(os.path.dirname(inferout), exist_ok=True) +os.makedirs(os.path.dirname(hintout), exist_ok=True) hint_offset = infer_offset = 0 if not opt.replace: diff --git a/tools/image_scale.py b/tools/image_scale.py index 64737ef..f3e8812 100644 --- a/tools/image_scale.py +++ b/tools/image_scale.py @@ -8,7 +8,7 @@ from utils import misc def batch_scale(src, target, size): - misc.create_dir(target) + os.makedirs(target, exist_ok=True) for file_name in os.listdir(src): postfix = os.path.splitext(file_name)[1] if postfix == '.jpg' or postfix == '.png': diff --git a/tools/merge_dataset.py b/tools/merge_dataset.py index 31a4f50..0b5c832 100644 --- a/tools/merge_dataset.py +++ b/tools/merge_dataset.py @@ -11,7 +11,7 @@ from utils import misc def copy_images(src_path, dst_path, n, offset=0): - misc.create_dir(os.path.dirname(dst_path)) + os.makedirs(os.path.dirname(dst_path), exist_ok=True) for i in range(n): copy(src_path % i, dst_path % (i + offset)) diff --git a/tools/pano_process.py b/tools/pano_process.py new file mode 100644 index 0000000..fca52c7 --- /dev/null +++ b/tools/pano_process.py @@ -0,0 +1,36 @@ +from pathlib import Path +import sys +import argparse +import math +import torch +import torchvision.transforms.functional as trans_F + +sys.path.append(str(Path(sys.path[0]).parent.absolute())) + +from utils import img + +parser = argparse.ArgumentParser() +parser.add_argument('-o', '--output', type=str) +parser.add_argument('dir', type=str) +args = parser.parse_args() + +data_dir = Path(args.dir) +output_dir = Path(args.output) +output_dir.mkdir(parents=True, exist_ok=True) + +files = [file for file in data_dir.glob('*') if file.suffix == '.png' or file.suffix == '.jpg'] +outfiles = [output_dir / file.name for file in data_dir.glob('*') + if file.suffix == '.png' or file.suffix == '.jpg'] +images = img.load(files) +print(f"{images.size(0)} images loaded.") +out_images = torch.zeros_like(images) +H, W = images.shape[-2:] +for row in range(H): + phi = math.pi / H * (row + 0.5) + length = math.ceil(math.sin(phi) * W * 0.5) * 2 + cols = slice((W - length) // 2, (W + length) // 2) + out_images[..., row:row + 1, cols] = trans_F.resize(images[..., row:row + 1, :], [1, length]) + sys.stdout.write(f'{row + 1} / {H} processed. 
\r') +print('') +img.save(out_images, outfiles) +print(f"{images.size(0)} images saved.") \ No newline at end of file diff --git a/tools/split_dataset.py b/tools/split_dataset.py index 1483b1f..3ed4fa2 100644 --- a/tools/split_dataset.py +++ b/tools/split_dataset.py @@ -4,26 +4,33 @@ import os import argparse import numpy as np import torch +from itertools import product, repeat +from pathlib import Path sys.path.append(os.path.abspath(sys.path[0] + '/../')) -from utils import misc - parser = argparse.ArgumentParser() parser.add_argument('-o', '--output', type=str, default='train1') +parser.add_argument("-t", "--trans", type=float) +parser.add_argument("-v", "--views", type=int) +parser.add_argument('-g', '--grids', nargs='+', type=int) parser.add_argument('dataset', type=str) args = parser.parse_args() +if not args.dataset.endswith(".json"): + args.dataset = args.dataset.rstrip("/") + ".json" +if not args.output.endswith(".json"): + args.output = args.output.rstrip("/") + ".json" -data_desc_path = args.dataset -data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0] -data_dir = os.path.dirname(data_desc_path) + '/' +in_desc_path = Path(args.dataset) +in_name = in_desc_path.stem +root_dir = in_desc_path.parent +out_desc_path: Path = root_dir / args.output +out_dir = out_desc_path.with_suffix("") -with open(data_desc_path, 'r') as fp: +with open(in_desc_path, 'r') as fp: dataset_desc = json.load(fp) -indices = torch.arange(len(dataset_desc['view_centers'])).view(dataset_desc['samples']) - idx = 0 ''' for i in range(3): @@ -40,7 +47,7 @@ for i in range(3): out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist() with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp: json.dump(out_desc, fp, indent=4) - misc.create_dir(os.path.join(data_dir, out_desc_name)) + os.makedirs(os.path.join(data_dir, out_desc_name), exist_ok=True) for k in range(len(views)): os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]), os.path.join(data_dir, out_desc['view_file_pattern'] % views[k])) @@ -61,26 +68,62 @@ for xi in range(0, 4, 2): out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist() with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp: json.dump(out_desc, fp, indent=4) - misc.create_dir(os.path.join(data_dir, out_desc_name)) + os.makedirs(os.path.join(data_dir, out_desc_name), exist_ok=True) for k in range(len(views)): os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]), os.path.join(data_dir, out_desc['view_file_pattern'] % views[k])) idx += 1 ''' -from itertools import product -out_desc_name = args.output + + +def extract_by_grid(*grid_indices): + indices = torch.arange(len(dataset_desc['view_centers'])).view(dataset_desc['samples']) + views = [] + for idx in product(*grid_indices): + views += indices[idx].flatten().tolist() + return views + + +def extract_by_trans(max_trans, max_views): + if max_trans is not None: + centers = np.array(dataset_desc['view_centers']) + trans = np.linalg.norm(centers, axis=-1) + indices = np.nonzero(trans <= max_trans)[0] + else: + indices = np.arange(len(dataset_desc['view_centers'])) + if max_views is not None: + indices = np.sort(indices[np.random.permutation(indices.shape[0])[:max_views]]) + return indices.tolist() + + +if args.grids: + views = extract_by_grid(*repeat(args.grids, 3)) # , [0, 2, 3, 5], [1, 2, 3, 4]) +else: + views = extract_by_trans(args.trans, args.views) + +image_path = dataset_desc['view_file_pattern'] +if "/" not in 
image_path:
+    image_path = in_name + "/" + image_path
+
+# Build the new dataset description
 out_desc = dataset_desc.copy()
-out_desc['view_file_pattern'] = f"{out_desc_name}/{dataset_desc['view_file_pattern'].split('/')[-1]}"
-views = []
-for idx in product([1,2,3,4], [1,2,3,4], [1,2,3,4]):#, [0, 2, 3, 5], [1, 2, 3, 4]):
-    views += indices[idx].flatten().tolist()
+out_desc['view_file_pattern'] = image_path.split('/')[-1]
 out_desc['samples'] = [len(views)]
 out_desc['views'] = views
 out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
-out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
-with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
+if 'view_rots' in dataset_desc:
+    out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
+
+# Write the new dataset description
+with open(out_desc_path, 'w') as fp:
     json.dump(out_desc, fp, indent=4)
-misc.create_dir(os.path.join(data_dir, out_desc_name))
+
+# Create symbolic links to the view images
+out_dir.mkdir()
 for k in range(len(views)):
-    os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
-               os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
+    if out_dir.parent.absolute() == root_dir.absolute():
+        os.symlink(Path("..") / (image_path % views[k]),
+                   out_dir / (out_desc['view_file_pattern'] % views[k]))
+    else:
+        os.symlink(root_dir.absolute() / (image_path % views[k]),
+                   out_dir / (out_desc['view_file_pattern'] % views[k]))
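As a quick illustration of what `extract_by_grid` above selects: with a hypothetical dataset whose 'samples' field describes a 3x3x3 grid and `-g 0 2` on the command line, the script keeps the eight corner views (the grid layout and values here are made up for this sketch):

    import torch
    from itertools import product, repeat

    samples = [3, 3, 3]                       # hypothetical sample-grid layout
    indices = torch.arange(27).view(samples)  # view indices arranged on the grid
    views = []
    for idx in product(*repeat([0, 2], 3)):   # 2^3 = 8 grid corners
        views += indices[idx].flatten().tolist()
    print(views)  # [0, 2, 6, 8, 18, 20, 24, 26]

Passing `-t`/`-v` instead goes through `extract_by_trans`, which keeps views whose camera translation lies within the given radius and then randomly subsamples at most the requested number of views.
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..6792a3e
--- /dev/null
+++ b/train.py
@@ -0,0 +1,103 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+import sys
+
+import model as mdl
+import train
+from utils import color
+from utils import device
+from data.dataset_factory import *
+from data.loader import DataLoader
+from utils.misc import list_epochs, print_and_log
+
+
+RAYS_PER_BATCH = 2 ** 16
+DATA_LOADER_CHUNK_SIZE = 1e8
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-c', '--config', type=str,
+                    help='Net config files')
+parser.add_argument('-e', '--epochs', type=int, default=50,
+                    help='Max epochs for train')
+parser.add_argument('--perf', type=int, default=0,
+                    help='Performance measurement frames (0 for disabling performance measurement)')
+parser.add_argument('--prune', type=int, default=5,
+                    help='Prune voxels on every # epochs')
+parser.add_argument('--split', type=int, default=10,
+                    help='Split voxels on every # epochs')
+parser.add_argument('--views', type=str,
+                    help='Specify the range of views to train')
+parser.add_argument('path', type=str,
+                    help='Dataset description file')
+args = parser.parse_args()
+
+argpath = Path(args.path)
+# argpath may be a model path or a data path:
+# 1) model path: continue training on the specified model
+# 2) data path: train a new model using the specified dataset
+
+if argpath.suffix == ".tar":
+    args.mdl_path = argpath
+else:
+    existed_epochs = list_epochs(argpath, "checkpoint_*.tar")
+    args.mdl_path = argpath / f"checkpoint_{existed_epochs[-1]}.tar" if existed_epochs else None
+
+if args.mdl_path:
+    # Infer dataset path from model path
+    # The model path follows this pattern: <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
+    dataset_name = args.mdl_path.parent.parent.name
+    dataset_dir = args.mdl_path.parent.parent.parent.parent
+    args.data_path = dataset_dir / dataset_name
+    args.mdl_path = args.mdl_path.relative_to(dataset_dir)
+else:
+    args.data_path = argpath
+args.views = range(*[int(val) for val in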
args.views.split('-')]) if args.views else None + +dataset = DatasetFactory.load(args.data_path, views_to_load=args.views) +print(f"Dataset loaded: {dataset.root}/{dataset.name}") +os.chdir(dataset.root) + +if args.mdl_path: + # Load model to continue training + model, states = mdl.load(args.mdl_path) + model_name = args.mdl_path.parent.name + model_class = model.__class__.__name__ + model_args = model.args +else: + # Create model from specified configuration + with Path(f'{sys.path[0]}/configs/{args.config}.json').open() as fp: + config = json.load(fp) + model_name = args.config + model_class = config['model'] + model_args = config['args'] + model_args['bbox'] = dataset.bbox + model_args['depth_range'] = dataset.depth_range + model, states = mdl.create(model_class, model_args), None +model.to(device.default()).train() + +run_dir = Path(f"_nets/{dataset.name}/{model_name}") +run_dir.mkdir(parents=True, exist_ok=True) + +log_file = run_dir / "train.log" +logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=logging.INFO, + filename=log_file, filemode='a' if log_file.exists() else 'w') + +print_and_log(f"model: {model_name} ({model_class})") +print_and_log(f"args: {json.dumps(model.args0)}") + + +if __name__ == "__main__": + # 1. Initialize data loader + data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE, + shuffle=True, enable_preload=True, + color=color.from_str(model.args['color'])) + + # 2. Initialize model and trainer + trainer = train.get_trainer(model, run_dir=run_dir, states=states, perf_frames=args.perf, + pruning_loop=args.prune, splitting_loop=args.split) + + # 3. Train + trainer.train(data_loader, args.epochs) \ No newline at end of file diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000..bb0a4bc --- /dev/null +++ b/train/__init__.py @@ -0,0 +1,26 @@ +import importlib +import os + +from model.base import BaseModel +from . 
import base + + +# Automatically import any python files this directory +package_dir = os.path.dirname(__file__) +package = os.path.basename(package_dir) +for file in os.listdir(package_dir): + path = os.path.join(package_dir, file) + if file.startswith('_') or file.startswith('.'): + continue + if file.endswith('.py') or os.path.isdir(path): + model_name = file[:-3] if file.endswith('.py') else file + importlib.import_module(f'{package}.{model_name}') + + +def get_class(class_name: str) -> type: + return base.train_classes[class_name] + + +def get_trainer(model: BaseModel, **kwargs) -> base.Train: + train_class = get_class(model.trainer) + return train_class(model, **kwargs) diff --git a/train/base.py b/train/base.py new file mode 100644 index 0000000..78dd048 --- /dev/null +++ b/train/base.py @@ -0,0 +1,225 @@ +import csv +import logging +import sys +import time +import torch +import torch.nn.functional as nn_f +from typing import Dict +from pathlib import Path + +import loss +from utils.constants import HUGE_FLOAT +from utils.misc import format_time +from utils.progress_bar import progress_bar +from utils.perf import Perf, checkpoint, enable_perf, perf, get_perf_result +from data.loader import DataLoader +from model.base import BaseModel +from model import save + + +train_classes = {} + + +class BaseTrainMeta(type): + + def __new__(cls, name, bases, attrs): + new_cls = type.__new__(cls, name, bases, attrs) + train_classes[name] = new_cls + return new_cls + + +class Train(object, metaclass=BaseTrainMeta): + + @property + def perf_mode(self): + return self.perf_frames > 0 + + def __init__(self, model: BaseModel, *, + run_dir: Path, states: dict = None, perf_frames: int = 0) -> None: + super().__init__() + self.model = model + self.epoch = 0 + self.iters = 0 + self.run_dir = run_dir + + self.model.train() + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4) + + if states: + if 'epoch' in states: + self.epoch = states['epoch'] + if 'iters' in states: + self.iters = states['iters'] + if 'opti' in states: + self.optimizer.load_state_dict(states['opti']) + + # For performance measurement + self.perf_frames = perf_frames + if self.perf_mode: + enable_perf() + + def train(self, data_loader: DataLoader, max_epochs: int): + self.data_loader = data_loader + self.iters_per_epoch = self.perf_frames or len(data_loader) + + print("Begin training...") + while self.epoch < max_epochs: + self.epoch += 1 + self._train_epoch() + self._save_checkpoint() + print("Train finished") + + def _save_checkpoint(self): + save(self.run_dir / f'checkpoint_{self.epoch}.tar', self.model, epoch=self.epoch, + iters=self.iters, opti=self.optimizer.state_dict()) + for i in range(1, self.epoch): + if i % 10 != 0: + (self.run_dir / f'checkpoint_{i}.tar').unlink(missing_ok=True) + + def _show_progress(self, iters_in_epoch: int, loss: Dict[str, float] = {}): + loss_val = loss.get('val', 0) + loss_min = loss.get('min', 0) + loss_max = loss.get('max', 0) + loss_avg = loss.get('avg', 0) + iters_per_epoch = self.perf_frames or len(self.data_loader) + progress_bar(iters_in_epoch, iters_per_epoch, + f"Loss: {loss_val:.2e} ({loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e})", + f"Epoch {self.epoch:<3d}", + f" {self.run_dir}") + + def _show_perf(self): + s = "Performance Report ==>\n" + res = get_perf_result() + if res is None: + s += "No available data.\n" + else: + for key, val in res.items(): + path_segs = key.split("/") + s += " " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n" + print(s) + + @perf + def 
_train_iter(self, rays_o: torch.Tensor, rays_d: torch.Tensor, + extra: Dict[str, torch.Tensor]) -> float: + out = self.model(rays_o, rays_d, extra_outputs=['energies', 'speculars']) + if 'rays_mask' in out: + extra = {key: value[out['rays_mask']] for key, value in extra.items()} + checkpoint("Forward") + + self.optimizer.zero_grad() + loss_val = loss.mse_loss(out['color'], extra['color']) + if self.model.args.get('density_regularization_weight'): + loss_val += loss.cauchy_loss(out['energies'], + s=self.model.args['density_regularization_scale']) \ + * self.model.args['density_regularization_weight'] + if self.model.args.get('specular_regularization_weight'): + loss_val += loss.cauchy_loss(out['speculars'], + s=self.model.args['specular_regularization_scale']) \ + * self.model.args['specular_regularization_weight'] + checkpoint("Compute loss") + + loss_val.backward() + checkpoint("Backward") + + self.optimizer.step() + checkpoint("Update") + + return loss_val.item() + + def _train_epoch(self): + iters_in_epoch = 0 + loss_min = HUGE_FLOAT + loss_max = 0 + loss_avg = 0 + train_epoch_node = Perf.Node("Train Epoch") + + self._show_progress(iters_in_epoch, loss={'val': 0, 'min': 0, 'max': 0, 'avg': 0}) + for idx, rays_o, rays_d, extra in self.data_loader: + loss_val = self._train_iter(rays_o, rays_d, extra) + + loss_min = min(loss_min, loss_val) + loss_max = max(loss_max, loss_val) + loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1) + + self.iters += 1 + iters_in_epoch += 1 + self._show_progress(iters_in_epoch, loss={ + 'val': loss_val, + 'min': loss_min, + 'max': loss_max, + 'avg': loss_avg + }) + + if self.perf_mode and iters_in_epoch >= self.perf_frames: + self._show_perf() + exit() + train_epoch_node.close() + torch.cuda.synchronize() + epoch_dur = train_epoch_node.duration() / 1000 + logging.info(f"Epoch {self.epoch} spent {format_time(epoch_dur)} " + f"(Avg. {format_time(epoch_dur / self.iters_per_epoch)}/iter). 
" + f"Loss is {loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e}") + + def _train_epoch_debug(self): # TBR + iters_in_epoch = 0 + loss_min = HUGE_FLOAT + loss_max = 0 + loss_avg = 0 + + self._show_progress(iters_in_epoch, loss={'val': 0, 'min': 0, 'max': 0, 'avg': 0}) + indices = [] + debug_data = [] + for idx, rays_o, rays_d, extra in self.data_loader: + out = self.model(rays_o, rays_d, extra_outputs=['layers', 'weights']) + loss_val = nn_f.mse_loss(out['color'], extra['color']).item() + + loss_min = min(loss_min, loss_val) + loss_max = max(loss_max, loss_val) + loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1) + + self.iters += 1 + iters_in_epoch += 1 + self._show_progress(iters_in_epoch, loss={ + 'val': loss_val, + 'min': loss_min, + 'max': loss_max, + 'avg': loss_avg + }) + + indices.append(idx) + debug_data.append(torch.cat([ + extra['view_idx'][..., None], + extra['pix_idx'][..., None], + rays_d, + #out['samples'].pts[:, 215:225].reshape(idx.size(0), -1), + #out['samples'].dirs[:, :3].reshape(idx.size(0), -1), + #out['samples'].voxel_indices[:, 215:225], + out['states'].densities[:, 210:230].detach().reshape(idx.size(0), -1), + out['states'].energies[:, 210:230].detach().reshape(idx.size(0), -1) + # out['color'].detach() + ], dim=-1)) + # states: VolumnRenderer.States = out['states'] # TBR + + indices = torch.cat(indices, dim=0) + debug_data = torch.cat(debug_data, dim=0) + indices, sort = indices.sort() + debug_data = debug_data[sort] + name = "rand.csv" if self.data_loader.shuffle else "seq.csv" + with (self.run_dir / name).open("w") as fp: + csv_writer = csv.writer(fp) + csv_writer.writerows(torch.cat([indices[:20, None], debug_data[:20]], dim=-1).tolist()) + return + with (self.run_dir / 'states.csv').open("w") as fp: + csv_writer = csv.writer(fp) + for chunk_info in states.chunk_infos: + csv_writer.writerow( + [*chunk_info['range'], chunk_info['hits'], chunk_info['core_i']]) + if chunk_info['hits'] > 0: + csv_writer.writerows(torch.cat([ + chunk_info['samples'].pts, + chunk_info['samples'].dirs, + chunk_info['samples'].voxel_indices[:, None], + chunk_info['colors'], + chunk_info['energies'] + ], dim=-1).tolist()) + csv_writer.writerow([]) diff --git a/train/train_with_space.py b/train/train_with_space.py new file mode 100644 index 0000000..4d236da --- /dev/null +++ b/train/train_with_space.py @@ -0,0 +1,127 @@ +from modules.sampler import Samples +from modules.space import Octree, Voxels +from utils.mem_profiler import MemProfiler +from utils.misc import print_and_log +from .base import * + + +class TrainWithSpace(Train): + + def __init__(self, model: BaseModel, pruning_loop: int = 10000, splitting_loop: int = 10000, + **kwargs) -> None: + super().__init__(model, **kwargs) + self.pruning_loop = pruning_loop + self.splitting_loop = splitting_loop + #MemProfiler.enable = True + + def _train_epoch(self): + if not self.perf_mode: + if self.epoch != 1: + if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1: + try: + with torch.no_grad(): + before, after = self.model.splitting() + print_and_log( + f"Splitting done. # of voxels before: {before}, after: {after}") + except NotImplementedError: + print_and_log( + "Note: The space does not support splitting operation. Just skip it.") + if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1: + try: + with torch.no_grad(): + #before, after = self.model.pruning() + # print(f"Pruning by voxel densities done. 
# of voxels before: {before}, after: {after}") + # self._prune_inner_voxels() + self._prune_voxels_by_weights() + except NotImplementedError: + print_and_log( + "Note: The space does not support pruning operation. Just skip it.") + + super()._train_epoch() + + def _prune_inner_voxels(self): + space: Voxels = self.model.space + voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long, + device=space.voxels.device) + iters_in_epoch = 0 + batch_size = self.data_loader.batch_size + self.data_loader.batch_size = 2 ** 14 + for _, rays_o, rays_d, _ in self.data_loader: + self.model(rays_o, rays_d, + raymarching_early_stop_tolerance=0.01, + raymarching_chunk_size_or_sections=[1], + perturb_sample=False, + voxel_access_counts=voxel_access_counts, + voxel_access_tolerance=0) + iters_in_epoch += 1 + percent = iters_in_epoch / len(self.data_loader) * 100 + sys.stdout.write(f'Pruning inner voxels...{percent:.1f}% \r') + self.data_loader.batch_size = batch_size + before, after = space.prune(voxel_access_counts > 0) + print(f"Prune inner voxels: {before} -> {after}") + + def _prune_voxels_by_weights(self): + space: Voxels = self.model.space + voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long, + device=space.voxels.device) + iters_in_epoch = 0 + batch_size = self.data_loader.batch_size + self.data_loader.batch_size = 2 ** 14 + for _, rays_o, rays_d, _ in self.data_loader: + ret = self.model(rays_o, rays_d, + raymarching_early_stop_tolerance=0, + raymarching_chunk_size_or_sections=None, + perturb_sample=False, + extra_outputs=['weights']) + valid_mask = ret['weights'][..., 0] > 0.01 + accessed_voxels = ret['samples'].voxel_indices[valid_mask] + voxel_access_counts.index_add_(0, accessed_voxels, torch.ones_like(accessed_voxels)) + iters_in_epoch += 1 + percent = iters_in_epoch / len(self.data_loader) * 100 + sys.stdout.write(f'Pruning by weights...{percent:.1f}% \r') + self.data_loader.batch_size = batch_size + before, after = space.prune(voxel_access_counts > 0) + print_and_log(f"Prune by weights: {before} -> {after}") + + def _prune_voxels_by_voxel_weights(self): + space: Voxels = self.model.space + voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long, + device=space.voxels.device) + with torch.no_grad(): + batch_size = self.data_loader.batch_size + self.data_loader.batch_size = 2 ** 14 + iters_in_epoch = 0 + for _, rays_o, rays_d, _ in self.data_loader: + ret = self.model(rays_o, rays_d, + raymarching_early_stop_tolerance=0, + raymarching_chunk_size_or_sections=None, + perturb_sample=False, + extra_outputs=['weights']) + self._accumulate_access_count_by_weight(ret['samples'], ret['weights'][..., 0], + voxel_access_counts) + iters_in_epoch += 1 + percent = iters_in_epoch / len(self.data_loader) * 100 + sys.stdout.write(f'Pruning by voxel weights...{percent:.1f}% \r') + self.data_loader.batch_size = batch_size + before, after = space.prune(voxel_access_counts > 0) + print_and_log(f"Prune by voxel weights: {before} -> {after}") + + def _accumulate_access_count_by_weight(self, samples: Samples, weights: torch.Tensor, + voxel_access_counts: torch.Tensor): + uni_vidxs = -torch.ones_like(samples.voxel_indices) + vidx_accu = torch.zeros_like(samples.voxel_indices, dtype=torch.float) + uni_vidxs_row = torch.arange(samples.size[0], dtype=torch.long, device=samples.device) + uni_vidxs_head = torch.zeros_like(samples.voxel_indices[:, 0]) + uni_vidxs[:, 0] = samples.voxel_indices[:, 0] + vidx_accu[:, 0].add_(weights[:, 0]) + for i in range(samples.size[1]): + # For those rows 
whose voxel index changed, advance the head one step forward
+            next_voxel = uni_vidxs[uni_vidxs_row, uni_vidxs_head].ne(samples.voxel_indices[:, i])
+            uni_vidxs_head[next_voxel] += 1
+            # Set voxel indices and accumulate weights. index_put_/index_add_ write in
+            # place; calling .add_() on an advanced-indexing result would only modify a copy
+            uni_vidxs[uni_vidxs_row, uni_vidxs_head] = samples.voxel_indices[:, i]
+            vidx_accu.index_put_((uni_vidxs_row, uni_vidxs_head), weights[:, i], accumulate=True)
+        max_accu = vidx_accu.max(dim=1, keepdim=True)[0]
+        uni_vidxs[vidx_accu < max_accu * 0.1] = -1
+        access_voxels, access_count = uni_vidxs.unique(return_counts=True)
+        # unique() returns sorted values, so the -1 placeholder comes first; skip it
+        voxel_access_counts.index_add_(0, access_voxels[1:], access_count[1:])
diff --git a/train_oracle.py b/train_oracle.py
index 82e5afd..219d81a 100644
--- a/train_oracle.py
+++ b/train_oracle.py
@@ -260,8 +260,8 @@ def train():
         if epochRange.start > 1:
             iters = netio.load(f'{run_dir}model-epoch_{epochRange.start - 1}.pth', model)
         else:
-            misc.create_dir(run_dir)
-            misc.create_dir(log_dir)
+            os.makedirs(run_dir, exist_ok=True)
+            os.makedirs(log_dir, exist_ok=True)
             iters = 0
 
     # 3. Train
@@ -333,7 +333,7 @@ def test():
 
     # 4. Save results
     print('Saving results...')
-    misc.create_dir(output_dir)
+    os.makedirs(output_dir, exist_ok=True)
 
     for key in out:
         shape = [n] + list(dataset.view_res) + list(out[key].size()[1:])
@@ -367,7 +367,7 @@ def test():
                 for i in range(n)
             ])
             output_subdir = f"{output_dir}/{output_dataset_id}_bins"
-            misc.create_dir(output_subdir)
+            os.makedirs(output_subdir, exist_ok=True)
             img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.view_idxs])
diff --git a/upsampling/run_upsampling.py b/upsampling/run_upsampling.py
index 79ee6ea..8b90e2f 100644
--- a/upsampling/run_upsampling.py
+++ b/upsampling/run_upsampling.py
@@ -60,7 +60,7 @@ args.color = color.from_str(args.color)
 
 
 def train():
-    misc.create_dir(run_dir)
+    os.makedirs(run_dir, exist_ok=True)
     train_set = UpsamplingDataset('.', 'input/out_view_%04d.png',
                                   'gt/view_%04d.png', color=args.color)
     training_data_loader = FastDataLoader(dataset=train_set,
@@ -80,7 +80,7 @@ def train():
 
 
 def test():
-    misc.create_dir(os.path.dirname(args.testOutPatt))
+    os.makedirs(os.path.dirname(args.testOutPatt), exist_ok=True)
     train_set = UpsamplingDataset(
         '.', 'input/out_view_%04d.png', None, color=args.color)
     training_data_loader = FastDataLoader(dataset=train_set,
diff --git a/utils/constants.py b/utils/constants.py
index 8d2ad1b..42601d4 100644
--- a/utils/constants.py
+++ b/utils/constants.py
@@ -2,4 +2,6 @@ import math
 
 HUGE_FLOAT = 1e10
 TINY_FLOAT = 1e-6
-PI = math.pi
\ No newline at end of file
+PI = math.pi
+NAN = math.nan
+E = math.e
\ No newline at end of file
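One subtlety behind the in-place fixes in train/train_with_space.py above deserves a standalone note: advanced (tensor) indexing in PyTorch returns a copy, so calling an in-place method on the indexed result silently modifies only that copy. A minimal, self-contained illustration (variable names hypothetical):

    import torch

    counts = torch.zeros(5, dtype=torch.long)
    idx = torch.tensor([1, 3, 3])

    counts[idx].add_(1)     # no-op: operates on a temporary copy
    print(counts.tolist())  # [0, 0, 0, 0, 0]

    # index_add_ writes into the original tensor and accumulates duplicates
    counts.index_add_(0, idx, torch.ones_like(idx))
    print(counts.tolist())  # [0, 1, 0, 2, 0]

diff --git a/utils/geometry.py b/utils/geometry.py
new file mode 100644
index 0000000..527ac4a
--- /dev/null
+++ b/utils/geometry.py
@@ -0,0 +1,284 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.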
+ +from typing import Union +import numpy as np +import torch +import torch.nn.functional as F + +INF = 1000.0 + + +def ones_like(x): + T = torch if isinstance(x, torch.Tensor) else np + return T.ones_like(x) + + +def stack(x): + T = torch if isinstance(x[0], torch.Tensor) else np + return T.stack(x) + + +def matmul(x, y): + T = torch if isinstance(x, torch.Tensor) else np + return T.matmul(x, y) + + +def cross(x, y, axis=0): + T = torch if isinstance(x, torch.Tensor) else np + return T.cross(x, y, axis) + + +def cat(x, axis=1): + if isinstance(x[0], torch.Tensor): + return torch.cat(x, dim=axis) + return np.concatenate(x, axis=axis) + + +def normalize(x, axis=-1, order=2): + if isinstance(x, torch.Tensor): + l2 = x.norm(p=order, dim=axis, keepdim=True) + return x / (l2 + 1e-8), l2 + + else: + l2 = np.linalg.norm(x, order, axis) + l2 = np.expand_dims(l2, axis) + l2[l2 == 0] = 1 + return x / l2, l2 + + +def parse_extrinsics(extrinsics, world2camera=True): + """ this function is only for numpy for now""" + if extrinsics.shape[0] == 3 and extrinsics.shape[1] == 4: + extrinsics = np.vstack([extrinsics, np.array([[0, 0, 0, 1.0]])]) + if extrinsics.shape[0] == 1 and extrinsics.shape[1] == 16: + extrinsics = extrinsics.reshape(4, 4) + if world2camera: + extrinsics = np.linalg.inv(extrinsics).astype(np.float32) + return extrinsics + + +def parse_intrinsics(intrinsics): + fx = intrinsics[0, 0] + fy = intrinsics[1, 1] + cx = intrinsics[0, 2] + cy = intrinsics[1, 2] + return fx, fy, cx, cy + + +def uv2cam(uv, z, intrinsics, homogeneous=False): + fx, fy, cx, cy = parse_intrinsics(intrinsics) + x_lift = (uv[0] - cx) / fx * z + y_lift = (uv[1] - cy) / fy * z + z_lift = ones_like(x_lift) * z + + if homogeneous: + return stack([x_lift, y_lift, z_lift, ones_like(z_lift)]) + else: + return stack([x_lift, y_lift, z_lift]) + + +def cam2world(xyz_cam, inv_RT): + return matmul(inv_RT, xyz_cam)[:3] + + +def r6d2mat(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def get_ray_direction(ray_start, uv, intrinsics, inv_RT, depths=None): + if depths is None: + depths = 1 + rt_cam = uv2cam(uv, depths, intrinsics, True) + rt = cam2world(rt_cam, inv_RT) + ray_dir, _ = normalize(rt - ray_start[:, None], axis=0) + return ray_dir + + +def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False): + """ + This function takes a vector 'camera_position' which specifies the location + of the camera in world coordinates and two vectors `at` and `up` which + indicate the position of the object and the up directions of the world + coordinate system respectively. The object is assumed to be centered at + the origin. + + The output is a rotation matrix representing the transformation + from world coordinates -> view coordinates. 
+
+    Input:
+        camera_position: 3
+        at: 1 x 3 or N x 3, (0, 0, 0) by default
+        up: 1 x 3 or N x 3, (0, 0, -1) by default (as set in the code below)
+    """
+
+    if at is None:
+        at = torch.zeros_like(camera_position)
+    else:
+        at = torch.tensor(at).type_as(camera_position)
+    if up is None:
+        up = torch.zeros_like(camera_position)
+        up[2] = -1
+    else:
+        up = torch.tensor(up).type_as(camera_position)
+
+    z_axis = normalize(at - camera_position)[0]
+    x_axis = normalize(cross(up, z_axis))[0]
+    y_axis = normalize(cross(z_axis, x_axis))[0]
+
+    R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
+    return R
+
+
+def ray(ray_start, ray_dir, depths):
+    return ray_start + ray_dir * depths
+
+
+def compute_normal_map(ray_start, ray_dir, depths, RT, width=512, proj=False):
+    raise NotImplementedError("This function needs fairnr.data.data_utils to work. "
+                              "Will remove this dependency later.")
+    # TODO:
+    # this function is pytorch-only (for now)
+    wld_coords = ray(ray_start, ray_dir, depths.unsqueeze(-1)).transpose(0, 1)
+    cam_coords = matmul(RT[:3, :3], wld_coords) + RT[:3, 3].unsqueeze(-1)
+    cam_coords = D.unflatten_img(cam_coords, width)
+
+    # estimate local normal
+    shift_l = cam_coords[:, 2:, :]
+    shift_r = cam_coords[:, :-2, :]
+    shift_u = cam_coords[:, :, 2:]
+    shift_d = cam_coords[:, :, :-2]
+    diff_hor = normalize(shift_r - shift_l, axis=0)[0][:, :, 1:-1]
+    diff_ver = normalize(shift_u - shift_d, axis=0)[0][:, 1:-1, :]
+    normal = cross(diff_hor, diff_ver)
+    _normal = normal.new_zeros(*cam_coords.size())
+    _normal[:, 1:-1, 1:-1] = normal
+    _normal = _normal.reshape(3, -1).transpose(0, 1)
+
+    # compute the projected color
+    if proj:
+        _normal = normalize(_normal, axis=1)[0]
+        wld_coords0 = ray(ray_start, ray_dir, 0).transpose(0, 1)
+        cam_coords0 = matmul(RT[:3, :3], wld_coords0) + RT[:3, 3].unsqueeze(-1)
+        cam_coords0 = D.unflatten_img(cam_coords0, width)
+        cam_raydir = normalize(cam_coords - cam_coords0, 0)[0].reshape(3, -1).transpose(0, 1)
+        proj_factor = (_normal * cam_raydir).sum(-1).abs() * 0.8 + 0.2
+        return proj_factor
+    return _normal
+
+
+# helper functions for encoder
+
+def padding_points(xs, pad):
+    if len(xs) == 1:
+        return xs[0].unsqueeze(0)
+
+    maxlen = max([x.size(0) for x in xs])
+    xt = xs[0].new_ones(len(xs), maxlen, xs[0].size(1)).fill_(pad)
+    for i in range(len(xs)):
+        xt[i, :xs[i].size(0)] = xs[i]
+    return xt
+
+
+def pruning_points(feats, points, scores, depth=0, th=0.5):
+    if depth > 0:
+        g = int(8 ** depth)
+        scores = scores.reshape(scores.size(0), -1, g).sum(-1, keepdim=True)
+        scores = scores.expand(*scores.size()[:2], g).reshape(scores.size(0), -1)
+    alpha = (1 - torch.exp(-scores)) > th
+    feats = [feats[i][alpha[i]] for i in range(alpha.size(0))]
+    points = [points[i][alpha[i]] for i in range(alpha.size(0))]
+    points = padding_points(points, INF)
+    feats = padding_points(feats, 0)
+    return feats, points
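`r6d2mat` above follows Zhou et al.'s continuous 6D rotation parameterization; by construction its output rows are an orthonormal right-handed basis, so every result is a proper rotation. A minimal sanity-check sketch, illustrative only (assumes the repository root is on sys.path and a PyTorch version providing torch.linalg):

    import torch
    from utils.geometry import r6d2mat

    d6 = torch.randn(4, 6)  # arbitrary batch of 6D rotation parameters
    R = r6d2mat(d6)         # -> (4, 3, 3), rows built by Gram-Schmidt
    eye = torch.eye(3).expand(4, 3, 3)
    # Orthonormal rows with b3 = b1 x b2 give R R^T = I and det(R) = +1
    assert torch.allclose(R @ R.transpose(-1, -2), eye, atol=1e-5)
    assert torch.allclose(torch.linalg.det(R), torch.ones(4), atol=1e-5)

+
+
+def offset_points(point_xyz: torch.Tensor, half_voxel: Union[torch.Tensor, int, float] = 1,
+                  offset_only: bool = False, bits: int = 2) -> torch.Tensor:
+    """
+    Get the `bits`^3 evenly spaced offset points inside each voxel (e.g. the
+    8 corners when `bits` is 2)
+
+    :param point_xyz `Tensor(N, 3)`: voxel center positions
+    :param half_voxel `Tensor(1) | int | float`: half size of a voxel, defaults to 1
+    :param offset_only `bool`: return only the offsets themselves, defaults to False
+    :param bits `int`: number of subdivisions along each axis, defaults to 2
+    :return `Tensor(N, X, 3)|Tensor(X, 3)`: offset points, or offsets only
+    """
+    c = torch.arange(1 - bits, bits, 2, dtype=point_xyz.dtype, device=point_xyz.device)
+    offset = (torch.stack(torch.meshgrid(c, c, c), dim=-1).reshape(-1, 3)) / (bits - 1) * half_voxel
+    return offset if offset_only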
+
+
+def discretize_points(voxel_points, voxel_size):
+    # this function turns voxel centers/corners into integer indices
+    # we assume all points are already put as voxels (real numbers)
+    minimal_voxel_point = voxel_points.min(dim=0, keepdim=True)[0]
+    voxel_indices = ((voxel_points - minimal_voxel_point) / voxel_size).round_().long()
+    residual = (voxel_points - voxel_indices.type_as(voxel_points)
+                * voxel_size).mean(0, keepdim=True)
+    return voxel_indices, residual
+
+
+def expand_points(voxel_points, voxel_size):
+    _voxel_size = min([
+        torch.sqrt(((voxel_points[j:j + 1] - voxel_points[j + 1:]) ** 2).sum(-1).min())
+        for j in range(100)])
+    depth = int(np.round(torch.log2(_voxel_size / voxel_size)))
+    if depth > 0:
+        half_voxel = _voxel_size / 2.0
+        for _ in range(depth):
+            voxel_points = offset_points(voxel_points, half_voxel / 2.0).reshape(-1, 3)
+            half_voxel = half_voxel / 2.0
+
+    return voxel_points, depth
+
+
+def get_edge(depth_pts, voxel_pts, voxel_size, th=0.05):
+    voxel_pts = offset_points(voxel_pts, voxel_size / 2.0)
+    diff_pts = (voxel_pts - depth_pts[:, None, :]).norm(dim=2)
+    ab = diff_pts.sort(dim=1)[0][:, :2]
+    a, b = ab[:, 0], ab[:, 1]
+    c = voxel_size
+    p = (ab.sum(-1) + c) / 2.0
+    h = (p * (p - a) * (p - b) * (p - c)) ** 0.5 / c
+    return h < (th * voxel_size)
+
+
+# fill-in image
+def fill_in(shape, hits, input, initial=1.0):
+    # guard against `input` being None before accessing its size
+    if input is not None and list(input.size()) == list(shape):
+        return input  # shape is the same, no need to fill
+
+    if isinstance(initial, torch.Tensor):
+        output = initial.expand(*shape)
+    else:
+        output = input.new_ones(*shape) * initial
+    if input is not None:
+        if len(shape) == 1:
+            return output.masked_scatter(hits, input)
+        return output.masked_scatter(hits.unsqueeze(-1).expand(*shape), input)
+    return output
diff --git a/utils/img.py b/utils/img.py
index a39d308..8920922 100644
--- a/utils/img.py
+++ b/utils/img.py
@@ -1,10 +1,11 @@
 import os
+from pathlib import Path
 import shutil
 import torch
 import matplotlib.pyplot as plt
 import numpy as np
 import torch.nn.functional as nn_f
-from typing import Tuple
+from typing import List, Tuple, Union
 from . 
import misc from .constants import * @@ -65,7 +66,7 @@ def load(*paths: str, permute=True, with_alpha=False) -> torch.Tensor: chns = 4 if with_alpha else 3 new_paths = [] for path in paths: - new_paths += [path] if isinstance(path, str) else list(path) + new_paths += [path] if isinstance(path, (str, Path)) else list(path) imgs = np.stack([plt.imread(path)[..., :chns] for path in new_paths]) if imgs.dtype == 'uint8': imgs = imgs.astype(np.float32) / 255 @@ -76,7 +77,7 @@ def load_seq(path: str, n: int, permute=True, with_alpha=False) -> torch.Tensor: return load([path % i for i in range(n)], permute=permute, with_alpha=with_alpha) -def save(input: torch.Tensor, *paths: str): +def save(input: torch.Tensor, *paths: Union[str, Path, List[Union[str, Path]]]): """ Save one or multiple torch-image(s) to `paths` @@ -86,7 +87,7 @@ def save(input: torch.Tensor, *paths: str): """ new_paths = [] for path in paths: - new_paths += [path] if isinstance(path, str) else list(path) + new_paths += [path] if isinstance(path, (str, Path)) else list(path) if len(input.size()) < 4: input = input[None] if input.size(0) != len(new_paths): @@ -100,9 +101,9 @@ def save(input: torch.Tensor, *paths: str): plt.imsave(path, np_img[i]) -def save_seq(input: torch.Tensor, path: str): +def save_seq(input: torch.Tensor, path: Union[str, Path]): n = 1 if len(input.size()) <= 3 else input.size(0) - return save(input, [path % i for i in range(n)]) + return save(input, [str(path) % i for i in range(n)]) def plot(input: torch.Tensor, *, ax: plt.Axes = None): @@ -118,7 +119,7 @@ def plot(input: torch.Tensor, *, ax: plt.Axes = None): return plt.imshow(im) if ax is None else ax.imshow(im) -def save_video(frames: torch.Tensor, path: str, fps: int, +def save_video(frames: torch.Tensor, path: Union[str, Path], fps: int, repeat: int = 1, pingpong: bool = False): """ Encode and save a sequence of frames as video file @@ -134,19 +135,16 @@ def save_video(frames: torch.Tensor, path: str, fps: int, frames = torch.cat([frames, frames.flip(0)], 0) if repeat > 1: frames = frames.expand(repeat, -1, -1, -1, -1).flatten(0, 1) - dir, file_name = os.path.split(path) - if not dir: - dir = './' - misc.create_dir(dir) - cwd = os.getcwd() - os.chdir(dir) - temp_out_dir = os.path.splitext(file_name)[0] + '_tempout' - misc.create_dir(temp_out_dir) - os.chdir(temp_out_dir) - save_seq(frames, 'out_%04d.png') - os.system(f'ffmpeg -y -r {fps:d} -i out_%04d.png -c:v libx264 ../{file_name}') - os.chdir(cwd) - shutil.rmtree(os.path.join(dir, temp_out_dir)) + + path = Path(path) + tempdir = Path('/dev/shm/dvs_tmp/video') + inferout = tempdir / path.stem / f"%04d.bmp" + os.makedirs(inferout.parent, exist_ok=True) + os.makedirs(path.parent, exist_ok=True) + + save_seq(frames, inferout) + os.system(f'ffmpeg -y -r {fps:d} -i {inferout} -c:v libx264 {path}') + shutil.rmtree(inferout.parent) def horizontal_shift(input: torch.Tensor, offset: int, dim=-1) -> torch.Tensor: diff --git a/utils/mem_profiler.py b/utils/mem_profiler.py index d034bc0..848d4e7 100644 --- a/utils/mem_profiler.py +++ b/utils/mem_profiler.py @@ -2,13 +2,14 @@ from cgitb import enable import torch from .device import * + class MemProfiler: enable = False @staticmethod - def print_memory_stats(prefix, last_allocated=None, device=None): - if not MemProfiler.enable: + def print_memory_stats(prefix, last_allocated=None, device=None, enable_once=False): + if not enable_once and not MemProfiler.enable: return if device is None: device = default() diff --git a/utils/misc.py b/utils/misc.py index 
2bf7250..e6f49fe 100644
--- a/utils/misc.py
+++ b/utils/misc.py
@@ -1,9 +1,13 @@
-import os
+from itertools import repeat
+import logging
+from pathlib import Path
+import re
+import shutil
 import torch
 import glm
 import csv
 import numpy as np
-from typing import List, Tuple, Union
+from typing import List, Union
 from torch.types import Number
 from .constants import *
 from .device import *
@@ -59,31 +63,11 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
     return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2) if normalize else torch.stack([x, y], 2)
 
 
-def create_dir(path):
-    if not os.path.exists(path):
-        os.makedirs(path)
-
-
 def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-    angle = -torch.atan(x / y) + (y < 0) * PI + 0.5 * PI
+    angle = -torch.atan(x / y) - (y < 0) * PI + 0.5 * PI
     return angle
 
 
-def depth_sample(depth_range: Tuple[float, float], n: int, lindisp: bool) -> torch.Tensor:
-    """
-    Get [n_layers] foreground layers whose diopters are distributed uniformly
-    in [depth_range] plus a background layer
-
-    :param depth_range: depth range of foreground layers
-    :param n_layers: number of foreground layers
-    :return: list of [n_layers+1] depths
-    """
-    if lindisp:
-        depth_range = (1 / depth_range[0], 1 / depth_range[1])
-    samples = torch.linspace(depth_range[0], depth_range[1], n)
-    return samples
-
-
 def broadcast_cat(input: torch.Tensor,
                   s: Union[Number, List[Number], torch.Tensor],
                   dim=-1,
@@ -130,4 +114,73 @@ def view_like(input: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
     return input.view(out_shape)
 
 
-def values(map, *keys): return list(map[key] for key in keys)
+def format_time(seconds):
+    days = int(seconds / 3600 / 24)
+    seconds = seconds - days * 3600 * 24
+    hours = int(seconds / 3600)
+    seconds = seconds - hours * 3600
+    minutes = int(seconds / 60)
+    seconds = seconds - minutes * 60
+    seconds_final = int(seconds)
+    seconds = seconds - seconds_final
+    millis = int(seconds * 1000)
+
+    if days > 0:
+        output = f"{days}D{hours:0>2d}h{minutes:0>2d}m"
+    elif hours > 0:
+        output = f"{hours:0>2d}h{minutes:0>2d}m{seconds_final:0>2d}s"
+    elif minutes > 0:
+        output = f"{minutes:0>2d}m{seconds_final:0>2d}s"
+    elif seconds_final > 0:
+        output = f"{seconds_final:0>2d}s{millis:0>3d}ms"
+    elif millis > 0:
+        output = f"{millis:0>3d}ms"
+    else:
+        output = '0ms'
+    return output
+
+
+def print_and_log(s):
+    print(s)
+    logging.info(s)
+
+
+def masked_scatter(mask: torch.Tensor, value: torch.Tensor, initial: Union[torch.Tensor, Number] = 0):
+    """
+    Extend PyTorch's built-in `masked_scatter` function
+
+    :param mask `Tensor(M...)`: the boolean mask
+    :param value `Tensor(N, D...)`: the value to fill in with, should have at least as many
+                                    elements as the number of ones in `mask`
+    :param initial `Tensor(M..., D...) | Number`: (optional) the destination tensor to fill, or a
+                                                  number used to initialize a new destination
+                                                  tensor, defaults to 0
+    :return `Tensor(M..., D...)`: the destination tensor after filling
+    """
+    M_ = mask.size()
+    D_ = value.size()[1:]
+    if not isinstance(initial, torch.Tensor):
+        initial = value.new_full([*M_, *D_], initial)
+    return initial.masked_scatter(mask.reshape(*M_, *repeat(1, len(D_))), value)
+
+
+def list_epochs(dir: Path, pattern: str) -> List[int]:
+    prefix = pattern.split("*")[0]
+    epoch_list = [int(str(path.stem)[len(prefix):]) for path in dir.glob(pattern)]
+ 
epoch_list.sort() + return epoch_list + + +def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int): + start, end = re.search(r'%0\dd', file_pattern).span() + prefix, suffix = start, len(file_pattern) - end + + seqs = [ + int(path.name[prefix:-suffix]) + for path in dir.glob(re.sub(r'%0\dd', "*", file_pattern)) + ] + seqs.sort(reverse=offset > 0) + for i in seqs: + (dir / (file_pattern % i)).rename(dir / (file_pattern % (i + offset))) diff --git a/utils/perf.py b/utils/perf.py index 2b9f278..5a3dd58 100644 --- a/utils/perf.py +++ b/utils/perf.py @@ -1,32 +1,137 @@ +from numpy import average +import torch import torch.cuda +from typing import Dict, List, OrderedDict class Perf(object): + frames: List[Dict[str, float]] - def __init__(self, enable, start=False) -> None: + class Node: + def __init__(self, name, parent=None) -> None: + self.name = name + self.parent = parent + self.events = [] + self.event_names = [] + self.child_nodes = [] + self.child_nodes_event_idx = [] + self.add_checkpoint("Start") + + def add_checkpoint(self, name): + event = torch.cuda.Event(enable_timing=True) + event.record() + self.events.append(event) + self.event_names.append(name) + + def add_child(self, name): + child = Perf.Node(name, self) + self.child_nodes.append(child) + self.child_nodes_event_idx.append(len(self.events)) + return child + + def close(self): + self.add_checkpoint("End") + return self.parent + + def duration(self, i0=0, i1=-1) -> float: + return self.events[i0].elapsed_time(self.events[i1]) + + def result(self, prefix: str = '') -> OrderedDict[str, float]: + path = f"{prefix}{self.name}" + res = {path: self.duration()} + j = 0 + for i in range(1, len(self.events) - 1): + event_path = f"{path}/{self.event_names[i]}" + res[event_path] = self.duration(i - 1, i) + while j < len(self.child_nodes): + if self.child_nodes_event_idx[j] > i: + break + res.update(self.child_nodes[j].result(f"{event_path}/")) + j += 1 + while j < len(self.child_nodes): + res.update(self.child_nodes[j].result(f"{path}/")) + j += 1 + return res + + def __init__(self) -> None: super().__init__() - self.enable = enable - self.start_event = None - if start: - self.start() - - def start(self): - if not self.enable: - return - if self.start_event == None: - self.start_event = torch.cuda.Event(enable_timing=True) - self.end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - self.start_event.record() - - def checkpoint(self, name: str = None, end: bool = False): - if not self.enable: - return 0 - self.end_event.record() - torch.cuda.synchronize() - duration = self.start_event.elapsed_time(self.end_event) - if name: - print('%s: %.1fms' % (name, duration)) - if not end: - self.start_event.record() - return duration + self.root_node = None + self.current_node = None + self.frames = [] + + def start_node(self, name): + if self.current_node is None: + self.root_node = self.current_node = Perf.Node(name) + else: + self.current_node = self.current_node.add_child(name) + + def checkpoint(self, name): + self.current_node.add_checkpoint(name) + + def end_node(self): + self.current_node = self.current_node.close() + if self.current_node is None: + torch.cuda.synchronize() + self.frames.append(self.root_node.result()) + + def get_result(self, i=None): + if i is not None: + return self.frames[i] + if len(self.frames) == 0: + return {} + res = {key: [val] for key, val in self.frames[0].items()} + for i in range(1, len(self.frames)): + for key, val in self.frames[i].items(): + res[key].append(val) + return 
{key: average(val) for key, val in res.items()} + + +default_perf_object = None + + +def enable_perf(): + global default_perf_object + default_perf_object = Perf() + + +def perf(fn_or_name): + if isinstance(fn_or_name, str): + name = fn_or_name + + def perf_with_name(fn): + def wrap_perf(*args, **kwargs): + start_node(name) + ret = fn(*args, **kwargs) + end_node() + return ret + return wrap_perf + return perf_with_name + fn = fn_or_name + + def wrap_perf(*args, **kwargs): + start_node(fn.__qualname__) + ret = fn(*args, **kwargs) + end_node() + return ret + return wrap_perf + + +def start_node(name): + if default_perf_object is not None: + default_perf_object.start_node(name) + + +def end_node(): + if default_perf_object is not None: + default_perf_object.end_node() + + +def checkpoint(name): + if default_perf_object is not None: + default_perf_object.checkpoint(name) + + +def get_perf_result(i=None): + if default_perf_object is not None: + return default_perf_object.get_result(i) + return None diff --git a/utils/progress_bar.py b/utils/progress_bar.py index 144586d..fd84774 100644 --- a/utils/progress_bar.py +++ b/utils/progress_bar.py @@ -1,78 +1,50 @@ +import shutil import sys import time -import os +from .misc import format_time +from .constants import NAN -bar_length = 50 -LAST_T = time.time() -BEGIN_T = LAST_T +last_time = time.time() +begin_time = last_time -def get_terminal_columns(): - return os.get_terminal_size().columns - -def progress_bar(current, total, msg=None, premsg=None): - global LAST_T, BEGIN_T +def progress_bar(current, total, msg=None, premsg=None, barmsg=None): + global last_time, begin_time if current == 0: - BEGIN_T = time.time() # Reset for new bar. + begin_time = time.time() # Reset for new bar. current_time = time.time() - step_time = current_time - LAST_T - LAST_T = current_time - total_time = current_time - BEGIN_T + step_time = current_time - last_time + total_time = current_time - begin_time + last_time = current_time + estimated_time = 0 if current == 0 else total_time / current * (total - current) + + show_opt = int(current_time) % 6 >= 3 and current < total + show_barmsg = barmsg is not None and show_opt str0 = f"{premsg} [" if premsg else '[' - str1 = f"] {current + 1:d}/{total:d} | Step: {format_time(step_time)} | Tot: {format_time(total_time)}" + str1 = f"] {current:d}/{total:d} | Step: {format_time(step_time)} | " + ( + f"Eta: {format_time(estimated_time)}" if show_opt else f"Tot: {format_time(total_time)}" + ) if msg: str1 += f" | {msg}" - tot_cols = get_terminal_columns() + tot_cols = shutil.get_terminal_size().columns - 10 bar_length = tot_cols - len(str0) - len(str1) - current_len = int(bar_length * (current + 1) / total) - rest_len = int(bar_length - current_len) - - if current_len == 0: - str_bar = '.' * rest_len + if show_barmsg and bar_length < len(barmsg): + sys.stdout.write(str0[:-1] + barmsg) + elif bar_length <= 0: + sys.stdout.write(str0[:-1] + str1[2:]) else: - str_bar = '=' * (current_len - 1) + '>' + '.' * rest_len - - sys.stdout.write(str0 + str_bar + str1) - - if current < total - 1: - sys.stdout.write('\r') - else: - sys.stdout.write('\n') + current_len = int(bar_length * current / total) + rest_len = int(bar_length - current_len) + str_bar = '' + if current_len > 0: + str_bar += '=' * (current_len - 1) + '>' + str_bar += '.' 
* rest_len
+        if show_barmsg:
+            str_bar = barmsg + str_bar[len(barmsg):]
+        sys.stdout.write(str0 + str_bar + str1)
+
+    sys.stdout.write('\r' if current < total else '\n')
     sys.stdout.flush()
-
-
-# return the formatted time
-def format_time(seconds):
-    days = int(seconds / 3600 / 24)
-    seconds = seconds - days * 3600 * 24
-    hours = int(seconds / 3600)
-    seconds = seconds - hours * 3600
-    minutes = int(seconds / 60)
-    seconds = seconds - minutes * 60
-    seconds_final = int(seconds)
-    seconds = seconds - seconds_final
-    millis = int(seconds * 1000)
-
-    output = ''
-    time_index = 1
-    if days > 0:
-        output += str(days) + 'D'
-        time_index += 1
-    if hours > 0 and time_index <= 2:
-        output += str(hours) + 'h'
-        time_index += 1
-    if minutes > 0 and time_index <= 2:
-        output += str(minutes) + 'm'
-        time_index += 1
-    if seconds_final > 0 and time_index <= 2:
-        output += '%02ds' % seconds_final
-        time_index += 1
-    if millis > 0 and time_index <= 2:
-        output += '%03dms' % millis
-        time_index += 1
-    if output == '':
-        output = '0ms'
-    return output
diff --git a/utils/sphere.py b/utils/sphere.py
index 8feb6d5..24d6309 100644
--- a/utils/sphere.py
+++ b/utils/sphere.py
@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Union
 import torch
 import math
 from . import misc
@@ -13,12 +13,12 @@ def cartesian2spherical(cart: torch.Tensor, inverse_r: bool = False) -> torch.Te
     :return `Tensor(..., 3)`: coordinates in Spherical (r, theta, phi)
     """
     rho = torch.sqrt(torch.sum(cart * cart, dim=-1))
-    theta = misc.get_angle(cart[..., 0], cart[..., 2])
+    theta = misc.get_angle(cart[..., 2], cart[..., 0])
     if inverse_r:
         rho = rho.reciprocal()
-        phi = torch.acos(cart[..., 1] * rho)
+        phi = torch.asin(cart[..., 1] * rho)
     else:
-        phi = torch.acos(cart[..., 1] / rho)
+        phi = torch.asin(cart[..., 1] / rho)
     return torch.stack([rho, theta, phi], dim=-1)
 
 
@@ -34,9 +34,9 @@ def spherical2cartesian(spher: torch.Tensor, inverse_r: bool = False) -> torch.T
         rho = rho.reciprocal()
     sin_theta_phi = torch.sin(spher[..., 1:3])
     cos_theta_phi = torch.cos(spher[..., 1:3])
-    x = rho * cos_theta_phi[..., 0] * sin_theta_phi[..., 1]
-    y = rho * cos_theta_phi[..., 1]
-    z = rho * sin_theta_phi[..., 0] * sin_theta_phi[..., 1]
+    x = rho * sin_theta_phi[..., 0] * cos_theta_phi[..., 1]
+    y = rho * sin_theta_phi[..., 1]
+    z = rho * cos_theta_phi[..., 0] * cos_theta_phi[..., 1]
     return torch.stack([x, y, z], dim=-1)
 
 
diff --git a/utils/voxels.py b/utils/voxels.py
new file mode 100644
index 0000000..824f2de
--- /dev/null
+++ b/utils/voxels.py
@@ -0,0 +1,174 @@
+import torch
+from typing import Tuple, Union
+
+
+def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor:
+    """
+    Get grid steps along every dim.
+
+    :param bbox `Tensor(2, D)`: bounding box
+    :param step_size `Tensor(1|D) | float`: step size
+    :return `Tensor(D)`: grid steps along every dim
+    """
+    return ((bbox[1] - bbox[0]) / step_size).ceil().long()
+
+
+def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
+                   step_size: Union[torch.Tensor, float] = None,
+                   steps: torch.Tensor = None) -> torch.Tensor:
+    """
+    Get discretized (integer) grid coordinates of points.
+
+    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
+    specified, then the grid coordinates will be calculated according to the step size, ignoring
+    the value of `steps`.
+
+    :param pts `Tensor(N..., D)`: points
+    :param bbox `Tensor(2, D)`: bounding box
+    :param step_size `Tensor(1|D) | float`: (optional) step size
+    :param steps `Tensor(1|D)`: (optional) steps along every dim
+    :return `Tensor(N..., D)`: discretized grid coordinates
+    """
+    if step_size is not None:
+        return ((pts - bbox[0]) / step_size).floor().long()
+    return ((pts - bbox[0]) / (bbox[1] - bbox[0]) * steps).floor().long()
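+
+
+# Editor's illustrative sketch (not part of the original patch): discretizing points on a
+# 10-cell-per-axis grid over a unit bounding box.
+def _example_to_grid_coords():
+    bbox = torch.tensor([[0., 0., 0.], [1., 1., 1.]])
+    steps = get_grid_steps(bbox, 0.1)              # tensor([10, 10, 10])
+    pts = torch.tensor([[0.05, 0.55, 0.95]])
+    return to_grid_coords(pts, bbox, steps=steps)  # tensor([[0, 5, 9]])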
+
+
+def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
+                    step_size: Union[torch.Tensor, float] = None,
+                    steps: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Get flattened grid indices of points.
+
+    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
+    specified, then the grid indices will be calculated according to the step size, ignoring
+    the value of `steps`.
+
+    :param pts `Tensor(N..., D)`: points
+    :param bbox `Tensor(2, D)`: bounding box
+    :param step_size `Tensor(1|D) | float`: (optional) step size
+    :param steps `Tensor(1|D)`: (optional) steps along every dim
+    :return `Tensor(N...)`: flattened grid indices
+    :return `Tensor(N...)`: a mask indicating whether each point falls outside the bounding box
+    """
+    if step_size is not None:
+        steps = get_grid_steps(bbox, step_size)  # (D)
+    grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps)  # (N..., D)
+    outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1)  # (N...)
+    if pts.size(-1) == 1:
+        grid_indices = grid_coords[..., 0]
+    elif pts.size(-1) == 2:
+        grid_indices = grid_coords[..., 0] * steps[1] + grid_coords[..., 1]
+    elif pts.size(-1) == 3:
+        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] \
+            + grid_coords[..., 1] * steps[2] + grid_coords[..., 2]
+    elif pts.size(-1) == 4:
+        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] * steps[3] \
+            + grid_coords[..., 1] * steps[2] * steps[3] \
+            + grid_coords[..., 2] * steps[3] \
+            + grid_coords[..., 3]
+    else:
+        raise NotImplementedError("The function does not support D>4")
+    return grid_indices, outside_mask
+
+
+def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
+    """
+    Initialize a dense grid of voxel centers covering the bounding box `bbox`.
+    """
+    x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)])
+    return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps)
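+
+
+# Editor's illustrative sketch (not part of the original patch): a 2x2x2 grid over a unit
+# bounding box has voxel centers at 0.25 and 0.75 along each axis.
+def _example_init_voxels():
+    bbox = torch.tensor([[0., 0., 0.], [1., 1., 1.]])
+    return init_voxels(bbox, torch.tensor([2, 2, 2]))  # (8, 3) voxel centers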
+
+
+def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
+                     step_size: Union[torch.Tensor, float] = None,
+                     steps: torch.Tensor = None) -> torch.Tensor:
+    """
+    Get voxel centers from integer grid coordinates, i.e. the inverse of `to_grid_coords`.
+
+    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
+    specified, then the voxel centers will be calculated according to the step size, ignoring
+    the value of `steps`.
+
+    :param grid_coords `Tensor(N..., D)`: integer grid coordinates
+    :param bbox `Tensor(2, D)`: bounding box
+    :param step_size `Tensor(1|D) | float`: (optional) step size
+    :param steps `Tensor(1|D)`: (optional) steps along every dim
+    :return `Tensor(N..., D)`: voxel centers
+    """
+    grid_coords = grid_coords.float() + 0.5
+    if step_size is not None:
+        return grid_coords * step_size + bbox[0]
+    return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0]
+
+
+def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_border: bool = True,
+                       dims=3, *, dtype: torch.dtype = None, device: torch.device = None,
+                       like: torch.Tensor = None):
+    """
+    Get the local offsets of the n**`dims` sub-voxel centers created by evenly splitting
+    a voxel of size `voxel_size` into n parts along every dim.
+
+    :param voxel_size `Tensor(D)|float`: size of the voxel to split
+    :param n `int`: number of splits along every dim
+    :param align_border `bool`: whether the outermost sub-voxel centers align with the
+                                voxel's border, defaults to True
+    :param dims `int`: number of dimensions, defaults to 3
+    :param dtype `dtype`: data type of the returned tensor, defaults to None
+    :param device `device`: device of the returned tensor, defaults to None
+    :param like `Tensor(*)`: if specified, take dtype and device from this tensor
+    :return `Tensor(X, D)`: local offsets to the voxel center, where X = n**`dims`
+    """
+    if like is not None:
+        dtype = like.dtype
+        device = like.device
+    c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
+    offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) * voxel_size / 2 /\
+        (n - 1 if align_border else n)
+    return offset
+
+
+def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, float], n: int,
+                 align_border: bool = True):
+    """
+    Split every voxel into n**D sub-voxels and return their centers.
+
+    :param voxel_centers `Tensor(N, D)`: voxel centers
+    :param voxel_size `Tensor(D)|float`: voxel size
+    :param n `int`: number of splits along every dim
+    :param align_border `bool`: whether the outermost sub-voxel centers align with the
+                                voxel's border, defaults to True
+    :return `Tensor(N, X, D)`: sub-voxel centers, where X = n**D
+    """
+    return voxel_centers[:, None] + split_voxels_local(
+        voxel_size, n, align_border, voxel_centers.shape[-1], like=voxel_centers)
+
+
+def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
+    expand_bbox = bbox.clone()  # clone to avoid modifying the caller's bbox in place
+    expand_bbox[0] -= 0.5 * half_voxel_size
+    expand_bbox[1] += 0.5 * half_voxel_size
+    double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size)
+    # (M, 3) -> [1, 3, 5, ...]
+
+    corner_coords = split_voxels(double_grid_coords, 2, 2).reshape(-1, 3)
+    # (8M, 3) -> [0, 2, 4, ...]
+
+    corner_coords, corner_indices = corner_coords.unique(dim=0, sorted=True, return_inverse=True)
+    corners = to_voxel_centers(corner_coords, expand_bbox, step_size=half_voxel_size)
+
+    return corners, corner_indices.reshape(-1, 8)
+
+
+def trilinear_interp(pts: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
+    """
+    Perform trilinear interpolation in unit voxel ([0,0,0] ~ [1,1,1]).
+
+    :param pts `Tensor(N, 3)`: uniform coordinates in voxels
+    :param corner_values `Tensor(N, 8X)|Tensor(N, 8, X)`: values at corners of voxels
+    :return `Tensor(N, X)`: interpolated values
+    """
+    pts = pts[:, None]  # (N, 1, 3)
+    corners = split_voxels_local(1, 2, like=pts) + 0.5  # (8, 3)
+    weights = (pts * corners * 2 - pts - corners + 1).prod(-1, keepdim=True)  # (N, 8, 1)
+    corner_values = corner_values.reshape(corner_values.size(0), 8, -1)  # (N, 8, X)
+    return (weights * corner_values).sum(1)
-- 
GitLab
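
# Editor's appendix (not part of the patch): a quick sanity check for trilinear_interp in
# utils/voxels.py, assuming the repository root is on PYTHONPATH. The interpolant of
# f(x, y, z) = x, which is linear, must reproduce each query point's x coordinate.
import torch
from utils.voxels import split_voxels_local, trilinear_interp

corners = split_voxels_local(1, 2, like=torch.empty(0)) + 0.5  # (8, 3) corners of the unit voxel
corner_values = corners[:, :1].expand(4, 8, 1)                 # f = x sampled at each corner
pts = torch.rand(4, 3)                                         # random query points in the voxel
assert torch.allclose(trilinear_interp(pts, corner_values), pts[:, :1])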