From 5699ccbf8536ceb9e79194f2a1b784da39b511eb Mon Sep 17 00:00:00 2001
From: Nianchen Deng <dengnianchen@sjtu.edu.cn>
Date: Fri, 3 Dec 2021 09:36:27 +0800
Subject: [PATCH] sync

---
 .vscode/launch.json                           |  57 ++
 README.md                                     |   6 +
 blender/gen_pano.py                           |  15 +
 blender/gen_utils.py                          | 184 ++++
 clib/__init__.py                              | 479 +++++++++++
 clib/include/cuda_utils.h                     |  46 +
 clib/include/cutil_math.h                     | 793 ++++++++++++++++++
 clib/include/intersect.h                      |  17 +
 clib/include/octree.h                         |  10 +
 clib/include/sample.h                         |  16 +
 clib/include/utils.h                          |  30 +
 clib/src/binding.cpp                          |  21 +
 clib/src/intersect.cpp                        | 146 ++++
 clib/src/intersect_gpu.cu                     | 375 +++++++++
 clib/src/octree.cpp                           | 136 +++
 clib/src/sample.cpp                           |  96 +++
 clib/src/sample_gpu.cu                        | 231 +++++
 configs/nerf_default.json                     |  22 +
 configs/nerf_voxels.json                      |  24 +
 configs/nsvf_coarse.json                      |  21 +
 configs/nsvf_default.json                     |  21 +
 configs/nsvf_voxels.json                      |  21 +
 configs/{ => old}/bgnet.py                    |   0
 configs/{ => old}/cnerf.py                    |   0
 configs/{ => old}/dnerfabins.py               |   0
 configs/{ => old}/fovea.py                    |   0
 configs/{ => old}/fovea_small_rot1.py         |   0
 configs/{ => old}/fovea_small_trans.py        |   0
 configs/{ => old}/msl2fast.py                 |   0
 configs/{ => old}/msl_fovea.py                |   0
 configs/{ => old}/mslfast.py                  |   0
 configs/{ => old}/mslray.py                   |   0
 configs/{ => old}/nerf.py                     |   0
 configs/{ => old}/nerf_horns.py               |   0
 configs/{ => old}/nerf_horns_4.py             |   0
 configs/{ => old}/nerf_horns_8.py             |   0
 configs/{ => old}/nerf_periph.py              |   0
 configs/{ => old}/nerf_trex.py                |   0
 configs/{ => old}/nerf_trex_4.py              |   0
 configs/{ => old}/nerf_trex_8.py              |   0
 configs/{ => old}/nerfsimple.py               |   0
 configs/{ => old}/nmsl_fovea.py               |   0
 configs/{ => old}/nnerf.py                    |   0
 configs/{ => old}/oracle.py                   |   0
 configs/{ => old}/periph.py                   |   0
 configs/{ => old}/periph_new.py               |   0
 configs/{ => old}/periph_small_trans.py       |   0
 configs/{ => old}/snerffast_periph.py         |   0
 configs/{ => old}/snerffastx.py               |   0
 configs/snerf_fine_voxels.json                |  21 +
 configs/snerf_voxels+ls-d.json                |  20 +
 configs/snerf_voxels+ls.json                  |  21 +
 configs/snerf_voxels.json                     |  19 +
 configs/snerf_voxels_128x8_x2.json            |  22 +
 configs/snerf_voxels_128x8_x4.json            |  22 +
 configs/snerf_voxels_feat.json                |  21 +
 configs/snerf_voxels_fine.json                |  21 +
 configs/snerfadv_finevoxels+ls.json           |  34 +
 ...erfadv_finevoxels+ls_256x4_256x6_16x2.json |  34 +
 ...dv_finevoxels+ls_256x4_256x6_combined.json |  30 +
 configs/snerfadv_finevoxels_ls2.json          |  34 +
 configs/snerfadv_voxels+ls+ns.json            |  36 +
 configs/snerfadv_voxels+ls.json               |  33 +
 configs/snerfadv_voxels+ls1.json              |  34 +
 configs/snerfadv_voxels+ls2.json              |  34 +
 configs/snerfadv_voxels+ls3.json              |  34 +
 configs/snerfadv_voxels+ls4.json              |  34 +
 configs/snerfadv_voxels+ls5.json              |  34 +
 configs/snerfadv_voxels+ls6.json              |  34 +
 configs/snerfadvx_voxels_x16.json             |  34 +
 configs/snerfadvx_voxels_x4.json              |  34 +
 configs/snerfadvx_voxels_x8.json              |  34 +
 configs/snerfx_voxels_128x4_x4.json           |  21 +
 configs/snerfx_voxels_128x4_x8.json           |  21 +
 configs/snerfx_voxels_128x8_x4.json           |  21 +
 configs/snerfx_voxels_256x4_x4.json           |  21 +
 configs/snerfx_voxels_256x4_x4_balance.json   |  22 +
 dash_test.py                                  |  66 +-
 data/dataset_factory.py                       |  35 +-
 data/loader.py                                |  23 +-
 data/pano_dataset.py                          | 100 ++-
 data/view_dataset.py                          | 138 +--
 debug/voxel_sampler_export3d.py               | 134 +++
 fntest.py                                     |  12 +
 loss/__init__.py                              |   5 +
 loss/cauchy.py                                |  16 +
 loss/ssim.py                                  |   1 -
 model/__init__.py                             |  45 +
 model/base.py                                 |  34 +
 {nets => model}/bg_net.py                     |   0
 model/nerf.py                                 | 181 ++++
 model/nerf_advance.py                         |  37 +
 {nets => model}/nerf_depth.py                 |   2 +-
 model/nsvf.py                                 |  16 +
 {nets => model}/oracle.py                     |   2 +-
 model/snerf.py                                |  26 +
 model/snerf_advance.py                        |  33 +
 model/snerf_advance_x.py                      |  74 ++
 {nets => model}/snerf_fast.py                 |   2 +-
 model/snerf_x.py                              |  79 ++
 modules/__init__.py                           |  42 +-
 modules/core.py                               | 175 ++++
 modules/generic.py                            |  20 +-
 modules/renderer.py                           | 384 +++++++--
 modules/sampler.py                            | 264 +++++-
 modules/space.py                              | 351 ++++++++
 nerf++                                        |   1 -
 nets/nerf.py                                  |  78 --
 nets/nsvf.py                                  |  71 --
 nets/snerf.py                                 | 110 ---
 notebook/gen_crop.ipynb                       |  27 +-
 notebook/gen_demo_mono.ipynb                  |   4 +-
 notebook/gen_demo_stereo.ipynb                |   4 +-
 notebook/gen_for_eval.ipynb                   |   4 +-
 notebook/gen_teaser.ipynb                     |   2 +-
 notebook/gen_test.ipynb                       |   2 +-
 notebook/gen_user_study_images.ipynb          |   2 +-
 notebook/gen_video.ipynb                      |   2 +-
 notebook/net_insight.ipynb                    |   4 +-
 notebook/test_mono_gen.ipynb                  |   2 +-
 notebook/test_mono_view.ipynb                 |   2 +-
 run_lf_syn.py                                 |   4 +-
 run_spherical_view_syn.py                     |  30 +-
 setup.py                                      |  27 +
 term_test.py                                  |  15 +
 test.py                                       | 226 +++++
 tools/clean_nets.py                           |  23 +-
 tools/depth_downsample.py                     |   2 +-
 tools/export_msl.py                           |   2 +-
 tools/export_nmsl.py                          |   2 +-
 tools/export_onnx.py                          |   2 +-
 tools/export_snerf_fast.py                    |   2 +-
 tools/gen_video.py                            |   6 +-
 tools/image_scale.py                          |   2 +-
 tools/merge_dataset.py                        |   2 +-
 tools/pano_process.py                         |  36 +
 tools/split_dataset.py                        |  85 +-
 train.py                                      | 103 +++
 train/__init__.py                             |  26 +
 train/base.py                                 | 225 +++++
 train/train_with_space.py                     | 127 +++
 train_oracle.py                               |   8 +-
 upsampling/run_upsampling.py                  |   4 +-
 utils/constants.py                            |   4 +-
 utils/geometry.py                             | 284 +++++++
 utils/img.py                                  |  38 +-
 utils/mem_profiler.py                         |   5 +-
 utils/misc.py                                 | 101 ++-
 utils/perf.py                                 | 157 +++-
 utils/progress_bar.py                         |  96 +--
 utils/sphere.py                               |  14 +-
 utils/voxels.py                               | 174 ++++
 152 files changed, 7186 insertions(+), 805 deletions(-)
 create mode 100644 .vscode/launch.json
 create mode 100644 blender/gen_pano.py
 create mode 100644 blender/gen_utils.py
 create mode 100644 clib/__init__.py
 create mode 100644 clib/include/cuda_utils.h
 create mode 100644 clib/include/cutil_math.h
 create mode 100644 clib/include/intersect.h
 create mode 100644 clib/include/octree.h
 create mode 100644 clib/include/sample.h
 create mode 100644 clib/include/utils.h
 create mode 100644 clib/src/binding.cpp
 create mode 100644 clib/src/intersect.cpp
 create mode 100644 clib/src/intersect_gpu.cu
 create mode 100644 clib/src/octree.cpp
 create mode 100644 clib/src/sample.cpp
 create mode 100644 clib/src/sample_gpu.cu
 create mode 100644 configs/nerf_default.json
 create mode 100644 configs/nerf_voxels.json
 create mode 100644 configs/nsvf_coarse.json
 create mode 100644 configs/nsvf_default.json
 create mode 100644 configs/nsvf_voxels.json
 rename configs/{ => old}/bgnet.py (100%)
 rename configs/{ => old}/cnerf.py (100%)
 rename configs/{ => old}/dnerfabins.py (100%)
 rename configs/{ => old}/fovea.py (100%)
 rename configs/{ => old}/fovea_small_rot1.py (100%)
 rename configs/{ => old}/fovea_small_trans.py (100%)
 rename configs/{ => old}/msl2fast.py (100%)
 rename configs/{ => old}/msl_fovea.py (100%)
 rename configs/{ => old}/mslfast.py (100%)
 rename configs/{ => old}/mslray.py (100%)
 rename configs/{ => old}/nerf.py (100%)
 rename configs/{ => old}/nerf_horns.py (100%)
 rename configs/{ => old}/nerf_horns_4.py (100%)
 rename configs/{ => old}/nerf_horns_8.py (100%)
 rename configs/{ => old}/nerf_periph.py (100%)
 rename configs/{ => old}/nerf_trex.py (100%)
 rename configs/{ => old}/nerf_trex_4.py (100%)
 rename configs/{ => old}/nerf_trex_8.py (100%)
 rename configs/{ => old}/nerfsimple.py (100%)
 rename configs/{ => old}/nmsl_fovea.py (100%)
 rename configs/{ => old}/nnerf.py (100%)
 rename configs/{ => old}/oracle.py (100%)
 rename configs/{ => old}/periph.py (100%)
 rename configs/{ => old}/periph_new.py (100%)
 rename configs/{ => old}/periph_small_trans.py (100%)
 rename configs/{ => old}/snerffast_periph.py (100%)
 rename configs/{ => old}/snerffastx.py (100%)
 create mode 100644 configs/snerf_fine_voxels.json
 create mode 100644 configs/snerf_voxels+ls-d.json
 create mode 100644 configs/snerf_voxels+ls.json
 create mode 100644 configs/snerf_voxels.json
 create mode 100644 configs/snerf_voxels_128x8_x2.json
 create mode 100644 configs/snerf_voxels_128x8_x4.json
 create mode 100644 configs/snerf_voxels_feat.json
 create mode 100644 configs/snerf_voxels_fine.json
 create mode 100644 configs/snerfadv_finevoxels+ls.json
 create mode 100644 configs/snerfadv_finevoxels+ls_256x4_256x6_16x2.json
 create mode 100644 configs/snerfadv_finevoxels+ls_256x4_256x6_combined.json
 create mode 100644 configs/snerfadv_finevoxels_ls2.json
 create mode 100644 configs/snerfadv_voxels+ls+ns.json
 create mode 100644 configs/snerfadv_voxels+ls.json
 create mode 100644 configs/snerfadv_voxels+ls1.json
 create mode 100644 configs/snerfadv_voxels+ls2.json
 create mode 100644 configs/snerfadv_voxels+ls3.json
 create mode 100644 configs/snerfadv_voxels+ls4.json
 create mode 100644 configs/snerfadv_voxels+ls5.json
 create mode 100644 configs/snerfadv_voxels+ls6.json
 create mode 100644 configs/snerfadvx_voxels_x16.json
 create mode 100644 configs/snerfadvx_voxels_x4.json
 create mode 100644 configs/snerfadvx_voxels_x8.json
 create mode 100644 configs/snerfx_voxels_128x4_x4.json
 create mode 100644 configs/snerfx_voxels_128x4_x8.json
 create mode 100644 configs/snerfx_voxels_128x8_x4.json
 create mode 100644 configs/snerfx_voxels_256x4_x4.json
 create mode 100644 configs/snerfx_voxels_256x4_x4_balance.json
 create mode 100644 debug/voxel_sampler_export3d.py
 create mode 100644 fntest.py
 create mode 100644 loss/__init__.py
 create mode 100644 loss/cauchy.py
 create mode 100644 model/__init__.py
 create mode 100644 model/base.py
 rename {nets => model}/bg_net.py (100%)
 create mode 100644 model/nerf.py
 create mode 100644 model/nerf_advance.py
 rename {nets => model}/nerf_depth.py (96%)
 create mode 100644 model/nsvf.py
 rename {nets => model}/oracle.py (96%)
 create mode 100644 model/snerf.py
 create mode 100644 model/snerf_advance.py
 create mode 100644 model/snerf_advance_x.py
 rename {nets => model}/snerf_fast.py (98%)
 create mode 100644 model/snerf_x.py
 create mode 100644 modules/core.py
 create mode 100644 modules/space.py
 delete mode 160000 nerf++
 delete mode 100644 nets/nerf.py
 delete mode 100644 nets/nsvf.py
 delete mode 100644 nets/snerf.py
 create mode 100644 setup.py
 create mode 100644 term_test.py
 create mode 100644 test.py
 create mode 100644 tools/pano_process.py
 create mode 100644 train.py
 create mode 100644 train/__init__.py
 create mode 100644 train/base.py
 create mode 100644 train/train_with_space.py
 create mode 100644 utils/geometry.py
 create mode 100644 utils/voxels.py

diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..70285a7
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,57 @@
+{
+    // 浣跨敤 IntelliSense 浜嗚В鐩稿叧灞炴€с€� 
+    // 鎮仠浠ユ煡鐪嬬幇鏈夊睘鎬х殑鎻忚堪銆�
+    // 娆蹭簡瑙f洿澶氫俊鎭紝璇疯闂�: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+
+
+        {
+            "name": "Debug/Voxel Sampler Export 3D",
+            "type": "python",
+            "request": "launch",
+            "program": "debug/voxel_sampler_export3d.py",
+            "args": [
+                "-p",
+                "data/__new/barbershop_fovea_r360x80_t0.6/train_t0.3.json"
+            ],
+            "console": "integratedTerminal"
+        },
+        {
+            "name": "Train",
+            "type": "python",
+            "request": "launch",
+            "program": "train.py",
+            "args": [
+                //"-c",
+                //"snerf_voxels",
+                "/home/dengnc/dvs/data/__new/barbershop_fovea_r360x80_t0.6/_nets/train_t0.3/snerfadvx_voxels_x4/checkpoint_10.tar",
+                "--prune",
+                "100",
+                "--split",
+                "100"
+                //"data/__new/barbershop_fovea_r360x80_t0.6/train_t0.3.json"
+            ],
+            "console": "integratedTerminal"
+        },
+        {
+            "name": "Test",
+            "type": "python",
+            "request": "launch",
+            "program": "test.py",
+            "args": [
+                "-m",
+                "/home/dengnc/dvs/data/__new/barbershop_fovea_r360x80_t0.6/_nets/train_t0.3/snerfadv_voxels+ls2/checkpoint_50.tar",
+                "-o",
+                "perf",
+                "color",
+                "--output-type",
+                "image",
+                "/home/dengnc/dvs/data/__new/barbershop_fovea_r360x80_t0.6/test_t0.3.json",
+                "--views",
+                "1"
+            ],
+            "console": "integratedTerminal"
+        }
+    ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 9005cdb..fee4397 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,12 @@ Or ref to https://pytorch.org/get-started/locally/ for install guide
 
 * tensorboard
 
+* plyfile
+
+```
+$ conda install -c conda-forge plyfile
+```
+
 * (Optional) dash
 
 ```
diff --git a/blender/gen_pano.py b/blender/gen_pano.py
new file mode 100644
index 0000000..ee0e9ed
--- /dev/null
+++ b/blender/gen_pano.py
@@ -0,0 +1,15 @@
+import sys
+import os
+import argparse
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from gen_utils import GenPano
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-r', '--radius', type=float, required=True)
+parser.add_argument("-n", "--samples", type=int, required=True)
+parser.add_argument("--cycles-device", type=str)
+args = parser.parse_args(sys.argv[sys.argv.index("--") + 1:])
+
+GenPano('output/pano', f'hr_r{args.radius:.1f}', samples=[args.samples], depth_range=[args.radius, 50])()
diff --git a/blender/gen_utils.py b/blender/gen_utils.py
new file mode 100644
index 0000000..8f100d3
--- /dev/null
+++ b/blender/gen_utils.py
@@ -0,0 +1,184 @@
+import bpy
+import math
+import json
+import os
+import math
+import numpy as np
+from typing import List, Tuple
+from itertools import product
+
+
+class Gen:
+    def __init__(self, root_dir: str, dataset_name: str, *,
+                 res: Tuple[int, int],
+                 fov: float,
+                 samples: List[int]) -> None:
+        self.res = res
+        self.fov = fov
+        self.samples = samples
+
+        self.scene = bpy.context.scene
+        self.cam_obj = self.scene.camera
+        self.cam = self.cam_obj.data
+        self.scene.render.resolution_x = self.res[0]
+        self.scene.render.resolution_y = self.res[1]
+        self.init_camera()
+
+        self.root_dir = root_dir
+        self.data_dir = f"{root_dir}/{dataset_name}/"
+        self.data_name = dataset_name
+        self.data_desc_file = f'{root_dir}/{dataset_name}.json'
+
+    def init_camera(self):
+        if self.fov < 0:
+            self.cam.type = 'PANO'
+            self.cam.cycles.panorama_type = 'EQUIRECTANGULAR'
+        else:
+            self.cam.type = 'PERSP'
+            self.cam.lens_unit = 'FOV'
+            self.cam.angle = math.radians(self.fov)
+        self.cam.dof.use_dof = False
+        self.cam.clip_start = 0.1
+        self.cam.clip_end = 1000
+
+    def init_desc(self):
+        return None
+
+    def save_desc(self):
+        with open(self.data_desc_file, 'w') as fp:
+            json.dump(self.desc, fp, indent=4)
+
+    def add_sample(self, i, x: List[float], render_only=False):
+        self.cam_obj.location = x[:3]
+        if len(x) > 3:
+            self.cam_obj.rotation_euler = [math.radians(x[4]), math.radians(x[3]), 0]
+        self.scene.render.filepath = self.data_dir + self.desc['view_file_pattern'] % i
+        bpy.ops.render.render(write_still=True)
+        if not render_only:
+            self.desc['view_centers'].append(x[:3])
+            if len(x) > 3:
+                self.desc['view_rots'].append(x[3:])
+            self.save_desc()
+
+    def gen_grid(self):
+        start_view = len(self.desc['view_centers'])
+        ranges = [
+            np.linspace(self.desc['range']['min'][i],
+                        self.desc['range']['max'][i],
+                        self.desc['samples'][i])
+            for i in range(len(self.desc['samples']))
+        ]
+        for i, x in enumerate(product(*ranges)):
+            if i >= start_view:
+                self.add_sample(i, list(x))
+
+    def gen_rand(self):
+        pass
+
+    def __call__(self):
+        os.makedirs(self.data_dir, exist_ok=True)
+        if os.path.exists(self.data_desc_file):
+            with open(self.data_desc_file, 'r') as fp:
+                self.desc = json.load(fp)
+        else:
+            self.desc = self.init_desc()
+
+        # Render missing views in data desc
+        for i in range(len(self.desc['view_centers'])):
+            if not os.path.exists(self.data_dir + self.desc['view_file_pattern'] % i):
+                x: List[float] = self.desc['view_centers'][i]
+                if 'view_rots' in self.desc:
+                    x += self.desc['view_rots'][i]
+                self.add_sample(i, x, render_only=True)
+
+        if len(self.desc['samples']) == 1:
+            self.gen_rand()
+        else:
+            self.gen_grid()
+
+
+class GenView(Gen):
+
+    def __init__(self, root_dir: str, dataset_name: str, *,
+                 res: Tuple[int, int], fov: float, samples: List[int],
+                 tbox: Tuple[float, float, float], rbox: Tuple[float, float]) -> None:
+        super().__init__(root_dir, dataset_name, res=res, fov=fov, samples=samples)
+        self.tbox = tbox
+        self.rbox = rbox
+
+    def init_desc(self):
+        return {
+            'view_file_pattern': 'view_%04d.png',
+            "gl_coord": True,
+            'view_res': {
+                'x': self.res[0],
+                'y': self.res[1]
+            },
+            'cam_params': {
+                'fov': self.fov,
+                'cx': 0.5,
+                'cy': 0.5,
+                'normalized': True
+            },
+            'range': {
+                'min': [-self.tbox[0] / 2, -self.tbox[1] / 2, -self.tbox[2] / 2,
+                        -self.rbox[0] / 2, -self.rbox[1] / 2],
+                'max': [self.tbox[0] / 2, self.tbox[1] / 2, self.tbox[2] / 2,
+                        self.rbox[0] / 2, self.rbox[1] / 2]
+            },
+            'samples': self.samples,
+            'view_centers': [],
+            'view_rots': []
+        }
+
+    def gen_rand(self):
+        start_view = len(self.desc['view_centers'])
+        n = self.desc['samples'][0] - start_view
+        range_min = np.array(self.desc['range']['min'])
+        range_max = np.array(self.desc['range']['max'])
+        samples = (range_max - range_min) * np.random.rand(n, 5) + range_min
+        for i in range(n):
+            self.add_sample(i + start_view, list(samples[i]))
+
+
+class GenPano(Gen):
+
+    def __init__(self, root_dir: str, dataset_name: str, *,
+                 samples: List[int], depth_range: Tuple[float, float],
+                 tbox: Tuple[float, float, float] = None) -> None:
+        self.depth_range = depth_range
+        self.tbox = tbox
+        super().__init__(root_dir, dataset_name, res=[4096, 2048], fov=-1, samples=samples)
+
+    def init_desc(self):
+        range = {
+            'range': {
+                'min': [-self.tbox[0] / 2, -self.tbox[1] / 2, -self.tbox[2] / 2],
+                'max': [self.tbox[0] / 2, self.tbox[1] / 2, self.tbox[2] / 2]
+            }
+        } if self.tbox else {}
+        return {
+            "type": "pano",
+            'view_file_pattern': 'view_%04d.png',
+            "gl_coord": True,
+            'view_res': {
+                'x': self.res[0],
+                'y': self.res[1]
+            },
+            **range,
+            "depth_range": {
+                "min": self.depth_range[0],
+                "max": self.depth_range[1]
+            },
+            'samples': self.samples,
+            'view_centers': []
+        }
+
+    def gen_rand(self):
+        start_view = len(self.desc['view_centers'])
+        n = self.desc['samples'][0] - start_view
+        r_max = self.desc['depth_range']['min']
+        pts = (np.random.rand(n * 5, 3) - 0.5) * 2 * r_max
+        samples = pts[np.linalg.norm(pts, axis=1) < r_max][:n]
+        for i in range(n):
+            self.add_sample(i + start_view, list(samples[i]))
diff --git a/clib/__init__.py b/clib/__init__.py
new file mode 100644
index 0000000..c740aff
--- /dev/null
+++ b/clib/__init__.py
@@ -0,0 +1,479 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch '''
+from __future__ import (
+    division,
+    absolute_import,
+    with_statement,
+    print_function,
+    unicode_literals,
+)
+import os
+import sys
+from typing import Tuple
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+import torch.nn as nn
+import sys
+import numpy as np
+from utils.geometry import discretize_points
+from utils.constants import HUGE_FLOAT
+
+try:
+    import builtins
+except:
+    import __builtin__ as builtins
+
+try:
+    import clib._ext as _ext
+except ImportError:
+    raise ImportError(
+        "Could not import _ext module.\n"
+        "Please see the setup instructions in the README"
+    )
+
+
+
+class BallRayIntersect(Function):
+    @staticmethod
+    def forward(ctx, radius, n_max, points, ray_start, ray_dir):
+        inds, min_depth, max_depth = _ext.ball_intersect(
+            ray_start.float(), ray_dir.float(), points.float(), radius, n_max)
+        min_depth = min_depth.type_as(ray_start)
+        max_depth = max_depth.type_as(ray_start)
+
+        ctx.mark_non_differentiable(inds)
+        ctx.mark_non_differentiable(min_depth)
+        ctx.mark_non_differentiable(max_depth)
+        return inds, min_depth, max_depth
+
+    @staticmethod
+    def backward(ctx, a, b, c):
+        return None, None, None, None, None
+
+
+ball_ray_intersect = BallRayIntersect.apply
+
+
+class AABBRayIntersect(Function):
+    @staticmethod
+    def forward(ctx, voxelsize, n_max, points, ray_start, ray_dir):
+        # HACK: speed-up ray-voxel intersection by batching...
+        G = min(2048, int(2 * 10 ** 9 / points.numel()))   # HACK: avoid out-of-memory
+        S, N = ray_start.shape[:2]
+        K = int(np.ceil(N / G))
+        G, K = 1, N # HACK
+        H = K * G
+        if H > N:
+            ray_start = torch.cat([ray_start, ray_start[:, :H - N]], 1)
+            ray_dir = torch.cat([ray_dir, ray_dir[:, :H - N]], 1)
+        ray_start = ray_start.reshape(S * G, K, 3)
+        ray_dir = ray_dir.reshape(S * G, K, 3)
+        points = points[None].expand(S * G, *points.size()).contiguous()
+
+        inds, min_depth, max_depth = _ext.aabb_intersect(
+            ray_start.float(), ray_dir.float(), points.float(), voxelsize, n_max)
+        min_depth = min_depth.type_as(ray_start)
+        max_depth = max_depth.type_as(ray_start)
+
+        inds = inds.reshape(S, H, -1)
+        min_depth = min_depth.reshape(S, H, -1)
+        max_depth = max_depth.reshape(S, H, -1)
+        if H > N:
+            inds = inds[:, :N]
+            min_depth = min_depth[:, :N]
+            max_depth = max_depth[:, :N]
+
+        ctx.mark_non_differentiable(inds)
+        ctx.mark_non_differentiable(min_depth)
+        ctx.mark_non_differentiable(max_depth)
+        return inds, min_depth, max_depth
+
+    @staticmethod
+    def backward(ctx, a, b, c):
+        return None, None, None, None, None
+
+
+def aabb_ray_intersect(voxelsize: float, n_max: int, points: torch.Tensor, ray_start: torch.Tensor,
+                       ray_dir: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    AABB-Ray intersect test
+
+    :param voxelsize `float`: size of a voxel
+    :param n_max `int`: maximum number of hits
+    :param points `Tensor(M, 3)`: voxels' centers
+    :param ray_start `Tensor(S, N, 3)`: rays' start positions
+    :param ray_dir `Tensor(S, N, 3)`: rays' directions
+    :return `Tensor(S, N, n_max)`: indices of intersected voxels or -1
+    :return `Tensor(S, N, n_max)`: min depths of every intersected voxels
+    :return `Tensor(S, N, n_max)`: max depths of every intersected voxels
+    """
+    return AABBRayIntersect.apply(voxelsize, n_max, points, ray_start, ray_dir)
+
+
+class SparseVoxelOctreeRayIntersect(Function):
+    @staticmethod
+    def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir):
+        # HACK: avoid out-of-memory
+        G = min(2048, int(2 * 10 ** 9 / (points.numel() + children.numel())))
+        S, N = ray_start.shape[:2]
+        K = int(np.ceil(N / G))
+        G, K = 1, N # HACK
+        H = K * G
+        if H > N:
+            ray_start = torch.cat([ray_start, ray_start[:, :H - N]], 1)
+            ray_dir = torch.cat([ray_dir, ray_dir[:, :H - N]], 1)
+        ray_start = ray_start.reshape(S * G, K, 3)
+        ray_dir = ray_dir.reshape(S * G, K, 3)
+        points = points[None].expand(S * G, *points.size()).contiguous()
+        children = children[None].expand(S * G, *children.size()).contiguous()
+        inds, min_depth, max_depth = _ext.svo_intersect(
+            ray_start.float(), ray_dir.float(), points.float(), children.int(), voxelsize, n_max)
+
+        min_depth = min_depth.type_as(ray_start)
+        max_depth = max_depth.type_as(ray_start)
+
+        inds = inds.reshape(S, H, -1)
+        min_depth = min_depth.reshape(S, H, -1)
+        max_depth = max_depth.reshape(S, H, -1)
+        if H > N:
+            inds = inds[:, :N]
+            min_depth = min_depth[:, :N]
+            max_depth = max_depth[:, :N]
+
+        ctx.mark_non_differentiable(inds)
+        ctx.mark_non_differentiable(min_depth)
+        ctx.mark_non_differentiable(max_depth)
+        return inds, min_depth, max_depth
+
+    @staticmethod
+    def backward(ctx, a, b, c):
+        return None, None, None, None, None
+
+
+def octree_ray_intersect(voxelsize: float, n_max: int, points: torch.Tensor, children: torch.Tensor,
+                         ray_start: torch.Tensor, ray_dir: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Octree-Ray intersect test
+
+    :param voxelsize `float`: size of a voxel
+    :param n_max `int`: maximum number of hits
+    :param points `Tensor(M, 3)`: voxels' centers
+    :param children `Tensor(M, 9)`: flattened octree structure
+    :param ray_start `Tensor(S, N, 3)`: rays' start positions
+    :param ray_dir `Tensor(S, N, 3)`: rays' directions
+    :return `Tensor(S, N, n_max)`: indices of intersected voxels or -1
+    :return `Tensor(S, N, n_max)`: min depths of every intersected voxels
+    :return `Tensor(S, N, n_max)`: max depths of every intersected voxels
+    """
+    return SparseVoxelOctreeRayIntersect.apply(voxelsize, n_max, points, children, ray_start,
+                                               ray_dir)
+
+
+class TriangleRayIntersect(Function):
+    @staticmethod
+    def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start, ray_dir):
+        # HACK: speed-up ray-voxel intersection by batching...
+        G = min(2048, int(2 * 10 ** 9 / (3 * faces.numel())))   # HACK: avoid out-of-memory
+        S, N = ray_start.shape[:2]
+        K = int(np.ceil(N / G))
+        H = K * G
+        if H > N:
+            ray_start = torch.cat([ray_start, ray_start[:, :H - N]], 1)
+            ray_dir = torch.cat([ray_dir, ray_dir[:, :H - N]], 1)
+        ray_start = ray_start.reshape(S * G, K, 3)
+        ray_dir = ray_dir.reshape(S * G, K, 3)
+        face_points = F.embedding(faces.reshape(-1, 3), points.reshape(-1, 3))
+        face_points = face_points.unsqueeze(0).expand(S * G, *face_points.size()).contiguous()
+        inds, depth, uv = _ext.triangle_intersect(
+            ray_start.float(), ray_dir.float(), face_points.float(), cagesize, blur_ratio, n_max)
+        depth = depth.type_as(ray_start)
+        uv = uv.type_as(ray_start)
+
+        inds = inds.reshape(S, H, -1)
+        depth = depth.reshape(S, H, -1, 3)
+        uv = uv.reshape(S, H, -1)
+        if H > N:
+            inds = inds[:, :N]
+            depth = depth[:, :N]
+            uv = uv[:, :N]
+
+        ctx.mark_non_differentiable(inds)
+        ctx.mark_non_differentiable(depth)
+        ctx.mark_non_differentiable(uv)
+        return inds, depth, uv
+
+    @staticmethod
+    def backward(ctx, a, b, c):
+        return None, None, None, None, None, None
+
+
+triangle_ray_intersect = TriangleRayIntersect.apply
+
+
+class UniformRaySampling(Function):
+    @staticmethod
+    def forward(ctx, pts_idx, min_depth, max_depth, step_size, max_ray_length, deterministic=False):
+        G, N, P = 256, pts_idx.size(0), pts_idx.size(1)
+        H = int(np.ceil(N / G)) * G
+        if H > N:
+            pts_idx = torch.cat([pts_idx, pts_idx[:H - N]], 0)
+            min_depth = torch.cat([min_depth, min_depth[:H - N]], 0)
+            max_depth = torch.cat([max_depth, max_depth[:H - N]], 0)
+        pts_idx = pts_idx.reshape(G, -1, P)
+        min_depth = min_depth.reshape(G, -1, P)
+        max_depth = max_depth.reshape(G, -1, P)
+
+        # pre-generate noise
+        max_steps = int(max_ray_length / step_size)
+        max_steps = max_steps + min_depth.size(-1) * 2
+        noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)
+        if deterministic:
+            noise += 0.5
+        else:
+            noise = noise.uniform_()
+
+        # call cuda function
+        sampled_idx, sampled_depth, sampled_dists = _ext.uniform_ray_sampling(
+            pts_idx, min_depth.float(), max_depth.float(), noise.float(), step_size, max_steps)
+        sampled_depth = sampled_depth.type_as(min_depth)
+        sampled_dists = sampled_dists.type_as(min_depth)
+
+        sampled_idx = sampled_idx.reshape(H, -1)
+        sampled_depth = sampled_depth.reshape(H, -1)
+        sampled_dists = sampled_dists.reshape(H, -1)
+        if H > N:
+            sampled_idx = sampled_idx[: N]
+            sampled_depth = sampled_depth[: N]
+            sampled_dists = sampled_dists[: N]
+
+        max_len = sampled_idx.ne(-1).sum(-1).max()
+        sampled_idx = sampled_idx[:, :max_len]
+        sampled_depth = sampled_depth[:, :max_len]
+        sampled_dists = sampled_dists[:, :max_len]
+
+        ctx.mark_non_differentiable(sampled_idx)
+        ctx.mark_non_differentiable(sampled_depth)
+        ctx.mark_non_differentiable(sampled_dists)
+        return sampled_idx, sampled_depth, sampled_dists
+
+    @staticmethod
+    def backward(ctx, a, b, c):
+        return None, None, None, None, None, None
+
+
+def uniform_ray_sampling(pts_idx: torch.Tensor, min_depth: torch.Tensor, max_depth: torch.Tensor,
+                         step_size: float, max_ray_length: float, deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Sample along rays uniformly
+
+    :param pts_idx `Tensor(N, P)`: indices of voxels intersected with rays
+    :param min_depth `Tensor(N, P)`: min depth of intersections of rays and voxels
+    :param max_depth `Tensor(N, P)`: max depth of intersections of rays and voxels
+    :param step_size `float`: size of sampling step
+    :param max_ray_length `float`: maximum sampling depth along rays
+    :param deterministic `bool`: (optional) sample deterministically (or randomly), defaults to False
+    :return `Tensor(N, P')`: voxel indices of sampled points
+    :return `Tensor(N, P')`: depth of sampled points
+    :return `Tensor(N, P')`: length of sampled points
+    """
+    return UniformRaySampling.apply(pts_idx, min_depth, max_depth, step_size, max_ray_length,
+                                    deterministic)
+
+
+class InverseCDFRaySampling(Function):
+    @staticmethod
+    def forward(ctx, pts_idx, min_depth, max_depth, probs, steps, fixed_step_size=-1, deterministic=False):
+        G, N, P = 200, pts_idx.size(0), pts_idx.size(1)
+        H = int(np.ceil(N / G)) * G
+
+        if H > N:
+            pts_idx = torch.cat([pts_idx, pts_idx[:1].expand(H - N, P)], 0)
+            min_depth = torch.cat([min_depth, min_depth[:1].expand(H - N, P)], 0)
+            max_depth = torch.cat([max_depth, max_depth[:1].expand(H - N, P)], 0)
+            probs = torch.cat([probs, probs[:1].expand(H - N, P)], 0)
+            steps = torch.cat([steps, steps[:1].expand(H - N)], 0)
+        # print(G, P, np.ceil(N / G), N, H, pts_idx.shape, min_depth.device)
+        pts_idx = pts_idx.reshape(G, -1, P)
+        min_depth = min_depth.reshape(G, -1, P)
+        max_depth = max_depth.reshape(G, -1, P)
+        probs = probs.reshape(G, -1, P)
+        steps = steps.reshape(G, -1)
+
+        # pre-generate noise
+        max_steps = steps.ceil().long().max() + P
+        noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)
+        if deterministic:
+            noise += 0.5
+        else:
+            noise = noise.uniform_().clamp(min=0.001, max=0.999)  # in case
+
+        # call cuda function
+        chunk_size = 4 * G  # to avoid oom?
+        results = [
+            _ext.inverse_cdf_sampling(
+                pts_idx[:, i:i + chunk_size].contiguous(),
+                min_depth.float()[:, i:i + chunk_size].contiguous(),
+                max_depth.float()[:, i:i + chunk_size].contiguous(),
+                noise.float()[:, i:i + chunk_size].contiguous(),
+                probs.float()[:, i:i + chunk_size].contiguous(),
+                steps.float()[:, i:i + chunk_size].contiguous(),
+                fixed_step_size)
+            for i in range(0, min_depth.size(1), chunk_size)
+        ]
+        sampled_idx, sampled_depth, sampled_dists = [
+            torch.cat([r[i] for r in results], 1)
+            for i in range(3)
+        ]
+        sampled_depth = sampled_depth.type_as(min_depth)
+        sampled_dists = sampled_dists.type_as(min_depth)
+
+        sampled_idx = sampled_idx.reshape(H, -1)
+        sampled_depth = sampled_depth.reshape(H, -1)
+        sampled_dists = sampled_dists.reshape(H, -1)
+        if H > N:
+            sampled_idx = sampled_idx[: N]
+            sampled_depth = sampled_depth[: N]
+            sampled_dists = sampled_dists[: N]
+
+        max_len = sampled_idx.ne(-1).sum(-1).max()
+        sampled_idx = sampled_idx[:, :max_len]
+        sampled_depth = sampled_depth[:, :max_len]
+        sampled_dists = sampled_dists[:, :max_len]
+
+        ctx.mark_non_differentiable(sampled_idx)
+        ctx.mark_non_differentiable(sampled_depth)
+        ctx.mark_non_differentiable(sampled_dists)
+        return sampled_idx, sampled_depth, sampled_dists
+
+    @staticmethod
+    def backward(ctx, a, b, c):
+        return None, None, None, None, None, None, None
+
+
+def inverse_cdf_sampling(pts_idx: torch.Tensor, min_depth: torch.Tensor, max_depth: torch.Tensor,
+                         probs: torch.Tensor, steps: torch.Tensor, fixed_step_size: float = -1,
+                         deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Sample along rays by inverse CDF
+
+    :param pts_idx `Tensor(N, P)`: indices of voxels intersected with rays
+    :param min_depth `Tensor(N, P)`: min depth of intersections of rays and voxels
+    :param max_depth `Tensor(N, P)`: max depth of intersections of rays and voxels
+    :param probs `Tensor(N, P)`:
+    :param steps `Tensor(N)`: 
+    :param fixed_step_size `float`:
+    :param deterministic `bool`: (optional) sample deterministically (or randomly), defaults to False
+    :return `Tensor(N, P')`: voxel indices of sampled points
+    :return `Tensor(N, P')`: depth of sampled points
+    :return `Tensor(N, P')`: length of sampled points
+    """
+    return InverseCDFRaySampling.apply(pts_idx, min_depth, max_depth, probs, steps, fixed_step_size,
+                                       deterministic)
+
+
+# back-up for ray point sampling
+@torch.no_grad()
+def _parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False):
+    # uniform sampling
+    _min_depth = min_depth.min(1)[0]
+    _max_depth = max_depth.masked_fill(max_depth.eq(HUGE_FLOAT), 0).max(1)[0]
+    max_ray_length = (_max_depth - _min_depth).max()
+
+    delta = torch.arange(int(max_ray_length / MARCH_SIZE),
+                         device=min_depth.device, dtype=min_depth.dtype)
+    delta = delta[None, :].expand(min_depth.size(0), delta.size(-1))
+    if deterministic:
+        delta = delta + 0.5
+    else:
+        delta = delta + delta.clone().uniform_().clamp(min=0.01, max=0.99)
+    delta = delta * MARCH_SIZE
+    sampled_depth = min_depth[:, :1] + delta
+    sampled_idx = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1
+    sampled_idx = pts_idx.gather(1, sampled_idx)
+
+    # include all boundary points
+    sampled_depth = torch.cat([min_depth, max_depth, sampled_depth], -1)
+    sampled_idx = torch.cat([pts_idx, pts_idx, sampled_idx], -1)
+
+    # reorder
+    sampled_depth, ordered_index = sampled_depth.sort(-1)
+    sampled_idx = sampled_idx.gather(1, ordered_index)
+    sampled_dists = sampled_depth[:, 1:] - sampled_depth[:, :-1]          # distances
+    sampled_depth = .5 * (sampled_depth[:, 1:] + sampled_depth[:, :-1])   # mid-points
+
+    # remove all invalid depths
+    min_ids = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1
+    max_ids = (sampled_depth[:, :, None] >= max_depth[:, None, :]).sum(-1)
+
+    sampled_depth.masked_fill_(
+        (max_ids.ne(min_ids)) |
+        (sampled_depth > _max_depth[:, None]) |
+        (sampled_dists == 0.0), HUGE_FLOAT)
+    sampled_depth, ordered_index = sampled_depth.sort(-1)  # sort again
+    sampled_masks = sampled_depth.eq(HUGE_FLOAT)
+    num_max_steps = (~sampled_masks).sum(-1).max()
+
+    sampled_depth = sampled_depth[:, :num_max_steps]
+    sampled_dists = sampled_dists.gather(1, ordered_index).masked_fill_(
+        sampled_masks, 0.0)[:, :num_max_steps]
+    sampled_idx = sampled_idx.gather(1, ordered_index).masked_fill_(
+        sampled_masks, -1)[:, :num_max_steps]
+
+    return sampled_idx, sampled_depth, sampled_dists
+
+
+@torch.no_grad()
+def parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False):
+    chunk_size = 4096
+    full_size = min_depth.shape[0]
+    if full_size <= chunk_size:
+        return _parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=deterministic)
+
+    outputs = zip(*[
+        _parallel_ray_sampling(
+            MARCH_SIZE,
+            pts_idx[i:i + chunk_size], min_depth[i:i + chunk_size], max_depth[i:i + chunk_size],
+            deterministic=deterministic)
+        for i in range(0, full_size, chunk_size)])
+    sampled_idx, sampled_depth, sampled_dists = outputs
+
+    def padding_points(xs, pad):
+        if len(xs) == 1:
+            return xs[0]
+
+        maxlen = max([x.size(1) for x in xs])
+        full_size = sum([x.size(0) for x in xs])
+        xt = xs[0].new_ones(full_size, maxlen).fill_(pad)
+        st = 0
+        for i in range(len(xs)):
+            xt[st: st + xs[i].size(0), :xs[i].size(1)] = xs[i]
+            st += xs[i].size(0)
+        return xt
+
+    sampled_idx = padding_points(sampled_idx, -1)
+    sampled_depth = padding_points(sampled_depth, HUGE_FLOAT)
+    sampled_dists = padding_points(sampled_dists, 0.0)
+    return sampled_idx, sampled_depth, sampled_dists
+
+
+def build_easy_octree(points: torch.Tensor, half_voxel: float) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Build an octree.
+
+    :param points `Tensor(M, 3)`: centers of leaf voxels
+    :param half_voxel `float`: half size of voxel
+    :return `Tensor(M', 3)`: centers of all nodes in octree
+    :return `Tensor(M', 9)`: flattened octree structure
+    """
+    coords, residual = discretize_points(points, half_voxel)
+    ranges = coords.max(0)[0] - coords.min(0)[0]
+    depths = torch.log2(ranges.max().float()).ceil_().long() - 1
+    center = (coords.max(0)[0] + coords.min(0)[0]) / 2
+    centers, children = _ext.build_octree(center, coords, int(depths))
+    centers = centers.float() * half_voxel + residual   # transform back to float
+    return centers, children
\ No newline at end of file
diff --git a/clib/include/cuda_utils.h b/clib/include/cuda_utils.h
new file mode 100644
index 0000000..d4c4bb4
--- /dev/null
+++ b/clib/include/cuda_utils.h
@@ -0,0 +1,46 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#ifndef _CUDA_UTILS_H
+#define _CUDA_UTILS_H
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cmath>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <vector>
+
+#define TOTAL_THREADS 512
+
+inline int opt_n_threads(int work_size) {
+  const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
+
+  return max(min(1 << pow_2, TOTAL_THREADS), 1);
+}
+
+inline dim3 opt_block_config(int x, int y) {
+  const int x_threads = opt_n_threads(x);
+  const int y_threads =
+      max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
+  dim3 block_config(x_threads, y_threads, 1);
+
+  return block_config;
+}
+
+#define CUDA_CHECK_ERRORS()                                           \
+  do {                                                                \
+    cudaError_t err = cudaGetLastError();                             \
+    if (cudaSuccess != err) {                                         \
+      fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n",  \
+              cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
+              __FILE__);                                              \
+      exit(-1);                                                       \
+    }                                                                 \
+  } while (0)
+
+#endif
diff --git a/clib/include/cutil_math.h b/clib/include/cutil_math.h
new file mode 100644
index 0000000..d8748b9
--- /dev/null
+++ b/clib/include/cutil_math.h
@@ -0,0 +1,793 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+/*
+ * Copyright 1993-2009 NVIDIA Corporation.  All rights reserved.
+ *
+ * NVIDIA Corporation and its licensors retain all intellectual property and 
+ * proprietary rights in and to this software and related documentation and 
+ * any modifications thereto.  Any use, reproduction, disclosure, or distribution 
+ * of this software and related documentation without an express license 
+ * agreement from NVIDIA Corporation is strictly prohibited.
+ * 
+ */
+
+/*
+    This file implements common mathematical operations on vector types
+    (float3, float4 etc.) since these are not provided as standard by CUDA.
+
+    The syntax is modelled on the Cg standard library.
+*/
+
+#ifndef CUTIL_MATH_H
+#define CUTIL_MATH_H
+
+#include "cuda_runtime.h"
+
+////////////////////////////////////////////////////////////////////////////////
+typedef unsigned int uint;
+typedef unsigned short ushort;
+
+#ifndef __CUDACC__
+#include <math.h>
+
+inline float fminf(float a, float b)
+{
+  return a < b ? a : b;
+}
+
+inline float fmaxf(float a, float b)
+{
+  return a > b ? a : b;
+}
+
+inline int max(int a, int b)
+{
+  return a > b ? a : b;
+}
+
+inline int min(int a, int b)
+{
+  return a < b ? a : b;
+}
+
+inline float rsqrtf(float x)
+{
+    return 1.0f / sqrtf(x);
+}
+
+#endif
+
+// float functions
+////////////////////////////////////////////////////////////////////////////////
+
+// lerp
+inline __device__ __host__ float lerp(float a, float b, float t)
+{
+    return a + t*(b-a);
+}
+
+// clamp
+inline __device__ __host__ float clamp(float f, float a, float b)
+{
+    return fmaxf(a, fminf(f, b));
+}
+
+inline __device__ __host__ void swap(float &a, float &b)
+{
+    float c = a;
+    a = b;
+    b = c;
+}
+
+inline __device__ __host__ void swap(int &a, int &b)
+{
+    float c = a;
+    a = b;
+    b = c;
+}
+
+
+// int2 functions
+////////////////////////////////////////////////////////////////////////////////
+
+// negate
+inline __host__ __device__ int2 operator-(int2 &a)
+{
+    return make_int2(-a.x, -a.y);
+}
+
+// addition
+inline __host__ __device__ int2 operator+(int2 a, int2 b)
+{
+    return make_int2(a.x + b.x, a.y + b.y);
+}
+inline __host__ __device__ void operator+=(int2 &a, int2 b)
+{
+    a.x += b.x; a.y += b.y;
+}
+
+// subtract
+inline __host__ __device__ int2 operator-(int2 a, int2 b)
+{
+    return make_int2(a.x - b.x, a.y - b.y);
+}
+inline __host__ __device__ void operator-=(int2 &a, int2 b)
+{
+    a.x -= b.x; a.y -= b.y;
+}
+
+// multiply
+inline __host__ __device__ int2 operator*(int2 a, int2 b)
+{
+    return make_int2(a.x * b.x, a.y * b.y);
+}
+inline __host__ __device__ int2 operator*(int2 a, int s)
+{
+    return make_int2(a.x * s, a.y * s);
+}
+inline __host__ __device__ int2 operator*(int s, int2 a)
+{
+    return make_int2(a.x * s, a.y * s);
+}
+inline __host__ __device__ void operator*=(int2 &a, int s)
+{
+    a.x *= s; a.y *= s;
+}
+
+// float2 functions
+////////////////////////////////////////////////////////////////////////////////
+
+// additional constructors
+inline __host__ __device__ float2 make_float2(float s)
+{
+    return make_float2(s, s);
+}
+inline __host__ __device__ float2 make_float2(int2 a)
+{
+    return make_float2(float(a.x), float(a.y));
+}
+
+// negate
+inline __host__ __device__ float2 operator-(float2 &a)
+{
+    return make_float2(-a.x, -a.y);
+}
+
+// addition
+inline __host__ __device__ float2 operator+(float2 a, float2 b)
+{
+    return make_float2(a.x + b.x, a.y + b.y);
+}
+inline __host__ __device__ void operator+=(float2 &a, float2 b)
+{
+    a.x += b.x; a.y += b.y;
+}
+
+// subtract
+inline __host__ __device__ float2 operator-(float2 a, float2 b)
+{
+    return make_float2(a.x - b.x, a.y - b.y);
+}
+inline __host__ __device__ void operator-=(float2 &a, float2 b)
+{
+    a.x -= b.x; a.y -= b.y;
+}
+
+// multiply
+inline __host__ __device__ float2 operator*(float2 a, float2 b)
+{
+    return make_float2(a.x * b.x, a.y * b.y);
+}
+inline __host__ __device__ float2 operator*(float2 a, float s)
+{
+    return make_float2(a.x * s, a.y * s);
+}
+inline __host__ __device__ float2 operator*(float s, float2 a)
+{
+    return make_float2(a.x * s, a.y * s);
+}
+inline __host__ __device__ void operator*=(float2 &a, float s)
+{
+    a.x *= s; a.y *= s;
+}
+
+// divide
+inline __host__ __device__ float2 operator/(float2 a, float2 b)
+{
+    return make_float2(a.x / b.x, a.y / b.y);
+}
+inline __host__ __device__ float2 operator/(float2 a, float s)
+{
+    float inv = 1.0f / s;
+    return a * inv;
+}
+inline __host__ __device__ float2 operator/(float s, float2 a)
+{
+    float inv = 1.0f / s;
+    return a * inv;
+}
+inline __host__ __device__ void operator/=(float2 &a, float s)
+{
+    float inv = 1.0f / s;
+    a *= inv;
+}
+
+// lerp
+inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
+{
+    return a + t*(b-a);
+}
+
+// clamp
+inline __device__ __host__ float2 clamp(float2 v, float a, float b)
+{
+    return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
+}
+
+inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
+{
+    return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
+}
+
+// dot product
+inline __host__ __device__ float dot(float2 a, float2 b)
+{ 
+    return a.x * b.x + a.y * b.y;
+}
+
+// length
+inline __host__ __device__ float length(float2 v)
+{
+    return sqrtf(dot(v, v));
+}
+
+// normalize
+inline __host__ __device__ float2 normalize(float2 v)
+{
+    float invLen = rsqrtf(dot(v, v));
+    return v * invLen;
+}
+
+// floor
+inline __host__ __device__ float2 floor(const float2 v)
+{
+    return make_float2(floor(v.x), floor(v.y));
+}
+
+// reflect
+inline __host__ __device__ float2 reflect(float2 i, float2 n)
+{
+	return i - 2.0f * n * dot(n,i);
+}
+
+// absolute value
+inline __host__ __device__ float2 fabs(float2 v)
+{
+	return make_float2(fabs(v.x), fabs(v.y));
+}
+
+// float3 functions
+////////////////////////////////////////////////////////////////////////////////
+
+// additional constructors
+inline __host__ __device__ float3 make_float3(float s)
+{
+    return make_float3(s, s, s);
+}
+inline __host__ __device__ float3 make_float3(float2 a)
+{
+    return make_float3(a.x, a.y, 0.0f);
+}
+inline __host__ __device__ float3 make_float3(float2 a, float s)
+{
+    return make_float3(a.x, a.y, s);
+}
+inline __host__ __device__ float3 make_float3(float4 a)
+{
+    return make_float3(a.x, a.y, a.z);  // discards w
+}
+inline __host__ __device__ float3 make_float3(int3 a)
+{
+    return make_float3(float(a.x), float(a.y), float(a.z));
+}
+
+// negate
+inline __host__ __device__ float3 operator-(float3 &a)
+{
+    return make_float3(-a.x, -a.y, -a.z);
+}
+
+// min
+static __inline__ __host__ __device__ float3 fminf(float3 a, float3 b)
+{
+	return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
+}
+
+// max
+static __inline__ __host__ __device__ float3 fmaxf(float3 a, float3 b)
+{
+	return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
+}
+
+// addition
+inline __host__ __device__ float3 operator+(float3 a, float3 b)
+{
+    return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
+}
+inline __host__ __device__ float3 operator+(float3 a, float b)
+{
+    return make_float3(a.x + b, a.y + b, a.z + b);
+}
+inline __host__ __device__ void operator+=(float3 &a, float3 b)
+{
+    a.x += b.x; a.y += b.y; a.z += b.z;
+}
+
+// subtract
+inline __host__ __device__ float3 operator-(float3 a, float3 b)
+{
+    return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
+}
+inline __host__ __device__ float3 operator-(float3 a, float b)
+{
+    return make_float3(a.x - b, a.y - b, a.z - b);
+}
+inline __host__ __device__ void operator-=(float3 &a, float3 b)
+{
+    a.x -= b.x; a.y -= b.y; a.z -= b.z;
+}
+
+// multiply
+inline __host__ __device__ float3 operator*(float3 a, float3 b)
+{
+    return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
+}
+inline __host__ __device__ float3 operator*(float3 a, float s)
+{
+    return make_float3(a.x * s, a.y * s, a.z * s);
+}
+inline __host__ __device__ float3 operator*(float s, float3 a)
+{
+    return make_float3(a.x * s, a.y * s, a.z * s);
+}
+inline __host__ __device__ void operator*=(float3 &a, float s)
+{
+    a.x *= s; a.y *= s; a.z *= s;
+}
+inline __host__ __device__ void operator*=(float3 &a, float3 b)
+{
+	a.x *= b.x; a.y *= b.y; a.z *= b.z;;
+}
+
+// divide
+inline __host__ __device__ float3 operator/(float3 a, float3 b)
+{
+    return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
+}
+inline __host__ __device__ float3 operator/(float3 a, float s)
+{
+    float inv = 1.0f / s;
+    return a * inv;
+}
+inline __host__ __device__ float3 operator/(float s, float3 a)
+{
+    float inv = 1.0f / s;
+    return a * inv;
+}
+inline __host__ __device__ void operator/=(float3 &a, float s)
+{
+    float inv = 1.0f / s;
+    a *= inv;
+}
+
+// lerp
+inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
+{
+    return a + t*(b-a);
+}
+
+// clamp
+inline __device__ __host__ float3 clamp(float3 v, float a, float b)
+{
+    return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
+}
+
+inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
+{
+    return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
+}
+
+// dot product
+inline __host__ __device__ float dot(float3 a, float3 b)
+{ 
+    return a.x * b.x + a.y * b.y + a.z * b.z;
+}
+
+// cross product
+inline __host__ __device__ float3 cross(float3 a, float3 b)
+{ 
+    return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x); 
+}
+
+// length
+inline __host__ __device__ float length(float3 v)
+{
+    return sqrtf(dot(v, v));
+}
+
+// normalize
+inline __host__ __device__ float3 normalize(float3 v)
+{
+    float invLen = rsqrtf(dot(v, v));
+    return v * invLen;
+}
+
+// floor
+inline __host__ __device__ float3 floor(const float3 v)
+{
+    return make_float3(floor(v.x), floor(v.y), floor(v.z));
+}
+
+// reflect
+inline __host__ __device__ float3 reflect(float3 i, float3 n)
+{
+	return i - 2.0f * n * dot(n,i);
+}
+
+// absolute value
+inline __host__ __device__ float3 fabs(float3 v)
+{
+	return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
+}
+
+// float4 functions
+////////////////////////////////////////////////////////////////////////////////
+
+// additional constructors
+inline __host__ __device__ float4 make_float4(float s)
+{
+    return make_float4(s, s, s, s);
+}
+inline __host__ __device__ float4 make_float4(float3 a)
+{
+    return make_float4(a.x, a.y, a.z, 0.0f);
+}
+inline __host__ __device__ float4 make_float4(float3 a, float w)
+{
+    return make_float4(a.x, a.y, a.z, w);
+}
+inline __host__ __device__ float4 make_float4(int4 a)
+{
+    return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
+}
+
+// negate
+inline __host__ __device__ float4 operator-(float4 &a)
+{
+    return make_float4(-a.x, -a.y, -a.z, -a.w);
+}
+
+// min
+static __inline__ __host__ __device__ float4 fminf(float4 a, float4 b)
+{
+	return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
+}
+
+// max
+static __inline__ __host__ __device__ float4 fmaxf(float4 a, float4 b)
+{
+	return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
+}
+
+// addition
+inline __host__ __device__ float4 operator+(float4 a, float4 b)
+{
+    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z,  a.w + b.w);
+}
+inline __host__ __device__ void operator+=(float4 &a, float4 b)
+{
+    a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w;
+}
+
+// subtract
+inline __host__ __device__ float4 operator-(float4 a, float4 b)
+{
+    return make_float4(a.x - b.x, a.y - b.y, a.z - b.z,  a.w - b.w);
+}
+inline __host__ __device__ void operator-=(float4 &a, float4 b)
+{
+    a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w;
+}
+
+// multiply
+inline __host__ __device__ float4 operator*(float4 a, float s)
+{
+    return make_float4(a.x * s, a.y * s, a.z * s, a.w * s);
+}
+inline __host__ __device__ float4 operator*(float s, float4 a)
+{
+    return make_float4(a.x * s, a.y * s, a.z * s, a.w * s);
+}
+inline __host__ __device__ void operator*=(float4 &a, float s)
+{
+    a.x *= s; a.y *= s; a.z *= s; a.w *= s;
+}
+
+// divide
+inline __host__ __device__ float4 operator/(float4 a, float4 b)
+{
+    return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
+}
+inline __host__ __device__ float4 operator/(float4 a, float s)
+{
+    float inv = 1.0f / s;
+    return a * inv;
+}
+inline __host__ __device__ float4 operator/(float s, float4 a)
+{
+    float inv = 1.0f / s;
+    return a * inv;
+}
+inline __host__ __device__ void operator/=(float4 &a, float s)
+{
+    float inv = 1.0f / s;
+    a *= inv;
+}
+
+// lerp
+inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
+{
+    return a + t*(b-a);
+}
+
+// clamp
+inline __device__ __host__ float4 clamp(float4 v, float a, float b)
+{
+    return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
+}
+
+inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
+{
+    return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
+}
+
+// dot product
+inline __host__ __device__ float dot(float4 a, float4 b)
+{ 
+    return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
+}
+
+// length
+inline __host__ __device__ float length(float4 r)
+{
+    return sqrtf(dot(r, r));
+}
+
+// normalize
+inline __host__ __device__ float4 normalize(float4 v)
+{
+    float invLen = rsqrtf(dot(v, v));
+    return v * invLen;
+}
+
+// floor
+inline __host__ __device__ float4 floor(const float4 v)
+{
+    return make_float4(floor(v.x), floor(v.y), floor(v.z), floor(v.w));
+}
+
+// absolute value
+inline __host__ __device__ float4 fabs(float4 v)
+{
+	return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
+}
+
+// int3 functions
+////////////////////////////////////////////////////////////////////////////////
+
+// additional constructors
+inline __host__ __device__ int3 make_int3(int s)
+{
+    return make_int3(s, s, s);
+}
+inline __host__ __device__ int3 make_int3(float3 a)
+{
+    return make_int3(int(a.x), int(a.y), int(a.z));
+}
+
+// negate
+inline __host__ __device__ int3 operator-(int3 &a)
+{
+    return make_int3(-a.x, -a.y, -a.z);
+}
+
+// min
+inline __host__ __device__ int3 min(int3 a, int3 b)
+{
+    return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
+}
+
+// max
+inline __host__ __device__ int3 max(int3 a, int3 b)
+{
+    return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
+}
+
+// addition
+inline __host__ __device__ int3 operator+(int3 a, int3 b)
+{
+    return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
+}
+inline __host__ __device__ void operator+=(int3 &a, int3 b)
+{
+    a.x += b.x; a.y += b.y; a.z += b.z;
+}
+
+// subtract
+inline __host__ __device__ int3 operator-(int3 a, int3 b)
+{
+    return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
+}
+
+inline __host__ __device__ void operator-=(int3 &a, int3 b)
+{
+    a.x -= b.x; a.y -= b.y; a.z -= b.z;
+}
+
+// multiply
+inline __host__ __device__ int3 operator*(int3 a, int3 b)
+{
+    return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
+}
+inline __host__ __device__ int3 operator*(int3 a, int s)
+{
+    return make_int3(a.x * s, a.y * s, a.z * s);
+}
+inline __host__ __device__ int3 operator*(int s, int3 a)
+{
+    return make_int3(a.x * s, a.y * s, a.z * s);
+}
+inline __host__ __device__ void operator*=(int3 &a, int s)
+{
+    a.x *= s; a.y *= s; a.z *= s;
+}
+
+// divide
+inline __host__ __device__ int3 operator/(int3 a, int3 b)
+{
+    return make_int3(a.x / b.x, a.y / b.y, a.z / b.z);
+}
+inline __host__ __device__ int3 operator/(int3 a, int s)
+{
+    return make_int3(a.x / s, a.y / s, a.z / s);
+}
+inline __host__ __device__ int3 operator/(int s, int3 a)
+{
+    return make_int3(a.x / s, a.y / s, a.z / s);
+}
+inline __host__ __device__ void operator/=(int3 &a, int s)
+{
+    a.x /= s; a.y /= s; a.z /= s;
+}
+
+// clamp
+inline __device__ __host__ int clamp(int f, int a, int b)
+{
+    return max(a, min(f, b));
+}
+
+inline __device__ __host__ int3 clamp(int3 v, int a, int b)
+{
+    return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
+}
+
+inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
+{
+    return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
+}
+
+
+// uint3 functions
+////////////////////////////////////////////////////////////////////////////////
+
+// additional constructors
+inline __host__ __device__ uint3 make_uint3(uint s)
+{
+    return make_uint3(s, s, s);
+}
+inline __host__ __device__ uint3 make_uint3(float3 a)
+{
+    return make_uint3(uint(a.x), uint(a.y), uint(a.z));
+}
+
+// min
+inline __host__ __device__ uint3 min(uint3 a, uint3 b)
+{
+    return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
+}
+
+// max
+inline __host__ __device__ uint3 max(uint3 a, uint3 b)
+{
+    return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
+}
+
+// addition
+inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
+{
+    return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
+}
+inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
+{
+    a.x += b.x; a.y += b.y; a.z += b.z;
+}
+
+// subtract
+inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
+{
+    return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
+}
+
+inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
+{
+    a.x -= b.x; a.y -= b.y; a.z -= b.z;
+}
+
+// multiply
+inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
+{
+    return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
+}
+inline __host__ __device__ uint3 operator*(uint3 a, uint s)
+{
+    return make_uint3(a.x * s, a.y * s, a.z * s);
+}
+inline __host__ __device__ uint3 operator*(uint s, uint3 a)
+{
+    return make_uint3(a.x * s, a.y * s, a.z * s);
+}
+inline __host__ __device__ void operator*=(uint3 &a, uint s)
+{
+    a.x *= s; a.y *= s; a.z *= s;
+}
+
+// divide
+inline __host__ __device__ uint3 operator/(uint3 a, uint3 b)
+{
+    return make_uint3(a.x / b.x, a.y / b.y, a.z / b.z);
+}
+inline __host__ __device__ uint3 operator/(uint3 a, uint s)
+{
+    return make_uint3(a.x / s, a.y / s, a.z / s);
+}
+inline __host__ __device__ uint3 operator/(uint s, uint3 a)
+{
+    return make_uint3(a.x / s, a.y / s, a.z / s);
+}
+inline __host__ __device__ void operator/=(uint3 &a, uint s)
+{
+    a.x /= s; a.y /= s; a.z /= s;
+}
+
+// clamp
+inline __device__ __host__ uint clamp(uint f, uint a, uint b)
+{
+    return max(a, min(f, b));
+}
+
+inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
+{
+    return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
+}
+
+inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
+{
+    return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
+}
+
+
+
+#endif
\ No newline at end of file
diff --git a/clib/include/intersect.h b/clib/include/intersect.h
new file mode 100644
index 0000000..757b137
--- /dev/null
+++ b/clib/include/intersect.h
@@ -0,0 +1,17 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <torch/extension.h>
+#include <utility>
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 
+               const float radius, const int n_max);
+std::tuple<at::Tensor, at::Tensor, at::Tensor> aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 
+               const float voxelsize, const int n_max);
+std::tuple<at::Tensor, at::Tensor, at::Tensor> svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, at::Tensor children,
+               const float voxelsize, const int n_max);
+std::tuple< at::Tensor, at::Tensor, at::Tensor > triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points, 
+               const float cagesize, const float blur, const int n_max);              
diff --git a/clib/include/octree.h b/clib/include/octree.h
new file mode 100644
index 0000000..429053e
--- /dev/null
+++ b/clib/include/octree.h
@@ -0,0 +1,10 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <torch/extension.h>
+#include <utility>
+
+std::tuple<at::Tensor, at::Tensor> build_octree(at::Tensor center, at::Tensor points, int depth);
\ No newline at end of file
diff --git a/clib/include/sample.h b/clib/include/sample.h
new file mode 100644
index 0000000..7547710
--- /dev/null
+++ b/clib/include/sample.h
@@ -0,0 +1,16 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <torch/extension.h>
+#include <utility>
+
+            
+std::tuple<at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling(
+    at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,
+    const float step_size, const int max_steps);
+std::tuple<at::Tensor, at::Tensor, at::Tensor> inverse_cdf_sampling(
+    at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,
+    at::Tensor probs, at::Tensor steps, float fixed_step_size);
\ No newline at end of file
diff --git a/clib/include/utils.h b/clib/include/utils.h
new file mode 100644
index 0000000..925f769
--- /dev/null
+++ b/clib/include/utils.h
@@ -0,0 +1,30 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+
+#define CHECK_CUDA(x)                                          \
+  do {                                                         \
+    TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \
+  } while (0)
+
+#define CHECK_CONTIGUOUS(x)                                         \
+  do {                                                              \
+    TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \
+  } while (0)
+
+#define CHECK_IS_INT(x)                              \
+  do {                                               \
+    TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
+             #x " must be an int tensor");           \
+  } while (0)
+
+#define CHECK_IS_FLOAT(x)                              \
+  do {                                                 \
+    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
+             #x " must be a float tensor");            \
+  } while (0)
diff --git a/clib/src/binding.cpp b/clib/src/binding.cpp
new file mode 100644
index 0000000..a7274d0
--- /dev/null
+++ b/clib/src/binding.cpp
@@ -0,0 +1,21 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "intersect.h"
+#include "octree.h"
+#include "sample.h"
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("ball_intersect", &ball_intersect);
+  m.def("aabb_intersect", &aabb_intersect);
+  m.def("svo_intersect", &svo_intersect);
+  m.def("triangle_intersect", &triangle_intersect);
+
+  m.def("uniform_ray_sampling", &uniform_ray_sampling);
+  m.def("inverse_cdf_sampling", &inverse_cdf_sampling);
+
+  m.def("build_octree", &build_octree);
+}
\ No newline at end of file
diff --git a/clib/src/intersect.cpp b/clib/src/intersect.cpp
new file mode 100644
index 0000000..5e5bab4
--- /dev/null
+++ b/clib/src/intersect.cpp
@@ -0,0 +1,146 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "intersect.h"
+#include "utils.h"
+#include <utility> 
+
+void ball_intersect_point_kernel_wrapper(
+  int b, int n, int m, float radius, int n_max,
+  const float *ray_start, const float *ray_dir, const float *points,
+  int *idx, float *min_depth, float *max_depth);
+
+std::tuple< at::Tensor, at::Tensor, at::Tensor > ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 
+               const float radius, const int n_max){
+  CHECK_CONTIGUOUS(ray_start);
+  CHECK_CONTIGUOUS(ray_dir);
+  CHECK_CONTIGUOUS(points);
+  CHECK_IS_FLOAT(ray_start);
+  CHECK_IS_FLOAT(ray_dir);
+  CHECK_IS_FLOAT(points);
+  CHECK_CUDA(ray_start);
+  CHECK_CUDA(ray_dir);
+  CHECK_CUDA(points);
+
+  at::Tensor idx =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Int));
+  at::Tensor min_depth =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Float));
+  at::Tensor max_depth =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Float));
+  ball_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1),
+                                      radius, n_max,
+                                      ray_start.data_ptr <float>(), ray_dir.data_ptr <float>(), points.data_ptr <float>(),
+                                      idx.data_ptr <int>(), min_depth.data_ptr <float>(), max_depth.data_ptr <float>());
+  return std::make_tuple(idx, min_depth, max_depth);
+}
+
+
+void aabb_intersect_point_kernel_wrapper(
+  int b, int n, int m, float voxelsize, int n_max,
+  const float *ray_start, const float *ray_dir, const float *points,
+  int *idx, float *min_depth, float *max_depth);
+
+std::tuple< at::Tensor, at::Tensor, at::Tensor > aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 
+               const float voxelsize, const int n_max){
+  CHECK_CONTIGUOUS(ray_start);
+  CHECK_CONTIGUOUS(ray_dir);
+  CHECK_CONTIGUOUS(points);
+  CHECK_IS_FLOAT(ray_start);
+  CHECK_IS_FLOAT(ray_dir);
+  CHECK_IS_FLOAT(points);
+  CHECK_CUDA(ray_start);
+  CHECK_CUDA(ray_dir);
+  CHECK_CUDA(points);
+
+  at::Tensor idx =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Int));
+  at::Tensor min_depth =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Float));
+  at::Tensor max_depth =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Float));
+  aabb_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1),
+                                      voxelsize, n_max,
+                                      ray_start.data_ptr <float>(), ray_dir.data_ptr <float>(), points.data_ptr <float>(),
+                                      idx.data_ptr <int>(), min_depth.data_ptr <float>(), max_depth.data_ptr <float>());
+  return std::make_tuple(idx, min_depth, max_depth);
+}
+
+
+void svo_intersect_point_kernel_wrapper(
+  int b, int n, int m, float voxelsize, int n_max,
+  const float *ray_start, const float *ray_dir, const float *points, const int *children,
+  int *idx, float *min_depth, float *max_depth);
+
+
+std::tuple< at::Tensor, at::Tensor, at::Tensor > svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, 
+               at::Tensor children, const float voxelsize, const int n_max){
+  CHECK_CONTIGUOUS(ray_start);
+  CHECK_CONTIGUOUS(ray_dir);
+  CHECK_CONTIGUOUS(points);
+  CHECK_CONTIGUOUS(children);
+  CHECK_IS_FLOAT(ray_start);
+  CHECK_IS_FLOAT(ray_dir);
+  CHECK_IS_FLOAT(points);
+  CHECK_CUDA(ray_start);
+  CHECK_CUDA(ray_dir);
+  CHECK_CUDA(points);
+  CHECK_CUDA(children);
+
+  at::Tensor idx =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Int));
+  at::Tensor min_depth =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Float));
+  at::Tensor max_depth =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Float));
+  svo_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1),
+                                      voxelsize, n_max,
+                                      ray_start.data_ptr <float>(), ray_dir.data_ptr <float>(), points.data_ptr <float>(),
+                                      children.data_ptr <int>(), idx.data_ptr <int>(), min_depth.data_ptr <float>(), max_depth.data_ptr <float>());
+  return std::make_tuple(idx, min_depth, max_depth);
+}
+
+
+void triangle_intersect_point_kernel_wrapper(
+  int b, int n, int m, float cagesize, float blur, int n_max,
+  const float *ray_start, const float *ray_dir, const float *face_points,
+  int *idx, float *depth, float *uv);
+
+std::tuple< at::Tensor, at::Tensor, at::Tensor > triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points, 
+               const float cagesize, const float blur, const int n_max){
+  CHECK_CONTIGUOUS(ray_start);
+  CHECK_CONTIGUOUS(ray_dir);
+  CHECK_CONTIGUOUS(face_points);
+  CHECK_IS_FLOAT(ray_start);
+  CHECK_IS_FLOAT(ray_dir);
+  CHECK_IS_FLOAT(face_points);
+  CHECK_CUDA(ray_start);
+  CHECK_CUDA(ray_dir);
+  CHECK_CUDA(face_points);
+
+  at::Tensor idx =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Int));
+  at::Tensor depth =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 3},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Float));
+  at::Tensor uv =
+      torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 2},
+                    at::device(ray_start.device()).dtype(at::ScalarType::Float));
+  triangle_intersect_point_kernel_wrapper(face_points.size(0), face_points.size(1), ray_start.size(1),
+                                          cagesize, blur, n_max,
+                                          ray_start.data_ptr <float>(), ray_dir.data_ptr <float>(), face_points.data_ptr <float>(),
+                                          idx.data_ptr <int>(), depth.data_ptr <float>(), uv.data_ptr <float>());
+  return std::make_tuple(idx, depth, uv);
+}
diff --git a/clib/src/intersect_gpu.cu b/clib/src/intersect_gpu.cu
new file mode 100644
index 0000000..fa25cda
--- /dev/null
+++ b/clib/src/intersect_gpu.cu
@@ -0,0 +1,375 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "cutil_math.h" // required for float3 vector math
+
+__global__ void ball_intersect_point_kernel(int b, int n, int m, float radius, int n_max,
+                                            const float *__restrict__ ray_start,
+                                            const float *__restrict__ ray_dir,
+                                            const float *__restrict__ points, int *__restrict__ idx,
+                                            float *__restrict__ min_depth,
+                                            float *__restrict__ max_depth) {
+
+    int batch_index = blockIdx.x;
+    points += batch_index * n * 3;
+    ray_start += batch_index * m * 3;
+    ray_dir += batch_index * m * 3;
+    idx += batch_index * m * n_max;
+    min_depth += batch_index * m * n_max;
+    max_depth += batch_index * m * n_max;
+
+    int index = threadIdx.x;
+    int stride = blockDim.x;
+    float radius2 = radius * radius;
+
+    for (int j = index; j < m; j += stride) {
+
+        float x0 = ray_start[j * 3 + 0];
+        float y0 = ray_start[j * 3 + 1];
+        float z0 = ray_start[j * 3 + 2];
+        float xw = ray_dir[j * 3 + 0];
+        float yw = ray_dir[j * 3 + 1];
+        float zw = ray_dir[j * 3 + 2];
+
+        for (int l = 0; l < n_max; ++l) {
+            idx[j * n_max + l] = -1;
+        }
+
+        for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k) {
+            float x = points[k * 3 + 0] - x0;
+            float y = points[k * 3 + 1] - y0;
+            float z = points[k * 3 + 2] - z0;
+            float d2 = x * x + y * y + z * z;
+            float d2_proj = pow(x * xw + y * yw + z * zw, 2);
+            float r2 = d2 - d2_proj;
+
+            if (r2 < radius2) {
+                idx[j * n_max + cnt] = k;
+
+                float depth = sqrt(d2_proj);
+                float depth_blur = sqrt(radius2 - r2);
+
+                min_depth[j * n_max + cnt] = depth - depth_blur;
+                max_depth[j * n_max + cnt] = depth + depth_blur;
+                ++cnt;
+            }
+        }
+    }
+}
+
+__device__ float2 RayAABBIntersection(const float3 &ori, const float3 &dir, const float3 &center,
+                                      float half_voxel) {
+
+    float f_low = 0;
+    float f_high = 100000.;
+    float f_dim_low, f_dim_high, temp, inv_ray_dir, start, aabb;
+
+    for (int d = 0; d < 3; ++d) {
+        switch (d) {
+        case 0:
+            inv_ray_dir = __fdividef(1.0f, dir.x);
+            start = ori.x;
+            aabb = center.x;
+            break;
+        case 1:
+            inv_ray_dir = __fdividef(1.0f, dir.y);
+            start = ori.y;
+            aabb = center.y;
+            break;
+        case 2:
+            inv_ray_dir = __fdividef(1.0f, dir.z);
+            start = ori.z;
+            aabb = center.z;
+            break;
+        }
+
+        f_dim_low = (aabb - half_voxel - start) * inv_ray_dir;
+        f_dim_high = (aabb + half_voxel - start) * inv_ray_dir;
+
+        // Make sure low is less than high
+        if (f_dim_high < f_dim_low) {
+            temp = f_dim_low;
+            f_dim_low = f_dim_high;
+            f_dim_high = temp;
+        }
+
+        // If this dimension's high is less than the low we got then we definitely missed.
+        // Likewise if the low is less than the high.
+        if (f_dim_high < f_low || f_dim_low > f_high)
+            return make_float2(-1.0f, -1.0f);
+
+        // Add the clip from this dimension to the previous results
+        f_low = max(f_dim_low, f_low);
+        f_high = min(f_dim_high, f_high);
+        if (f_low >= f_high - 1e-5f)
+            return make_float2(-1.0f, -1.0f);
+    }
+    return make_float2(f_low, f_high);
+}
+
+__global__ void aabb_intersect_point_kernel(int b, int n, int m, float voxelsize, int n_max,
+                                            const float *__restrict__ ray_start,
+                                            const float *__restrict__ ray_dir,
+                                            const float *__restrict__ points, int *__restrict__ idx,
+                                            float *__restrict__ min_depth,
+                                            float *__restrict__ max_depth) {
+
+    int batch_index = blockIdx.x;
+    points += batch_index * n * 3;
+    ray_start += batch_index * m * 3;
+    ray_dir += batch_index * m * 3;
+    idx += batch_index * m * n_max;
+    min_depth += batch_index * m * n_max;
+    max_depth += batch_index * m * n_max;
+
+    int index = threadIdx.x;
+    int stride = blockDim.x;
+    float half_voxel = voxelsize * 0.5;
+
+    for (int j = index; j < m; j += stride) {
+        for (int l = 0; l < n_max; ++l) {
+            idx[j * n_max + l] = -1;
+        }
+
+        for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k) {
+            float2 depths = RayAABBIntersection(
+                make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]),
+                make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]),
+                make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]), half_voxel);
+
+            if (depths.x > -1.0f) {
+                idx[j * n_max + cnt] = k;
+                min_depth[j * n_max + cnt] = depths.x;
+                max_depth[j * n_max + cnt] = depths.y;
+                ++cnt;
+            }
+        }
+    }
+}
+
+__global__ void svo_intersect_point_kernel(int b, int n, int m, float voxelsize, int n_max,
+                                           const float *__restrict__ ray_start,
+                                           const float *__restrict__ ray_dir,
+                                           const float *__restrict__ points,
+                                           const int *__restrict__ children, int *__restrict__ idx,
+                                           float *__restrict__ min_depth,
+                                           float *__restrict__ max_depth) {
+    /*
+    TODO: this is an inefficient implementation of the
+          navie Ray -- Sparse Voxel Octree Intersection.
+          It can be further improved using:
+
+          Revelles, Jorge, Carlos Urena, and Miguel Lastra.
+          "An efficient parametric algorithm for octree traversal." (2000).
+    */
+    int batch_index = blockIdx.x;
+    points += batch_index * n * 3;
+    children += batch_index * n * 9;
+    ray_start += batch_index * m * 3;
+    ray_dir += batch_index * m * 3;
+    idx += batch_index * m * n_max;
+    min_depth += batch_index * m * n_max;
+    max_depth += batch_index * m * n_max;
+
+    int index = threadIdx.x;
+    int stride = blockDim.x;
+    float half_voxel = voxelsize * 0.5;
+
+    for (int j = index; j < m; j += stride) {
+        for (int l = 0; l < n_max; ++l) {
+            idx[j * n_max + l] = -1;
+        }
+        int stack[256] = {-1}; // DFS, initialize the stack
+        int ptr = 0, cnt = 0, k = -1;
+        stack[ptr] = n - 1; // ROOT node is always the last
+        while (ptr > -1 && cnt < n_max) {
+            assert((ptr < 256));
+
+            // evaluate the current node
+            k = stack[ptr];
+            float2 depths = RayAABBIntersection(
+                make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]),
+                make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]),
+                make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]),
+                half_voxel * float(children[k * 9 + 8]));
+            stack[ptr] = -1;
+            ptr--;
+
+            if (depths.x > -1.0f) { // ray did not miss the voxel
+                // TODO: here it should be able to know which children is ok, further optimize the
+                // code
+                if (children[k * 9 + 8] == 1) { // this is a terminal node
+                    idx[j * n_max + cnt] = k;
+                    min_depth[j * n_max + cnt] = depths.x;
+                    max_depth[j * n_max + cnt] = depths.y;
+                    ++cnt;
+                    continue;
+                }
+
+                for (int u = 0; u < 8; u++) {
+                    if (children[k * 9 + u] > -1) {
+                        ptr++;
+                        stack[ptr] = children[k * 9 + u]; // push child to the stack
+                    }
+                }
+            }
+        }
+    }
+}
+
+__device__ float3 RayTriangleIntersection(const float3 &ori, const float3 &dir, const float3 &v0,
+                                          const float3 &v1, const float3 &v2, float blur) {
+
+    float3 v0v1 = v1 - v0;
+    float3 v0v2 = v2 - v0;
+    float3 v0O = ori - v0;
+    float3 dir_crs_v0v2 = cross(dir, v0v2);
+
+    float det = dot(v0v1, dir_crs_v0v2);
+    det = __fdividef(1.0f, det); // CUDA intrinsic function
+
+    float u = dot(v0O, dir_crs_v0v2) * det;
+    if ((u < 0.0f - blur) || (u > 1.0f + blur))
+        return make_float3(-1.0f, 0.0f, 0.0f);
+
+    float3 v0O_crs_v0v1 = cross(v0O, v0v1);
+    float v = dot(dir, v0O_crs_v0v1) * det;
+    if ((v < 0.0f - blur) || (v > 1.0f + blur))
+        return make_float3(-1.0f, 0.0f, 0.0f);
+
+    if (((u + v) < 0.0f - blur) || ((u + v) > 1.0f + blur))
+        return make_float3(-1.0f, 0.0f, 0.0f);
+
+    float t = dot(v0v2, v0O_crs_v0v1) * det;
+    return make_float3(t, u, v);
+}
+
+__global__ void triangle_intersect_point_kernel(int b, int n, int m, float cagesize, float blur,
+                                                int n_max, const float *__restrict__ ray_start,
+                                                const float *__restrict__ ray_dir,
+                                                const float *__restrict__ face_points,
+                                                int *__restrict__ idx, float *__restrict__ depth,
+                                                float *__restrict__ uv) {
+
+    int batch_index = blockIdx.x;
+    face_points += batch_index * n * 9;
+    ray_start += batch_index * m * 3;
+    ray_dir += batch_index * m * 3;
+    idx += batch_index * m * n_max;
+    depth += batch_index * m * n_max * 3;
+    uv += batch_index * m * n_max * 2;
+
+    int index = threadIdx.x;
+    int stride = blockDim.x;
+    for (int j = index; j < m; j += stride) {
+        // go over rays
+        for (int l = 0; l < n_max; ++l) {
+            idx[j * n_max + l] = -1;
+        }
+
+        int cnt = 0;
+        for (int k = 0; k < n && cnt < n_max; ++k) {
+            // go over triangles
+            float3 tuv = RayTriangleIntersection(
+                make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]),
+                make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]),
+                make_float3(face_points[k * 9 + 0], face_points[k * 9 + 1], face_points[k * 9 + 2]),
+                make_float3(face_points[k * 9 + 3], face_points[k * 9 + 4], face_points[k * 9 + 5]),
+                make_float3(face_points[k * 9 + 6], face_points[k * 9 + 7], face_points[k * 9 + 8]),
+                blur);
+
+            if (tuv.x > 0) {
+                int ki = k;
+                float d = tuv.x, u = tuv.y, v = tuv.z;
+
+                // sort
+                for (int l = 0; l < cnt; l++) {
+                    if (d < depth[j * n_max * 3 + l * 3]) {
+                        swap(ki, idx[j * n_max + l]);
+                        swap(d, depth[j * n_max * 3 + l * 3]);
+                        swap(u, uv[j * n_max * 2 + l * 2]);
+                        swap(v, uv[j * n_max * 2 + l * 2 + 1]);
+                    }
+                }
+                idx[j * n_max + cnt] = ki;
+                depth[j * n_max * 3 + cnt * 3] = d;
+                uv[j * n_max * 2 + cnt * 2] = u;
+                uv[j * n_max * 2 + cnt * 2 + 1] = v;
+                cnt++;
+            }
+        }
+
+        for (int l = 0; l < cnt; l++) {
+            // compute min_depth
+            if (l == 0)
+                depth[j * n_max * 3 + l * 3 + 1] = -cagesize;
+            else
+                depth[j * n_max * 3 + l * 3 + 1] =
+                    -fminf(cagesize,
+                           .5 * (depth[j * n_max * 3 + l * 3] - depth[j * n_max * 3 + l * 3 - 3]));
+
+            // compute max_depth
+            if (l == cnt - 1)
+                depth[j * n_max * 3 + l * 3 + 2] = cagesize;
+            else
+                depth[j * n_max * 3 + l * 3 + 2] =
+                    fminf(cagesize,
+                          .5 * (depth[j * n_max * 3 + l * 3 + 3] - depth[j * n_max * 3 + l * 3]));
+        }
+    }
+}
+
+void ball_intersect_point_kernel_wrapper(int b, int n, int m, float radius, int n_max,
+                                         const float *ray_start, const float *ray_dir,
+                                         const float *points, int *idx, float *min_depth,
+                                         float *max_depth) {
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    ball_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
+        b, n, m, radius, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth);
+
+    CUDA_CHECK_ERRORS();
+}
+
+void aabb_intersect_point_kernel_wrapper(int b, int n, int m, float voxelsize, int n_max,
+                                         const float *ray_start, const float *ray_dir,
+                                         const float *points, int *idx, float *min_depth,
+                                         float *max_depth) {
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    aabb_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
+        b, n, m, voxelsize, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth);
+
+    CUDA_CHECK_ERRORS();
+}
+
+void svo_intersect_point_kernel_wrapper(int b, int n, int m, float voxelsize, int n_max,
+                                        const float *ray_start, const float *ray_dir,
+                                        const float *points, const int *children, int *idx,
+                                        float *min_depth, float *max_depth) {
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    svo_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
+        b, n, m, voxelsize, n_max, ray_start, ray_dir, points, children, idx, min_depth, max_depth);
+
+    CUDA_CHECK_ERRORS();
+}
+
+void triangle_intersect_point_kernel_wrapper(int b, int n, int m, float cagesize, float blur,
+                                             int n_max, const float *ray_start,
+                                             const float *ray_dir, const float *face_points,
+                                             int *idx, float *depth, float *uv) {
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    triangle_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
+        b, n, m, cagesize, blur, n_max, ray_start, ray_dir, face_points, idx, depth, uv);
+
+    CUDA_CHECK_ERRORS();
+}
diff --git a/clib/src/octree.cpp b/clib/src/octree.cpp
new file mode 100644
index 0000000..e1c8ab0
--- /dev/null
+++ b/clib/src/octree.cpp
@@ -0,0 +1,136 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "octree.h"
+#include "utils.h"
+#include <utility> 
+#include <chrono>
+using namespace std::chrono; 
+
+
+typedef struct OcTree
+{
+    int depth;
+    int index;
+    at::Tensor center;
+    struct OcTree *children[8];
+    void init(at::Tensor center, int d, int i) {
+        this->center = center;
+        this->depth = d;
+        this->index = i;
+        for (int i=0; i<8; i++) this->children[i] = nullptr;
+    }
+}OcTree;
+
+class EasyOctree {
+    public:
+        OcTree *root;
+        int total;
+        int terminal;
+
+        at::Tensor all_centers;
+        at::Tensor all_children;
+
+        EasyOctree(at::Tensor center, int depth) {
+            root = new OcTree;
+            root->init(center, depth, -1);
+            total = -1;
+            terminal = -1;
+        }
+        ~EasyOctree() {
+            OcTree *p = root;
+            destory(p);
+        }
+        void destory(OcTree * &p);
+        void insert(OcTree * &p, at::Tensor point, int index);
+        void finalize();
+        std::pair<int, int> count(OcTree * &p);
+};
+
+void EasyOctree::destory(OcTree * &p){
+    if (p != nullptr) {
+        for (int i=0; i<8; i++) {
+            if (p->children[i] != nullptr) destory(p->children[i]);
+        }
+        delete p;
+        p = nullptr;
+    }
+}
+
+void EasyOctree::insert(OcTree * &p, at::Tensor point, int index) {
+    at::Tensor diff = (point > p->center).to(at::kInt);
+    int idx = diff[0].item<int>() + 2 * diff[1].item<int>() + 4 * diff[2].item<int>();
+    if (p->depth == 0) {
+        p->children[idx] = new OcTree;
+        p->children[idx]->init(point, -1, index);
+    } else {
+        if (p->children[idx] == nullptr) {
+            int length = 1 << (p->depth - 1);
+            at::Tensor new_center = p->center + (2 * diff - 1) * length;
+            p->children[idx] = new OcTree;
+            p->children[idx]->init(new_center, p->depth-1, -1);
+        }
+        insert(p->children[idx], point, index);
+    }
+}
+
+std::pair<int, int> EasyOctree::count(OcTree * &p) {
+    int total = 0, terminal = 0;
+    for (int i=0; i<8; i++) {
+        if (p->children[i] != nullptr) {
+            std::pair<int, int> sub = count(p->children[i]);
+            total += sub.first;
+            terminal += sub.second;
+        }
+    }
+    total += 1;
+    if (p->depth == -1) terminal += 1;
+    return std::make_pair(total, terminal);
+}
+
+void EasyOctree::finalize() {
+    std::pair<int, int> outs = count(root);
+    total = outs.first; terminal = outs.second;
+    
+    all_centers =
+      torch::zeros({outs.first, 3}, at::device(root->center.device()).dtype(at::ScalarType::Int));
+    all_children =
+      -torch::ones({outs.first, 9}, at::device(root->center.device()).dtype(at::ScalarType::Int));
+
+    int node_idx = outs.first - 1;
+    root->index = node_idx;
+
+    std::queue<OcTree*> all_leaves; all_leaves.push(root);    
+    while (!all_leaves.empty()) {
+        OcTree* node_ptr = all_leaves.front();
+        all_leaves.pop();
+        for (int i=0; i<8; i++) {
+            if (node_ptr->children[i] != nullptr) {
+                if (node_ptr->children[i]->depth > -1) {
+                    node_idx--; 
+                    node_ptr->children[i]->index = node_idx;
+                }
+                all_leaves.push(node_ptr->children[i]);
+                all_children[node_ptr->index][i] = node_ptr->children[i]->index;
+            }
+        }
+        all_children[node_ptr->index][8] = 1 << (node_ptr->depth + 1);
+        all_centers[node_ptr->index] = node_ptr->center;
+    }
+    assert (node_idx == outs.second);
+};
+
+std::tuple<at::Tensor, at::Tensor> build_octree(at::Tensor center, at::Tensor points, int depth) {
+    auto start = high_resolution_clock::now();
+    EasyOctree tree(center, depth);
+    for (int k=0; k<points.size(0); k++) 
+        tree.insert(tree.root, points[k], k);
+    tree.finalize();
+    auto stop = high_resolution_clock::now(); 
+    auto duration = duration_cast<microseconds>(stop - start);
+    printf("Building EasyOctree done. total #nodes = %d, terminal #nodes = %d (time taken %f s)\n", 
+        tree.total, tree.terminal, float(duration.count()) / 1000000.);
+    return std::make_tuple(tree.all_centers, tree.all_children);
+}
\ No newline at end of file
diff --git a/clib/src/sample.cpp b/clib/src/sample.cpp
new file mode 100644
index 0000000..a67c2f7
--- /dev/null
+++ b/clib/src/sample.cpp
@@ -0,0 +1,96 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "sample.h"
+#include "utils.h"
+#include <utility> 
+
+
+void uniform_ray_sampling_kernel_wrapper(
+  int b, int num_rays, int max_hits, int max_steps, float step_size,
+  const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise,
+  int *sampled_idx, float *sampled_depth, float *sampled_dists);
+
+void inverse_cdf_sampling_kernel_wrapper(
+  int b, int num_rays, int max_hits, int max_steps, float fixed_step_size,
+  const int *pts_idx, const float *min_depth, const float *max_depth,
+  const float *uniform_noise, const float *probs, const float *steps,
+  int *sampled_idx, float *sampled_depth, float *sampled_dists);
+
+                  
+std::tuple< at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling(
+  at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,
+  const float step_size, const int max_steps){
+
+  CHECK_CONTIGUOUS(pts_idx);
+  CHECK_CONTIGUOUS(min_depth);
+  CHECK_CONTIGUOUS(max_depth);
+  CHECK_CONTIGUOUS(uniform_noise);
+  CHECK_IS_FLOAT(min_depth);
+  CHECK_IS_FLOAT(max_depth);
+  CHECK_IS_FLOAT(uniform_noise);
+  CHECK_IS_INT(pts_idx);
+  CHECK_CUDA(pts_idx);
+  CHECK_CUDA(min_depth);
+  CHECK_CUDA(max_depth);
+  CHECK_CUDA(uniform_noise);
+
+  at::Tensor sampled_idx =
+      -torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps},
+                    at::device(pts_idx.device()).dtype(at::ScalarType::Int));
+  at::Tensor sampled_depth =
+      torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},
+                    at::device(min_depth.device()).dtype(at::ScalarType::Float));
+  at::Tensor sampled_dists =
+      torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},
+                    at::device(min_depth.device()).dtype(at::ScalarType::Float));
+  uniform_ray_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2),
+                                      step_size,
+                                      pts_idx.data_ptr <int>(), min_depth.data_ptr <float>(), max_depth.data_ptr <float>(),
+                                      uniform_noise.data_ptr <float>(), sampled_idx.data_ptr <int>(), 
+                                      sampled_depth.data_ptr <float>(), sampled_dists.data_ptr <float>());
+  return std::make_tuple(sampled_idx, sampled_depth, sampled_dists);
+}
+
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> inverse_cdf_sampling(
+    at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,
+    at::Tensor probs, at::Tensor steps, float fixed_step_size) {
+  
+  CHECK_CONTIGUOUS(pts_idx);
+  CHECK_CONTIGUOUS(min_depth);
+  CHECK_CONTIGUOUS(max_depth);
+  CHECK_CONTIGUOUS(probs);
+  CHECK_CONTIGUOUS(steps);
+  CHECK_CONTIGUOUS(uniform_noise);
+  CHECK_IS_FLOAT(min_depth);
+  CHECK_IS_FLOAT(max_depth);
+  CHECK_IS_FLOAT(uniform_noise);
+  CHECK_IS_FLOAT(probs);
+  CHECK_IS_FLOAT(steps);
+  CHECK_IS_INT(pts_idx);
+  CHECK_CUDA(pts_idx);
+  CHECK_CUDA(min_depth);
+  CHECK_CUDA(max_depth);
+  CHECK_CUDA(uniform_noise);
+  CHECK_CUDA(probs);
+  CHECK_CUDA(steps);
+
+  int max_steps = uniform_noise.size(-1);
+  at::Tensor sampled_idx =
+      -torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps},
+                    at::device(pts_idx.device()).dtype(at::ScalarType::Int));
+  at::Tensor sampled_depth =
+      torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},
+                    at::device(min_depth.device()).dtype(at::ScalarType::Float));
+  at::Tensor sampled_dists =
+      torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},
+                    at::device(min_depth.device()).dtype(at::ScalarType::Float));
+  inverse_cdf_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2), fixed_step_size,
+                                      pts_idx.data_ptr <int>(), min_depth.data_ptr <float>(), max_depth.data_ptr <float>(),
+                                      uniform_noise.data_ptr <float>(), probs.data_ptr <float>(), steps.data_ptr <float>(),
+                                      sampled_idx.data_ptr <int>(), sampled_depth.data_ptr <float>(), sampled_dists.data_ptr <float>());
+  return std::make_tuple(sampled_idx, sampled_depth, sampled_dists);
+}
\ No newline at end of file
diff --git a/clib/src/sample_gpu.cu b/clib/src/sample_gpu.cu
new file mode 100644
index 0000000..7e4e212
--- /dev/null
+++ b/clib/src/sample_gpu.cu
@@ -0,0 +1,231 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// 
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "cutil_math.h"  // required for float3 vector math
+
+
+__global__ void uniform_ray_sampling_kernel(
+            int b, int num_rays, 
+            int max_hits,
+            int max_steps,
+            float step_size,
+            const int *__restrict__ pts_idx,
+            const float *__restrict__ min_depth,
+            const float *__restrict__ max_depth,
+            const float *__restrict__ uniform_noise,
+            int *__restrict__ sampled_idx,
+            float *__restrict__ sampled_depth,
+            float *__restrict__ sampled_dists) {
+  
+  int batch_index = blockIdx.x;
+  int index = threadIdx.x;
+  int stride = blockDim.x;
+
+  pts_idx += batch_index * num_rays * max_hits;
+  min_depth += batch_index * num_rays * max_hits;
+  max_depth += batch_index * num_rays * max_hits;
+
+  uniform_noise += batch_index * num_rays * max_steps;
+  sampled_idx += batch_index * num_rays * max_steps;
+  sampled_depth += batch_index * num_rays * max_steps;
+  sampled_dists += batch_index * num_rays * max_steps;
+
+  // loop over all rays
+  for (int j = index; j < num_rays; j += stride) {
+    int H = j * max_hits, K = j * max_steps;
+    int s = 0, ucur = 0, umin = 0, umax = 0;
+    float last_min_depth, last_max_depth, curr_depth;
+    
+    // sort all depths
+    while (true) {
+      if ((umax == max_hits) || (ucur == max_steps) || (pts_idx[H + umax] == -1)) {
+        break;  // reach the maximum
+      }
+      if (umin < max_hits) {
+        last_min_depth = min_depth[H + umin];
+      } else {
+        last_min_depth = 10000.0;
+      }
+      if (umax < max_hits) {
+        last_max_depth = max_depth[H + umax];
+      } else {
+        last_max_depth = 10000.0;
+      }
+      if (ucur < max_steps) {
+        curr_depth = min_depth[H] + (float(ucur) + uniform_noise[K + ucur]) * step_size;
+      }
+      
+      if ((last_max_depth <= curr_depth) && (last_max_depth <= last_min_depth)) {
+        sampled_depth[K + s] = last_max_depth;
+        sampled_idx[K + s] = pts_idx[H + umax];
+        umax++; s++; continue;
+      }
+      if ((curr_depth <= last_min_depth) && (curr_depth <= last_max_depth)) {
+        sampled_depth[K + s] = curr_depth;
+        sampled_idx[K + s] = pts_idx[H + umin - 1];
+        ucur++; s++; continue;
+      }
+      if ((last_min_depth <= curr_depth) && (last_min_depth <= last_max_depth)) {
+        sampled_depth[K + s] = last_min_depth;
+        sampled_idx[K + s] = pts_idx[H + umin];
+        umin++; s++; continue;
+      }
+    }
+
+    float l_depth, r_depth;
+    int step = 0;
+    for (ucur = 0, umin = 0, umax = 0; ucur < max_steps - 1; ucur++) {
+      if (sampled_idx[K + ucur + 1] == -1) break;
+      l_depth = sampled_depth[K + ucur];
+      r_depth = sampled_depth[K + ucur + 1];  
+      sampled_depth[K + ucur] = (l_depth + r_depth) * .5;
+      sampled_dists[K + ucur] = (r_depth - l_depth);
+      if ((umin < max_hits) && (sampled_depth[K + ucur] >= min_depth[H + umin]) && (pts_idx[H + umin] > -1)) umin++;
+      if ((umax < max_hits) && (sampled_depth[K + ucur] >= max_depth[H + umax]) && (pts_idx[H + umax] > -1)) umax++;
+      if ((umax == max_hits) || (pts_idx[H + umax] == -1)) break;
+      if ((umin - 1 == umax) && (sampled_dists[K + ucur] > 0)) {
+        sampled_depth[K + step] = sampled_depth[K + ucur];
+        sampled_dists[K + step] = sampled_dists[K + ucur];
+        sampled_idx[K + step] = sampled_idx[K + ucur];
+        step++;
+      }
+    }
+    
+    for (int s = step; s < max_steps; s++) {
+      sampled_idx[K + s] = -1;
+    }
+  }
+}
+
+__global__ void inverse_cdf_sampling_kernel(
+    int b, int num_rays, 
+    int max_hits,
+    int max_steps,
+    float fixed_step_size,
+    const int *__restrict__ pts_idx,
+    const float *__restrict__ min_depth,
+    const float *__restrict__ max_depth,
+    const float *__restrict__ uniform_noise,
+    const float *__restrict__ probs,
+    const float *__restrict__ steps,
+    int *__restrict__ sampled_idx,
+    float *__restrict__ sampled_depth,
+    float *__restrict__ sampled_dists) {
+
+    int batch_index = blockIdx.x;
+    int index = threadIdx.x;
+    int stride = blockDim.x;
+
+    pts_idx += batch_index * num_rays * max_hits;
+    min_depth += batch_index * num_rays * max_hits;
+    max_depth += batch_index * num_rays * max_hits;
+    probs += batch_index * num_rays * max_hits;
+    steps += batch_index * num_rays;
+
+    uniform_noise += batch_index * num_rays * max_steps;
+    sampled_idx += batch_index * num_rays * max_steps;
+    sampled_depth += batch_index * num_rays * max_steps;
+    sampled_dists += batch_index * num_rays * max_steps;
+
+    // loop over all rays
+    for (int j = index; j < num_rays; j += stride) {
+        int H = j * max_hits, K = j * max_steps;
+        int curr_bin = 0, s = 0;  // current index (bin)
+
+        float curr_min_depth = min_depth[H];  // lower depth
+        float curr_max_depth = max_depth[H];  // upper depth
+        float curr_min_cdf = 0;
+        float curr_max_cdf = probs[H];
+        float step_size = 1.0 / steps[j];
+        float z_low = curr_min_depth;        
+        int total_steps = int(ceil(steps[j]));
+        bool done = false;
+
+        // optional use a fixed step size
+        if (fixed_step_size > 0.0) step_size = fixed_step_size;
+
+        // sample points 
+        for (int curr_step = 0; curr_step < total_steps; curr_step++) {
+            float curr_cdf = (float(curr_step) + uniform_noise[K + curr_step]) * step_size;
+            while (curr_cdf > curr_max_cdf) {
+                // first include max cdf
+                sampled_idx[K + s] = pts_idx[H + curr_bin];
+                sampled_dists[K + s] = (curr_max_depth - z_low);
+                sampled_depth[K + s] = (curr_max_depth + z_low) * .5;
+
+                // move to next cdf
+                curr_bin++; 
+                s++;
+                if ((curr_bin >= max_hits) || (pts_idx[H + curr_bin] == -1)) {
+                    done = true; break;
+                }
+                curr_min_depth = min_depth[H + curr_bin];
+                curr_max_depth = max_depth[H + curr_bin];
+                curr_min_cdf = curr_max_cdf;
+                curr_max_cdf = curr_max_cdf + probs[H + curr_bin];
+                z_low = curr_min_depth;
+            }
+            if (done) break;
+            
+            // if the sampled cdf is inside bin
+            float u = (curr_cdf - curr_min_cdf) / (curr_max_cdf - curr_min_cdf);
+            float z = curr_min_depth + u * (curr_max_depth - curr_min_depth);
+            sampled_idx[K + s] = pts_idx[H + curr_bin];
+            sampled_dists[K + s] = (z - z_low);
+            sampled_depth[K + s] = (z + z_low) * .5;
+            z_low = z; s++;
+        }
+        
+        // if there are bins still remained
+        while ((z_low < curr_max_depth) && (~done)) {
+            sampled_idx[K + s] = pts_idx[H + curr_bin];
+            sampled_dists[K + s] = (curr_max_depth - z_low);
+            sampled_depth[K + s] = (curr_max_depth + z_low) * .5;
+            curr_bin++; 
+            s++;
+            if ((curr_bin >= max_hits) || (pts_idx[curr_bin] == -1)) 
+                break;
+            
+            curr_min_depth = min_depth[H + curr_bin];
+            curr_max_depth = max_depth[H + curr_bin];
+            z_low = curr_min_depth;
+        }
+    }
+}
+
+void uniform_ray_sampling_kernel_wrapper(
+  int b, int num_rays, int max_hits, int max_steps, float step_size,
+  const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise,
+  int *sampled_idx, float *sampled_depth, float *sampled_dists) {
+  
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  uniform_ray_sampling_kernel<<<b, opt_n_threads(num_rays), 0, stream>>>(
+      b, num_rays, max_hits, max_steps, step_size, pts_idx, 
+      min_depth, max_depth, uniform_noise, sampled_idx, sampled_depth, sampled_dists);
+  
+  CUDA_CHECK_ERRORS();
+}
+
+void inverse_cdf_sampling_kernel_wrapper(
+    int b, int num_rays, int max_hits, int max_steps, float fixed_step_size,
+    const int *pts_idx, const float *min_depth, const float *max_depth,
+    const float *uniform_noise, const float *probs, const float *steps,
+    int *sampled_idx, float *sampled_depth, float *sampled_dists) {
+    
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    inverse_cdf_sampling_kernel<<<b, opt_n_threads(num_rays), 0, stream>>>(
+        b, num_rays, max_hits, max_steps, fixed_step_size,
+        pts_idx, min_depth, max_depth, uniform_noise, probs, steps, 
+        sampled_idx, sampled_depth, sampled_dists);
+    
+    CUDA_CHECK_ERRORS();
+}
+  
\ No newline at end of file
diff --git a/configs/nerf_default.json b/configs/nerf_default.json
new file mode 100644
index 0000000..3f9165b
--- /dev/null
+++ b/configs/nerf_default.json
@@ -0,0 +1,22 @@
+{
+    "model": "NeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "sample_range": [0, 10],
+        "n_samples": 256,
+        "perturb_sample": true,
+        "spherical": false,
+        "lindisp": false,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/nerf_voxels.json b/configs/nerf_voxels.json
new file mode 100644
index 0000000..411ab9d
--- /dev/null
+++ b/configs/nerf_voxels.json
@@ -0,0 +1,24 @@
+{
+    "model": "NeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "voxel_size": 0.5,
+        "sample_range": [0, 10],
+        "n_samples": 50,
+        "perturb_sample": true,
+        "spherical": false,
+        "lindisp": false,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/nsvf_coarse.json b/configs/nsvf_coarse.json
new file mode 100644
index 0000000..f9b341c
--- /dev/null
+++ b/configs/nsvf_coarse.json
@@ -0,0 +1,21 @@
+{
+    "model": "NSVF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 128,
+            "n_layers": 4,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "octree",
+        "voxel_size": 0.5,
+        "sample_step_ratio": 0.2,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/nsvf_default.json b/configs/nsvf_default.json
new file mode 100644
index 0000000..ad6faaf
--- /dev/null
+++ b/configs/nsvf_default.json
@@ -0,0 +1,21 @@
+{
+    "model": "NSVF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "octree",
+        "voxel_size": 0.5,
+        "sample_step_ratio": 0.2,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/nsvf_voxels.json b/configs/nsvf_voxels.json
new file mode 100644
index 0000000..b60ae89
--- /dev/null
+++ b/configs/nsvf_voxels.json
@@ -0,0 +1,21 @@
+{
+    "model": "NSVF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "voxel_size": 0.5,
+        "sample_step_ratio": 0.2,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/bgnet.py b/configs/old/bgnet.py
similarity index 100%
rename from configs/bgnet.py
rename to configs/old/bgnet.py
diff --git a/configs/cnerf.py b/configs/old/cnerf.py
similarity index 100%
rename from configs/cnerf.py
rename to configs/old/cnerf.py
diff --git a/configs/dnerfabins.py b/configs/old/dnerfabins.py
similarity index 100%
rename from configs/dnerfabins.py
rename to configs/old/dnerfabins.py
diff --git a/configs/fovea.py b/configs/old/fovea.py
similarity index 100%
rename from configs/fovea.py
rename to configs/old/fovea.py
diff --git a/configs/fovea_small_rot1.py b/configs/old/fovea_small_rot1.py
similarity index 100%
rename from configs/fovea_small_rot1.py
rename to configs/old/fovea_small_rot1.py
diff --git a/configs/fovea_small_trans.py b/configs/old/fovea_small_trans.py
similarity index 100%
rename from configs/fovea_small_trans.py
rename to configs/old/fovea_small_trans.py
diff --git a/configs/msl2fast.py b/configs/old/msl2fast.py
similarity index 100%
rename from configs/msl2fast.py
rename to configs/old/msl2fast.py
diff --git a/configs/msl_fovea.py b/configs/old/msl_fovea.py
similarity index 100%
rename from configs/msl_fovea.py
rename to configs/old/msl_fovea.py
diff --git a/configs/mslfast.py b/configs/old/mslfast.py
similarity index 100%
rename from configs/mslfast.py
rename to configs/old/mslfast.py
diff --git a/configs/mslray.py b/configs/old/mslray.py
similarity index 100%
rename from configs/mslray.py
rename to configs/old/mslray.py
diff --git a/configs/nerf.py b/configs/old/nerf.py
similarity index 100%
rename from configs/nerf.py
rename to configs/old/nerf.py
diff --git a/configs/nerf_horns.py b/configs/old/nerf_horns.py
similarity index 100%
rename from configs/nerf_horns.py
rename to configs/old/nerf_horns.py
diff --git a/configs/nerf_horns_4.py b/configs/old/nerf_horns_4.py
similarity index 100%
rename from configs/nerf_horns_4.py
rename to configs/old/nerf_horns_4.py
diff --git a/configs/nerf_horns_8.py b/configs/old/nerf_horns_8.py
similarity index 100%
rename from configs/nerf_horns_8.py
rename to configs/old/nerf_horns_8.py
diff --git a/configs/nerf_periph.py b/configs/old/nerf_periph.py
similarity index 100%
rename from configs/nerf_periph.py
rename to configs/old/nerf_periph.py
diff --git a/configs/nerf_trex.py b/configs/old/nerf_trex.py
similarity index 100%
rename from configs/nerf_trex.py
rename to configs/old/nerf_trex.py
diff --git a/configs/nerf_trex_4.py b/configs/old/nerf_trex_4.py
similarity index 100%
rename from configs/nerf_trex_4.py
rename to configs/old/nerf_trex_4.py
diff --git a/configs/nerf_trex_8.py b/configs/old/nerf_trex_8.py
similarity index 100%
rename from configs/nerf_trex_8.py
rename to configs/old/nerf_trex_8.py
diff --git a/configs/nerfsimple.py b/configs/old/nerfsimple.py
similarity index 100%
rename from configs/nerfsimple.py
rename to configs/old/nerfsimple.py
diff --git a/configs/nmsl_fovea.py b/configs/old/nmsl_fovea.py
similarity index 100%
rename from configs/nmsl_fovea.py
rename to configs/old/nmsl_fovea.py
diff --git a/configs/nnerf.py b/configs/old/nnerf.py
similarity index 100%
rename from configs/nnerf.py
rename to configs/old/nnerf.py
diff --git a/configs/oracle.py b/configs/old/oracle.py
similarity index 100%
rename from configs/oracle.py
rename to configs/old/oracle.py
diff --git a/configs/periph.py b/configs/old/periph.py
similarity index 100%
rename from configs/periph.py
rename to configs/old/periph.py
diff --git a/configs/periph_new.py b/configs/old/periph_new.py
similarity index 100%
rename from configs/periph_new.py
rename to configs/old/periph_new.py
diff --git a/configs/periph_small_trans.py b/configs/old/periph_small_trans.py
similarity index 100%
rename from configs/periph_small_trans.py
rename to configs/old/periph_small_trans.py
diff --git a/configs/snerffast_periph.py b/configs/old/snerffast_periph.py
similarity index 100%
rename from configs/snerffast_periph.py
rename to configs/old/snerffast_periph.py
diff --git a/configs/snerffastx.py b/configs/old/snerffastx.py
similarity index 100%
rename from configs/snerffastx.py
rename to configs/old/snerffastx.py
diff --git a/configs/snerf_fine_voxels.json b/configs/snerf_fine_voxels.json
new file mode 100644
index 0000000..8079e6f
--- /dev/null
+++ b/configs/snerf_fine_voxels.json
@@ -0,0 +1,21 @@
+{
+    "model": "SNeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [8, 32, 16],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/snerf_voxels+ls-d.json b/configs/snerf_voxels+ls-d.json
new file mode 100644
index 0000000..2eebf9b
--- /dev/null
+++ b/configs/snerf_voxels+ls-d.json
@@ -0,0 +1,20 @@
+{
+    "model": "SNeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerf_voxels+ls.json b/configs/snerf_voxels+ls.json
new file mode 100644
index 0000000..a7cde45
--- /dev/null
+++ b/configs/snerf_voxels+ls.json
@@ -0,0 +1,21 @@
+{
+    "model": "SNeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerf_voxels.json b/configs/snerf_voxels.json
new file mode 100644
index 0000000..7e68cb4
--- /dev/null
+++ b/configs/snerf_voxels.json
@@ -0,0 +1,19 @@
+{
+    "model": "SNeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true
+    }
+}
\ No newline at end of file
diff --git a/configs/snerf_voxels_128x8_x2.json b/configs/snerf_voxels_128x8_x2.json
new file mode 100644
index 0000000..0052515
--- /dev/null
+++ b/configs/snerf_voxels_128x8_x2.json
@@ -0,0 +1,22 @@
+{
+    "model": "SNeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 128,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1,
+        "multi_nets": 2
+    }
+}
\ No newline at end of file
diff --git a/configs/snerf_voxels_128x8_x4.json b/configs/snerf_voxels_128x8_x4.json
new file mode 100644
index 0000000..268498c
--- /dev/null
+++ b/configs/snerf_voxels_128x8_x4.json
@@ -0,0 +1,22 @@
+{
+    "model": "SNeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 128,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1,
+        "multi_nets": 4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerf_voxels_feat.json b/configs/snerf_voxels_feat.json
new file mode 100644
index 0000000..fcd8dce
--- /dev/null
+++ b/configs/snerf_voxels_feat.json
@@ -0,0 +1,21 @@
+{
+    "model": "SNeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [ 4 ]
+        },
+        "n_featdim": 32,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/snerf_voxels_fine.json b/configs/snerf_voxels_fine.json
new file mode 100644
index 0000000..2a356ed
--- /dev/null
+++ b/configs/snerf_voxels_fine.json
@@ -0,0 +1,21 @@
+{
+    "model": "SNeRF",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 512,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [4]
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [32, 128, 64],
+        "n_samples": 128,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_finevoxels+ls.json b/configs/snerfadv_finevoxels+ls.json
new file mode 100644
index 0000000..5511733
--- /dev/null
+++ b/configs/snerfadv_finevoxels+ls.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 3,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [8, 32, 16],
+        "n_samples": 64,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_finevoxels+ls_256x4_256x6_16x2.json b/configs/snerfadv_finevoxels+ls_256x4_256x6_16x2.json
new file mode 100644
index 0000000..8ee790a
--- /dev/null
+++ b/configs/snerfadv_finevoxels+ls_256x4_256x6_16x2.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 6,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 16,
+            "n_layers": 2,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [16, 64, 32],
+        "n_samples": 64,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_finevoxels+ls_256x4_256x6_combined.json b/configs/snerfadv_finevoxels+ls_256x4_256x6_combined.json
new file mode 100644
index 0000000..ce9dd1b
--- /dev/null
+++ b/configs/snerfadv_finevoxels+ls_256x4_256x6_combined.json
@@ -0,0 +1,30 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 6,
+            "act": "relu",
+            "skips": []
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [16, 64, 32],
+        "n_samples": 64,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4,
+        "appearance": "combined"
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_finevoxels_ls2.json b/configs/snerfadv_finevoxels_ls2.json
new file mode 100644
index 0000000..94d4fd6
--- /dev/null
+++ b/configs/snerfadv_finevoxels_ls2.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 2,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 3,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [16, 64, 32],
+        "n_samples": 64,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_voxels+ls+ns.json b/configs/snerfadv_voxels+ls+ns.json
new file mode 100644
index 0000000..9e69a9d
--- /dev/null
+++ b/configs/snerfadv_voxels+ls+ns.json
@@ -0,0 +1,36 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 3,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "appearance": "newtype",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4,
+        "specular_regularization_weight": 1e-1,
+        "specular_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_voxels+ls.json b/configs/snerfadv_voxels+ls.json
new file mode 100644
index 0000000..e533e59
--- /dev/null
+++ b/configs/snerfadv_voxels+ls.json
@@ -0,0 +1,33 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 3,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_voxels+ls1.json b/configs/snerfadv_voxels+ls1.json
new file mode 100644
index 0000000..d634054
--- /dev/null
+++ b/configs/snerfadv_voxels+ls1.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 5,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 2,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_voxels+ls2.json b/configs/snerfadv_voxels+ls2.json
new file mode 100644
index 0000000..bbe5723
--- /dev/null
+++ b/configs/snerfadv_voxels+ls2.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 3,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_voxels+ls3.json b/configs/snerfadv_voxels+ls3.json
new file mode 100644
index 0000000..51aa71f
--- /dev/null
+++ b/configs/snerfadv_voxels+ls3.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 3,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_voxels+ls4.json b/configs/snerfadv_voxels+ls4.json
new file mode 100644
index 0000000..99daf4d
--- /dev/null
+++ b/configs/snerfadv_voxels+ls4.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 2,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 5,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_voxels+ls5.json b/configs/snerfadv_voxels+ls5.json
new file mode 100644
index 0000000..897dcb9
--- /dev/null
+++ b/configs/snerfadv_voxels+ls5.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 256,
+            "n_layers": 8,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 256,
+            "n_layers": 6,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 2,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadv_voxels+ls6.json b/configs/snerfadv_voxels+ls6.json
new file mode 100644
index 0000000..3971cae
--- /dev/null
+++ b/configs/snerfadv_voxels+ls6.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvance",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 512,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 512,
+            "n_layers": 3,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 256,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadvx_voxels_x16.json b/configs/snerfadvx_voxels_x16.json
new file mode 100644
index 0000000..7e55bea
--- /dev/null
+++ b/configs/snerfadvx_voxels_x16.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvanceX",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 128,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 128,
+            "n_layers": 3,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "_nets/hr_r0.8s/snerfadv_voxels+ls6/checkpoint_50.tar",
+        "n_samples": 256,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4,
+        "multi_nets": 16
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadvx_voxels_x4.json b/configs/snerfadvx_voxels_x4.json
new file mode 100644
index 0000000..1e55fff
--- /dev/null
+++ b/configs/snerfadvx_voxels_x4.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvanceX",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 128,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 128,
+            "n_layers": 3,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "_nets/train_t0.3/snerfadv_voxels+ls2/checkpoint_50.tar",
+        "n_samples": 256,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4,
+        "multi_nets": 4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfadvx_voxels_x8.json b/configs/snerfadvx_voxels_x8.json
new file mode 100644
index 0000000..1285ead
--- /dev/null
+++ b/configs/snerfadvx_voxels_x8.json
@@ -0,0 +1,34 @@
+{
+    "model": "SNeRFAdvanceX",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "density_net": {
+            "nf": 128,
+            "n_layers": 4,
+            "act": "relu",
+            "skips": []
+        },
+        "color_net": {
+            "nf": 128,
+            "n_layers": 3,
+            "act": "relu",
+            "skips": []
+        },
+        "specular_net": {
+            "nf": 128,
+            "n_layers": 1,
+            "act": "relu"
+        },
+        "n_featdim": 0,
+        "space": "_nets/hr_t1.0s/snerfadv_voxels+ls2/checkpoint_50.tar",
+        "n_samples": 256,
+        "perturb_sample": true,
+        "appearance": "combined",
+        "density_color_connection": true,
+        "density_regularization_weight": 1e-4,
+        "density_regularization_scale": 1e4,
+        "multi_nets": 8
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfx_voxels_128x4_x4.json b/configs/snerfx_voxels_128x4_x4.json
new file mode 100644
index 0000000..3d3af48
--- /dev/null
+++ b/configs/snerfx_voxels_128x4_x4.json
@@ -0,0 +1,21 @@
+{
+    "model": "SNeRFX",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 128,
+            "n_layers": 4,
+            "activation": "relu",
+            "skips": []
+        },
+        "n_featdim": 0,
+        "space": "nets/train1/snerf_voxels/checkpoint_50.tar",
+        "n_samples": 256,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1,
+        "multi_nets": 4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfx_voxels_128x4_x8.json b/configs/snerfx_voxels_128x4_x8.json
new file mode 100644
index 0000000..374d8fc
--- /dev/null
+++ b/configs/snerfx_voxels_128x4_x8.json
@@ -0,0 +1,21 @@
+{
+    "model": "SNeRFX",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 128,
+            "n_layers": 4,
+            "activation": "relu",
+            "skips": []
+        },
+        "n_featdim": 0,
+        "space": "nets/train_t0.3/snerf_voxels/checkpoint_50.tar",
+        "n_samples": 256,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1,
+        "multi_nets": 8
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfx_voxels_128x8_x4.json b/configs/snerfx_voxels_128x8_x4.json
new file mode 100644
index 0000000..6a2fe8a
--- /dev/null
+++ b/configs/snerfx_voxels_128x8_x4.json
@@ -0,0 +1,21 @@
+{
+    "model": "SNeRFX",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 128,
+            "n_layers": 8,
+            "activation": "relu",
+            "skips": [4]
+        },
+        "n_featdim": 0,
+        "space": "nets/train_t0.3/snerf_voxels/checkpoint_50.tar",
+        "n_samples": 256,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1,
+        "multi_nets": 4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfx_voxels_256x4_x4.json b/configs/snerfx_voxels_256x4_x4.json
new file mode 100644
index 0000000..706d22f
--- /dev/null
+++ b/configs/snerfx_voxels_256x4_x4.json
@@ -0,0 +1,21 @@
+{
+    "model": "SNeRFX",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 4,
+            "activation": "relu",
+            "skips": []
+        },
+        "n_featdim": 0,
+        "space": "nets/train1/snerf_voxels/checkpoint_50.tar",
+        "n_samples": 256,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1,
+        "multi_nets": 4
+    }
+}
\ No newline at end of file
diff --git a/configs/snerfx_voxels_256x4_x4_balance.json b/configs/snerfx_voxels_256x4_x4_balance.json
new file mode 100644
index 0000000..11bc5e8
--- /dev/null
+++ b/configs/snerfx_voxels_256x4_x4_balance.json
@@ -0,0 +1,22 @@
+{
+    "model": "SNeRFX",
+    "args": {
+        "color": "rgb",
+        "n_pot_encode": 10,
+        "n_dir_encode": 4,
+        "fc_params": {
+            "nf": 256,
+            "n_layers": 4,
+            "activation": "relu",
+            "skips": []
+        },
+        "n_featdim": 0,
+        "space": "voxels",
+        "steps": [4, 16, 8],
+        "n_samples": 16,
+        "perturb_sample": true,
+        "raymarching_tolerance": 0,
+        "raymarching_chunk_size": -1,
+        "multi_nets": 4
+    }
+}
\ No newline at end of file
diff --git a/dash_test.py b/dash_test.py
index 46f1434..9761996 100644
--- a/dash_test.py
+++ b/dash_test.py
@@ -1,67 +1,38 @@
 import os
-import argparse
 import torch
 import json
 import dash
 import dash_core_components as dcc
 import dash_html_components as html
 import plotly.express as px
-import pandas as pd
 import numpy as np
 # from skimage import data
+from pathlib import Path
 from dash.dependencies import Input, Output
 from dash.exceptions import PreventUpdate
 
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--device', type=int, default=0,
-                        help='Which CUDA device to use.')
-    opt = parser.parse_args()
-
-    # Select device
-    torch.cuda.set_device(opt.device)
-    print("Set CUDA:%d as current device." % torch.cuda.current_device())
 torch.autograd.set_grad_enabled(False)
 
 
-from data.spherical_view_syn import *
-from configs.spherical_view_syn import SphericalViewSynConfig
-from utils import netio
 from utils import device
 from utils import view
 from utils import img
 from utils import misc
-from nets.modules import AlphaComposition, Sampler
+import model as mdl
+from modules import AlphaComposition, Sampler
 
 
-datadir = 'data/__new/lobby_fovea_r360x80_t1.0/'
-data_desc_file = 'train1.json'
+datadir = Path('data/__new/classroom_fovea_r360x80_t0.6')
+data_desc_file = 'r120x80.json'
 net_config = 'fovea@snerffast4-rgb_e6_fc512x4_d2.00-50.00_s64_~p'
-net_path = datadir + net_config + '/model-epoch_200.pth'
+model_path = datadir / 'snerf_voxels/checkpoint_50.tar'
 fov = 40
 res = (256, 256)
 pix_img_res = (256, 256)
 center = (0, 0)
 
 
-def load_net(path):
-    print(path)
-    config = SphericalViewSynConfig()
-    config.from_id(net_config)
-    config.sa['perturb_sample'] = False
-    net = config.create_net().to(device.default())
-    netio.load(path, net)
-    return net
-
-
-def load_net_by_name(name):
-    for path in os.listdir(datadir):
-        if path.startswith(name + '@'):
-            return load_net(datadir + path)
-    return None
-
-
 def load_data_desc(data_desc_file) -> view.Trans:
     with open(datadir + data_desc_file, 'r', encoding='utf-8') as file:
         data_desc = json.loads(file.read())
@@ -85,7 +56,10 @@ cam = view.CameraParam({
     'cy': 0.5,
     'normalized': True
 }, res, device=device.default())
-net = load_net(net_path)
+model, _ = mdl.load(model_path, {
+    "perturb_sample": False
+})
+
 
 # Global states
 x = y = None
@@ -159,7 +133,7 @@ app.layout = html.Div([
 
 def plot_alpha_and_density(ray_o, ray_d):
     # colors, densities, depths = net.sample_and_infer(ray_o, ray_d, sampler=sampler)
-    ret = net(ray_o, ray_d, ret_depth=True, debug=True)
+    ret = model(ray_o, ray_d, extra_outputs=['depth', 'layers'])
     colors = ret['layers'][..., : 3]
     densities = ret['sample_densities']
     depths = ret['sample_depths']
@@ -202,7 +176,7 @@ def plot_pixel_image(ray_o, ray_d, r=1):
         ], dim=-1).to(device.default())
         rays_d = pixel_point - rays_o
         rays_d /= rays_d.norm(dim=-1, keepdim=True)
-        image = net(rays_o.view(-1, 3), rays_d.view(-1, 3))['color'] \
+        image = model(rays_o.view(-1, 3), rays_d.view(-1, 3))['color'] \
             .view(1, *pix_img_res, -1).permute(0, 3, 1, 2)
         fig = px.imshow(img.torch2np(image)[0])
     return fig
@@ -230,10 +204,10 @@ def render_view(tx, ty, tz, rx, ry):
             torch.tensor(view.euler_to_matrix([ry, rx, 0]), device=device.default()).view(-1, 3, 3)
         )
         rays_o, rays_d = cam.get_global_rays(test_view, True)
-        ret = net(rays_o.view(-1, 3), rays_d.view(-1, 3), debug=True)
-        image = ret['color'].view(1, res[0], res[1], 3).permute(0, 3, 1, 2)
-        layers = ret['layers'].view(res[0], res[1], -1, 4)
-        layer_weights = ret['weight'].view(res[0], res[1], -1)
+        ret = model(rays_o.view(-1, 3), rays_d.view(-1, 3), extra_outputs=['layers', 'weights'])
+        image = ret['color'].view(1, *res, 3).permute(0, 3, 1, 2)
+        layers = ret['layers'].view(*res, -1, 4)
+        layer_weights = ret['weight'].view(*res, -1)
         fig = px.imshow(img.torch2np(image)[0])
     return fig
 
@@ -241,17 +215,13 @@ def render_view(tx, ty, tz, rx, ry):
 def render_layer(layer):
     if layer is None:
         return None
-    layer_data = torch.sum(layers[..., range(*layer), :3] * layer_weights[..., range(*layer), None],
-                           dim=-2)
-    #layer_data = layer_data[..., :3] * layer_data[..., 3:]
+    layer_data = torch.sum((layers * layer_weights)[..., range(*layer), :3], dim=-2)
     fig = px.imshow(img.torch2np(layer_data))
     return fig
 
 
 def view_pixel(fig, x, y, samples):
-    sampler = Sampler(depth_range=(1, 50), n_samples=samples,
-                      perturb_sample=False, spherical=True,
-                      lindisp=True, inverse_r=True)
+    sampler = model.sampler
     if x is None or y is None:
         return None
     p = torch.tensor([x, y], device=device.default())
diff --git a/data/dataset_factory.py b/data/dataset_factory.py
index 7a1f7c2..53793da 100644
--- a/data/dataset_factory.py
+++ b/data/dataset_factory.py
@@ -1,5 +1,8 @@
-import os
 import json
+import os
+from pathlib import Path
+from typing import Union
+
 import utils.device
 from .pano_dataset import PanoDataset
 from .view_dataset import ViewDataset
@@ -8,16 +11,26 @@ from .view_dataset import ViewDataset
 class DatasetFactory(object):
 
     @staticmethod
-    def load(path, device=None, **kwargs):
+    def get_dataset_desc_path(path: Union[Path, str]):
+        if isinstance(path, str):
+            path = Path(path)
+        if path.suffix != ".json":
+            if os.path.exists(f"{path}.json"):
+                path = Path(f"{path}.json")
+            else:
+                path = path / "train.json"
+        return path
+
+    @staticmethod
+    def load(path: Path, device=None, **kwargs):
         device = device or utils.device.default()
-        data_dir = os.path.dirname(path)
+        path = DatasetFactory.get_dataset_desc_path(path)
         with open(path, 'r', encoding='utf-8') as file:
-            data_desc = json.loads(file.read())
-        cwd = os.getcwd()
-        os.chdir(data_dir)
-        if 'type' in data_desc and data_desc['type'] == 'pano':
-            dataset = PanoDataset(data_desc, device=device, **kwargs)
+            data_desc: dict = json.loads(file.read())
+        if data_desc.get('type') == 'pano':
+            dataset_class = PanoDataset
         else:
-            dataset = ViewDataset(data_desc, device=device, **kwargs)
-        os.chdir(cwd)
-        return dataset
\ No newline at end of file
+            dataset_class = ViewDataset
+        dataset = dataset_class(data_desc, root=path.absolute().parent, name=path.stem,
+                                device=device, **kwargs)
+        return dataset
diff --git a/data/loader.py b/data/loader.py
index 49163cd..bcc474c 100644
--- a/data/loader.py
+++ b/data/loader.py
@@ -1,8 +1,8 @@
-from doctest import debug_script
-from logging import *
 import threading
 import torch
 import math
+from logging import *
+from typing import Dict
 
 
 class Preloader(object):
@@ -75,17 +75,18 @@ class DataLoader(object):
             self.chunk_idx += 1
             self.current_chunk = self.chunks[self.chunk_idx]
             self.offset = 0
-            self.indices = torch.randperm(len(self.current_chunk), device=self.device) \
+            self.indices = torch.randperm(len(self.current_chunk)).to(device=self.device) \
                 if self.shuffle else None
             if self.preloader is not None:
                 self.preloader.preload_chunk(self.chunks[(self.chunk_idx + 1) % len(self.chunks)])
 
     def __init__(self, dataset, batch_size, *,
-                 chunk_max_items=None, shuffle=False, enable_preload=True):
+                 chunk_max_items=None, shuffle=False, enable_preload=True, **chunk_args):
         super().__init__()
         self.dataset = dataset
         self.batch_size = batch_size
         self.shuffle = shuffle
+        self.chunk_args = chunk_args
         self.preloader = Preloader(self.dataset.device) if enable_preload else None
         self._init_chunks(chunk_max_items)
 
@@ -97,20 +98,18 @@ class DataLoader(object):
         return sum(math.ceil(len(chunk) / self.batch_size) for chunk in self.chunks)
 
     def _init_chunks(self, chunk_max_items):
-        data = self.dataset.get_data()
+        data: Dict[str, torch.Tensor] = self.dataset.get_data()
         if self.shuffle:
-            rand_seq = torch.randperm(self.dataset.n_views, device=self.dataset.device)
-            for key in data:
-                data[key] = data[key][rand_seq]
+            rand_seq = torch.randperm(self.dataset.n_views).to(device=self.dataset.device)
+            data = {key: val[rand_seq] for key, val in data.items()}
         self.chunks = []
         n_chunks = 1 if chunk_max_items is None else \
             math.ceil(self.dataset.n_pixels / chunk_max_items)
         views_per_chunk = math.ceil(self.dataset.n_views / n_chunks)
         for offset in range(0, self.dataset.n_views, views_per_chunk):
             sel = slice(offset, offset + views_per_chunk)
-            chunk_data = {}
-            for key in data:
-                chunk_data[key] = data[key][sel]
-            self.chunks.append(self.dataset.Chunk(len(self.chunks), self.dataset, **chunk_data))
+            chunk_data = {key: val[sel] for key, val in data.items()}
+            self.chunks.append(self.dataset.Chunk(len(self.chunks), self.dataset,
+                                                  chunk_data=chunk_data, **self.chunk_args))
         if self.preloader is not None:
             self.preloader.preload_chunk(self.chunks[0])
diff --git a/data/pano_dataset.py b/data/pano_dataset.py
index 9953c8f..918ba87 100644
--- a/data/pano_dataset.py
+++ b/data/pano_dataset.py
@@ -1,10 +1,12 @@
 import os
 import torch
 import torch.nn.functional as nn_f
-from typing import Tuple, Union
+from typing import Dict, Tuple, Union
+from operator import itemgetter
+from pathlib import Path
+
 from utils import img
 from utils import color
-from utils import misc
 from utils import sphere
 from utils.mem_profiler import *
 from utils.constants import *
@@ -27,8 +29,16 @@ class PanoDataset(object):
 
     class Chunk(object):
 
-        def __init__(self, id, dataset, *,
-                     indices: torch.Tensor, centers: torch.Tensor):
+        @property
+        def n_views(self):
+            return self.indices.size(0)
+
+        @property
+        def n_pixels_per_view(self):
+            return self.dataset.n_pixels_per_view
+
+        def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
+                     color: int, **kwargs):
             """
             [summary]
 
@@ -38,10 +48,9 @@ class PanoDataset(object):
             """
             self.id = id
             self.dataset = dataset
-            self.indices = indices
-            self.centers = centers
-            self.n_views = self.indices.size(0)
-            self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
+            self.indices = chunk_data['indices']
+            self.centers = chunk_data['centers']
+            self.color = color
             self.colors_cpu = None
             self.colors = None
             self.loaded = False
@@ -53,12 +62,12 @@ class PanoDataset(object):
 
         def load(self):
             if self.dataset.image_path is not None and self.colors_cpu is None:
-                images = color.cvt(
-                    img.load(self.dataset.image_path % i for i in self.indices),
-                    color.RGB, self.dataset.c)
-                if self.dataset.res != list(images.shape[-2:]):
+                images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices),
+                                   color.RGB, self.color)
+                if self.dataset.res != tuple(images.shape[-2:]):
                     images = nn_f.interpolate(images, self.dataset.res)
-                self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
+                self.colors_cpu = images.permute(0, 2, 3, 1) \
+                    [:, self.dataset.pixels[:, 0], self.dataset.pixels[:, 1]].flatten(0, 1)
             if self.colors_cpu is not None:
                 self.colors = self.colors_cpu.to(self.dataset.device)
             self.loaded = True
@@ -74,15 +83,27 @@ class PanoDataset(object):
                 self.load()
             view_idx = idx // self.n_pixels_per_view
             pix_idx = idx % self.n_pixels_per_view
+            global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
             extra_data = {}
             if self.colors is not None:
-                extra_data['colors'] = self.colors[idx]
+                extra_data['color'] = self.colors[idx]
             rays_o = self.centers[view_idx]
-            rays_d = self.dataset.pano_rays[pix_idx]
-            return idx, rays_o, rays_d, extra_data
+            rays_d = self.dataset.rays[pix_idx]
+            return global_idx, rays_o, rays_d, extra_data
+
+    @property
+    def n_views(self):
+        return self.centers.size(0)
+
+    @property
+    def n_pixels_per_view(self):
+        return self.pixels.size(0)
+
+    @property
+    def n_pixels(self):
+        return self.n_views * self.n_pixels_per_view
 
-    def __init__(self, desc: dict, *,
-                 c: int = color.RGB,
+    def __init__(self, desc: dict, root: Path, name: str, *,
                  load_images: bool = True,
                  res: Tuple[int, int] = None,
                  views_to_load: Union[range, torch.Tensor] = None,
@@ -104,7 +125,8 @@ class PanoDataset(object):
         :param c ```int```: color space to convert view images to
         :param calculate_rays ```bool```: whether calculate rays
         """
-        self.c = c
+        self.root = root
+        self.name = name
         self.device = device
         self._load_desc(desc, res, views_to_load, load_images)
 
@@ -119,26 +141,26 @@ class PanoDataset(object):
                    views_to_load: Union[range, torch.Tensor],
                    load_images: bool):
         if load_images and desc.get('view_file_pattern'):
-            self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
+            file_pattern = desc['view_file_pattern']
+            if "/" not in file_pattern:
+                file_pattern = f"{self.name}/{file_pattern}"
+            self.image_path = str(self.root / file_pattern)
         else:
             self.image_path = None
-        self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
-        self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
+        self.res = res if res else itemgetter("y", "x")(desc['view_res'])
+        self.depth_range = itemgetter("min", "max")(desc['depth_range']) \
             if 'depth_range' in desc else None
-        self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
+        self.bbox = None
         self.samples = desc.get('samples')
         self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
-        self.indices = torch.tensor(
-            desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
-            device=self.device)
+        self.indices = torch.tensor(desc.get('views') or [*range(self.centers.size(0))],
+                                    device=self.device)
 
         if views_to_load is not None:
             self.centers = self.centers[views_to_load]
             self.indices = self.indices[views_to_load]
 
-        self.n_views = self.centers.size(0)
-        self.n_pixels = self.n_views * self.res[0] * self.res[1]
-        self.pano_rays = self._get_pano_rays()  # [H*W, 3]
+        self.pixels, self.rays = self._get_pano_rays()
 
         if desc.get('gl_coord'):
             print('Convert from OGL coordinate to DX coordinate (i. e. flip z axis)')
@@ -148,12 +170,16 @@ class PanoDataset(object):
         """
         Get unprojected rays of pixels on a panorama
 
-        :return `Tensor(H*W, 3)`: rays' directions with one unit length
+        :return `Tensor(N, 2)`: rays' pixel coordinates in pano image
+        :return `Tensor(N, 3)`: rays' directions with one unit length
         """
-        spher_coords = torch.cat([
-            torch.ones(*self.res, 1),
-            ((misc.meshgrid(*self.res, normalize=True)) *
-             torch.tensor([-2.0, 1.0]) + torch.tensor([1.5, 0.0])) * PI
-        ], dim=-1).to(device=self.device)
-        coords = sphere.spherical2cartesian(spher_coords)
-        return coords.flatten(0, 1)  # [H*W, 3]
+        phi = (torch.arange(self.res[0], device=self.device) + 0.5) / self.res[0] * PI  # (H)
+        length = (phi.sin() * self.res[1] * 0.5).ceil() * 2
+        cols = torch.arange(self.res[1], device=self.device)[None, :].expand(*self.res)  # (H, W)
+        mask = torch.logical_and(cols >= (self.res[1] - length[:, None]) / 2,
+                                 cols < (self.res[1] + length[:, None]) / 2)  # (H, W)
+        pixs = mask.nonzero()  # (N, 2)
+        pixs_phi = (0.5 - (pixs[:, 0] + 0.5) / self.res[0]) * PI
+        pixs_theta = (pixs[:, 1] * 2 + 1 - self.res[1]) / length[pixs[:, 0]] * PI
+        spher_coords = torch.stack([torch.ones_like(pixs_phi), pixs_theta, pixs_phi], dim=-1)
+        return pixs, sphere.spherical2cartesian(spher_coords)  # (N, 3)
diff --git a/data/view_dataset.py b/data/view_dataset.py
index 477629b..34acf39 100644
--- a/data/view_dataset.py
+++ b/data/view_dataset.py
@@ -1,11 +1,13 @@
 import os
 import torch
 import torch.nn.functional as nn_f
-from typing import Tuple, Union
+from typing import Dict, Tuple, Union
+from operator import itemgetter
+from pathlib import Path
+
 from utils import img
 from utils import view
 from utils import color
-from utils import misc
 
 
 class ViewDataset(object):
@@ -25,20 +27,21 @@ class ViewDataset(object):
 
     class Chunk(object):
 
-        def __init__(self, id, dataset, *,
-                     indices: torch.Tensor, centers: torch.Tensor, rots: torch.Tensor):
+        def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
+                     color: int, **kwargs):
             """
             [summary]
 
-            :param dataset `PanoDataset`: dataset object
+            :param dataset `ViewDataset`: dataset object
             :param indices `Tensor(N)`: indices of views
             :param centers `Tensor(N, 3)`: centers of views
             """
             self.id = id
             self.dataset = dataset
-            self.indices = indices
-            self.centers = centers
-            self.rots = rots
+            self.indices = chunk_data['indices']
+            self.centers = chunk_data['centers']
+            self.rots = chunk_data['rots']
+            self.color = color
             self.n_views = self.indices.size(0)
             self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
             self.colors = self.depths = self.bins = None
@@ -50,35 +53,39 @@ class ViewDataset(object):
             self.loaded = False
 
         def load(self):
-            if self.dataset.image_path and self.colors_cpu is None:
-                images = color.cvt(
-                    img.load(self.dataset.image_path % i for i in self.indices),
-                    color.RGB, self.dataset.c)
-                if self.dataset.res != list(images.shape[-2:]):
-                    images = nn_f.interpolate(images, self.dataset.res)
-                self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
-            if self.colors_cpu is not None:
-                self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True)
-
-            if self.dataset.depth_path and self.depths_cpu is None:
-                depths = self.dataset._decode_depth_images(
-                    img.load(self.depth_path % i for i in self.indices))
-                if self.dataset.res != list(depths.shape[-2:]):
-                    depths = nn_f.interpolate(depths, self.dataset.res)
-                self.depths_cpu = depths.flatten(0, 2)
-            if self.depths_cpu is not None:
-                self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True)
-
-            if self.dataset.bins_path and self.bins_cpu is None:
-                bins = img.load([self.dataset.bins_path % i for i in self.indices])
-                if self.dataset.res != list(bins.shape[-2:]):
-                    bins = nn_f.interpolate(bins, self.dataset.res)
-                self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2)
-            if self.bins_cpu is not None:
-                self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True)
-
-            torch.cuda.current_stream(self.dataset.device).synchronize()
-            self.loaded = True
+            #print("chunk load")
+            try:
+                if self.dataset.image_path and self.colors_cpu is None:
+                    images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices),
+                                       color.RGB, self.color)
+                    if self.dataset.res != list(images.shape[-2:]):
+                        images = nn_f.interpolate(images, self.dataset.res)
+                    self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
+                if self.colors_cpu is not None:
+                    self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True)
+
+                if self.dataset.depth_path and self.depths_cpu is None:
+                    depths = self.dataset._decode_depth_images(
+                        img.load(self.depth_path % i for i in self.indices))
+                    if self.dataset.res != list(depths.shape[-2:]):
+                        depths = nn_f.interpolate(depths, self.dataset.res)
+                    self.depths_cpu = depths.flatten(0, 2)
+                if self.depths_cpu is not None:
+                    self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True)
+
+                if self.dataset.bins_path and self.bins_cpu is None:
+                    bins = img.load([self.dataset.bins_path % i for i in self.indices])
+                    if self.dataset.res != list(bins.shape[-2:]):
+                        bins = nn_f.interpolate(bins, self.dataset.res)
+                    self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2)
+                if self.bins_cpu is not None:
+                    self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True)
+
+                torch.cuda.current_stream(self.dataset.device).synchronize()
+                self.loaded = True
+            except Exception as ex:
+                print(ex)
+                exit(-1)
 
         def __len__(self):
             return self.n_views * self.n_pixels_per_view
@@ -88,21 +95,24 @@ class ViewDataset(object):
                 self.load()
             view_idx = idx // self.n_pixels_per_view
             pix_idx = idx % self.n_pixels_per_view
+            global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
             rays_o = self.centers[view_idx]
-            rays_d = self.dataset.cam_rays[pix_idx] # (N, 3)
-            r = self.rots[view_idx].movedim(-1, -2) # (N, 3, 3)
-            rays_d = torch.matmul(rays_d, r)
-            extra_data = {}
+            rays_d = self.dataset.cam_rays[pix_idx][:, None]  # (N, 1, 3)
+            r = self.rots[view_idx].movedim(-1, -2)  # (N, 3, 3)
+            rays_d = torch.matmul(rays_d, r)[:, 0]  # (N, 3)
+            extra_data = {
+                'view_idx': view_idx,
+                'pix_idx': pix_idx
+            } # TBR
             if self.colors is not None:
-                extra_data['colors'] = self.colors[idx]
+                extra_data['color'] = self.colors[idx]
             if self.depths is not None:
-                extra_data['depths'] = self.depths[idx]
+                extra_data['depth'] = self.depths[idx]
             if self.bins is not None:
-                extra_data['bins'] = self.bins[idx]
-            return idx, rays_o, rays_d, extra_data
+                extra_data['bin'] = self.bins[idx]
+            return global_idx, rays_o, rays_d, extra_data
 
-    def __init__(self, desc: dict, *,
-                 c: int = color.RGB,
+    def __init__(self, desc: dict, root: Path, name: str, *,
                  load_images: bool = True,
                  load_depths: bool = False,
                  load_bins: bool = False,
@@ -127,7 +137,8 @@ class ViewDataset(object):
         :param c ```int```: color space to convert view images to
         :param calculate_rays ```bool```: whether calculate rays
         """
-        self.c = c
+        self.root = root
+        self.name = name
         self.device = device
         self._load_desc(desc, res, views_to_load, load_images, load_depths, load_bins)
 
@@ -137,7 +148,7 @@ class ViewDataset(object):
             'centers': self.centers,
             'rots': self.rots
         }
-    
+
     def _decode_depth_images(self, input):
         disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
         disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
@@ -150,22 +161,32 @@ class ViewDataset(object):
                    load_depths: bool,
                    load_bins: bool):
         if load_images and desc.get('view_file_pattern'):
-            self.image_path = os.path.join(self.data_dir, desc['view_file_pattern'])
+            file_pattern = desc['view_file_pattern']
+            if "/" not in file_pattern:
+                file_pattern = f"{self.name}/{file_pattern}"
+            self.image_path = str(self.root / file_pattern)
         else:
             self.image_path = None
         if load_depths and desc.get('depth_file_pattern'):
-            self.depth_path = os.path.join(self.data_dir, desc['depth_file_pattern'])
+            file_pattern = desc['depth_file_pattern']
+            if "/" not in file_pattern:
+                file_pattern = f"{self.name}/{file_pattern}"
+            self.depth_path = str(self.root / file_pattern)
         else:
             self.depth_path = None
         if load_bins and desc.get('bins_file_pattern'):
-            self.bins_path = os.path.join(self.data_dir, desc['bins_file_pattern'])
+            file_pattern = desc['bins_file_pattern']
+            if "/" not in file_pattern:
+                file_pattern = f"{self.name}/{file_pattern}"
+            self.bins_path = str(self.root / file_pattern)
         else:
             self.bins_path = None
-        self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
+        self.res = res or itemgetter("y", "x")(desc['view_res'])
         self.cam = view.CameraParam(desc['cam_params'], self.res, device=self.device)
-        self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
+        self.depth_range = itemgetter("min", "max")(desc['depth_range']) \
             if 'depth_range' in desc else None
-        self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
+        self.range = itemgetter("min", "max")(desc['range']) if 'range' in desc else None
+        self.bbox = desc.get('bbox')
         self.samples = desc.get('samples')
         self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
         self.rots = torch.tensor(
@@ -175,9 +196,8 @@ class ViewDataset(object):
             ]
             if len(desc['view_rots'][0]) == 2 else desc['view_rots'],
             device=self.device).view(-1, 3, 3)  # (N, 3, 3)
-        self.indices = torch.tensor(
-            desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
-            device=self.device)
+        self.indices = torch.tensor(desc.get('views') or [*range(self.centers.size(0))],
+                                    device=self.device)
 
         if views_to_load is not None:
             self.centers = self.centers[views_to_load]
@@ -194,5 +214,5 @@ class ViewDataset(object):
             self.centers[:, 2] *= -1
             self.rots[:, 2] *= -1
             self.rots[..., 2] *= -1
-        
+
         self.cam_rays = self.cam.get_local_rays(flatten=True)
diff --git a/debug/voxel_sampler_export3d.py b/debug/voxel_sampler_export3d.py
new file mode 100644
index 0000000..bb24d54
--- /dev/null
+++ b/debug/voxel_sampler_export3d.py
@@ -0,0 +1,134 @@
+import os
+import sys
+import argparse
+import torch
+
+sys.path.append(os.path.abspath(sys.path[0] + '/../'))
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-m', '--model', type=str,
+                    help='The model file to load for testing')
+parser.add_argument('-r', '--output-rays', type=int, default=100,
+                    help='How many rays to output')
+parser.add_argument('-p', '--prompt', action='store_true',
+                    help='Interactive prompt mode')
+parser.add_argument('dataset', type=str,
+                    help='Dataset description file')
+args = parser.parse_args()
+
+
+import model as mdl
+from utils import misc
+from utils import color
+from utils import interact
+from utils import device
+from data.dataset_factory import *
+from data.loader import DataLoader
+from modules import Samples, Voxels
+from model.nsvf import NSVF
+
+model: NSVF
+samples: Samples
+
+DATA_LOADER_CHUNK_SIZE = 1e8
+
+
+data_desc_path = args.dataset if args.dataset.endswith('.json') \
+    else os.path.join(args.dataset, 'train.json')
+data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
+data_dir = os.path.dirname(data_desc_path) + '/'
+
+
+def get_model_files(datadir):
+    model_files = []
+    for root, _, files in os.walk(datadir):
+        model_files += [
+            os.path.join(root, file).replace(datadir, '')
+            for file in files if file.endswith('.tar') or file.endswith('.pth')
+        ]
+    return model_files
+
+
+if args.prompt:  # Prompt test model, output resolution, output mode
+    model_files = get_model_files(data_dir)
+    args.model = interact.input_enum('Specify test model:', model_files,
+                                     err_msg='No such model file')
+    args.output_rays = interact.input_ex('Specify number of rays to output:',
+                                         interact.input_to_int(), default=10)
+
+model_path = os.path.join(data_dir, args.model)
+model_name = os.path.splitext(os.path.basename(model_path))[0]
+model, iters = mdl.load(model_path, {"perturb_sample": False})
+model.to(device.default()).eval()
+model_class = model.__class__.__name__
+model_args = model.args
+print(f"model: {model_name} ({model_class})")
+print("args:", json.dumps(model.args0))
+
+dataset = DatasetFactory.load(data_desc_path)
+print("Dataset loaded: " + data_desc_path)
+
+run_dir = os.path.dirname(model_path) + '/'
+output_dir = f"{run_dir}output_{int(model_name.split('_')[-1])}"
+
+
+if __name__ == "__main__":
+    with torch.no_grad():
+        # 1. Initialize data loader
+        data_loader = DataLoader(dataset, args.output_rays, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
+                                 shuffle=True, enable_preload=True,
+                                 color=color.from_str(model.args['color']))
+        sys.stdout.write("Export samples...\r")
+        for _, rays_o, rays_d, extra in data_loader:
+            samples, rays_mask = model.sampler(rays_o, rays_d, model.space)
+            invalid_rays_o = rays_o[torch.logical_not(rays_mask)]
+            invalid_rays_d = rays_d[torch.logical_not(rays_mask)]
+            rays_o = rays_o[rays_mask]
+            rays_d = rays_d[rays_mask]
+            break
+        print("Export samples...Done")
+        
+        os.makedirs(output_dir, exist_ok=True)
+
+        export_data = {}
+
+        if model.space.bbox is not None:
+            export_data['bbox'] = model.space.bbox.tolist()
+        if isinstance(model.space, Voxels):
+            export_data['voxel_size'] = model.space.voxel_size.tolist()
+            export_data['voxels'] = model.space.voxels.tolist()
+
+            if False:
+                voxel_access_counts = torch.zeros(model.space.n_voxels, dtype=torch.long,
+                                                device=device.default())
+                iters_in_epoch = 0
+                data_loader.batch_size = 2 ** 20
+                for _, rays_o1, rays_d1, _ in data_loader:
+                    model(rays_o1, rays_d1,
+                        raymarching_tolerance=0.5,
+                        raymarching_chunk_size=0,
+                        voxel_access_counts=voxel_access_counts)
+                    iters_in_epoch += 1
+                    percent = iters_in_epoch / len(data_loader) * 100
+                    sys.stdout.write(f'Export voxel access counts...{percent:.1f}%   \r')
+                export_data['voxel_access_counts'] = voxel_access_counts.tolist()
+                print("Export voxel access counts...Done  ")
+
+        export_data.update({
+            'rays_o': rays_o.tolist(),
+            'rays_d': rays_d.tolist(),
+            'invalid_rays_o': invalid_rays_o.tolist(),
+            'invalid_rays_d': invalid_rays_d.tolist(),
+            'samples': {
+                'depths': samples.depths.tolist(),
+                'dists': samples.dists.tolist(),
+                'voxel_indices': samples.voxel_indices.tolist()
+            }
+        })
+        with open(f'{output_dir}/debug_voxel_sampler_export3d.json', 'w') as fp:
+            json.dump(export_data, fp)
+        print("Write JSON file...Done")
+
+        args.output_rays
+        print(f"Rays: total {args.output_rays}, valid {rays_o.size(0)}")
+        print(f"Samples: average {samples.voxel_indices.ne(-1).sum(-1).float().mean().item()} per ray")
diff --git a/fntest.py b/fntest.py
new file mode 100644
index 0000000..5f5a879
--- /dev/null
+++ b/fntest.py
@@ -0,0 +1,12 @@
+from math import ceil
+
+cdf = [2.2, 3.5, 3.6, 3.7, 4.0]
+bins = []
+part = 1
+offset = 0
+for i in range(len(cdf)):
+    if cdf[i] >= part:
+        bins.append(i + 1 - offset)
+        offset = i + 1
+        part = int(cdf[i]) + 1
+print(bins)
\ No newline at end of file
diff --git a/loss/__init__.py b/loss/__init__.py
new file mode 100644
index 0000000..a4eecbc
--- /dev/null
+++ b/loss/__init__.py
@@ -0,0 +1,5 @@
+from torch.nn import L1Loss, MSELoss
+from torch.nn.functional import l1_loss, mse_loss
+from .ssim import SSIM
+from .perc_loss import VGGPerceptualLoss
+from .cauchy import cauchy_loss, CauchyLoss
\ No newline at end of file
diff --git a/loss/cauchy.py b/loss/cauchy.py
new file mode 100644
index 0000000..5dd213e
--- /dev/null
+++ b/loss/cauchy.py
@@ -0,0 +1,16 @@
+import torch
+
+
+def cauchy_loss(input: torch.Tensor, target: torch.Tensor = None, *, s = 1.0):
+    x = input - target if target is not None else input
+    return (s * x * x * 0.5 + 1).log().mean()
+
+
+class CauchyLoss(torch.nn.Module):
+
+    def __init__(self, s = 1.0):
+        super().__init__()
+        self.s = s
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor = None):
+        return cauchy_loss(input, target, s=self.s)
diff --git a/loss/ssim.py b/loss/ssim.py
index 93f390b..cd38987 100644
--- a/loss/ssim.py
+++ b/loss/ssim.py
@@ -1,7 +1,6 @@
 import torch
 import torch.nn.functional as F
 from torch.autograd import Variable
-import numpy as np
 from math import exp
 
 def gaussian(window_size, sigma):
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..60f9f6b
--- /dev/null
+++ b/model/__init__.py
@@ -0,0 +1,45 @@
+import importlib
+import os
+import torch
+from typing import Tuple, Union
+from . import base
+
+
+# Automatically import any python files this directory
+package_dir = os.path.dirname(__file__)
+package = os.path.basename(package_dir)
+for file in os.listdir(package_dir):
+    path = os.path.join(package_dir, file)
+    if file.startswith('_') or file.startswith('.'):
+        continue
+    if file.endswith('.py') or os.path.isdir(path):
+        model_name = file[:-3] if file.endswith('.py') else file
+        importlib.import_module(f'{package}.{model_name}')
+
+
+def get_class(model_class_name: str) -> type:
+    return base.model_classes[model_class_name]
+
+
+def create(model_class_name: str, args0: dict, **extra_args) -> base.BaseModel:
+    model_class = get_class(model_class_name)
+    return model_class(args0, extra_args)
+
+
+def load(path: Union[str, os.PathLike], args0: dict = {}, **extra_args) -> Tuple[base.BaseModel, dict]:
+    states: dict = torch.load(path)
+    states['args'].update(args0)
+    model = create(states['model'], states['args'], **extra_args)
+    model.load_state_dict(states['states'])
+    return model, states
+
+
+def save(path: Union[str, os.PathLike], model: base.BaseModel, **extra_states):
+    #print(f'Save model to {path}...')
+    dict = {
+        'model': model.__class__.__name__,
+        'args': model.args0,
+        'states': model.state_dict(),
+        **extra_states
+    }
+    torch.save(dict, path)
diff --git a/model/base.py b/model/base.py
new file mode 100644
index 0000000..324ad93
--- /dev/null
+++ b/model/base.py
@@ -0,0 +1,34 @@
+import torch.nn as nn
+from utils import color
+
+
+model_classes = {}
+
+
+class BaseModelMeta(type):
+
+    def __new__(cls, name, bases, attrs):
+        new_cls = type.__new__(cls, name, bases, attrs)
+        if name != 'BaseModel':
+            model_classes[name] = new_cls
+        return new_cls
+
+
+class BaseModel(nn.Module, metaclass=BaseModelMeta):
+
+    trainer = "Train"
+
+    @property
+    def args(self):
+        return {**self.args0, **self.args1}
+
+    def __init__(self, args0: dict, args1: dict = {}):
+        super().__init__()
+        self.args0 = args0
+        self.args1 = args1
+        self._chns = {
+            "color": color.chns(color.from_str(self.args['color']))
+        }
+
+    def chns(self, name: str):
+        return self._chns.get(name, 1)
\ No newline at end of file
diff --git a/nets/bg_net.py b/model/bg_net.py
similarity index 100%
rename from nets/bg_net.py
rename to model/bg_net.py
diff --git a/model/nerf.py b/model/nerf.py
new file mode 100644
index 0000000..35adb39
--- /dev/null
+++ b/model/nerf.py
@@ -0,0 +1,181 @@
+import torch
+
+import model
+from .base import *
+from modules import *
+from utils.mem_profiler import MemProfiler
+from utils.perf import perf
+from utils.misc import masked_scatter
+
+
+class NeRF(BaseModel):
+
+    trainer = "TrainWithSpace"
+    SamplerClass = Sampler
+    RendererClass = VolumnRenderer
+
+    def __init__(self, args0: dict, args1: dict = {}):
+        """
+        Initialize a NeRF model
+
+        :param args0 `dict`: basic arguments
+        :param args1 `dict`: extra arguments, defaults to {}
+        """
+        if "sample_step_ratio" in args0:
+            args1["sample_step"] = args0["voxel_size"] * args0["sample_step_ratio"]
+        super().__init__(args0, args1)
+
+        # Initialize components
+        self._init_space()
+        self._init_encoders()
+        self._init_core()
+        self.sampler = self.SamplerClass(**self.args)
+        self.rendering = self.RendererClass(**self.args)
+
+    def _init_encoders(self):
+        self.pot_encoder = InputEncoder.Get(self.args['n_pot_encode'],
+                                            self.args.get('n_featdim') or 3)
+        if self.args.get('n_dir_encode'):
+            self.dir_chns = 3
+            self.dir_encoder = InputEncoder.Get(self.args['n_dir_encode'], self.dir_chns)
+        else:
+            self.dir_chns = 0
+            self.dir_encoder = None
+
+    def _init_space(self):
+        if 'space' not in self.args:
+            self.space = Space(**self.args)
+        elif self.args['space'] == 'octree':
+            self.space = Octree(**self.args)
+        elif self.args['space'] == 'voxels':
+            self.space = Voxels(**self.args)
+        else:
+            self.space = model.load(self.args['space'])[0].space
+        if self.args.get('n_featdim'):
+            self.space.create_embedding(self.args['n_featdim'])
+
+    def _new_core_unit(self):
+        return NerfCore(coord_chns=self.pot_encoder.out_dim,
+                        density_chns=self.chns('density'),
+                        color_chns=self.chns('color'),
+                        core_nf=self.args['fc_params']['nf'],
+                        core_layers=self.args['fc_params']['n_layers'],
+                        dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
+                        dir_nf=self.args['fc_params']['nf'] // 2,
+                        act=self.args['fc_params']['activation'],
+                        skips=self.args['fc_params']['skips'])
+
+    def _create_core(self, n_nets=1):
+        return self._new_core_unit() if n_nets == 1 else nn.ModuleList([
+            self._new_core_unit() for _ in range(n_nets)
+        ])
+
+    def _init_core(self):
+        if not self.args.get("net_bounds"):
+            self.core = self._create_core()
+        else:
+            self.register_buffer("net_bounds", torch.tensor(self.args["net_bounds"]), False)
+            self.cores = self._create_core(self.net_bounds.size(0))
+
+    def render(self, samples: Samples, *outputs: str, **kwargs) -> Dict[str, torch.Tensor]:
+        """
+        Render colors, energies and other values (specified by `outputs`) of samples 
+        (invalid items are filtered out)
+
+        :param samples `Samples(N)`: samples
+        :param outputs `str...`: which types of inferred data should be returned
+        :return `Dict[str, Tensor(N, *)]`: outputs of cores
+        """
+        x = self.encode_x(samples)
+        d = self.encode_d(samples)
+        return self.infer(x, d, *outputs, pts=samples.pts, **kwargs)
+    
+    def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, pts: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        """
+        Infer colors, energies and other values (specified by `outputs`) of samples 
+        (invalid items are filtered out) given their encoded positions and directions
+
+        :param x `Tensor(N, Ex)`: encoded positions
+        :param d `Tensor(N, Ed)`: encoded directions 
+        :param outputs `str...`: which types of inferred data should be returned
+        :param pts `Tensor(N, 3)`: raw sample positions
+        :return `Dict[str, Tensor(N, *)]`: outputs of cores
+        """
+        if getattr(self, "core", None):
+            return self.core(x, d, outputs)
+        ret = {}
+        for i, core in enumerate(self.cores):
+            selector = (pts >= self.net_bounds[i, 0] and pts < self.net_bounds[i, 1]).all(-1)
+            partial_ret = core(x[selector], d[selector], outputs)
+            for key, value in partial_ret.items():
+                if value is None:
+                    ret[key] = None
+                    continue
+                if key not in ret:
+                    ret[key] = torch.zeros(*x.shape[:-1], value.shape[-1], device=x.device)
+                ret[key] = masked_scatter(selector, value, ret[key])
+        return ret
+
+    def embed(self, samples: Samples) -> torch.Tensor:
+        return self.space.extract_embedding(samples.pts, samples.voxel_indices)
+
+    def encode_x(self, samples: Samples) -> torch.Tensor:
+        x = self.embed(samples) if self.args.get('n_featdim') else samples.pts
+        return self.pot_encoder(x)
+
+    def encode_d(self, samples: Samples) -> torch.Tensor:
+        return self.dir_encoder(samples.dirs) if self.dir_encoder is not None else None
+
+    @torch.no_grad()
+    def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
+        densities = self.render(Samples(sampled_points, None, None, None, sampled_voxel_indices),
+                                'density')
+        return 1 - (-densities).exp()
+
+    @torch.no_grad()
+    def pruning(self, threshold: float = 0.5, train_stats=False):
+        return self.space.pruning(self.get_scores, threshold, train_stats)
+
+    @torch.no_grad()
+    def splitting(self):
+        ret = self.space.splitting()
+        if 'n_samples' in self.args0:
+            self.args0['n_samples'] *= 2
+        if 'voxel_size' in self.args0:
+            self.args0['voxel_size'] /= 2
+            if "sample_step_ratio" in self.args0:
+                self.args1["sample_step"] = self.args0["voxel_size"] \
+                    * self.args0["sample_step_ratio"]
+        if 'sample_step' in self.args0:
+            self.args0['sample_step'] /= 2
+        self.sampler = self.SamplerClass(**self.args)
+        return ret
+
+    @torch.no_grad()
+    def double_samples(self):
+        pass
+
+    @perf
+    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
+                extra_outputs: List[str] = [], **kwargs) -> torch.Tensor:
+        """
+        Perform rendering for given rays.
+
+        :param rays_o `Tensor(N, 3)`: rays' origin
+        :param rays_d `Tensor(N, 3)`: rays' direction
+        :param extra_outputs `list[str]`: extra items should be contained in the rendering result,
+                                          defaults to []
+        :return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation
+        """
+        args = {**self.args, **kwargs}
+        with MemProfiler(f"{self.__class__}.forward: before sampling"):
+            samples, rays_mask = self.sampler(rays_o, rays_d, self.space, **args)
+        MemProfiler.print_memory_stats(f"{self.__class__}.forward: after sampling")
+        with MemProfiler(f"{self.__class__}.forward: rendering"):
+            if samples is None:
+                return None
+            return {
+                **self.rendering(self, samples, extra_outputs, **args),
+                'samples': samples,
+                'rays_mask': rays_mask
+            }
diff --git a/model/nerf_advance.py b/model/nerf_advance.py
new file mode 100644
index 0000000..b3d9716
--- /dev/null
+++ b/model/nerf_advance.py
@@ -0,0 +1,37 @@
+import torch
+from modules import *
+from .nerf import *
+
+
+class NeRFAdvance(NeRF):
+
+    RendererClass = DensityFirstVolumnRenderer
+
+    def __init__(self, args0: dict, args1: dict = {}):
+        super().__init__(args0, args1)
+
+    def _new_core_unit(self):
+        return NerfAdvCore(
+            x_chns=self.pot_encoder.out_dim,
+            d_chns=self.dir_encoder.out_dim,
+            density_chns=self.chns('density'),
+            color_chns=self.chns('color'),
+            density_net_params=self.args["density_net"],
+            color_net_params=self.args["color_net"],
+            specular_net_params=self.args.get("specular_net"),
+            appearance=self.args.get("appearance", "decomposite"),
+            density_color_connection=self.args.get("density_color_connection", False)
+        )
+
+    def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, extras={}, **kwargs) -> Dict[str, torch.Tensor]:
+        """
+        Infer colors, energies and other values (specified by `outputs`) of samples 
+        (invalid items are filtered out) given their encoded positions and directions
+
+        :param x `Tensor(N, Ex)`: encoded positions
+        :param d `Tensor(N, Ed)`: encoded directions 
+        :param outputs `str...`: which types of inferred data should be returned
+        :param extras `dict`: extra data needed by cores
+        :return `Dict[str, Tensor(N, *)]`: outputs of cores
+        """
+        return self.core(x, d, outputs, **extras)
diff --git a/nets/nerf_depth.py b/model/nerf_depth.py
similarity index 96%
rename from nets/nerf_depth.py
rename to model/nerf_depth.py
index fe5fafd..8826dfc 100644
--- a/nets/nerf_depth.py
+++ b/model/nerf_depth.py
@@ -27,7 +27,7 @@ class NerfDepth(nn.Module):
                           color_chns=self.color_chns,
                           core_nf=fc_params['nf'],
                           core_layers=fc_params['n_layers'],
-                          activation=fc_params['activation'],
+                          act=fc_params['activation'],
                           skips=fc_params['skips'])
         self.sampler = AdaptiveSampler(**sampler_params, n_bins=n_bins,
                                        include_neighbor_bins=include_neighbor_bins)
diff --git a/model/nsvf.py b/model/nsvf.py
new file mode 100644
index 0000000..08bfaeb
--- /dev/null
+++ b/model/nsvf.py
@@ -0,0 +1,16 @@
+from .nerf import *
+from utils.geometry import *
+
+
+class NSVF(NeRF):
+
+    SamplerClass = VoxelSampler
+
+    def __init__(self, args0: dict, args1: dict = {}):
+        """
+        Initialize a NSVF model
+
+        :param args0 `dict`: basic arguments
+        :param args1 `dict`: extra arguments, defaults to {}
+        """
+        super().__init__(args0, args1)
diff --git a/nets/oracle.py b/model/oracle.py
similarity index 96%
rename from nets/oracle.py
rename to model/oracle.py
index 7f61890..1fa3e01 100644
--- a/nets/oracle.py
+++ b/model/oracle.py
@@ -27,7 +27,7 @@ class Oracle(nn.Module):
         self.net = nn.Sequential(
             FcNet(in_chns=self.pos_encoder.out_dim * self.n_samples,
                     out_chns=0, nf=fc_params['nf'], n_layers=fc_params['n_layers'],
-                    skips=[], activation=fc_params['activation']),
+                    skips=[], act=fc_params['activation']),
             FcLayer(fc_params['nf'], self.n_samples, out_activation)
         )
 
diff --git a/model/snerf.py b/model/snerf.py
new file mode 100644
index 0000000..534c515
--- /dev/null
+++ b/model/snerf.py
@@ -0,0 +1,26 @@
+import math
+from .nerf import *
+
+
+class SNeRF(NeRF):
+    SamplerClass = SphericalSampler
+
+    def __init__(self, args0: dict, args1: dict = {}):
+        """
+        Initialize a multi-sphere-layer net
+
+        :param fc_params: parameters for full-connection network
+        :param sampler_params: parameters for sampler
+        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
+        :param c: color mode
+        :param encode_to_dim: encode input to number of dimensions
+        """
+        sample_range = [1 / args0['depth_range'][0], 1 / args0['depth_range'][1]] \
+            if args0.get('depth_range') else [1, 0]
+        rot_range = [[-180, -90], [180, 90]]
+        args1['bbox'] = [
+            [sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
+            [sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
+        ]
+        args1['sample_range'] = sample_range
+        super().__init__(args0, args1)
\ No newline at end of file
diff --git a/model/snerf_advance.py b/model/snerf_advance.py
new file mode 100644
index 0000000..dd321c8
--- /dev/null
+++ b/model/snerf_advance.py
@@ -0,0 +1,33 @@
+import math
+from .nerf_advance import *
+
+
+class SNeRFAdvance(NeRFAdvance):
+    SamplerClass = SphericalSampler
+
+    def __init__(self, args0: dict, args1: dict = {}):
+        """
+        Initialize a multi-sphere-layer net
+
+        :param fc_params: parameters for full-connection network
+        :param sampler_params: parameters for sampler
+        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
+        :param c: color mode
+        :param encode_to_dim: encode input to number of dimensions
+        """
+        sample_range = [1 / args0['depth_range'][0], 1 / args0['depth_range'][1]] \
+            if args0.get('depth_range') else [1, 0]
+        rot_range = [[-180, -90], [180, 90]]
+        args1['bbox'] = [
+            [sample_range[0], math.radians(rot_range[0][0]), math.radians(rot_range[0][1])],
+            [sample_range[1], math.radians(rot_range[1][0]), math.radians(rot_range[1][1])]
+        ]
+        args1['sample_range'] = sample_range
+        if args0.get('multi_nets'):
+            n = args0['multi_nets']
+            step = (sample_range[1] - sample_range[0]) / n
+            args1['net_bounds'] = [[
+                [sample_range[0] + step * (i + 1), *args1['bbox'][0][1:]],
+                [sample_range[0] + step * i, *args1['bbox'][1][1:]]
+            ] for i in range(n)]
+        super().__init__(args0, args1)
\ No newline at end of file
diff --git a/model/snerf_advance_x.py b/model/snerf_advance_x.py
new file mode 100644
index 0000000..de2d71c
--- /dev/null
+++ b/model/snerf_advance_x.py
@@ -0,0 +1,74 @@
+from utils.misc import print_and_log
+from .snerf_advance import *
+
+
+class SNeRFAdvanceX(SNeRFAdvance):
+
+    RendererClass = DensityFirstVolumnRenderer
+
+    def __init__(self, args0: dict, args1: dict = {}):
+        """
+        Initialize a multi-sphere-layer net
+
+        :param fc_params: parameters for full-connection network
+        :param sampler_params: parameters for sampler
+        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
+        :param c: color mode
+        :param encode_to_dim: encode input to number of dimensions
+        """
+        super().__init__(args0, args1)
+
+    def _init_core(self):
+        if "net_samples" not in self.args:
+            n_nets = self.args.get("multi_nets", 1)
+            k = self.args["n_samples"] // self.space.steps[0].item()
+            self.args0["net_samples"] = [val * k for val in self.space.balance_cut(0, n_nets)]
+        self.cores = self._create_core(len(self.args0["net_samples"]))
+
+    def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs, chunk_id: int, extras={}, **kwargs) -> Dict[str, torch.Tensor]:
+        """
+        Infer colors, energies and other values (specified by `outputs`) of samples 
+        (invalid items are filtered out) given their encoded positions and directions
+
+        :param x `Tensor(N, Ex)`: encoded positions
+        :param d `Tensor(N, Ed)`: encoded directions
+        :param outputs `str...`: which types of inferred data should be returned
+        :param chunk_id `int`: current index of sample chunk in renderer
+        :param extras `dict`: extra data needed by cores
+        :return `Dict[str, Tensor(N, *)]`: outputs of cores
+        """
+        return self.cores[chunk_id](x, d, outputs, **extras)
+
+    @torch.no_grad()
+    def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @torch.no_grad()
+    def pruning(self, threshold: float = 0.5, train_stats=False):
+        raise NotImplementedError()
+
+    @torch.no_grad()
+    def splitting(self):
+        ret = super().splitting()
+        k = self.args["n_samples"] // self.space.steps[0].item()
+        net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))]
+        if len(net_samples) != len(self.cores):
+            print_and_log('Note: the result of balance cut has no enough bins. Keep origin cut.')
+            net_samples = [val * 2 for val in self.args0["net_samples"]]
+        self.args0['net_samples'] = net_samples
+        return ret
+
+    @perf
+    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
+                extra_outputs: List[str] = [], **kwargs) -> torch.Tensor:
+        """
+        Perform rendering for given rays.
+
+        :param rays_o `Tensor(N, 3)`: rays' origin
+        :param rays_d `Tensor(N, 3)`: rays' direction
+        :param extra_outputs `list[str]`: extra items should be contained in the rendering result,
+                                          defaults to []
+        :return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation
+        """
+        return super().forward(rays_o, rays_d, extra_outputs=extra_outputs, **kwargs,
+                               raymarching_chunk_size_or_sections=self.args["net_samples"])
diff --git a/nets/snerf_fast.py b/model/snerf_fast.py
similarity index 98%
rename from nets/snerf_fast.py
rename to model/snerf_fast.py
index d99165e..4d627e8 100644
--- a/nets/snerf_fast.py
+++ b/model/snerf_fast.py
@@ -48,7 +48,7 @@ class SnerfFast(nn.Module):
                      core_layers=fc_params['n_layers'],
                      dir_chns=self.dir_chns_per_part,
                      dir_nf=fc_params['nf'] // 2,
-                     activation=fc_params['activation'])
+                     act=fc_params['activation'])
             for _ in range(self.n_parts)
         ]
         for i in range(self.n_parts):
diff --git a/model/snerf_x.py b/model/snerf_x.py
new file mode 100644
index 0000000..64bfcf9
--- /dev/null
+++ b/model/snerf_x.py
@@ -0,0 +1,79 @@
+from utils.misc import print_and_log
+from .snerf import *
+
+
+class SNeRFX(SNeRF):
+
+    trainer = "TrainWithSpace"
+    SamplerClass = SphericalSampler
+    RendererClass = VolumnRenderer
+
+    def __init__(self, args0: dict, args1: dict = {}):
+        """
+        Initialize a multi-sphere-layer net
+
+        :param fc_params: parameters for full-connection network
+        :param sampler_params: parameters for sampler
+        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
+        :param c: color mode
+        :param encode_to_dim: encode input to number of dimensions
+        """
+        super().__init__(args0, args1)
+
+    def _init_core(self):
+        if "net_samples" not in self.args:
+            n_nets = self.args.get("multi_nets", 1)
+            k = self.args["n_samples"] // self.space.steps[0].item()
+            self.args0["net_samples"] = [val * k for val in self.space.balance_cut(0, n_nets)]
+        self.cores = self._create_core(len(self.args0["net_samples"]))
+
+    def render(self, samples: Samples, *outputs: str, chunk_id: int, **kwargs) -> Dict[str, torch.Tensor]:
+        """
+        Infer colors, energies and other values (specified by `outputs`) of samples 
+        (invalid items are filtered out)
+
+        :param samples `Samples(N)`: samples
+        :param outputs `str...`: which types of inferred data should be returned
+        :param chunk_id `int`: current index of sample chunk in renderer
+        :return `Dict[str, Tensor(N, *)]`: outputs of cores
+        """
+        x = self.encode_x(samples)
+        d = self.encode_d(samples)
+        return self.cores[chunk_id](x, d, outputs)
+
+    @torch.no_grad()
+    def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @torch.no_grad()
+    def pruning(self, threshold: float = 0.5, train_stats=False):
+        raise NotImplementedError()
+
+    @torch.no_grad()
+    def splitting(self):
+        ret = super().splitting()
+        k = self.args["n_samples"] // self.space.steps[0].item()
+        net_samples = [
+            val * k for val in self.space.balance_cut(0, len(self.cores))
+        ]
+        if len(net_samples) != len(self.cores):
+            print_and_log('Note: the result of balance cut has no enough bins. Keep origin cut.')
+            net_samples = [val * 2 for val in self.args0["net_samples"]]
+        self.args0['net_samples'] = net_samples
+        self.sampler = self.SamplerClass(**self.args)
+        return ret
+
+    @perf
+    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
+                extra_outputs: List[str] = [], **kwargs) -> torch.Tensor:
+        """
+        Perform rendering for given rays.
+
+        :param rays_o `Tensor(N, 3)`: rays' origin
+        :param rays_d `Tensor(N, 3)`: rays' direction
+        :param extra_outputs `list[str]`: extra items should be contained in the rendering result,
+                                          defaults to []
+        :return `dict[str, Tensor]`: the rendering result, see corresponding Renderer implementation
+        """
+        return super().forward(rays_o, rays_d, extra_outputs=extra_outputs, **kwargs,
+                               raymarching_chunk_size_or_sections=self.args["net_samples"])
diff --git a/modules/__init__.py b/modules/__init__.py
index 2facaa3..c45b69e 100644
--- a/modules/__init__.py
+++ b/modules/__init__.py
@@ -1,43 +1,5 @@
-from typing import Tuple
-import torch
-import torch.nn as nn
-from torch.nn.modules.linear import Identity
-from utils.constants import *
-from .generic import *
 from .sampler import *
 from .input_encoder import *
 from .renderer import *
-
-
-class NerfCore(nn.Module):
-
-    def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers,
-                 dir_chns=0, dir_nf=0, activation='relu', skips=[]):
-        super().__init__()
-        self.core = FcNet(in_chns=coord_chns, out_chns=0, nf=core_nf, n_layers=core_layers,
-                          skips=skips, activation=activation)
-        self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None
-        if color_chns == 0:
-            self.feature_out = None
-            self.color_out = None
-        elif dir_chns > 0:
-            self.feature_out = FcLayer(core_nf, core_nf)
-            self.color_out = nn.Sequential(
-                FcLayer(core_nf + dir_chns, dir_nf, activation),
-                FcLayer(dir_nf, color_chns)
-            )
-        else:
-            self.feature_out = Identity()
-            self.color_out = FcLayer(core_nf, color_chns)
-
-    def forward(self, coord: torch.Tensor, dir: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
-        core_output = self.core(coord)
-        density = self.density_out(core_output) if self.density_out is not None else None
-        if self.color_out is None:
-            color = None
-        else:
-            feature = self.feature_out(core_output)
-            if dir is not None:
-                feature = torch.cat([feature, dir], dim=-1)
-            color = torch.sigmoid(self.color_out(feature))
-        return color, density
\ No newline at end of file
+from .space import *
+from .core import *
\ No newline at end of file
diff --git a/modules/core.py b/modules/core.py
new file mode 100644
index 0000000..28219fd
--- /dev/null
+++ b/modules/core.py
@@ -0,0 +1,175 @@
+from .generic import *
+from typing import Dict
+
+
+class NerfCore(nn.Module):
+
+    def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers,
+                 dir_chns=0, dir_nf=0, act='relu', skips=[]):
+        super().__init__()
+        self.core = FcNet(in_chns=coord_chns, out_chns=None, nf=core_nf, n_layers=core_layers,
+                          skips=skips, act=act)
+        self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None
+        if color_chns == 0:
+            self.feature_out = None
+            self.color_out = None
+        elif dir_chns > 0:
+            self.feature_out = FcLayer(core_nf, core_nf)
+            self.color_out = nn.Sequential(
+                FcLayer(core_nf + dir_chns, dir_nf, act),
+                FcLayer(dir_nf, color_chns)
+            )
+        else:
+            self.feature_out = torch.nn.Identity()
+            self.color_out = FcLayer(core_nf, color_chns)
+
+    def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str]) -> Dict[str, torch.Tensor]:
+        ret = {}
+        core_output = self.core(x)
+        if 'density' in outputs:
+            ret['density'] = torch.relu(self.density_out(core_output)) \
+                if self.density_out is not None else None
+        if 'color' in outputs:
+            if self.color_out is None:
+                ret['color'] = None
+            else:
+                feature = self.feature_out(core_output)
+                if dir is not None:
+                    feature = torch.cat([feature, d], dim=-1)
+                ret['color'] = self.color_out(feature).sigmoid()
+        for key in outputs:
+            if key == 'density' or key == 'color':
+                continue
+            ret[key] = None
+        return ret
+
+
+class NerfAdvCore(nn.Module):
+
+    def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
+                 density_net_params: dict, color_net_params: dict,
+                 specular_net_params: dict = None,
+                 appearance="decomposite",
+                 density_color_connection=False):
+        """
+        Create a NeRF-Adv Core Net.
+        Required parameters for the sub-mlps include: "nf", "n_layers", "skips" and "act".
+        Other parameters will be properly set automatically.
+
+        :param x_chns `int`: the channels of input "position"
+        :param d_chns `int`: the channels of input "direction"
+        :param density_chns `int`: the channels of output "density"
+        :param color_chns `int`: the channels of output "color"
+        :param density_net_params `dict`: parameters for the density net
+        :param color_net_params `dict`: parameters for the color net
+        :param specular_net_params `dict`: (optional) parameters for the optional specular net, defaults to None
+        :param appearance `str`: (optional) options are [decomposite|combined], defaults to "decomposite"
+        :param density_color_connection `bool`: (optional) whether to add connections between 
+                                                density net and color net, defaults to False
+        """
+        super().__init__()
+        self.density_chns = density_chns
+        self.color_chns = color_chns
+        self.specular_feature_chns = color_net_params["nf"] if specular_net_params else 0
+        self.color_feature_chns = density_net_params["nf"] if density_color_connection else 0
+        self.appearance = appearance
+        self.density_color_connection = density_color_connection
+        self.density_net = FcNet(**density_net_params,
+                                 in_chns=x_chns,
+                                 out_chns=self.density_chns + self.color_feature_chns,
+                                 out_act='relu')
+        if self.appearance == "newtype":
+            self.specular_feature_chns = d_chns * 3
+            self.color_net = FcNet(**color_net_params,
+                                   in_chns=x_chns + self.color_feature_chns,
+                                   out_chns=self.color_chns + self.specular_feature_chns)
+            self.specular_net = "Placeholder"
+        else:
+            if self.appearance == "decomposite":
+                self.color_net = FcNet(**color_net_params,
+                                       in_chns=x_chns + self.color_feature_chns,
+                                       out_chns=self.color_chns + self.specular_feature_chns)
+            else:
+                if specular_net_params:
+                    self.color_net = FcNet(**color_net_params,
+                                           in_chns=x_chns + self.color_feature_chns,
+                                           out_chns=self.specular_feature_chns)
+                else:
+                    self.color_net = FcNet(**color_net_params,
+                                           in_chns=x_chns + d_chns + self.color_feature_chns,
+                                           out_chns=self.color_chns)
+            self.specular_net = FcNet(**specular_net_params,
+                                      in_chns=d_chns + self.specular_feature_chns,
+                                      out_chns=self.color_chns) if specular_net_params else None
+
+    def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str], *,
+                color_feats: torch.Tensor = None) -> Dict[str, torch.Tensor]:
+        input_shape = x.shape[:-1]
+        if len(input_shape) > 1:
+            x = x.flatten(0, -2)
+            d = d.flatten(0, -2)
+        n = x.shape[0]
+        c = self.color_chns
+
+        ret: Dict[str, torch.Tensor] = {}
+
+        if 'density' in outputs:
+            density_net_out: torch.Tensor = self.density_net(x)
+            ret['density'] = density_net_out[:, :self.density_chns]
+            color_feats = density_net_out[:, self.density_chns:]
+            if 'color_feat' in outputs:
+                ret['color_feat'] = color_feats
+
+        if 'color' in outputs or 'specluar' in outputs:
+            if 'density' in ret:
+                valid_mask = ret['density'][:, 0].detach() >= 1e-4
+                indices = valid_mask.nonzero()[:, 0]
+                x, d, color_feats = x[indices], d[indices], color_feats[indices]
+            else:
+                indices = None
+
+            speculars = None
+            color_net_in = [x]
+            if not self.specular_net:
+                color_net_in.append(d)
+            if self.density_color_connection:
+                color_net_in.append(color_feats)
+            color_net_in = torch.cat(color_net_in, -1)
+            color_net_out: torch.Tensor = self.color_net(color_net_in)
+            diffuses = color_net_out[:, :c]
+            specular_features = color_net_out[:, -self.specular_feature_chns:]
+
+            if self.appearance == "newtype":
+                speculars = torch.bmm(specular_features.reshape(n, 3, d.shape[-1]),
+                                      d[..., None])[..., 0]
+                # TODO relu or not?
+                diffuses = diffuses.relu()
+                speculars = speculars.relu()
+                colors = diffuses + speculars
+            else:
+                if not self.specular_net:
+                    colors = diffuses
+                    diffuses = None
+                else:
+                    specular_net_in = torch.cat([d, specular_features], -1)
+                    specular_net_out = self.specular_net(specular_net_in)
+                    if self.appearance == "decomposite":
+                        speculars = specular_net_out
+                        colors = diffuses + speculars
+                    else:
+                        diffuses = None
+                        colors = specular_net_out
+                colors = torch.sigmoid(colors) # TODO indent or not?
+            if 'color' in outputs:
+                ret['color'] = colors.new_zeros(n, c).index_copy(0, indices, colors) \
+                    if indices else colors
+            if 'diffuse' in outputs:
+                ret['diffuse'] = diffuses.new_zeros(n, c).index_copy(0, indices, diffuses) \
+                    if indices is not None and diffuses is not None else diffuses
+            if 'specular' in outputs:
+                ret['specular'] = speculars.new_zeros(n, c).index_copy(0, indices, speculars) \
+                    if indices is not None and speculars is not None else speculars
+
+        if len(input_shape) > 1:
+            ret = {key: val.reshape(*input_shape, -1) for key, val in ret.items()}
+        return ret
diff --git a/modules/generic.py b/modules/generic.py
index fe9e234..c8b0987 100644
--- a/modules/generic.py
+++ b/modules/generic.py
@@ -34,7 +34,7 @@ class Sine(nn.Module):
 
 class FcLayer(nn.Module):
 
-    def __init__(self, in_chns: int, out_chns: int, activation: str = 'linear', skip_chns: int = 0):
+    def __init__(self, in_chns: int, out_chns: int, act: str = 'linear', skip_chns: int = 0):
         super().__init__()
         nls_and_inits = {
             'sine': (Sine(), sine_init),
@@ -48,7 +48,7 @@ class FcLayer(nn.Module):
             'logsoftmax': (nn.LogSoftmax(dim=-1), softmax_init),
             'linear': (None, None)
         }
-        nl, nl_weight_init = nls_and_inits[activation]
+        nl, nl_weight_init = nls_and_inits[act]
 
         self.net = nn.Sequential(
             nn.Linear(in_chns + skip_chns, out_chns),
@@ -59,7 +59,7 @@ class FcLayer(nn.Module):
         if nl_weight_init is not None:
             nl_weight_init(self.net if isinstance(self.net, nn.Linear) else self.net[0])
         else:
-            self.init_params(activation)
+            self.init_params(act)
 
     def forward(self, x: torch.Tensor, x0: torch.Tensor = None) -> torch.Tensor:
         return self.net(torch.cat([x0, x], dim=-1) if self.skip else x)
@@ -68,9 +68,9 @@ class FcLayer(nn.Module):
         linear_net = self.net if isinstance(self.net, nn.Linear) else self.net[0]
         return linear_net.weight, linear_net.bias
     
-    def init_params(self, activation):
+    def init_params(self, act):
         weight, bias = self.get_params()
-        nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(activation))
+        nn.init.xavier_normal_(weight, gain=nn.init.calculate_gain(act))
         nn.init.zeros_(bias)
 
     def copy_to(self, layer):
@@ -83,7 +83,7 @@ class FcLayer(nn.Module):
 class FcNet(nn.Module):
 
     def __init__(self, *, in_chns: int, out_chns: int, nf: int, n_layers: int,
-                 skips: List[int] = [], activation: str = 'relu'):
+                 skips: List[int] = [], act: str = 'relu', out_act = 'linear'):
         """
         Initialize a full-connection net
 
@@ -95,12 +95,12 @@ class FcNet(nn.Module):
         """
         super().__init__()
 
-        self.layers = [FcLayer(in_chns, nf, activation)] + [
-            FcLayer(nf, nf, activation, skip_chns=in_chns if i in skips else 0)
+        self.layers = [FcLayer(in_chns, nf, act)] + [
+            FcLayer(nf, nf, act, skip_chns=in_chns if i in skips else 0)
             for i in range(n_layers - 1)
         ]
-        if out_chns > 0:
-            self.layers.append(FcLayer(nf, out_chns))
+        if out_chns:
+            self.layers.append(FcLayer(nf, out_chns, out_act))
         for i, layer in enumerate(self.layers):
             self.add_module(f"layer{i}", layer)
 
diff --git a/modules/renderer.py b/modules/renderer.py
index 4f8b467..84a86b4 100644
--- a/modules/renderer.py
+++ b/modules/renderer.py
@@ -1,8 +1,44 @@
+from itertools import cycle
+from math import ceil
+from typing import Dict, Tuple, Union
 import torch
 import torch.nn as nn
-import torch.nn.functional as nn_f
+
 from utils.constants import *
+from utils.perf import perf
 from .generic import *
+from .sampler import Samples
+
+
+def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
+    """
+    Calculate energies from densities inferred by model.
+
+    :param densities `Tensor(N..., 1)`: model's output densities
+    :param dists `Tensor(N...)`: integration times
+    :param raw_noise_std `float`: the noise std used to egularize network during training (prevents 
+                                  floater artifacts), defaults to 0, means no noise is added
+    :return `Tensor(N..., 1)`: energies which block light rays
+    """
+    if raw_noise_std > 0:
+        # Add noise to model's predictions for density. Can be used to
+        # regularize network during training (prevents floater artifacts).
+        densities = densities + torch.normal(0.0, raw_noise_std, densities.size())
+    return densities * dists[..., None]
+
+
+def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
+    """
+    Calculate alphas from densities inferred by model.
+
+    :param densities `Tensor(N..., 1)`: model's output densities
+    :param dists `Tensor(N...)`: integration times
+    :param raw_noise_std `float`: the noise std used to egularize network during training (prevents 
+                                  floater artifacts), defaults to 0, means no noise is added
+    :return `Tensor(N..., 1)`: alphas
+    """
+    energies = density2energy(densities, dists, raw_noise_std)
+    return 1.0 - torch.exp(-energies)
 
 
 class AlphaComposition(nn.Module):
@@ -11,18 +47,26 @@ class AlphaComposition(nn.Module):
         super().__init__()
 
     def forward(self, colors, alphas, bg=None):
+        """
+        [summary]
+
+        :param colors `Tensor(N, P, C)`: [description]
+        :param alphas `Tensor(N, P, 1)`: [description]
+        :param bg `Tensor([N, ]C)`: [description], defaults to None
+        :return `Tensor(N, C)`: [description]
+        """
         # Compute weight for RGB of each sample along each ray.  A cumprod() is
         # used to express the idea of the ray not having reflected up to this
         # sample yet.
-        one_minus_alpha = torch.cumprod(1 - alphas[..., :-1] + TINY_FLOAT, dim=-1)
+        one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + TINY_FLOAT, dim=-2)
         one_minus_alpha = torch.cat([
-            torch.ones_like(one_minus_alpha[..., 0:1]),
+            torch.ones_like(one_minus_alpha[..., :1, :]),
             one_minus_alpha
-        ], dim=-1)
-        weights = alphas * one_minus_alpha  # (N_rays, N)
+        ], dim=-2)
+        weights = alphas * one_minus_alpha  # (N, P, 1)
 
-        # (N_rays, 1|3), computed weighted color of each sample along each ray.
-        final_color = torch.sum(weights[..., None] * colors, dim=-2)
+        # (N, C), computed weighted color of each sample along each ray.
+        final_color = torch.sum(weights * colors, dim=-2)
 
         # To composite onto a white background, use the accumulated alpha map.
         if bg is not None:
@@ -38,58 +82,290 @@ class AlphaComposition(nn.Module):
 
 class VolumnRenderer(nn.Module):
 
-    def __init__(self, *, raw_noise_std=0.0, sigma_as_density=True):
-        """
-        Initialize a Rendering module
-        """
+    class States:
+        kernel: nn.Module
+        samples: Samples
+        hit_mask: torch.Tensor
+        early_stop_tolerance: float
+        N: int
+        P: int
+
+        colors: torch.Tensor
+        diffuses: torch.Tensor
+        speculars: torch.Tensor
+        energies: torch.Tensor
+        weights: torch.Tensor
+        cum_energies: torch.Tensor
+        exp_energies: torch.Tensor
+        tot_evaluations: Dict[str, int]
+
+        chunk: Tuple[slice, slice]
+        cum_chunk: Tuple[slice, slice]
+        cum_last: Tuple[slice, slice]
+        chunk_id: int
+
+        @property
+        def start(self) -> int:
+            return self.chunk[1].start
+
+        @property
+        def end(self) -> int:
+            return self.chunk[1].stop
+
+        def __init__(self, kernel: nn.Module, samples: Samples, early_stop_tolerance: float) -> None:
+            self.kernel = kernel
+            self.samples = samples
+            self.early_stop_tolerance = early_stop_tolerance
+
+            N, P = samples.size
+            self.hit_mask = samples.voxel_indices != -1  # (N, P)
+            self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
+            self.diffuses = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
+            self.speculars = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
+            self.energies = torch.zeros(N, P, 1, device=samples.device)
+            self.weights = torch.zeros(N, P, 1, device=samples.device)
+            self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
+            self.exp_energies = torch.ones(N, P + 1, 1, device=samples.device)
+            self.tot_evaluations = {}
+            self.N, self.P = N, P
+            self.chunk_id = -1
+
+        def n_hits(self, start: int = None, end: int = None) -> int:
+            if start is None:
+                return self.hit_mask.count_nonzero().item()
+            if end is None:
+                return self.hit_mask[:, start].count_nonzero().item()
+            return self.hit_mask[:, start:end].count_nonzero().item()
+
+        def accumulate_tot_evaluations(self, key: str, n: int):
+            if key not in self.tot_evaluations:
+                self.tot_evaluations[key] = 0
+            self.tot_evaluations[key] += n
+
+        def next_chunk(self, *, length=None, end=None):
+            start = 0 if not hasattr(self, "chunk") else self.end
+            length = length or self.P
+            end = min(end or start + length, self.P)
+            self.chunk = slice(None), slice(start, end)
+            self.cum_chunk = slice(None), slice(start + 1, end + 1)
+            self.cum_last = slice(None), slice(start, start + 1)
+            self.chunk_id += 1
+            return self
+
+    def __init__(self, **kwargs):
         super().__init__()
-        self.alpha_composition = AlphaComposition()
-        self.sigma_as_density = sigma_as_density
-        self.raw_noise_std = raw_noise_std
-
-    def forward(self, colors, sigmas, z_vals, bg_color=None, ret_depth=False, debug=False):
-        """Transforms model's predictions to semantically meaningful values.
-
-        Args:
-          color: [num_rays, num_samples along ray, 1|3]. Predicted color from model.
-          density: [num_rays, num_samples along ray]. Predicted density from model.
-          z_vals: [num_rays, num_samples along ray]. Integration time.
-
-        Returns:
-          rgb_map: [num_rays, 1|3]. Estimated RGB color of a ray.
-          disp_map: [num_rays]. Disparity map. Inverse of depth map.
-          acc_map: [num_rays]. Sum of weights along each ray.
-          weights: [num_rays, num_samples]. Weights assigned to each sampled color.
-          depth_map: [num_rays]. Estimated distance to object.
+
+    @perf
+    def forward(self, kernel: nn.Module, samples: Samples, extra_outputs: List[str] = [], *,
+                raymarching_early_stop_tolerance: float = 0,
+                raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
+                **kwargs):
+        """
+        Perform volumn rendering.
+
+        :param kernel: render kernel
+        :param samples `Samples(N, P)`: samples
+        :param extra_outputs `list[str]`: extra items should be contained in the result dict.
+                Optional values include 'depth', 'layers', 'states' and attribute names in class `States` (e.g. 'weights'). Defaults to []
+        :param raymarching_early_stop_tolerance `float`: tolerance of raymarching early stop.
+                Should between 0 and 1 (0 means no early stop). Defaults to 0
+        :param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split raymarching process.
+                Use a list of integers to specify samples of every chunk, or a positive integer to specify number of chunks.
+                Use a negative interger to split by number of hits in chunks, and the absolute value means maximum number of hits in a chunk.
+                0 and `None` means not splitting the raymarching process. Defaults to `None`
+        :return `dict`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
         """
-        alphas = self.density2alpha(sigmas, z_vals) if self.sigma_as_density \
-            else nn_f.sigmoid(sigmas)
-        ret = self.alpha_composition(colors, alphas, bg_color)
-        if ret_depth:
-            ret['depth'] = torch.sum(ret['weights'] * z_vals, dim=-1)
-        if debug:
-            ret['layers'] = torch.cat([colors, alphas[..., None]], dim=-1)
+        if samples.size[1] == 0:
+            print("VolumnRenderer.forward(): # of samples is zero")
+            return None
+
+        s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance)
+
+        if not raymarching_chunk_size_or_sections:
+            raymarching_chunk_size_or_sections = [s.P]
+        elif isinstance(raymarching_chunk_size_or_sections, int) and \
+                raymarching_chunk_size_or_sections > 0:
+            raymarching_chunk_size_or_sections = [ceil(s.P / raymarching_chunk_size_or_sections)]
+
+        if isinstance(raymarching_chunk_size_or_sections, list):
+            chunk_sections = raymarching_chunk_size_or_sections
+            for chunk_samples in cycle(chunk_sections):
+                self._forward_chunk(s.next_chunk(length=chunk_samples))
+                if s.end >= s.P:
+                    break
+        else:
+            chunk_size = -raymarching_chunk_size_or_sections
+            chunk_hits = s.n_hits(0)
+            for i in range(1, s.P):
+                n_hits = s.n_hits(i)
+                if chunk_hits + n_hits > chunk_size:
+                    self._forward_chunk(s.next_chunk(end=i))
+                    n_hits = s.n_hits(i)
+                    chunk_hits = 0
+                chunk_hits += n_hits
+            self._forward_chunk(s.next_chunk())
+
+        ret = {
+            'color': torch.sum(s.colors * s.weights, 1),
+            'tot_evaluations': s.tot_evaluations
+        }
+        for key in extra_outputs:
+            if key == 'depth':
+                ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
+            elif key == 'diffuse':
+                ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
+            elif key == 'specular':
+                ret['specular'] = torch.sum(s.speculars * s.weights, 1)
+            elif key == 'layers':
+                ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
+            elif key == 'states':
+                ret['states'] = s
+            else:
+                ret[key] = getattr(s, key)
         return ret
 
-    def density2alpha(self, densities: torch.Tensor, z_vals: torch.Tensor):
+        # if raymarching_chunk_size == 0:
+        #     raymarching_chunk_samples = 1
+        # if raymarching_chunk_samples != 0:
+        #     if isinstance(raymarching_chunk_samples, int):
+        #         raymarching_chunk_samples = repeat(raymarching_chunk_samples,
+        #                                            ceil(s.P / raymarching_chunk_samples))
+        #     chunk_offset = 0
+        #     for chunk_samples in raymarching_chunk_samples:
+        #         start, end = chunk_offset, chunk_offset + chunk_samples
+        #         n_hits = self._forward_chunk(s, start, end)
+        #         if n_hits > 0 and tolerance > 0:  # Early stop
+        #             s.hit_mask[s.cum_energies[:, end, 0] > tolerance] = 0
+        #         chunk_offset += chunk_samples
+        # elif raymarching_chunk_size > 0:
+        #     chunk_offset, chunk_hits = 0, s.n_hits(0)
+        #     for i in range(1, s.P):
+        #         n_hits = s.n_hits(i)
+        #         if chunk_hits + n_hits > raymarching_chunk_size:
+        #             self._forward_chunk(s, chunk_offset, i, chunk_hits)
+        #             if chunk_hits > 0 and tolerance > 0:  # Early stop
+        #                 s.hit_mask[s.cum_energies[:, i, 0] > tolerance] = 0
+        #                 n_hits = s.n_hits(i)
+        #             chunk_hits, chunk_offset = 0, i
+        #         chunk_hits += n_hits
+        #     self._forward_chunk(s, chunk_offset, s.P, chunk_hits)
+        # else:
+        #     self._forward_chunk(s, 0, s.P)
+
+        # return self._composite(s, extra_outputs)
+        # original_depth = samples.get('original_point_depth', None)
+        # if original_depth is not None:
+        #    results['z'] = (original_depth * probs).sum(-1)
+        # if getattr(input_fn, "track_max_probs", False) and (not self.training):
+        #    input_fn.track_voxel_probs(samples['sampled_point_voxel_idx'].long(), results['probs'])
+
+    def _calc_weights(self, s: States):
+        """
+        Calculate weights of samples in composited outputs
+
+        :param s `States`: states
+        :param start `int`: chunk's start
+        :param end `int`: chunk's end
         """
-        Raw value inferred from model to color and alpha
+        s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
+            + s.cum_energies[s.cum_last]
+        s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
+        s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]
 
-        :param densities `Tensor(N.rays, N.samples)`: model's output density
-        :param z_vals `Tensor(N.rays, N.samples)`: integration time
-        :return `Tensor(N.rays, N.samples)`: alpha
+    def _apply_early_stop(self, s: States):
         """
+        Stop rays whose accumulated opacity are larger than a threshold
+
+        :param s `States`: s
+        :param end `int`: chunk's end
+        """
+        if s.end < s.P and s.early_stop_tolerance > 0:
+            rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
+            s.hit_mask[rays_to_stop, s.end:] = 0
+
+    def _forward_chunk(self, s: States) -> int:
+        fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)  # (N')
+        fi_idxs[1].add_(s.start)
+
+        if fi_idxs[0].size(0) == 0:
+            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
+            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
+            return 0
+
+        # fi_* means "filtered" by hit mask
+        fi_samples = s.samples[fi_idxs]  # N -> N'
+
+        # Infer densities and colors
+        fi_outputs = s.kernel.render(fi_samples, 'color', 'density', 'specular', 'diffuse',
+                                     chunk_id=s.chunk_id)
+        s.colors.index_put_(fi_idxs, fi_outputs['color'])
+        if fi_outputs['specular'] is not None:
+            s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
+        if fi_outputs['diffuse'] is not None:
+            s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
+        s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
+        s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
+
+        self._calc_weights(s)
+        self._apply_early_stop(s)
+
+
+class DensityFirstVolumnRenderer(VolumnRenderer):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _forward_chunk(self, s: VolumnRenderer.States) -> int:
+        fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)  # (N')
+        fi_idxs[1].add_(s.start)
+
+        if fi_idxs[0].size(0) == 0:
+            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
+            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
+            return 0
+
+        # fi_* means "filtered" by hit mask
+        fi_samples = s.samples[fi_idxs]  # N -> N'
+
+        # For all valid samples: encode X
+        fi_encoded_x = s.kernel.encode_x(fi_samples)  # (N', Ex)
+
+        # Infer densities (shape)
+        fi_outputs = s.kernel.infer(fi_encoded_x, None, 'density', 'color_feat',
+                                    chunk_id=s.chunk_id)
+        s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
+        s.accumulate_tot_evaluations("density", fi_idxs[0].size(0))
+
+        self._calc_weights(s)
+        self._apply_early_stop(s)
+
+        # Remove samples whose weights are less than a threshold
+        s.hit_mask[s.chunk][s.weights[s.chunk][..., 0] < 0.01] = 0
+
+        # Update "filtered" tensors
+        fi_mask = s.hit_mask[fi_idxs]
+        fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask])  # N' -> N"
+        fi_encoded_x = fi_encoded_x[fi_mask]  # (N", Ex)
+        fi_color_feats = fi_outputs['color_feat'][fi_mask]
+
+        # For all valid samples: encode D
+        fi_encoded_d = s.kernel.encode_d(s.samples[fi_idxs])  # (N", Ed)
 
-        # Compute 'distance' (in time) between each integration time along a ray.
-        # The 'distance' from the last integration time is infinity.
-        # dists: (N_rays, N)
-        dists = z_vals[..., 1:] - z_vals[..., :-1]
-        last_dist = torch.zeros_like(z_vals[..., 0:1]) + TINY_FLOAT
-        dists = torch.cat([dists, last_dist], -1)
-
-        if self.raw_noise_std > 0.:
-            # Add noise to model's predictions for density. Can be used to
-            # regularize network during training (prevents floater artifacts).
-            noise = torch.normal(0.0, self.raw_noise_std, densities.size())
-            densities = densities + noise
-        return -torch.exp(-torch.relu(densities) * dists) + 1.0
+        # Infer colors (appearance)
+        fi_outputs = s.kernel.infer(fi_encoded_x, fi_encoded_d, 'color', 'specular', 'diffuse',
+                                    chunk_id=s.chunk_id,
+                                    extras={"color_feats": fi_color_feats})
+        # if s.chunk_id == 0:
+        #     fi_colors[:] *= fi_colors.new_tensor([1, 0, 0])
+        # elif s.chunk_id == 1:
+        #     fi_colors[:] *= fi_colors.new_tensor([0, 1, 0])
+        # elif s.chunk_id == 2:
+        #     fi_colors[:] *= fi_colors.new_tensor([0, 0, 1])
+        # else:
+        #     fi_colors[:] *= fi_colors.new_tensor([1, 1, 0])
+        s.colors.index_put_(fi_idxs, fi_outputs['color'])
+        if fi_outputs['specular'] is not None:
+            s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
+        if fi_outputs['diffuse'] is not None:
+            s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
+        s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
diff --git a/modules/sampler.py b/modules/sampler.py
index eacd072..1eae990 100644
--- a/modules/sampler.py
+++ b/modules/sampler.py
@@ -1,14 +1,26 @@
-from typing import Tuple
+from .space import Space, Voxels
 import torch
 import torch.nn as nn
+from typing import Tuple
+
 from utils import device
 from utils import sphere
 from utils.constants import *
+from utils.perf import perf, checkpoint
 from .generic import *
+from clib import *
 
 
 class Bins(object):
 
+    @property
+    def up(self):
+        return self.bounds[1:]
+
+    @property
+    def lo(self):
+        return self.bounds[:-1]
+
     def __init__(self, vals: torch.Tensor):
         self.vals = vals
         self.bounds = torch.cat([
@@ -16,8 +28,6 @@ class Bins(object):
             0.5 * (self.vals[1:] + self.vals[:-1]),
             self.vals[-1:]
         ])
-        self.up = self.bounds[1:]
-        self.lo = self.bounds[:-1]
 
     @staticmethod
     def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
@@ -26,14 +36,60 @@ class Bins(object):
     def to(self, device: torch.device):
         self.vals = self.vals.to(device)
         self.bounds = self.bounds.to(device)
-        self.up = self.bounds[1:]
-        self.lo = self.bounds[:-1]
+
+
+class Samples:
+    pts: torch.Tensor
+    """`Tensor(N[, P], 3)`"""
+
+    dirs: torch.Tensor
+    """`Tensor(N[, P], 3)`"""
+
+    depths: torch.Tensor
+    """`Tensor(N[, P])`"""
+
+    dists: torch.Tensor
+    """`Tensor(N[, P])`"""
+
+    voxel_indices: torch.Tensor
+    """`Tensor(N[, P])`"""
+
+    @property
+    def size(self):
+        return self.pts.size()[:-1]
+
+    @property
+    def device(self):
+        return self.pts.device
+
+    def __init__(self, pts: torch.Tensor, dirs: torch.Tensor, depths: torch.Tensor,
+                 dists: torch.Tensor, voxel_indices: torch.Tensor) -> None:
+        self.pts = pts
+        self.dirs = dirs
+        self.depths = depths
+        self.dists = dists
+        self.voxel_indices = voxel_indices
+
+    def __getitem__(self, index):
+        return Samples(
+            pts=self.pts[index],
+            dirs=self.dirs[index],
+            depths=self.depths[index],
+            dists=self.dists[index],
+            voxel_indices=self.voxel_indices[index])
+
+    def reshape(self, *shape: int):
+        return Samples(
+            pts=self.pts.reshape(*shape, 3),
+            dirs=self.dirs.reshape(*shape, 3),
+            depths=self.depths.reshape(*shape),
+            dists=self.dists.reshape(*shape),
+            voxel_indices=self.voxel_indices.reshape(*shape))
 
 
 class Sampler(nn.Module):
 
-    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
-                 perturb_sample: bool, spherical: bool, lindisp: bool):
+    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int, lindisp: bool, **kwargs):
         """
         Initialize a Sampler module
 
@@ -44,37 +100,81 @@ class Sampler(nn.Module):
         """
         super().__init__()
         self.lindisp = lindisp
-        self.spherical = spherical
-        self.perturb_sample = perturb_sample
         s_range = (1 / sample_range[0], 1 / sample_range[1]) if self.lindisp else sample_range
+        if s_range[1] > s_range[0]:
+            s_range[0] += 1e-4
+            s_range[1] -= 1e-4
+        else:
+            s_range[0] -= 1e-4
+            s_range[1] += 1e-4
         self.bins = Bins.linspace(s_range, n_samples, device=device.default())
 
-    def forward(self, rays_o, rays_d):
+    @perf
+    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
+                perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
         """
         Sample points along rays. return Spherical or Cartesian coordinates, 
         specified by `self.shperical`
 
-        :param rays_o `Tensor(B, 3)`: rays' origin
-        :param rays_d `Tensor(B, 3)`: rays' direction
-        :return `Tensor(B, N, 3)`: sampled points
-        :return `Tensor(B, N)`: corresponding depths along rays
+        :param rays_o `Tensor(N, 3)`: rays' origin
+        :param rays_d `Tensor(N, 3)`: rays' direction
+        :return `Samples(N, P)`: samples
         """
         s = self.bins.vals.expand(rays_o.size(0), -1)
-        if self.perturb_sample:
+        if perturb_sample:
             s = self.bins.lo + (self.bins.up - self.bins.lo) * torch.rand_like(s)
+        pts, depths = self._get_sample_points(rays_o, rays_d, s)
+        voxel_indices = space_module.get_voxel_indices(pts)
+        valid_rays_mask = voxel_indices.ne(-1).any(dim=-1)
+        return Samples(
+            pts=pts,
+            dirs=rays_d[:, None].expand(-1, depths.size(1), -1),
+            depths=depths,
+            dists=self._calc_dists(depths),
+            voxel_indices=voxel_indices
+        )[valid_rays_mask], valid_rays_mask
+
+    def _get_sample_points(self, rays_o, rays_d, s):
         z = torch.reciprocal(s) if self.lindisp else s
-        if self.spherical:
-            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
-            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
-            return sphers, depths, s, pts
-        else:
-            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s, None
+        pts = rays_o[:, None] + rays_d[:, None] * z[..., None]
+        depths = z
+        return pts, depths
+
+    def _calc_dists(self, vals):
+        # Compute 'distance' (in time) between each integration time along a ray.
+        # The 'distance' from the last integration time is infinity.
+        # dists: (N_rays, N)
+        dists = vals[..., 1:] - vals[..., :-1]
+        last_dist = torch.zeros_like(vals[..., :1]) + TINY_FLOAT
+        return torch.cat([dists, last_dist], -1)
+
+
+class SphericalSampler(Sampler):
+
+    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
+                 perturb_sample: bool, **kwargs):
+        """
+        Initialize a Sampler module
+
+        :param depth_range: depth range for sampler
+        :param n_samples: count to sample along ray
+        :param perturb_sample: perturb the sample depths
+        :param lindisp: If True, sample linearly in inverse depth rather than in depth
+        """
+        super().__init__(sample_range=sample_range, n_samples=n_samples,
+                         perturb_sample=perturb_sample, lindisp=False)
+
+    def _get_sample_points(self, rays_o, rays_d, s):
+        r = torch.reciprocal(s)
+        pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, r)
+        pts = sphere.cartesian2spherical(pts, inverse_r=True)
+        return pts, depths
 
 
 class PdfSampler(nn.Module):
 
     def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
-                 spherical: bool, lindisp: bool):
+                 spherical: bool, lindisp: bool, **kwargs):
         """
         Initialize a Sampler module
 
@@ -90,7 +190,7 @@ class PdfSampler(nn.Module):
         self.n_samples = n_samples
         self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range
 
-    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
+    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
         """
         Sample points along rays. return Spherical or Cartesian coordinates, 
         specified by `self.shperical`
@@ -166,22 +266,116 @@ class PdfSampler(nn.Module):
 
 class VoxelSampler(nn.Module):
 
-    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
-                 lindisp: bool, space):
+    def __init__(self, *, perturb_sample: bool, sample_step: float, **kwargs):
         """
-        Initialize a Sampler module
+        Initialize a VoxelSampler module
 
-        :param depth_range: depth range for sampler
-        :param n_samples: count to sample along ray
         :param perturb_sample: perturb the sample depths
-        :param lindisp: If True, sample linearly in inverse depth rather than in depth
+        :param step_size: step size
         """
         super().__init__()
-        self.lindisp = lindisp
         self.perturb_sample = perturb_sample
-        self.n_samples = n_samples
-        self.space = space
-        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range
+        self.sample_step = sample_step
+
+    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
+                 **kwargs) -> Tuple[Samples, torch.Tensor]:
+        """
+        [summary]
 
-    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
-        
\ No newline at end of file
+        :param rays_o `Tensor(N, 3)`: rays' origin positions
+        :param rays_d `Tensor(N, 3)`: rays' directions
+        :param step_size `float`: gap between samples along a ray
+        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
+        :return `Tensor(N)`: valid rays mask
+        """
+        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
+        valid_rays_mask = intersections.hits > 0
+        rays_o = rays_o[valid_rays_mask]
+        rays_d = rays_d[valid_rays_mask]
+        intersections = intersections[valid_rays_mask]  # (N) -> (N')
+        n_rays = rays_o.size(0)
+        ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long)  # (N')
+
+        hits = intersections.hits
+        min_depths = intersections.min_depths
+        max_depths = intersections.max_depths
+        voxel_indices = intersections.voxel_indices
+
+        rays_near_depth = min_depths[:, :1]  # (N', 1)
+        rays_far_depth = max_depths[ray_index_list, hits - 1][:, None]  # (N', 1)
+        rays_length = rays_far_depth - rays_near_depth
+        rays_steps = (rays_length / self.sample_step).ceil().long()
+        rays_step_size = rays_length / rays_steps
+        max_steps = rays_steps.max().item()
+        rays_step = torch.arange(max_steps, device=rays_o.device,
+                                 dtype=torch.float)[None].repeat(n_rays, 1)  # (N', P)
+        invalid_samples_mask = rays_step >= rays_steps
+        samples_min_depth = rays_near_depth + rays_step * rays_step_size
+        samples_depth = samples_min_depth + rays_step_size \
+            * (torch.rand_like(samples_min_depth) if self.perturb_sample else 0.5)  # (N', P)
+        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)
+        samples_voxel_index = voxel_indices[
+            ray_index_list[:, None],
+            torch.searchsorted(max_depths, samples_depth)
+        ]  # (N', P)
+        samples_depth[invalid_samples_mask] = HUGE_FLOAT
+        samples_dist[invalid_samples_mask] = 0
+        samples_voxel_index[invalid_samples_mask] = -1
+
+        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
+        return Samples(
+            pts=rays_o + rays_d * samples_depth[..., None],
+            dirs=rays_d.expand(-1, max_steps, -1),
+            depths=samples_depth,
+            dists=samples_dist,
+            voxel_indices=samples_voxel_index
+        ), valid_rays_mask
+
+    @perf
+    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
+                **kwargs) -> Tuple[Samples, torch.Tensor]:
+        """
+        [summary]
+
+        :param rays_o `Tensor(N, 3)`: [description]
+        :param rays_d `Tensor(N, 3)`: [description]
+        :param step_size `float`: [description]
+        :return `Samples(N, P)`: [description]
+        """
+        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
+        valid_rays_mask = intersections.hits > 0
+        rays_o = rays_o[valid_rays_mask]
+        rays_d = rays_d[valid_rays_mask]
+        intersections = intersections[valid_rays_mask]  # (N) -> (N')
+
+        checkpoint("Ray intersect")
+
+        if intersections.size == 0:
+            return None, valid_rays_mask
+        else:
+            min_depth = intersections.min_depths
+            max_depth = intersections.max_depths
+            pts_idx = intersections.voxel_indices
+            dists = max_depth - min_depth
+            tot_dists = dists.sum(dim=-1, keepdim=True)  # (N, 1)
+            probs = dists / tot_dists
+            steps = tot_dists[:, 0] / self.sample_step
+
+            # sample points and use middle point approximation
+            sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
+                pts_idx, min_depth, max_depth, probs, steps, -1, not self.perturb_sample)
+            sampled_indices = sampled_indices.long()
+            invalid_idx_mask = sampled_indices.eq(-1)
+            sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
+            sampled_depths.masked_fill_(invalid_idx_mask, HUGE_FLOAT)
+
+            checkpoint("Inverse CDF sampling")
+
+            rays_o, rays_d = rays_o[:, None], rays_d[:, None]
+            return Samples(
+                pts=rays_o + rays_d * sampled_depths[..., None],
+                dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
+                depths=sampled_depths,
+                dists=sampled_dists,
+                voxel_indices=sampled_indices
+            ), valid_rays_mask
diff --git a/modules/space.py b/modules/space.py
new file mode 100644
index 0000000..26dac98
--- /dev/null
+++ b/modules/space.py
@@ -0,0 +1,351 @@
+from math import ceil
+import torch
+import numpy as np
+from typing import List, NoReturn, Tuple, Union
+from torch import nn
+from plyfile import PlyData, PlyElement
+
+from utils.geometry import *
+from utils.constants import *
+from utils.voxels import *
+from utils.perf import perf
+from clib import *
+
+
+class Intersections:
+    min_depths: torch.Tensor
+    """`Tensor(N, P)` Min ray depths of intersected voxels"""
+
+    max_depths: torch.Tensor
+    """`Tensor(N, P)` Max ray depths of intersected voxels"""
+
+    voxel_indices: torch.Tensor
+    """`Tensor(N, P)` Indices of intersected voxels"""
+
+    hits: torch.Tensor
+    """`Tensor(N)` Number of hits"""
+
+    @property
+    def size(self):
+        return self.hits.size(0)
+
+    def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor,
+                 voxel_indices: torch.Tensor, hits: torch.Tensor) -> None:
+        self.min_depths = min_depths
+        self.max_depths = max_depths
+        self.voxel_indices = voxel_indices
+        self.hits = hits
+
+    def __getitem__(self, index):
+        return Intersections(
+            min_depths=self.min_depths[index],
+            max_depths=self.max_depths[index],
+            voxel_indices=self.voxel_indices[index],
+            hits=self.hits[index])
+
+
+class Space(nn.Module):
+    bbox: Union[torch.Tensor, None]
+    """`Tensor(2, 3)` Bounding box"""
+
+    def __init__(self, *, bbox: List[float] = None, **kwargs):
+        super().__init__()
+        if bbox is None:
+            self.bbox = None
+        else:
+            self.register_buffer('bbox', torch.Tensor(bbox).reshape(2, 3), persistent=False)
+
+    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
+        raise NotImplementedError
+
+    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
+                          name: str = 'default') -> torch.Tensor:
+        raise NotImplementedError
+
+    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
+        raise NotImplementedError
+
+    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
+        voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
+        if self.bbox is not None:
+            out_bbox = torch.logical_or(pts < self.bbox[0], pts >= self.bbox[1]).any(-1)  # (N...)
+            voxel_indices[out_bbox] = -1
+        return voxel_indices
+
+    @torch.no_grad()
+    def pruning(self, score_fn, threshold: float = 0.5, train_stats=False):
+        raise NotImplementedError()
+
+    @torch.no_grad()
+    def splitting(self):
+        raise NotImplementedError()
+
+
+class Voxels(Space):
+    steps: torch.Tensor
+    """`Tensor(3)` Steps along each dimension"""
+
+    corners: torch.Tensor
+    """`Tensor(C, 3)` Corner positions"""
+
+    voxels: torch.Tensor
+    """`Tensor(M, 3)` Voxel centers"""
+
+    corner_indices: torch.Tensor
+    """`Tensor(M, 8)` Voxel corner indices"""
+
+    voxel_indices_in_grid: torch.Tensor
+    """`Tensor(G)` Indices in voxel list or -1 for pruned space"""
+
+    @property
+    def dims(self) -> int:
+        """`int` Number of dimensions"""
+        return self.steps.size(0)
+
+    @property
+    def n_voxels(self) -> int:
+        """`int` Number of voxels"""
+        return self.voxels.size(0)
+
+    @property
+    def n_corner(self) -> int:
+        """`int` Number of corners"""
+        return self.corners.size(0)
+
+    @property
+    def voxel_size(self) -> torch.Tensor:
+        """`Tensor(3)` Voxel size"""
+        return (self.bbox[1] - self.bbox[0]) / self.steps
+
+    @property
+    def device(self) -> torch.device:
+        return self.voxels.device
+
+    def __init__(self, *, voxel_size: float = None,
+                 steps: Union[torch.Tensor, Tuple[int, int, int]] = None, **kwargs) -> None:
+        super().__init__(**kwargs)
+        if self.bbox is None:
+            raise ValueError("Missing argument 'bbox'")
+        if voxel_size is not None:
+            self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
+        else:
+            self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
+        self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
+        corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
+        self.register_buffer("corners", corners)
+        self.register_buffer("corner_indices", corner_indices)
+        self.register_buffer('voxel_indices_in_grid', torch.arange(self.n_voxels))
+        self._register_load_state_dict_pre_hook(self._before_load_state_dict)
+
+    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
+        """
+        Create a embedding on voxel corners.
+
+        :param name `str`: embedding name
+        :param n_dims `int`: embedding dimension
+        :return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
+        """
+        name = f'emb_{name}'
+        self.add_module(name, torch.nn.Embedding(self.n_corners.item(), n_dims))
+        return self.__getattr__(name)
+
+    def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
+        return getattr(self, f'emb_{name}')
+
+    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
+                          name: str = 'default') -> torch.Tensor:
+        """
+        Extract embedding values at given points using trilinear interpolation.
+
+        :param pts `Tensor(N, 3)`: points to extract values
+        :param voxel_indices `Tensor(N)`: corresponding voxel indices
+        :param name `str`: embedding name, default to 'default'
+        :return `Tensor(N, X)`: extracted values
+        """
+        emb = self.get_embedding(name)
+        if emb is None:
+            raise KeyError(f"Embedding '{name}' doesn't exist")
+        voxels = self.voxels[voxel_indices]  # (N, 3)
+        corner_indices = self.corner_indices[voxel_indices]  # (N, 8)
+        p = (pts - voxels) / self.voxel_size + 0.5  # (N, 3) normed-coords in voxel
+        features = emb(corner_indices).reshape(pts.size(0), 8, -1)  # (N, 8, X)
+        return trilinear_interp(p, features)
+
+    @perf
+    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
+        """
+        Calculate intersections of rays and voxels.
+
+        :param rays_o `Tensor(N, 3)`: rays' origin
+        :param rays_d `Tensor(N, 3)`: rays' direction
+        :param n_max_hits `int`: maximum number of hits (for allocating enough space)
+        :return `Intersection`: intersections of rays and voxels
+        """
+        # Prepend a dim to meet the requirement of external call
+        rays_o = rays_o[None].contiguous()
+        rays_d = rays_d[None].contiguous()
+
+        voxel_indices, min_depths, max_depths = self._ray_intersect(rays_o, rays_d, n_max_hits)
+        invalid_voxel_mask = voxel_indices.eq(-1)
+        hits = n_max_hits - invalid_voxel_mask.sum(-1)
+
+        # Sort intersections according to their depths
+        min_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
+        max_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
+        min_depths, sorted_idx = min_depths.sort(dim=-1)
+        max_depths = max_depths.gather(-1, sorted_idx)
+        voxel_indices = voxel_indices.gather(-1, sorted_idx)
+
+        return Intersections(
+            min_depths=min_depths[0],
+            max_depths=max_depths[0],
+            voxel_indices=voxel_indices[0],
+            hits=hits[0]
+        )
+
+    @perf
+    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
+        """
+        Get voxel indices of points.
+
+        If a point is not in any valid voxels, its corresponding voxel index is -1.
+
+        :param pts `Tensor(N..., 3)`: points
+        :return `Tensor(N...)`: corresponding voxel indices
+        """
+        grid_indices, out_mask = to_grid_indices(pts, self.bbox, steps=self.steps)
+        grid_indices[out_mask] = 0
+        voxel_indices = self.voxel_indices_in_grid[grid_indices]
+        voxel_indices[out_mask] = -1
+        return voxel_indices
+
+    @torch.no_grad()
+    def splitting(self) -> None:
+        """
+        Split voxels into smaller voxels with half size.
+        """
+        n_voxels_before = self.n_voxels
+        self.steps *= 2
+        self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
+            .reshape(-1, 3)
+        self._update_corners()
+        self._update_voxel_indices_in_grid()
+        return n_voxels_before, self.n_voxels
+
+    @torch.no_grad()
+    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
+        self.voxels = self.voxels[keeps]
+        self.corner_indices = self.corner_indices[keeps]
+        self._update_voxel_indices_in_grid()
+        return keeps.size(0), keeps.sum().item()
+
+    @torch.no_grad()
+    def pruning(self, score_fn, threshold: float = 0.5) -> None:
+        scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0])  # (M)
+        return self.prune(scores > threshold)
+
+    def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
+        sum_dims = [val for val in range(self.dims) if val != dim]
+        return self.voxel_indices_in_grid.reshape(*self.steps).ne(-1).sum(sum_dims)
+
+    def balance_cut(self, dim: int, n_parts: int) -> List[int]:
+        n_voxels_list = self.n_voxels_along_dim(dim)
+        cdf = (n_voxels_list.cumsum(0) / self.n_voxels * n_parts).tolist()
+        bins = []
+        part = 1
+        offset = 0
+        for i in range(len(cdf)):
+            if cdf[i] >= part:
+                bins.append(i + 1 - offset)
+                offset = i + 1
+                part = int(cdf[i]) + 1
+        return bins
+
+    def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+        sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits)
+        sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand(
+            *sampled_xyz.shape[:2])
+        sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.flatten()
+
+    @torch.no_grad()
+    def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor:
+        def get_scores_once(pts, idxs):
+            scores = score_fn(pts, idxs).reshape(-1, bits ** 3)  # (B, P)
+            if reduce_fn is not None:
+                scores = reduce_fn(scores)  # (B[, ...])
+            return scores
+
+        sampled_xyz, sampled_idx = self.sample(bits)
+        chunk_size = 64
+        return torch.cat([
+            get_scores_once(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
+            for i in range(0, self.voxels.size(0), chunk_size)
+        ], 0)  # (M[, ...])
+
+    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
+
+
+    def _update_corners(self):
+        """
+        Update voxel corners.
+        """
+        corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
+        self.register_buffer("corners", corners)
+        self.register_buffer("corner_indices", corner_indices)
+
+    def _update_voxel_indices_in_grid(self):
+        """
+        Update voxel indices in grid.
+        """
+        grid_indices, _ = to_grid_indices(self.voxels, self.bbox, steps=self.steps)
+        self.voxel_indices_in_grid = grid_indices.new_full([self.steps.prod().item()], -1)
+        self.voxel_indices_in_grid[grid_indices] = torch.arange(self.n_voxels, device=self.device)
+
+    @torch.no_grad()
+    def _before_load_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
+                                unexpected_keys, error_msgs):
+        # Handle buffers
+        for name, buffer in self.named_buffers(recurse=False):
+            if name in self._non_persistent_buffers_set:
+                continue
+            buffer.resize_as_(state_dict[prefix + name])
+
+        # Handle embeddings
+        for name, module in self.named_modules():
+            if name.startswith('emb_'):
+                setattr(self, name, torch.nn.Embedding(self.n_corners.item(), module.embedding_dim))
+
+
+class Octree(Voxels):
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.nodes_cached = None
+        self.tree_cached = None
+
+    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.nodes_cached is None:
+            self.nodes_cached, self.tree_cached = build_easy_octree(
+                self.voxels, 0.5 * self.voxel_size)
+        return self.nodes_cached, self.tree_cached
+
+    def clear(self):
+        self.nodes_cached = None
+        self.tree_cached = None
+
+    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int):
+        nodes, tree = self.get()
+        return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)
+
+    @torch.no_grad()
+    def splitting(self):
+        ret = super().splitting()
+        self.clear()
+        return ret
+
+    @torch.no_grad()
+    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
+        ret = super().prune(keeps)
+        self.clear()
+        return ret
diff --git a/nerf++ b/nerf++
deleted file mode 160000
index a30f1a5..0000000
--- a/nerf++
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit a30f1a5ad116e43aad90c426a966b2a3fcedaf7e
diff --git a/nets/nerf.py b/nets/nerf.py
deleted file mode 100644
index 2d424d4..0000000
--- a/nets/nerf.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import torch
-import torch.nn as nn
-from modules import *
-from utils import color
-
-
-class Nerf(nn.Module):
-
-    def __init__(self, fc_params, sampler_params, *,
-                 c: int = color.RGB,
-                 n_pos_encode: int = 0,
-                 n_dir_encode: int = None,
-                 coarse_net=None, **kwargs):
-        """
-        Initialize a NeRF unit
-
-        :param fc_params `dict`: parameters for full-connection network
-        :param sampler_params `dict`: parameters for sampler
-        :param c `int`: color mode
-        :param n_pos_encode `int`: encode position to number of dimensions
-        :param n_dir_encode `int`: encode direction to number of dimensions, `None` means direction is ignored
-        :param coarse_net `NerfUnit`: optional coarse net
-        """
-        super().__init__()
-        self.coarse_net = coarse_net
-        self.color = c
-        self.coord_chns = 3
-        self.color_chns = color.chns(self.color)
-
-        self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
-
-        if n_dir_encode is not None:
-            self.dir_chns = 3
-            self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
-        else:
-            self.dir_chns = 0
-            self.dir_encoder = None
-        self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
-                             density_chns=1,
-                             color_chns=self.color_chns,
-                             core_nf=fc_params['nf'],
-                             core_layers=fc_params['n_layers'],
-                             dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
-                             dir_nf=fc_params['nf'] // 2,
-                             activation=fc_params['activation'],
-                             skips=fc_params['skips'])
-        sampler_params['spherical'] = False
-        self.sampler = PdfSampler(**sampler_params) if self.coarse_net is not None \
-            else Sampler(**sampler_params)
-        self.rendering = VolumnRenderer()
-
-    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
-                ret_depth=False, debug=False) -> torch.Tensor:
-        """
-        rays -> colors
-
-        :param rays_o `Tensor(B, 3)`: rays' origin
-        :param rays_d `Tensor(B, 3)`: rays' direction
-        :param prev_ret `Mapping`:
-        :param ret_depth `bool`:
-        :return: `Tensor(B, C)``, inferred images/pixels
-        """
-        if self.coarse_net is not None:
-            coarse_ret = self.coarse_net(rays_o, rays_d, ret_depth=ret_depth, debug=debug)
-            coords, depths, s_vals, _ = self.sampler(rays_o, rays_d, coarse_ret['sample'],
-                                                     coarse_ret['weight'])
-        else:
-            coords, depths, s_vals, _ = self.sampler(rays_o, rays_d)
-        coords_encoded = self.pos_encoder(coords)
-        dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, s_vals.size(-1), -1) \
-            if self.dir_encoder is not None else None
-        colors, densities = self.core(coords_encoded, dirs_encoded)
-        ret = self.rendering(colors, densities[..., 0], depths, ret_depth=ret_depth, debug=debug)
-        ret['sample'] = s_vals
-        if self.coarse_net is not None:
-            ret['coarse'] = coarse_ret
-        return ret
-
diff --git a/nets/nsvf.py b/nets/nsvf.py
deleted file mode 100644
index cd1fb7b..0000000
--- a/nets/nsvf.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-import torch.nn as nn
-from modules import *
-from utils import color
-
-
-class NSVF(nn.Module):
-
-    def __init__(self, fc_params, sampler_params, *,
-                 c: int = color.RGB,
-                 n_featdim: int = 32,
-                 n_pos_encode: int = 0,
-                 n_dir_encode: int = None,
-                 **kwargs):
-        """
-        Initialize a NSVF model
-
-        :param fc_params `dict`: parameters for full-connection network
-        :param sampler_params `dict`: parameters for sampler
-        :param c `int`: color mode
-        :param n_pos_encode `int`: encode position to number of dimensions
-        :param n_dir_encode `int`: encode direction to number of dimensions, `None` means direction is ignored
-        :param coarse_net `NerfUnit`: optional coarse net
-        """
-        super().__init__()
-        self.color = c
-        self.coord_chns = n_featdim
-        self.color_chns = color.chns(self.color)
-
-        self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
-        if n_dir_encode is not None:
-            self.dir_chns = 3
-            self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
-        else:
-            self.dir_chns = 0
-            self.dir_encoder = None
-        self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
-                             density_chns=1,
-                             color_chns=self.color_chns,
-                             core_nf=fc_params['nf'],
-                             core_layers=fc_params['n_layers'],
-                             dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
-                             dir_nf=fc_params['nf'] // 2,
-                             activation=fc_params['activation'],
-                             skips=fc_params['skips'])
-        
-        self.space = OctTreeSpace()
-
-        sampler_params['space'] = self.space
-        self.sampler = VoxelSampler(**sampler_params)
-        self.rendering = VolumnRenderer()
-
-    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
-                ret_depth=False, debug=False) -> torch.Tensor:
-        """
-        rays -> colors
-
-        :param rays_o `Tensor(B, 3)`: rays' origin
-        :param rays_d `Tensor(B, 3)`: rays' direction
-        :param prev_ret `Mapping`:
-        :param ret_depth `bool`:
-        :return: `Tensor(B, C)``, inferred images/pixels
-        """
-        feats, dirs, z_s, dz_s = self.sampler(rays_o, rays_d)
-        feats_encoded = self.pos_encoder(feats)
-        dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, z_s.size(-1), -1) \
-            if self.dir_encoder is not None else None
-        colors, densities = self.core(feats_encoded, dirs_encoded)
-        ret = self.rendering(colors, densities[..., 0], z_s, dz_s, ret_depth=ret_depth, debug=debug)
-        return ret
-
diff --git a/nets/snerf.py b/nets/snerf.py
deleted file mode 100644
index b2fdfb0..0000000
--- a/nets/snerf.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import torch
-import torch.nn as nn
-from modules import *
-from utils import sphere
-from utils import color
-
-
-class Snerf(nn.Module):
-
-    def __init__(self, fc_params, sampler_params, *,
-                 n_parts: int = 1,
-                 c: int = color.RGB,
-                 pos_encode: int = 10,
-                 dir_encode: int = None,
-                 spherical_dir: bool = False, **kwargs):
-        """
-        Initialize a multi-sphere-layer net
-
-        :param fc_params: parameters for full-connection network
-        :param sampler_params: parameters for sampler
-        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
-        :param c: color mode
-        :param encode_to_dim: encode input to number of dimensions
-        """
-        super().__init__()
-        self.color = c
-        self.spherical_dir = spherical_dir
-        self.n_samples = sampler_params['n_samples']
-        self.n_parts = n_parts
-        self.samples_per_part = self.n_samples // self.n_parts
-        self.coord_chns = 3
-        self.color_chns = color.chns(self.color)
-        self.pos_encoder = InputEncoder.Get(pos_encode, self.coord_chns)
-
-        if dir_encode is not None:
-            self.dir_encoder = InputEncoder.Get(dir_encode, 2 if self.spherical_dir else 3)
-            self.dir_chns_encoded = self.dir_encoder.out_dim
-        else:
-            self.dir_encoder = None
-            self.dir_chns_encoded = 0
-
-        self.nets = nn.ModuleList(
-            NerfCore(coord_chns=self.pos_encoder.out_dim,
-                     density_chns=1,
-                     color_chns=self.color_chns,
-                     core_nf=fc_params['nf'],
-                     core_layers=fc_params['n_layers'],
-                     dir_chns=self.dir_chns_encoded,
-                     dir_nf=fc_params['nf'] // 2,
-                     activation=fc_params['activation'])
-            for _ in range(self.n_parts))
-        sampler_params['spherical'] = True
-        self.sampler = Sampler(**sampler_params)
-        self.rendering = VolumnRenderer()
-
-    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
-                ret_depth=False, debug=False) -> torch.Tensor:
-        """
-        rays -> colors
-
-        :param rays_o `Tensor(B, 3)`: rays' origin
-        :param rays_d `Tensor(B, 3)`: rays' direction
-        :return: `Tensor(B, C)``, inferred images/pixels
-        """
-        n_rays = rays_o.size(0)
-        coords, depths, _, pts = self.sampler(rays_o, rays_d)
-        coords_encoded = self.pos_encoder(coords)
-        if self.dir_encoder is not None:
-            if self.spherical_dir:
-                dirs_encoded = self.dir_encoder(sphere.calc_local_dir(rays_d, coords, pts))
-            else:
-                dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, self.n_samples, -1)
-        else:
-            dirs_encoded = None
-
-        densities = torch.empty(n_rays, self.n_samples, device=device.default())
-        colors = torch.empty(n_rays, self.n_samples, self.color_chns, device=device.default())
-        for i, net in enumerate(self.nets):
-            s = slice(i * self.samples_per_part, (i + 1) * self.samples_per_part)
-            c, d = net(coords_encoded[:, s],
-                       dirs_encoded[:, s] if dirs_encoded is not None else None)
-            colors[:, s] = c
-            densities[:, s] = d
-        ret = self.rendering(colors.view(-1, self.n_samples, self.color_chns),
-                             densities, depths, ret_depth=ret_depth, debug=debug)
-        if debug:
-            ret['sample_densities'] = densities
-            ret['sample_depths'] = depths
-        return ret
-
-
-class SnerfExport(nn.Module):
-
-    def __init__(self, net: Snerf):
-        super().__init__()
-        self.net = net
-
-    def forward(self, coords_encoded, z_vals):
-        colors = []
-        densities = []
-        for i in range(self.net.n_parts):
-            s = slice(i * self.net.samples_per_part, (i + 1) * self.net.samples_per_part)
-            mlp = self.net.nets[i] if self.net.nets is not None else self.net.net
-            c, d = mlp(coords_encoded[:, s].flatten(1, 2))
-            colors.append(c.view(-1, self.net.samples_per_part, self.net.color_chns))
-            densities.append(d)
-        colors = torch.cat(colors, 1)
-        densities = torch.cat(densities, 1)
-        alphas = self.net.rendering.density2alpha(densities, z_vals)
-        return torch.cat([colors, alphas[..., None]], -1)
diff --git a/notebook/gen_crop.ipynb b/notebook/gen_crop.ipynb
index 5c191cf..e7f6a9b 100644
--- a/notebook/gen_crop.ipynb
+++ b/notebook/gen_crop.ipynb
@@ -3,13 +3,11 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "import sys\n",
     "import os\n",
     "import torch\n",
-    "import torch.nn as nn\n",
+    "import torch.nn.functional as nn_f\n",
     "import matplotlib.pyplot as plt\n",
     "\n",
     "rootdir = os.path.abspath(sys.path[0] + '/../')\n",
@@ -18,21 +16,16 @@
     "print(\"Set CUDA:%d as current device.\" % torch.cuda.current_device())\n",
     "torch.autograd.set_grad_enabled(False)\n",
     "\n",
-    "from data.spherical_view_syn import *\n",
-    "from configs.spherical_view_syn import SphericalViewSynConfig\n",
-    "from utils import netio\n",
     "from utils import img\n",
-    "from utils import device\n",
     "from utils.view import *\n",
-    "from components.fnr import FoveatedNeuralRenderer\n",
     "\n",
     "datadir = f\"{rootdir}/data/__new/__demo/for_crop\"\n",
     "figs = ['our', 'gt', 'nerf', 'fgt']\n",
     "crops = {\n",
-    "    'classroom_0': [[720, 800, 128], [1097, 982, 256]],\n",
-    "    'lobby_1': [[570, 1000, 100], [1049, 1049, 256]],\n",
-    "    'stones_2': [[720, 800, 100], [680, 1317, 256]],\n",
-    "    'barbershop_3': [[745, 810, 100], [1135, 627, 256]]\n",
+    "    'classroom_0': [[720, 790, 100], [370, 1160, 200]],\n",
+    "    'lobby_1': [[570, 1000, 100], [1300, 1000, 200]],\n",
+    "    'stones_2': [[720, 800, 100], [680, 1317, 200]],\n",
+    "    'barbershop_3': [[745, 810, 100], [950, 900, 200]]\n",
     "}\n",
     "colors = torch.tensor([[0, 1, 0, 1], [1, 1, 0, 1]], dtype=torch.float)\n",
     "border = 10\n",
@@ -78,16 +71,18 @@
     "    img.save(torch.cat([fovea_patches, periph_patches], dim=-1),\n",
     "             [f\"{datadir}/patch/{scene}_{fig}.png\" for fig in figs])\n",
     "    img.save(overlay, f\"{datadir}/overlay/{scene}.png\")\n"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   }
  ],
  "metadata": {
   "interpreter": {
-   "hash": "82066b63b621a9e3d15e3b7c11ca76da6238eff3834294910d715044bd0561e5"
+   "hash": "65406b00395a48e1d89cf658ae895e7869e05878f5469716b06a752a3915211c"
   },
   "kernelspec": {
-   "display_name": "Python 3.8.5 64-bit ('base': conda)",
-   "name": "python3"
+   "name": "python3",
+   "display_name": "Python 3.8.5 64-bit ('base': conda)"
   },
   "language_info": {
    "codemirror_mode": {
diff --git a/notebook/gen_demo_mono.ipynb b/notebook/gen_demo_mono.ipynb
index 0371863..5968a13 100644
--- a/notebook/gen_demo_mono.ipynb
+++ b/notebook/gen_demo_mono.ipynb
@@ -170,7 +170,7 @@
     "    images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
     "    if True:\n",
     "        outputdir = '../__demo/mono/'\n",
-    "        misc.create_dir(outputdir)\n",
+    "        os.makedirs(outputdir, exist_ok=True)\n",
     "        img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n",
     "        img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n",
     "        img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n",
@@ -203,7 +203,7 @@
     "    center = (0, 0)\n",
     "    images = renderer(views.get(view_idx), center, using_mask=True)\n",
     "    outputdir = 'panorama'\n",
-    "    misc.create_dir(outputdir)\n",
+    "    os.makedirs(outputdir, exist_ok=True)\n",
     "    img.save(images['blended'], f'{outputdir}/{view_idx:04d}.png')"
    ],
    "outputs": [
diff --git a/notebook/gen_demo_stereo.ipynb b/notebook/gen_demo_stereo.ipynb
index f44e65e..138950d 100644
--- a/notebook/gen_demo_stereo.ipynb
+++ b/notebook/gen_demo_stereo.ipynb
@@ -216,7 +216,7 @@
     "                                                ret_raw=False)\n",
     "            if True:\n",
     "                outputdir = '../__demo/stereo_m%d' % mono_periph if mono_periph else '../__demo/stereo'\n",
-    "                misc.create_dir(outputdir)\n",
+    "                os.makedirs(outputdir, exist_ok=True)\n",
     "                img.save(torch.cat([\n",
     "                    left_images['blended'],\n",
     "                    right_images['blended']\n",
@@ -228,7 +228,7 @@
     "                    right_images['blended'][:, 1:3]\n",
     "                ], dim=1)\n",
     "                img.save(stereo_overlap, '%s/%s_%d_stereo.png' % (outputdir, scene, i))\n",
-    "                #misc.create_dir(outputdir + '/mid')\n",
+    "                #os.makedirs(outputdir + '/mid', exist_ok=True)\n",
     "                #img.save(left_images['layers_img'][1], '%s/mid/%s_%d_l.png' % (outputdir, scene, i))\n",
     "                #img.save(right_images['layers_img'][1], '%s/mid/%s_%d_r.png' % (outputdir, scene, i))\n",
     "                print(\"%s %d Saved\" % (scene, i))\n",
diff --git a/notebook/gen_for_eval.ipynb b/notebook/gen_for_eval.ipynb
index cada07b..a861f25 100644
--- a/notebook/gen_for_eval.ipynb
+++ b/notebook/gen_for_eval.ipynb
@@ -110,7 +110,7 @@
     "    #plot_figures(images, center)\n",
     "\n",
     "    outputdir = '../__1_eval/output_mono_periph/ref_as_right_eye/%s/' % scene\n",
-    "    misc.create_dir(outputdir)\n",
+    "    os.makedirs(outputdir, exist_ok=True)\n",
     "    #for key in images:\n",
     "    key = 'blended'\n",
     "    img.save(images[key], outputdir + 'view%04d_%s.png' % (view_idx, key))\n"
@@ -131,7 +131,7 @@
     "        images = gen.gen(center, test_view, True)\n",
     "        #plot_figures(images, center)\n",
     "\n",
-    "        misc.create_dir('output/eval_gaze')\n",
+    "        os.makedirs('output/eval_gaze', exist_ok=True)\n",
     "        out_path = 'output/eval_gaze/gaze%03d_%d,%d.png' % (gaze_idx, x, y)\n",
     "        img.save(images['blended'], out_path)\n",
     "        print('Output ' + out_path)\n",
diff --git a/notebook/gen_teaser.ipynb b/notebook/gen_teaser.ipynb
index ea770f6..2c20f68 100644
--- a/notebook/gen_teaser.ipynb
+++ b/notebook/gen_teaser.ipynb
@@ -130,7 +130,7 @@
     "    images = gen.gen(center, test_view, True)\n",
     "    #plot_figures(images, center)\n",
     "\n",
-    "    misc.create_dir('output/teasers')\n",
+    "    os.makedirs('output/teasers', exist_ok=True)\n",
     "    for key in images:\n",
     "        img.save(\n",
     "            images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
diff --git a/notebook/gen_test.ipynb b/notebook/gen_test.ipynb
index eabf583..cd6197d 100644
--- a/notebook/gen_test.ipynb
+++ b/notebook/gen_test.ipynb
@@ -150,7 +150,7 @@
     "print(\"Encoded:\", encoded)\n",
     "#plot_figures(images, center)\n",
     "\n",
-    "#misc.create_dir('output/teasers')\n",
+    "#os.makedirs('output/teasers', exist_ok=True)\n",
     "#for key in images:\n",
     "#    img.save(\n",
     "#        images[key], 'output/teasers/view%04d_%s.png' % (view_idx, key))\n"
diff --git a/notebook/gen_user_study_images.ipynb b/notebook/gen_user_study_images.ipynb
index 238185d..4518b11 100644
--- a/notebook/gen_user_study_images.ipynb
+++ b/notebook/gen_user_study_images.ipynb
@@ -188,7 +188,7 @@
     "\n",
     "#plot_figures(left_images, right_images, centers[set_id][0], centers[set_id][1])\n",
     "\n",
-    "misc.create_dir('output')\n",
+    "os.makedirs('output', exist_ok=True)\n",
     "for key in left_images:\n",
     "    img.save(\n",
     "        left_images[key], 'output/set%d_%s_l.png' % (set_id, key))\n",
diff --git a/notebook/gen_video.ipynb b/notebook/gen_video.ipynb
index a5feb72..df80a3c 100644
--- a/notebook/gen_video.ipynb
+++ b/notebook/gen_video.ipynb
@@ -117,7 +117,7 @@
     "    left_images = gen.gen(left_center, left_view, mono_trans=mono_trans)\n",
     "    right_images = gen.gen(right_center, right_view, mono_trans=mono_trans)\n",
     "    \n",
-    "    misc.create_dir('output/video_frames/hmd2')\n",
+    "    os.makedirs('output/video_frames/hmd2', exist_ok=True)\n",
     "    img.save(torch.cat([left_images['blended'], right_images['blended']], -1),\n",
     "                          'output/video_frames/hmd2/view%04d.png' % view_idx)\n",
     "    print('Frame %d saved' % view_idx)\n"
diff --git a/notebook/net_insight.ipynb b/notebook/net_insight.ipynb
index e71898e..00eb3e9 100644
--- a/notebook/net_insight.ipynb
+++ b/notebook/net_insight.ipynb
@@ -155,7 +155,7 @@
     "    images['overlaid'] = renderer.foveation.synthesis(images['layers_raw'], param[-2:], do_blend=False)\n",
     "    if True:\n",
     "        outputdir = '../__demo/mono/'\n",
-    "        misc.create_dir(outputdir)\n",
+    "        os.makedirs(outputdir, exist_ok=True)\n",
     "        img.save(images['layers_img'][0], f'{outputdir}{scene}_{i}_fovea.png')\n",
     "        img.save(images['layers_img'][1], f'{outputdir}{scene}_{i}_mid.png')\n",
     "        img.save(images['layers_img'][2], f'{outputdir}{scene}_{i}_periph.png')\n",
@@ -196,7 +196,7 @@
     "    center = (0, 0)\n",
     "    images = renderer(views.get(view_idx), center, using_mask=True)\n",
     "    outputdir = 'nerf_our'\n",
-    "    misc.create_dir(outputdir)\n",
+    "    os.makedirs(outputdir, exist_ok=True)\n",
     "    img.save(images['blended'], f'{outputdir}/{view_idx:04d}.png')"
    ]
   }
diff --git a/notebook/test_mono_gen.ipynb b/notebook/test_mono_gen.ipynb
index e278546..545ee19 100644
--- a/notebook/test_mono_gen.ipynb
+++ b/notebook/test_mono_gen.ipynb
@@ -101,7 +101,7 @@
     "gaze = [37.55656052, 20.7297554]\n",
     "images = renderer(view, gaze, using_mask=False, ret_raw=True)\n",
     "outputdir = '../__demo/mono_f60&m110/'\n",
-    "misc.create_dir(outputdir)\n",
+    "os.makedirs(outputdir, exist_ok=True)\n",
     "img.save(images['layers_img'][0], f'{outputdir}{scene}_fovea.png')\n",
     "img.save(images['blended'], f'{outputdir}{scene}.png')\n",
     "img.save(images['blended_raw'], f'{outputdir}{scene}_noCE.png')"
diff --git a/notebook/test_mono_view.ipynb b/notebook/test_mono_view.ipynb
index 72d05c5..323287e 100644
--- a/notebook/test_mono_view.ipynb
+++ b/notebook/test_mono_view.ipynb
@@ -249,7 +249,7 @@
     "\n",
     "plot_figures(left_images, right_images, left_center, right_center)\n",
     "\n",
-    "misc.create_dir('output/mono_test')\n",
+    "os.makedirs('output/mono_test', exist_ok=True)\n",
     "for key in left_images:\n",
     "    img.save(\n",
     "        left_images[key], 'output/mono_test/set%d_%s_l.png' % (set_id, key))\n",
diff --git a/run_lf_syn.py b/run_lf_syn.py
index 8c1b5e9..23b3bf1 100644
--- a/run_lf_syn.py
+++ b/run_lf_syn.py
@@ -58,7 +58,7 @@ def train():
     epoch = EPOCH_BEGIN
     iters = EPOCH_BEGIN * len(train_data_loader) * BATCH_SIZE
 
-    misc.create_dir(RUN_DIR)
+    os.makedirs(RUN_DIR, exist_ok=True)
 
     perf = Perf(enable=(MODE == "Perf"), start=True)
     writer = SummaryWriter(RUN_DIR)
@@ -129,7 +129,7 @@ def test(net_file: str):
 
     # 3. Test on train dataset
     print("Begin test on train dataset...")
-    misc.create_dir(OUTPUT_DIR)
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
     for view_idxs, view_images, _, view_positions in train_data_loader:
         out_view_images = model(view_positions)
         img.save(view_images,
diff --git a/run_spherical_view_syn.py b/run_spherical_view_syn.py
index b87daa9..d5ef230 100644
--- a/run_spherical_view_syn.py
+++ b/run_spherical_view_syn.py
@@ -316,8 +316,8 @@ def train():
     if epochRange.start > 1:
         iters = netio.load(f'{run_dir}model-epoch_{epochRange.start - 1}.pth', model)
     else:
-        misc.create_dir(run_dir)
-        misc.create_dir(log_dir)
+        os.makedirs(run_dir, exist_ok=True)
+        os.makedirs(log_dir, exist_ok=True)
         iters = 0
 
     # 3. Train
@@ -400,7 +400,7 @@ def test():
 
         # 4. Save results
         print('Saving results...')
-        misc.create_dir(output_dir)
+        os.makedirs(output_dir, exist_ok=True)
 
         for key in out:
             shape = [n] + list(dataset.res) + list(out[key].size()[1:])
@@ -446,7 +446,7 @@ def test():
                 img.save_video(out['color'], output_file, 30)
             else:
                 output_subdir = f"{output_dir}/{output_dataset_id}_color"
-                misc.create_dir(output_subdir)
+                os.makedirs(output_subdir, exist_ok=True)
                 img.save(out['color'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
 
         if args.output_flags['depth']:
@@ -457,13 +457,13 @@ def test():
                 img.save_video(colorized_depths, output_file, 30)
             else:
                 output_subdir = f"{output_dir}/{output_dataset_id}_depth"
-                misc.create_dir(output_subdir)
+                os.makedirs(output_subdir, exist_ok=True)
                 img.save(colorized_depths, [
                     f'{output_subdir}/{i:0>4d}.png'
                     for i in dataset.indices
                 ])
                 output_subdir = f"{output_dir}/{output_dataset_id}_bins"
-                misc.create_dir(output_subdir)
+                os.makedirs(output_subdir, exist_ok=True)
                 img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
 
         if args.output_flags['layers']:
@@ -473,7 +473,7 @@ def test():
                     img.save_video(out['layers'][j], output_file, 30)
             else:
                 output_subdir = f"{output_dir}/{output_dataset_id}_layers"
-                misc.create_dir(output_subdir)
+                os.makedirs(output_subdir, exist_ok=True)
                 for j in range(config.sa['n_samples']):
                     img.save(out['layers'][j], [
                         f'{output_subdir}/{i:0>4d}[{j:0>3d}].png'
@@ -543,7 +543,7 @@ def test1():
 
         # 4. Save results
         print('Saving results...')
-        misc.create_dir(output_dir)
+        os.makedirs(output_dir, exist_ok=True)
 
         for key in out:
             shape = [n] + list(dataset.res) + list(out[key].size()[1:])
@@ -587,7 +587,7 @@ def test1():
                 img.save_video(out['color'], output_file, 30)
             else:
                 output_subdir = f"{output_dir}/{output_dataset_id}_color"
-                misc.create_dir(output_subdir)
+                os.makedirs(output_subdir, exist_ok=True)
                 img.save(out['color'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
 
         if args.output_flags['depth']:
@@ -598,7 +598,7 @@ def test1():
                 img.save_video(colorized_depths, output_file, 30)
             else:
                 output_subdir = f"{output_dir}/{output_dataset_id}_depth"
-                misc.create_dir(output_subdir)
+                os.makedirs(output_subdir, exist_ok=True)
                 img.save(colorized_depths, [
                     f'{output_subdir}/{i:0>4d}.png'
                     for i in dataset.indices
@@ -611,7 +611,7 @@ def test1():
                     img.save_video(out['layers'][j], output_file, 30)
             else:
                 output_subdir = f"{output_dir}/{output_dataset_id}_layers"
-                misc.create_dir(output_subdir)
+                os.makedirs(output_subdir, exist_ok=True)
                 for j in range(config.sa['n_samples']):
                     img.save(out['layers'][j], [
                         f'{output_subdir}/{i:0>4d}[{j:0>3d}].png'
@@ -679,7 +679,7 @@ def test2():
 
         # 4. Save results
         print('Saving results...')
-        misc.create_dir(output_dir)
+        os.makedirs(output_dir, exist_ok=True)
 
         for key in out:
             shape = [n] + list(dataset.res) + list(out[key].size()[1:])
@@ -723,7 +723,7 @@ def test2():
                 img.save_video(out['color'], output_file, 30)
             else:
                 output_subdir = f"{output_dir}/{output_dataset_id}_color"
-                misc.create_dir(output_subdir)
+                os.makedirs(output_subdir, exist_ok=True)
                 img.save(out['color'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
 
         if args.output_flags['depth']:
@@ -734,7 +734,7 @@ def test2():
                 img.save_video(colorized_depths, output_file, 30)
             else:
                 output_subdir = f"{output_dir}/{output_dataset_id}_depth"
-                misc.create_dir(output_subdir)
+                os.makedirs(output_subdir, exist_ok=True)
                 img.save(colorized_depths, [
                     f'{output_subdir}/{i:0>4d}.png'
                     for i in dataset.indices
@@ -747,7 +747,7 @@ def test2():
                     img.save_video(out['layers'][j], output_file, 30)
             else:
                 output_subdir = f"{output_dir}/{output_dataset_id}_layers"
-                misc.create_dir(output_subdir)
+                os.makedirs(output_subdir, exist_ok=True)
                 for j in range(config.sa['n_samples']):
                     img.save(out['layers'][j], [
                         f'{output_subdir}/{i:0>4d}[{j:0>3d}].png'
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..26d489e
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,27 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import glob
+import os
+import sys
+
+# build clib
+src_root = "clib"
+sources = glob.glob(f"{src_root}/src/*.cpp") + glob.glob(f"{src_root}/src/*.cu")
+includes = f"{sys.path[0]}/{src_root}/include"
+
+setup(
+    name='dvs',
+    ext_modules=[
+        CUDAExtension(
+            name='clib._ext',
+            sources=sources,
+            extra_compile_args={
+                "cxx": ["-O2", f"-I{includes}"],
+                "nvcc": ["-O2", f"-I{includes}"],
+            },
+        )
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    }
+)
\ No newline at end of file
diff --git a/term_test.py b/term_test.py
new file mode 100644
index 0000000..249bf9a
--- /dev/null
+++ b/term_test.py
@@ -0,0 +1,15 @@
+import os
+import shutil
+from sys import stdout
+from time import sleep
+from utils.progress_bar import *
+
+i = 0
+while True:
+    rows = shutil.get_terminal_size().lines
+    cols = shutil.get_terminal_size().columns
+    os.system('cls' if os.name == 'nt' else 'clear')
+    stdout.write("\n" * (rows - 1))
+    progress_bar(i, 10000, "Test", "XXX")
+    i += 1
+    sleep(0.02)
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..eab0340
--- /dev/null
+++ b/test.py
@@ -0,0 +1,226 @@
+import os
+import argparse
+import torch
+import torch.nn.functional as nn_f
+from math import nan, ceil, prod
+from pathlib import Path
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-m', '--model', type=str,
+                    help='The model file to load for testing')
+parser.add_argument('-r', '--output-res', type=str,
+                    help='Output resolution')
+parser.add_argument('-o', '--output', nargs='+', type=str, default=['perf', 'color'],
+                    help='Specify what to output (perf, color, depth, all)')
+parser.add_argument('--output-type', type=str, default='image',
+                    help='Specify the output type (image, video, debug)')
+parser.add_argument('--views', type=str,
+                    help='Specify the range of views to test')
+parser.add_argument('-p', '--prompt', action='store_true',
+                    help='Interactive prompt mode')
+parser.add_argument('--time', action='store_true',
+                    help='Enable time measurement')
+parser.add_argument('dataset', type=str,
+                    help='Dataset description file')
+args = parser.parse_args()
+
+
+import model as mdl
+from loss.ssim import ssim
+from utils import color
+from utils import interact
+from utils import device
+from utils import img
+from utils.perf import Perf, enable_perf, get_perf_result
+from utils.progress_bar import progress_bar
+from data.dataset_factory import *
+from data.loader import DataLoader
+from utils.constants import HUGE_FLOAT
+
+
+RAYS_PER_BATCH = 2 ** 14
+DATA_LOADER_CHUNK_SIZE = 1e8
+
+
+data_desc_path = DatasetFactory.get_dataset_desc_path(args.dataset)
+os.chdir(data_desc_path.parent)
+nets_dir = Path("_nets")
+data_desc_path = data_desc_path.name
+
+
+def set_outputs(args, outputs_str: str):
+    args.output = [s.strip() for s in outputs_str.split(',')]
+
+
+if args.prompt:  # Prompt test model, output resolution, output mode
+    model_files = [str(path.relative_to(nets_dir)) for path in nets_dir.rglob("*.tar")] \
+        + [str(path.relative_to(nets_dir)) for path in nets_dir.rglob("*.pth")]
+    args.model = interact.input_enum('Specify test model:', model_files,
+                                     err_msg='No such model file')
+    args.output_res = interact.input_ex('Specify output resolution:',
+                                        default='')
+    set_outputs(args, interact.input_ex('Specify the outputs | [perf,color,depth,layers,diffuse,specular]/all:',
+                                        default='perf,color'))
+    args.output_type = interact.input_enum('Specify the output type | image/video:',
+                                           ['image', 'video'],
+                                           err_msg='Wrong output type',
+                                           default='image')
+args.output_res = tuple(int(s) for s in reversed(args.output_res.split('x'))) if args.output_res \
+    else None
+args.output_flags = {
+    item: item in args.output or 'all' in args.output
+    for item in ['perf', 'color', 'depth', 'layers', 'diffuse', 'specular']
+}
+args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None
+
+if args.time:
+    enable_perf()
+
+dataset = DatasetFactory.load(data_desc_path, res=args.output_res,
+                              load_images=args.output_flags['perf'],
+                              views_to_load=args.views)
+print(f"Dataset loaded: {dataset.root}/{dataset.name}")
+
+
+model_path: Path = nets_dir / args.model
+model_name = model_path.parent.name
+model = mdl.load(model_path, {
+    "raymarching_early_stop_tolerance": 0.01,
+    # "raymarching_chunk_size_or_sections": [8],
+    "perturb_sample": False
+})[0].to(device.default()).eval()
+model_class = model.__class__.__name__
+model_args = model.args
+print(f"model: {model_name} ({model_class})")
+print("args:", json.dumps(model.args0))
+
+run_dir = model_path.parent
+output_dir = run_dir / f"output_{int(model_path.stem.split('_')[-1])}"
+output_dataset_id = '%s%s' % (
+    dataset.name,
+    f'_{args.output_res[1]}x{args.output_res[0]}' if args.output_res else ''
+)
+
+
+if __name__ == "__main__":
+    with torch.no_grad():
+        # 1. Initialize data loader
+        data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
+                                 shuffle=False, enable_preload=True,
+                                 color=color.from_str(model.args['color']))
+
+        # 3. Test on dataset
+        print("Begin test, batch size is %d" % RAYS_PER_BATCH)
+
+        i = 0
+        offset = 0
+        chns = model.chns('color')
+        n = dataset.n_views
+        total_pixels = prod([n, *dataset.res])
+
+        out = {}
+        if args.output_flags['perf'] or args.output_flags['color']:
+            out['color'] = torch.zeros(total_pixels, chns, device=device.default())
+        if args.output_flags['diffuse']:
+            out['diffuse'] = torch.zeros(total_pixels, chns, device=device.default())
+        if args.output_flags['specular']:
+            out['specular'] = torch.zeros(total_pixels, chns, device=device.default())
+        if args.output_flags['depth']:
+            out['depth'] = torch.full([total_pixels, 1], HUGE_FLOAT, device=device.default())
+        gt_images = torch.empty_like(out['color']) if dataset.image_path else None
+
+        tot_time = 0
+        tot_iters = len(data_loader)
+        progress_bar(i, tot_iters, 'Inferring...')
+        for _, rays_o, rays_d, extra in data_loader:
+            if args.output_flags['perf']:
+                test_perf = Perf.Node("Test")
+            n_rays = rays_o.size(0)
+            idx = slice(offset, offset + n_rays)
+            ret = model(rays_o, rays_d, extra_outputs=[key for key in out.keys() if key != 'color'])
+            if ret is not None:
+                for key in out:
+                    out[key][idx][ret['rays_mask']] = ret[key]
+            if args.output_flags['perf']:
+                test_perf.close()
+                torch.cuda.synchronize()
+                tot_time += test_perf.duration()
+            if gt_images is not None:
+                gt_images[idx] = extra['color']
+            i += 1
+            progress_bar(i, tot_iters, 'Inferring...')
+            offset += n_rays
+
+        # 4. Save results
+        print('Saving results...')
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        for key in out:
+            out[key] = out[key].reshape([n, *dataset.res, *out[key].shape[1:]])
+        if 'color' in out:
+            out['color'] = out['color'].permute(0, 3, 1, 2)
+        if 'diffuse' in out:
+            out['diffuse'] = out['diffuse'].permute(0, 3, 1, 2)
+        if 'specular' in out:
+            out['specular'] = out['specular'].permute(0, 3, 1, 2)
+
+        if args.output_flags['perf']:
+            perf_errors = torch.full([n], nan)
+            perf_ssims = torch.full([n], nan)
+            if gt_images is not None:
+                gt_images = gt_images.reshape(n, *dataset.res, chns).permute(0, 3, 1, 2)
+                for i in range(n):
+                    perf_errors[i] = nn_f.mse_loss(gt_images[i], out['color'][i]).item()
+                    perf_ssims[i] = ssim(gt_images[i:i + 1], out['color'][i:i + 1]).item() * 100
+            perf_mean_time = tot_time / n
+            perf_mean_error = torch.mean(perf_errors).item()
+            perf_name = f'perf_{output_dataset_id}_{perf_mean_time:.1f}ms_{perf_mean_error:.2e}.csv'
+
+            # Remove old performance reports
+            for file in output_dir.glob(f'perf_{output_dataset_id}*'):
+                file.unlink()
+
+            # Save new performance reports
+            with (output_dir / perf_name).open('w') as fp:
+                fp.write('View, PSNR, SSIM\n')
+                fp.writelines([
+                    f'{dataset.indices[i]}, '
+                    f'{img.mse2psnr(perf_errors[i].item()):.2f}, {perf_ssims[i].item():.2f}\n'
+                    for i in range(n)
+                ])
+
+        for output_type in ['color', 'diffuse', 'specular']:
+            if not args.output_flags[output_type]:
+                continue
+            if args.output_type == 'video':
+                output_file = output_dir / f"{output_dataset_id}_{output_type}.mp4"
+                img.save_video(out[output_type], output_file, 30)
+            else:
+                output_subdir = output_dir / f"{output_dataset_id}_{output_type}"
+                output_subdir.mkdir(exist_ok=True)
+                img.save(out[output_type],
+                         [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
+
+        if args.output_flags['depth']:
+            colored_depths = img.colorize_depthmap(out['depth'][..., 0], model_args['sample_range'])
+            if args.output_type == 'video':
+                output_file = output_dir / f"{output_dataset_id}_depth.mp4"
+                img.save_video(colored_depths, output_file, 30)
+            else:
+                output_subdir = output_dir / f"{output_dataset_id}_depth"
+                output_subdir.mkdir(exist_ok=True)
+                img.save(colored_depths, [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
+                #output_subdir = output_dir / f"{output_dataset_id}_bins"
+                # output_dir.mkdir(exist_ok=True)
+                #img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.indices])
+
+        if args.time:
+            s = "Performance Report ==>\n"
+            res = get_perf_result()
+            if res is None:
+                s += "No available data.\n"
+            else:
+                for key, val in res.items():
+                    path_segs = key.split("/")
+                    s += "  " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
+            print(s)
diff --git a/tools/clean_nets.py b/tools/clean_nets.py
index 8140163..e2ac2e4 100644
--- a/tools/clean_nets.py
+++ b/tools/clean_nets.py
@@ -1,21 +1,26 @@
 """
-Clean trained nets (*/model-epoch_#.pth) whose epoch is neither the largest nor a multiple of 50
+Clean trained nets (*/checkpoint_#.tar) whose epoch is neither the largest nor a multiple of 10
 """
 import sys
 import os
 
-sys.path.append(os.path.abspath(sys.path[0] + '/../'))
-
+base_dir = os.path.abspath(sys.path[0] + '/../')
+sys.path.append(base_dir)
 
 if __name__ == "__main__":
-    for dirpath, _, filenames in os.walk('../data'):
-        epoch_list = [int(filename[12:-4]) for filename in filenames
-                        if filename.startswith("model-epoch_")]
+    root = sys.argv[1] if len(sys.argv) > 1 else f'{base_dir}/data'
+    print(f"Clean model files in {root}...")
+
+    for dirpath, _, filenames in os.walk(root):
+        epoch_list = [int(filename[11:-4]) for filename in filenames
+                    if filename.startswith("checkpoint_")]
         if len(epoch_list) <= 1:
             continue
         epoch_list.sort()
         for epoch in epoch_list[:-1]:
-            if epoch % 50 != 0:
-                file_to_del = f"{dirpath}/model-epoch_{epoch}.pth"
+            if epoch % 10 != 0:
+                file_to_del = f"{dirpath}/checkpoint_{epoch}.tar"
                 print(f"Clean model file: {file_to_del}")
-                os.remove(file_to_del)
\ No newline at end of file
+                os.remove(file_to_del)
+
+    print("Finished.")
\ No newline at end of file
diff --git a/tools/depth_downsample.py b/tools/depth_downsample.py
index fd0a5d4..684fd6d 100644
--- a/tools/depth_downsample.py
+++ b/tools/depth_downsample.py
@@ -18,6 +18,6 @@ os.chdir(in_set)
 depthmaps = img.load(img_names)
 depthmaps = torch.floor((depthmaps * 16)) / 16
 
-misc.create_dir(out_set)
+os.makedirs(out_set, exist_ok=True)
 os.chdir(out_set)
 img.save(depthmaps, img_names)
\ No newline at end of file
diff --git a/tools/export_msl.py b/tools/export_msl.py
index 79b4c1a..e006e56 100644
--- a/tools/export_msl.py
+++ b/tools/export_msl.py
@@ -74,7 +74,7 @@ if __name__ == "__main__":
         # Load model`
         net, name = load_net(model_file)
 
-        misc.create_dir(os.path.join(opt.outdir, config.to_id()))
+        os.makedirs(os.path.join(opt.outdir, config.to_id()), exist_ok=True)
 
         # Export Sampler
         export_net(ExportNet(net), 'msl', {
diff --git a/tools/export_nmsl.py b/tools/export_nmsl.py
index af5e6c7..1c72a77 100644
--- a/tools/export_nmsl.py
+++ b/tools/export_nmsl.py
@@ -74,7 +74,7 @@ if __name__ == "__main__":
         # Load model`
         net, name = load_net(model_file)
 
-        misc.create_dir(os.path.join(opt.outdir, config.to_id()))
+        os.makedirs(os.path.join(opt.outdir, config.to_id()), exist_ok=True)
 
         # Export Sampler
         export_net(Sampler(net), 'sampler', {
diff --git a/tools/export_onnx.py b/tools/export_onnx.py
index f2745b3..f8a247c 100644
--- a/tools/export_onnx.py
+++ b/tools/export_onnx.py
@@ -54,7 +54,7 @@ if __name__ == "__main__":
         rays_o = torch.empty(batch_size, 3, device=device.default())
         rays_d = torch.empty(batch_size, 3, device=device.default())
 
-        misc.create_dir(opt.outdir)
+        os.makedirs(opt.outdir, exist_ok=True)
 
         # Export the model
         outpath = os.path.join(opt.outdir, config.to_id() + ".onnx")
diff --git a/tools/export_snerf_fast.py b/tools/export_snerf_fast.py
index 68a955b..a2d036c 100644
--- a/tools/export_snerf_fast.py
+++ b/tools/export_snerf_fast.py
@@ -44,7 +44,7 @@ if not opt.output:
     else:
         outdir = f"{dir_path}/export"
         output = os.path.join(outdir, f"{model_file.split('@')[0]}@{batch_size_str}.onnx")
-    misc.create_dir(outdir)
+    os.makedirs(outdir, exist_ok=True)
 else:
     output = opt.output
 outname = os.path.splitext(os.path.split(output)[-1])[0]
diff --git a/tools/gen_video.py b/tools/gen_video.py
index f78c14c..2c2e73a 100644
--- a/tools/gen_video.py
+++ b/tools/gen_video.py
@@ -172,7 +172,7 @@ print('Dataset loaded. Views:', n_views)
 
 
 videodir = os.path.dirname(os.path.abspath(opt.view_file))
-tempdir = '/dev/shm/dvs_tmp/realvideo'
+tempdir = '/dev/shm/dvs_tmp/video'
 videoname = f"{os.path.splitext(os.path.split(opt.view_file)[-1])[0]}_{'stereo' if opt.stereo else 'mono'}"
 gazeout = f"{videodir}/{videoname}_gaze.csv"
 if opt.noCE:
@@ -220,8 +220,8 @@ def add_hint(image, center, right_center=None):
         exit()
 
 
-misc.create_dir(os.path.dirname(inferout))
-misc.create_dir(os.path.dirname(hintout))
+os.makedirs(os.path.dirname(inferout), exist_ok=True)
+os.makedirs(os.path.dirname(hintout), exist_ok=True)
 
 hint_offset = infer_offset = 0
 if not opt.replace:
diff --git a/tools/image_scale.py b/tools/image_scale.py
index 64737ef..f3e8812 100644
--- a/tools/image_scale.py
+++ b/tools/image_scale.py
@@ -8,7 +8,7 @@ from utils import misc
 
 
 def batch_scale(src, target, size):
-    misc.create_dir(target)
+    os.makedirs(target, exist_ok=True)
     for file_name in os.listdir(src):
         postfix = os.path.splitext(file_name)[1]
         if postfix == '.jpg' or postfix == '.png':
diff --git a/tools/merge_dataset.py b/tools/merge_dataset.py
index 31a4f50..0b5c832 100644
--- a/tools/merge_dataset.py
+++ b/tools/merge_dataset.py
@@ -11,7 +11,7 @@ from utils import misc
 
 
 def copy_images(src_path, dst_path, n, offset=0):
-    misc.create_dir(os.path.dirname(dst_path))
+    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
     for i in range(n):
         copy(src_path % i, dst_path % (i + offset))
 
diff --git a/tools/pano_process.py b/tools/pano_process.py
new file mode 100644
index 0000000..fca52c7
--- /dev/null
+++ b/tools/pano_process.py
@@ -0,0 +1,36 @@
+from pathlib import Path
+import sys
+import argparse
+import math
+import torch
+import torchvision.transforms.functional as trans_F
+
+sys.path.append(str(Path(sys.path[0]).parent.absolute()))
+
+from utils import img
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-o', '--output', type=str)
+parser.add_argument('dir', type=str)
+args = parser.parse_args()
+
+data_dir = Path(args.dir)
+output_dir = Path(args.output)
+output_dir.mkdir(parents=True, exist_ok=True)
+
+files = [file for file in data_dir.glob('*') if file.suffix == '.png' or file.suffix == '.jpg']
+outfiles = [output_dir / file.name for file in data_dir.glob('*')
+            if file.suffix == '.png' or file.suffix == '.jpg']
+images = img.load(files)
+print(f"{images.size(0)} images loaded.")
+out_images = torch.zeros_like(images)
+H, W = images.shape[-2:]
+for row in range(H):
+    phi = math.pi / H * (row + 0.5)
+    length = math.ceil(math.sin(phi) * W * 0.5) * 2
+    cols = slice((W - length) // 2, (W + length) // 2)
+    out_images[..., row:row + 1, cols] = trans_F.resize(images[..., row:row + 1, :], [1, length])
+    sys.stdout.write(f'{row + 1} / {H} processed.   \r')
+print('')
+img.save(out_images, outfiles)
+print(f"{images.size(0)} images saved.")
\ No newline at end of file
diff --git a/tools/split_dataset.py b/tools/split_dataset.py
index 1483b1f..3ed4fa2 100644
--- a/tools/split_dataset.py
+++ b/tools/split_dataset.py
@@ -4,26 +4,33 @@ import os
 import argparse
 import numpy as np
 import torch
+from itertools import product, repeat
+from pathlib import Path
 
 sys.path.append(os.path.abspath(sys.path[0] + '/../'))
 
-from utils import misc
-
 parser = argparse.ArgumentParser()
 parser.add_argument('-o', '--output', type=str, default='train1')
+parser.add_argument("-t", "--trans", type=float)
+parser.add_argument("-v", "--views", type=int)
+parser.add_argument('-g', '--grids', nargs='+', type=int)
 parser.add_argument('dataset', type=str)
 args = parser.parse_args()
 
+if not args.dataset.endswith(".json"):
+    args.dataset = args.dataset.rstrip("/") + ".json"
+if not args.output.endswith(".json"):
+    args.output = args.output.rstrip("/") + ".json"
 
-data_desc_path = args.dataset
-data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
-data_dir = os.path.dirname(data_desc_path) + '/'
+in_desc_path = Path(args.dataset)
+in_name = in_desc_path.stem
+root_dir = in_desc_path.parent
+out_desc_path: Path = root_dir / args.output
+out_dir = out_desc_path.with_suffix("")
 
-with open(data_desc_path, 'r') as fp:
+with open(in_desc_path, 'r') as fp:
     dataset_desc = json.load(fp)
 
-indices = torch.arange(len(dataset_desc['view_centers'])).view(dataset_desc['samples'])
-
 idx = 0
 '''
 for i in range(3):
@@ -40,7 +47,7 @@ for i in range(3):
         out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
         with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
             json.dump(out_desc, fp, indent=4)
-        misc.create_dir(os.path.join(data_dir, out_desc_name))
+        os.makedirs(os.path.join(data_dir, out_desc_name), exist_ok=True)
         for k in range(len(views)):
             os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
                     os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
@@ -61,26 +68,62 @@ for xi in range(0, 4, 2):
             out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
             with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
                 json.dump(out_desc, fp, indent=4)
-            misc.create_dir(os.path.join(data_dir, out_desc_name))
+            os.makedirs(os.path.join(data_dir, out_desc_name), exist_ok=True)
             for k in range(len(views)):
                 os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
                            os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
             idx += 1
 '''
-from itertools import product
-out_desc_name = args.output
+
+
+def extract_by_grid(*grid_indices):
+    indices = torch.arange(len(dataset_desc['view_centers'])).view(dataset_desc['samples'])
+    views = []
+    for idx in product(*grid_indices):
+        views += indices[idx].flatten().tolist()
+    return views
+
+
+def extract_by_trans(max_trans, max_views):
+    if max_trans is not None:
+        centers = np.array(dataset_desc['view_centers'])
+        trans = np.linalg.norm(centers, axis=-1)
+        indices = np.nonzero(trans <= max_trans)[0]
+    else:
+        indices = np.arange(len(dataset_desc['view_centers']))
+    if max_views is not None:
+        indices = np.sort(indices[np.random.permutation(indices.shape[0])[:max_views]])
+    return indices.tolist()
+
+
+if args.grids:
+    views = extract_by_grid(*repeat(args.grids, 3))  # , [0, 2, 3, 5], [1, 2, 3, 4])
+else:
+    views = extract_by_trans(args.trans, args.views)
+
+image_path = dataset_desc['view_file_pattern']
+if "/" not in image_path:
+    image_path = in_name + "/" + image_path
+
+# Save new dataset
 out_desc = dataset_desc.copy()
-out_desc['view_file_pattern'] = f"{out_desc_name}/{dataset_desc['view_file_pattern'].split('/')[-1]}"
-views = []
-for idx in product([1,2,3,4], [1,2,3,4], [1,2,3,4]):#, [0, 2, 3, 5], [1, 2, 3, 4]):
-    views += indices[idx].flatten().tolist()
+out_desc['view_file_pattern'] = image_path.split('/')[-1]
 out_desc['samples'] = [len(views)]
 out_desc['views'] = views
 out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
-out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
-with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
+if 'view_rots' in dataset_desc:
+    out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
+
+# Write new data desc
+with open(out_desc_path, 'w') as fp:
     json.dump(out_desc, fp, indent=4)
-misc.create_dir(os.path.join(data_dir, out_desc_name))
+
+# Create symbol links of images
+out_dir.mkdir()
 for k in range(len(views)):
-    os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
-               os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
+    if out_dir.parent.absolute() == root_dir.absolute():
+        os.symlink(Path("..") / (image_path % views[k]),
+                   out_dir / (out_desc['view_file_pattern'] % views[k]))
+    else:
+        os.symlink(root_dir.absolute() / (image_path % views[k]),
+                   out_dir / (out_desc['view_file_pattern'] % views[k]))
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..6792a3e
--- /dev/null
+++ b/train.py
@@ -0,0 +1,103 @@
+import argparse
+import logging
+import os
+from pathlib import Path
+import sys
+
+import model as mdl
+import train
+from utils import color
+from utils import device
+from data.dataset_factory import *
+from data.loader import DataLoader
+from utils.misc import list_epochs, print_and_log
+
+
+RAYS_PER_BATCH = 2 ** 16
+DATA_LOADER_CHUNK_SIZE = 1e8
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-c', '--config', type=str,
+                    help='Net config files')
+parser.add_argument('-e', '--epochs', type=int, default=50,
+                    help='Max epochs for train')
+parser.add_argument('--perf', type=int, default=0,
+                    help='Performance measurement frames (0 for disabling performance measurement)')
+parser.add_argument('--prune', type=int, default=5,
+                    help='Prune voxels on every # epochs')
+parser.add_argument('--split', type=int, default=10,
+                    help='Split voxels on every # epochs')
+parser.add_argument('--views', type=str,
+                    help='Specify the range of views to train')
+parser.add_argument('path', type=str,
+                    help='Dataset description file')
+args = parser.parse_args()
+
+argpath = Path(args.path)
+# argpath: May be model path or data path
+# 1) model path: continue training on the specified model
+# 2) data path: train a new model using specified dataset
+
+if argpath.suffix == ".tar":
+    args.mdl_path = argpath
+else:
+    existed_epochs = list_epochs(argpath, "checkpoint_*.tar")
+    args.mdl_path = argpath / f"checkpoint_{existed_epochs[-1]}.tar" if existed_epochs else None
+
+if args.mdl_path:
+    # Infer dataset path from model path
+    # The model path follows such rule: <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
+    dataset_name = args.mdl_path.parent.parent.name
+    dataset_dir = args.mdl_path.parent.parent.parent.parent
+    args.data_path = dataset_dir / dataset_name
+    args.mdl_path = args.mdl_path.relative_to(dataset_dir)
+else:
+    args.data_path = argpath
+args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None
+
+dataset = DatasetFactory.load(args.data_path, views_to_load=args.views)
+print(f"Dataset loaded: {dataset.root}/{dataset.name}")
+os.chdir(dataset.root)
+
+if args.mdl_path:
+    # Load model to continue training
+    model, states = mdl.load(args.mdl_path)
+    model_name = args.mdl_path.parent.name
+    model_class = model.__class__.__name__
+    model_args = model.args
+else:
+    # Create model from specified configuration
+    with Path(f'{sys.path[0]}/configs/{args.config}.json').open() as fp:
+        config = json.load(fp)
+    model_name = args.config
+    model_class = config['model']
+    model_args = config['args']
+    model_args['bbox'] = dataset.bbox
+    model_args['depth_range'] = dataset.depth_range
+    model, states = mdl.create(model_class, model_args), None
+model.to(device.default()).train()
+
+run_dir = Path(f"_nets/{dataset.name}/{model_name}")
+run_dir.mkdir(parents=True, exist_ok=True)
+
+log_file = run_dir / "train.log"
+logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=logging.INFO,
+                    filename=log_file, filemode='a' if log_file.exists() else 'w')
+
+print_and_log(f"model: {model_name} ({model_class})")
+print_and_log(f"args: {json.dumps(model.args0)}")
+
+
+if __name__ == "__main__":
+    # 1. Initialize data loader
+    data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
+                             shuffle=True, enable_preload=True,
+                             color=color.from_str(model.args['color']))
+
+    # 2. Initialize model and trainer
+    trainer = train.get_trainer(model, run_dir=run_dir, states=states, perf_frames=args.perf,
+                                pruning_loop=args.prune, splitting_loop=args.split)
+
+    # 3. Train
+    trainer.train(data_loader, args.epochs)
\ No newline at end of file
diff --git a/train/__init__.py b/train/__init__.py
new file mode 100644
index 0000000..bb0a4bc
--- /dev/null
+++ b/train/__init__.py
@@ -0,0 +1,26 @@
+import importlib
+import os
+
+from model.base import BaseModel
+from . import base
+
+
+# Automatically import any python files this directory
+package_dir = os.path.dirname(__file__)
+package = os.path.basename(package_dir)
+for file in os.listdir(package_dir):
+    path = os.path.join(package_dir, file)
+    if file.startswith('_') or file.startswith('.'):
+        continue
+    if file.endswith('.py') or os.path.isdir(path):
+        model_name = file[:-3] if file.endswith('.py') else file
+        importlib.import_module(f'{package}.{model_name}')
+
+
+def get_class(class_name: str) -> type:
+    return base.train_classes[class_name]
+
+
+def get_trainer(model: BaseModel, **kwargs) -> base.Train:
+    train_class = get_class(model.trainer)
+    return train_class(model, **kwargs)
diff --git a/train/base.py b/train/base.py
new file mode 100644
index 0000000..78dd048
--- /dev/null
+++ b/train/base.py
@@ -0,0 +1,225 @@
+import csv
+import logging
+import sys
+import time
+import torch
+import torch.nn.functional as nn_f
+from typing import Dict
+from pathlib import Path
+
+import loss
+from utils.constants import HUGE_FLOAT
+from utils.misc import format_time
+from utils.progress_bar import progress_bar
+from utils.perf import Perf, checkpoint, enable_perf, perf, get_perf_result
+from data.loader import DataLoader
+from model.base import BaseModel
+from model import save
+
+
+train_classes = {}
+
+
+class BaseTrainMeta(type):
+
+    def __new__(cls, name, bases, attrs):
+        new_cls = type.__new__(cls, name, bases, attrs)
+        train_classes[name] = new_cls
+        return new_cls
+
+
+class Train(object, metaclass=BaseTrainMeta):
+
+    @property
+    def perf_mode(self):
+        return self.perf_frames > 0
+
+    def __init__(self, model: BaseModel, *,
+                 run_dir: Path, states: dict = None, perf_frames: int = 0) -> None:
+        super().__init__()
+        self.model = model
+        self.epoch = 0
+        self.iters = 0
+        self.run_dir = run_dir
+
+        self.model.train()
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
+
+        if states:
+            if 'epoch' in states:
+                self.epoch = states['epoch']
+            if 'iters' in states:
+                self.iters = states['iters']
+            if 'opti' in states:
+                self.optimizer.load_state_dict(states['opti'])
+
+        # For performance measurement
+        self.perf_frames = perf_frames
+        if self.perf_mode:
+            enable_perf()
+
+    def train(self, data_loader: DataLoader, max_epochs: int):
+        self.data_loader = data_loader
+        self.iters_per_epoch = self.perf_frames or len(data_loader)
+
+        print("Begin training...")
+        while self.epoch < max_epochs:
+            self.epoch += 1
+            self._train_epoch()
+            self._save_checkpoint()
+        print("Train finished")
+
+    def _save_checkpoint(self):
+        save(self.run_dir / f'checkpoint_{self.epoch}.tar', self.model, epoch=self.epoch,
+             iters=self.iters, opti=self.optimizer.state_dict())
+        for i in range(1, self.epoch):
+            if i % 10 != 0:
+                (self.run_dir / f'checkpoint_{i}.tar').unlink(missing_ok=True)
+
+    def _show_progress(self, iters_in_epoch: int, loss: Dict[str, float] = {}):
+        loss_val = loss.get('val', 0)
+        loss_min = loss.get('min', 0)
+        loss_max = loss.get('max', 0)
+        loss_avg = loss.get('avg', 0)
+        iters_per_epoch = self.perf_frames or len(self.data_loader)
+        progress_bar(iters_in_epoch, iters_per_epoch,
+                     f"Loss: {loss_val:.2e} ({loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e})",
+                     f"Epoch {self.epoch:<3d}",
+                     f" {self.run_dir}")
+
+    def _show_perf(self):
+        s = "Performance Report ==>\n"
+        res = get_perf_result()
+        if res is None:
+            s += "No available data.\n"
+        else:
+            for key, val in res.items():
+                path_segs = key.split("/")
+                s += "  " * (len(path_segs) - 1) + f"{path_segs[-1]}: {val:.1f}ms\n"
+        print(s)
+
+    @perf
+    def _train_iter(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
+                    extra: Dict[str, torch.Tensor]) -> float:
+        out = self.model(rays_o, rays_d, extra_outputs=['energies', 'speculars'])
+        if 'rays_mask' in out:
+            extra = {key: value[out['rays_mask']] for key, value in extra.items()}
+        checkpoint("Forward")
+
+        self.optimizer.zero_grad()
+        loss_val = loss.mse_loss(out['color'], extra['color'])
+        if self.model.args.get('density_regularization_weight'):
+            loss_val += loss.cauchy_loss(out['energies'],
+                                         s=self.model.args['density_regularization_scale']) \
+                * self.model.args['density_regularization_weight']
+        if self.model.args.get('specular_regularization_weight'):
+            loss_val += loss.cauchy_loss(out['speculars'],
+                                         s=self.model.args['specular_regularization_scale']) \
+                * self.model.args['specular_regularization_weight']
+        checkpoint("Compute loss")
+
+        loss_val.backward()
+        checkpoint("Backward")
+
+        self.optimizer.step()
+        checkpoint("Update")
+
+        return loss_val.item()
+
+    def _train_epoch(self):
+        iters_in_epoch = 0
+        loss_min = HUGE_FLOAT
+        loss_max = 0
+        loss_avg = 0
+        train_epoch_node = Perf.Node("Train Epoch")
+
+        self._show_progress(iters_in_epoch, loss={'val': 0, 'min': 0, 'max': 0, 'avg': 0})
+        for idx, rays_o, rays_d, extra in self.data_loader:
+            loss_val = self._train_iter(rays_o, rays_d, extra)
+
+            loss_min = min(loss_min, loss_val)
+            loss_max = max(loss_max, loss_val)
+            loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
+
+            self.iters += 1
+            iters_in_epoch += 1
+            self._show_progress(iters_in_epoch, loss={
+                'val': loss_val,
+                'min': loss_min,
+                'max': loss_max,
+                'avg': loss_avg
+            })
+
+            if self.perf_mode and iters_in_epoch >= self.perf_frames:
+                self._show_perf()
+                exit()
+        train_epoch_node.close()
+        torch.cuda.synchronize()
+        epoch_dur = train_epoch_node.duration() / 1000
+        logging.info(f"Epoch {self.epoch} spent {format_time(epoch_dur)} "
+                     f"(Avg. {format_time(epoch_dur / self.iters_per_epoch)}/iter). "
+                     f"Loss is {loss_min:.2e}/{loss_avg:.2e}/{loss_max:.2e}")
+
+    def _train_epoch_debug(self):  # TBR
+        iters_in_epoch = 0
+        loss_min = HUGE_FLOAT
+        loss_max = 0
+        loss_avg = 0
+
+        self._show_progress(iters_in_epoch, loss={'val': 0, 'min': 0, 'max': 0, 'avg': 0})
+        indices = []
+        debug_data = []
+        for idx, rays_o, rays_d, extra in self.data_loader:
+            out = self.model(rays_o, rays_d, extra_outputs=['layers', 'weights'])
+            loss_val = nn_f.mse_loss(out['color'], extra['color']).item()
+
+            loss_min = min(loss_min, loss_val)
+            loss_max = max(loss_max, loss_val)
+            loss_avg = (loss_avg * iters_in_epoch + loss_val) / (iters_in_epoch + 1)
+
+            self.iters += 1
+            iters_in_epoch += 1
+            self._show_progress(iters_in_epoch, loss={
+                'val': loss_val,
+                'min': loss_min,
+                'max': loss_max,
+                'avg': loss_avg
+            })
+
+            indices.append(idx)
+            debug_data.append(torch.cat([
+                extra['view_idx'][..., None],
+                extra['pix_idx'][..., None],
+                rays_d,
+                #out['samples'].pts[:, 215:225].reshape(idx.size(0), -1),
+                #out['samples'].dirs[:, :3].reshape(idx.size(0), -1),
+                #out['samples'].voxel_indices[:, 215:225],
+                out['states'].densities[:, 210:230].detach().reshape(idx.size(0), -1),
+                out['states'].energies[:, 210:230].detach().reshape(idx.size(0), -1)
+                # out['color'].detach()
+            ], dim=-1))
+            # states: VolumnRenderer.States = out['states'] # TBR
+
+        indices = torch.cat(indices, dim=0)
+        debug_data = torch.cat(debug_data, dim=0)
+        indices, sort = indices.sort()
+        debug_data = debug_data[sort]
+        name = "rand.csv" if self.data_loader.shuffle else "seq.csv"
+        with (self.run_dir / name).open("w") as fp:
+            csv_writer = csv.writer(fp)
+            csv_writer.writerows(torch.cat([indices[:20, None], debug_data[:20]], dim=-1).tolist())
+        return
+        with (self.run_dir / 'states.csv').open("w") as fp:
+            csv_writer = csv.writer(fp)
+            for chunk_info in states.chunk_infos:
+                csv_writer.writerow(
+                    [*chunk_info['range'], chunk_info['hits'], chunk_info['core_i']])
+                if chunk_info['hits'] > 0:
+                    csv_writer.writerows(torch.cat([
+                        chunk_info['samples'].pts,
+                        chunk_info['samples'].dirs,
+                        chunk_info['samples'].voxel_indices[:, None],
+                        chunk_info['colors'],
+                        chunk_info['energies']
+                    ], dim=-1).tolist())
+                csv_writer.writerow([])
diff --git a/train/train_with_space.py b/train/train_with_space.py
new file mode 100644
index 0000000..4d236da
--- /dev/null
+++ b/train/train_with_space.py
@@ -0,0 +1,127 @@
+from modules.sampler import Samples
+from modules.space import Octree, Voxels
+from utils.mem_profiler import MemProfiler
+from utils.misc import print_and_log
+from .base import *
+
+
+class TrainWithSpace(Train):
+
+    def __init__(self, model: BaseModel, pruning_loop: int = 10000, splitting_loop: int = 10000,
+                 **kwargs) -> None:
+        super().__init__(model, **kwargs)
+        self.pruning_loop = pruning_loop
+        self.splitting_loop = splitting_loop
+        #MemProfiler.enable = True
+
+    def _train_epoch(self):
+        if not self.perf_mode:
+            if self.epoch != 1:
+                if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
+                    try:
+                        with torch.no_grad():
+                            before, after = self.model.splitting()
+                        print_and_log(
+                            f"Splitting done. # of voxels before: {before}, after: {after}")
+                    except NotImplementedError:
+                        print_and_log(
+                            "Note: The space does not support splitting operation. Just skip it.")
+                if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
+                    try:
+                        with torch.no_grad():
+                            #before, after = self.model.pruning()
+                            # print(f"Pruning by voxel densities done. # of voxels before: {before}, after: {after}")
+                            # self._prune_inner_voxels()
+                            self._prune_voxels_by_weights()
+                    except NotImplementedError:
+                        print_and_log(
+                            "Note: The space does not support pruning operation. Just skip it.")
+
+        super()._train_epoch()
+
+    def _prune_inner_voxels(self):
+        space: Voxels = self.model.space
+        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
+                                          device=space.voxels.device)
+        iters_in_epoch = 0
+        batch_size = self.data_loader.batch_size
+        self.data_loader.batch_size = 2 ** 14
+        for _, rays_o, rays_d, _ in self.data_loader:
+            self.model(rays_o, rays_d,
+                       raymarching_early_stop_tolerance=0.01,
+                       raymarching_chunk_size_or_sections=[1],
+                       perturb_sample=False,
+                       voxel_access_counts=voxel_access_counts,
+                       voxel_access_tolerance=0)
+            iters_in_epoch += 1
+            percent = iters_in_epoch / len(self.data_loader) * 100
+            sys.stdout.write(f'Pruning inner voxels...{percent:.1f}%   \r')
+        self.data_loader.batch_size = batch_size
+        before, after = space.prune(voxel_access_counts > 0)
+        print(f"Prune inner voxels: {before} -> {after}")
+
+    def _prune_voxels_by_weights(self):
+        space: Voxels = self.model.space
+        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
+                                          device=space.voxels.device)
+        iters_in_epoch = 0
+        batch_size = self.data_loader.batch_size
+        self.data_loader.batch_size = 2 ** 14
+        for _, rays_o, rays_d, _ in self.data_loader:
+            ret = self.model(rays_o, rays_d,
+                             raymarching_early_stop_tolerance=0,
+                             raymarching_chunk_size_or_sections=None,
+                             perturb_sample=False,
+                             extra_outputs=['weights'])
+            valid_mask = ret['weights'][..., 0] > 0.01
+            accessed_voxels = ret['samples'].voxel_indices[valid_mask]
+            voxel_access_counts.index_add_(0, accessed_voxels, torch.ones_like(accessed_voxels))
+            iters_in_epoch += 1
+            percent = iters_in_epoch / len(self.data_loader) * 100
+            sys.stdout.write(f'Pruning by weights...{percent:.1f}%   \r')
+        self.data_loader.batch_size = batch_size
+        before, after = space.prune(voxel_access_counts > 0)
+        print_and_log(f"Prune by weights: {before} -> {after}")
+
+    def _prune_voxels_by_voxel_weights(self):
+        space: Voxels = self.model.space
+        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
+                                          device=space.voxels.device)
+        with torch.no_grad():
+            batch_size = self.data_loader.batch_size
+            self.data_loader.batch_size = 2 ** 14
+            iters_in_epoch = 0
+            for _, rays_o, rays_d, _ in self.data_loader:
+                ret = self.model(rays_o, rays_d,
+                                 raymarching_early_stop_tolerance=0,
+                                 raymarching_chunk_size_or_sections=None,
+                                 perturb_sample=False,
+                                 extra_outputs=['weights'])
+                self._accumulate_access_count_by_weight(ret['samples'], ret['weights'][..., 0],
+                                                        voxel_access_counts)
+                iters_in_epoch += 1
+                percent = iters_in_epoch / len(self.data_loader) * 100
+                sys.stdout.write(f'Pruning by voxel weights...{percent:.1f}%   \r')
+            self.data_loader.batch_size = batch_size
+        before, after = space.prune(voxel_access_counts > 0)
+        print_and_log(f"Prune by voxel weights: {before} -> {after}")
+
+    def _accumulate_access_count_by_weight(self, samples: Samples, weights: torch.Tensor,
+                                           voxel_access_counts: torch.Tensor):
+        uni_vidxs = -torch.ones_like(samples.voxel_indices)
+        vidx_accu = torch.zeros_like(samples.voxel_indices, dtype=torch.float)
+        uni_vidxs_row = torch.arange(samples.size[0], dtype=torch.long, device=samples.device)
+        uni_vidxs_head = torch.zeros_like(samples.voxel_indices[:, 0])
+        uni_vidxs[:, 0] = samples.voxel_indices[:, 0]
+        vidx_accu[:, 0].add_(weights[:, 0])
+        for i in range(samples.size[1]):
+            # For those rows that voxels are changed, move the head one step forward
+            next_voxel = uni_vidxs[uni_vidxs_row, uni_vidxs_head].ne(samples.voxel_indices[:, i])
+            uni_vidxs_head[next_voxel].add_(1)
+            # Set voxel indices and accumulate weights
+            uni_vidxs[uni_vidxs_row, uni_vidxs_head] = samples.voxel_indices[:, i]
+            vidx_accu[uni_vidxs_row, uni_vidxs_head].add_(weights[:, i])
+        max_accu = vidx_accu.max(dim=1, keepdim=True)[0]
+        uni_vidxs[vidx_accu < max_accu * 0.1] = -1
+        access_voxels, access_count = uni_vidxs.unique(return_counts=True)
+        voxel_access_counts[access_voxels[1:]].add_(access_count[1:])
diff --git a/train_oracle.py b/train_oracle.py
index 82e5afd..219d81a 100644
--- a/train_oracle.py
+++ b/train_oracle.py
@@ -260,8 +260,8 @@ def train():
     if epochRange.start > 1:
         iters = netio.load(f'{run_dir}model-epoch_{epochRange.start - 1}.pth', model)
     else:
-        misc.create_dir(run_dir)
-        misc.create_dir(log_dir)
+        os.makedirs(run_dir, exist_ok=True)
+        os.makedirs(log_dir, exist_ok=True)
         iters = 0
 
     # 3. Train
@@ -333,7 +333,7 @@ def test():
 
         # 4. Save results
         print('Saving results...')
-        misc.create_dir(output_dir)
+        os.makedirs(output_dir, exist_ok=True)
 
         for key in out:
             shape = [n] + list(dataset.view_res) + list(out[key].size()[1:])
@@ -367,7 +367,7 @@ def test():
                     for i in range(n)
                 ])
         output_subdir = f"{output_dir}/{output_dataset_id}_bins"
-        misc.create_dir(output_subdir)
+        os.makedirs(output_subdir, exist_ok=True)
         img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.view_idxs])
 
 
diff --git a/upsampling/run_upsampling.py b/upsampling/run_upsampling.py
index 79ee6ea..8b90e2f 100644
--- a/upsampling/run_upsampling.py
+++ b/upsampling/run_upsampling.py
@@ -60,7 +60,7 @@ args.color = color.from_str(args.color)
 
 
 def train():
-    misc.create_dir(run_dir)
+    os.makedirs(run_dir, exist_ok=True)
     train_set = UpsamplingDataset('.', 'input/out_view_%04d.png',
                                   'gt/view_%04d.png', color=args.color)
     training_data_loader = FastDataLoader(dataset=train_set,
@@ -80,7 +80,7 @@ def train():
 
 
 def test():
-    misc.create_dir(os.path.dirname(args.testOutPatt))
+    os.makedirs(os.path.dirname(args.testOutPatt), exist_ok=True)
     train_set = UpsamplingDataset(
         '.', 'input/out_view_%04d.png', None, color=args.color)
     training_data_loader = FastDataLoader(dataset=train_set,
diff --git a/utils/constants.py b/utils/constants.py
index 8d2ad1b..42601d4 100644
--- a/utils/constants.py
+++ b/utils/constants.py
@@ -2,4 +2,6 @@ import math
 
 HUGE_FLOAT = 1e10
 TINY_FLOAT = 1e-6
-PI = math.pi
\ No newline at end of file
+PI = math.pi
+NAN = math.nan
+E = math.e
\ No newline at end of file
diff --git a/utils/geometry.py b/utils/geometry.py
new file mode 100644
index 0000000..527ac4a
--- /dev/null
+++ b/utils/geometry.py
@@ -0,0 +1,284 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+INF = 1000.0
+
+
+def ones_like(x):
+    T = torch if isinstance(x, torch.Tensor) else np
+    return T.ones_like(x)
+
+
+def stack(x):
+    T = torch if isinstance(x[0], torch.Tensor) else np
+    return T.stack(x)
+
+
+def matmul(x, y):
+    T = torch if isinstance(x, torch.Tensor) else np
+    return T.matmul(x, y)
+
+
+def cross(x, y, axis=0):
+    T = torch if isinstance(x, torch.Tensor) else np
+    return T.cross(x, y, axis)
+
+
+def cat(x, axis=1):
+    if isinstance(x[0], torch.Tensor):
+        return torch.cat(x, dim=axis)
+    return np.concatenate(x, axis=axis)
+
+
+def normalize(x, axis=-1, order=2):
+    if isinstance(x, torch.Tensor):
+        l2 = x.norm(p=order, dim=axis, keepdim=True)
+        return x / (l2 + 1e-8), l2
+
+    else:
+        l2 = np.linalg.norm(x, order, axis)
+        l2 = np.expand_dims(l2, axis)
+        l2[l2 == 0] = 1
+        return x / l2, l2
+
+
+def parse_extrinsics(extrinsics, world2camera=True):
+    """ this function is only for numpy for now"""
+    if extrinsics.shape[0] == 3 and extrinsics.shape[1] == 4:
+        extrinsics = np.vstack([extrinsics, np.array([[0, 0, 0, 1.0]])])
+    if extrinsics.shape[0] == 1 and extrinsics.shape[1] == 16:
+        extrinsics = extrinsics.reshape(4, 4)
+    if world2camera:
+        extrinsics = np.linalg.inv(extrinsics).astype(np.float32)
+    return extrinsics
+
+
+def parse_intrinsics(intrinsics):
+    fx = intrinsics[0, 0]
+    fy = intrinsics[1, 1]
+    cx = intrinsics[0, 2]
+    cy = intrinsics[1, 2]
+    return fx, fy, cx, cy
+
+
+def uv2cam(uv, z, intrinsics, homogeneous=False):
+    fx, fy, cx, cy = parse_intrinsics(intrinsics)
+    x_lift = (uv[0] - cx) / fx * z
+    y_lift = (uv[1] - cy) / fy * z
+    z_lift = ones_like(x_lift) * z
+
+    if homogeneous:
+        return stack([x_lift, y_lift, z_lift, ones_like(z_lift)])
+    else:
+        return stack([x_lift, y_lift, z_lift])
+
+
+def cam2world(xyz_cam, inv_RT):
+    return matmul(inv_RT, xyz_cam)[:3]
+
+
+def r6d2mat(d6: torch.Tensor) -> torch.Tensor:
+    """
+    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+    using Gram--Schmidt orthogonalisation per Section B of [1].
+    Args:
+        d6: 6D rotation representation, of size (*, 6)
+
+    Returns:
+        batch of rotation matrices of size (*, 3, 3)
+
+    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+    On the Continuity of Rotation Representations in Neural Networks.
+    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+    Retrieved from http://arxiv.org/abs/1812.07035
+    """
+
+    a1, a2 = d6[..., :3], d6[..., 3:]
+    b1 = F.normalize(a1, dim=-1)
+    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+    b2 = F.normalize(b2, dim=-1)
+    b3 = torch.cross(b1, b2, dim=-1)
+    return torch.stack((b1, b2, b3), dim=-2)
+
+
+def get_ray_direction(ray_start, uv, intrinsics, inv_RT, depths=None):
+    if depths is None:
+        depths = 1
+    rt_cam = uv2cam(uv, depths, intrinsics, True)
+    rt = cam2world(rt_cam, inv_RT)
+    ray_dir, _ = normalize(rt - ray_start[:, None], axis=0)
+    return ray_dir
+
+
+def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
+    """
+    This function takes a vector 'camera_position' which specifies the location
+    of the camera in world coordinates and two vectors `at` and `up` which
+    indicate the position of the object and the up directions of the world
+    coordinate system respectively. The object is assumed to be centered at
+    the origin.
+
+    The output is a rotation matrix representing the transformation
+    from world coordinates -> view coordinates.
+
+    Input:
+        camera_position: 3
+        at: 1 x 3 or N x 3  (0, 0, 0) in default
+        up: 1 x 3 or N x 3  (0, 1, 0) in default
+    """
+
+    if at is None:
+        at = torch.zeros_like(camera_position)
+    else:
+        at = torch.tensor(at).type_as(camera_position)
+    if up is None:
+        up = torch.zeros_like(camera_position)
+        up[2] = -1
+    else:
+        up = torch.tensor(up).type_as(camera_position)
+
+    z_axis = normalize(at - camera_position)[0]
+    x_axis = normalize(cross(up, z_axis))[0]
+    y_axis = normalize(cross(z_axis, x_axis))[0]
+
+    R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
+    return R
+
+
+def ray(ray_start, ray_dir, depths):
+    return ray_start + ray_dir * depths
+
+
+def compute_normal_map(ray_start, ray_dir, depths, RT, width=512, proj=False):
+    raise NotImplementedError("This function needs fairnr.data.data_utils to work. "
+                              "Will remove this dependency later.")
+    # TODO:
+    # this function is pytorch-only (for not)
+    wld_coords = ray(ray_start, ray_dir, depths.unsqueeze(-1)).transpose(0, 1)
+    cam_coords = matmul(RT[:3, :3], wld_coords) + RT[:3, 3].unsqueeze(-1)
+    cam_coords = D.unflatten_img(cam_coords, width)
+
+    # estimate local normal
+    shift_l = cam_coords[:, 2:, :]
+    shift_r = cam_coords[:, :-2, :]
+    shift_u = cam_coords[:, :, 2:]
+    shift_d = cam_coords[:, :, :-2]
+    diff_hor = normalize(shift_r - shift_l, axis=0)[0][:, :, 1:-1]
+    diff_ver = normalize(shift_u - shift_d, axis=0)[0][:, 1:-1, :]
+    normal = cross(diff_hor, diff_ver)
+    _normal = normal.new_zeros(*cam_coords.size())
+    _normal[:, 1:-1, 1:-1] = normal
+    _normal = _normal.reshape(3, -1).transpose(0, 1)
+
+    # compute the projected color
+    if proj:
+        _normal = normalize(_normal, axis=1)[0]
+        wld_coords0 = ray(ray_start, ray_dir, 0).transpose(0, 1)
+        cam_coords0 = matmul(RT[:3, :3], wld_coords0) + RT[:3, 3].unsqueeze(-1)
+        cam_coords0 = D.unflatten_img(cam_coords0, width)
+        cam_raydir = normalize(cam_coords - cam_coords0, 0)[0].reshape(3, -1).transpose(0, 1)
+        proj_factor = (_normal * cam_raydir).sum(-1).abs() * 0.8 + 0.2
+        return proj_factor
+    return _normal
+
+
+# helper functions for encoder
+
+def padding_points(xs, pad):
+    if len(xs) == 1:
+        return xs[0].unsqueeze(0)
+
+    maxlen = max([x.size(0) for x in xs])
+    xt = xs[0].new_ones(len(xs), maxlen, xs[0].size(1)).fill_(pad)
+    for i in range(len(xs)):
+        xt[i, :xs[i].size(0)] = xs[i]
+    return xt
+
+
+def pruning_points(feats, points, scores, depth=0, th=0.5):
+    if depth > 0:
+        g = int(8 ** depth)
+        scores = scores.reshape(scores.size(0), -1, g).sum(-1, keepdim=True)
+        scores = scores.expand(*scores.size()[:2], g).reshape(scores.size(0), -1)
+    alpha = (1 - torch.exp(-scores)) > th
+    feats = [feats[i][alpha[i]] for i in range(alpha.size(0))]
+    points = [points[i][alpha[i]] for i in range(alpha.size(0))]
+    points = padding_points(points, INF)
+    feats = padding_points(feats, 0)
+    return feats, points
+
+
+def offset_points(point_xyz: torch.Tensor, half_voxel: Union[torch.Tensor, int, float] = 1,
+                  offset_only: bool = False, bits: int = 2) -> torch.Tensor:
+    """
+    [summary]
+
+    :param point_xyz `Tensor(N, 3)`: [description]
+    :param half_voxel `Tensor(1) | int | float`: [description], defaults to 1
+    :param offset_only `bool`: [description], defaults to False
+    :param bits `int`: [description], defaults to 2
+    :return `Tensor(N, X, 3)|Tensor(X, 3)`: [description]
+    """
+    c = torch.arange(1 - bits, bits, 2, dtype=point_xyz.dtype, device=point_xyz.device)
+    offset = (torch.stack(torch.meshgrid(c, c, c), dim=-1).reshape(-1, 3)) / (bits - 1) * half_voxel
+    return offset if offset_only else point_xyz[:, None] + offset
+
+
+def discretize_points(voxel_points, voxel_size):
+    # this function turns voxel centers/corners into integer indeices
+    # we assume all points are alreay put as voxels (real numbers)
+    minimal_voxel_point = voxel_points.min(dim=0, keepdim=True)[0]
+    voxel_indices = ((voxel_points - minimal_voxel_point) / voxel_size).round_().long()  # float
+    residual = (voxel_points - voxel_indices.type_as(voxel_points)
+                * voxel_size).mean(0, keepdim=True)
+    return voxel_indices, residual
+
+
+def expand_points(voxel_points, voxel_size):
+    _voxel_size = min([
+        torch.sqrt(((voxel_points[j:j + 1] - voxel_points[j + 1:]) ** 2).sum(-1).min())
+        for j in range(100)])
+    depth = int(np.round(torch.log2(_voxel_size / voxel_size)))
+    if depth > 0:
+        half_voxel = _voxel_size / 2.0
+        for _ in range(depth):
+            voxel_points = offset_points(voxel_points, half_voxel / 2.0).reshape(-1, 3)
+            half_voxel = half_voxel / 2.0
+
+    return voxel_points, depth
+
+
+def get_edge(depth_pts, voxel_pts, voxel_size, th=0.05):
+    voxel_pts = offset_points(voxel_pts, voxel_size / 2.0)
+    diff_pts = (voxel_pts - depth_pts[:, None, :]).norm(dim=2)
+    ab = diff_pts.sort(dim=1)[0][:, :2]
+    a, b = ab[:, 0], ab[:, 1]
+    c = voxel_size
+    p = (ab.sum(-1) + c) / 2.0
+    h = (p * (p - a) * (p - b) * (p - c)) ** 0.5 / c
+    return h < (th * voxel_size)
+
+
+# fill-in image
+def fill_in(shape, hits, input, initial=1.0):
+    input_sizes = [k for k in input.size()]
+    if (len(input_sizes) == len(shape)) and \
+            all([shape[i] == input_sizes[i] for i in range(len(shape))]):
+        return input   # shape is the same no need to fill
+
+    if isinstance(initial, torch.Tensor):
+        output = initial.expand(*shape)
+    else:
+        output = input.new_ones(*shape) * initial
+    if input is not None:
+        if len(shape) == 1:
+            return output.masked_scatter(hits, input)
+        return output.masked_scatter(hits.unsqueeze(-1).expand(*shape), input)
+    return output
diff --git a/utils/img.py b/utils/img.py
index a39d308..8920922 100644
--- a/utils/img.py
+++ b/utils/img.py
@@ -1,10 +1,11 @@
 import os
+from pathlib import Path
 import shutil
 import torch
 import matplotlib.pyplot as plt
 import numpy as np
 import torch.nn.functional as nn_f
-from typing import Tuple
+from typing import List, Tuple, Union
 from . import misc
 from .constants import *
 
@@ -65,7 +66,7 @@ def load(*paths: str, permute=True, with_alpha=False) -> torch.Tensor:
     chns = 4 if with_alpha else 3
     new_paths = []
     for path in paths:
-        new_paths += [path] if isinstance(path, str) else list(path)
+        new_paths += [path] if isinstance(path, (str, Path)) else list(path)
     imgs = np.stack([plt.imread(path)[..., :chns] for path in new_paths])
     if imgs.dtype == 'uint8':
         imgs = imgs.astype(np.float32) / 255
@@ -76,7 +77,7 @@ def load_seq(path: str, n: int, permute=True, with_alpha=False) -> torch.Tensor:
     return load([path % i for i in range(n)], permute=permute, with_alpha=with_alpha)
 
 
-def save(input: torch.Tensor, *paths: str):
+def save(input: torch.Tensor, *paths: Union[str, Path, List[Union[str, Path]]]):
     """
     Save one or multiple torch-image(s) to `paths`
 
@@ -86,7 +87,7 @@ def save(input: torch.Tensor, *paths: str):
     """
     new_paths = []
     for path in paths:
-        new_paths += [path] if isinstance(path, str) else list(path)
+        new_paths += [path] if isinstance(path, (str, Path)) else list(path)
     if len(input.size()) < 4:
         input = input[None]
     if input.size(0) != len(new_paths):
@@ -100,9 +101,9 @@ def save(input: torch.Tensor, *paths: str):
         plt.imsave(path, np_img[i])
 
 
-def save_seq(input: torch.Tensor, path: str):
+def save_seq(input: torch.Tensor, path: Union[str, Path]):
     n = 1 if len(input.size()) <= 3 else input.size(0)
-    return save(input, [path % i for i in range(n)])
+    return save(input, [str(path) % i for i in range(n)])
 
 
 def plot(input: torch.Tensor, *, ax: plt.Axes = None):
@@ -118,7 +119,7 @@ def plot(input: torch.Tensor, *, ax: plt.Axes = None):
     return plt.imshow(im) if ax is None else ax.imshow(im)
 
 
-def save_video(frames: torch.Tensor, path: str, fps: int,
+def save_video(frames: torch.Tensor, path: Union[str, Path], fps: int,
                repeat: int = 1, pingpong: bool = False):
     """
     Encode and save a sequence of frames as video file
@@ -134,19 +135,16 @@ def save_video(frames: torch.Tensor, path: str, fps: int,
         frames = torch.cat([frames, frames.flip(0)], 0)
     if repeat > 1:
         frames = frames.expand(repeat, -1, -1, -1, -1).flatten(0, 1)
-    dir, file_name = os.path.split(path)
-    if not dir:
-        dir = './'
-    misc.create_dir(dir)
-    cwd = os.getcwd()
-    os.chdir(dir)
-    temp_out_dir = os.path.splitext(file_name)[0] + '_tempout'
-    misc.create_dir(temp_out_dir)
-    os.chdir(temp_out_dir)
-    save_seq(frames, 'out_%04d.png')
-    os.system(f'ffmpeg -y -r {fps:d} -i out_%04d.png -c:v libx264 ../{file_name}')
-    os.chdir(cwd)
-    shutil.rmtree(os.path.join(dir, temp_out_dir))
+
+    path = Path(path)
+    tempdir = Path('/dev/shm/dvs_tmp/video')
+    inferout = tempdir / path.stem / f"%04d.bmp"
+    os.makedirs(inferout.parent, exist_ok=True)
+    os.makedirs(path.parent, exist_ok=True)
+
+    save_seq(frames, inferout)
+    os.system(f'ffmpeg -y -r {fps:d} -i {inferout} -c:v libx264 {path}')
+    shutil.rmtree(inferout.parent)
 
 
 def horizontal_shift(input: torch.Tensor, offset: int, dim=-1) -> torch.Tensor:
diff --git a/utils/mem_profiler.py b/utils/mem_profiler.py
index d034bc0..848d4e7 100644
--- a/utils/mem_profiler.py
+++ b/utils/mem_profiler.py
@@ -2,13 +2,14 @@ from cgitb import enable
 import torch
 from .device import *
 
+
 class MemProfiler:
 
     enable = False
 
     @staticmethod
-    def print_memory_stats(prefix, last_allocated=None, device=None):
-        if not MemProfiler.enable:
+    def print_memory_stats(prefix, last_allocated=None, device=None, enable_once=False):
+        if not enable_once and not MemProfiler.enable:
             return
         if device is None:
             device = default()
diff --git a/utils/misc.py b/utils/misc.py
index 2bf7250..e6f49fe 100644
--- a/utils/misc.py
+++ b/utils/misc.py
@@ -1,9 +1,13 @@
-import os
+from itertools import repeat
+import logging
+from pathlib import Path
+import re
+import shutil
 import torch
 import glm
 import csv
 import numpy as np
-from typing import List, Tuple, Union
+from typing import List, Union
 from torch.types import Number
 from .constants import *
 from .device import *
@@ -59,31 +63,11 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
     return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2) if normalize else torch.stack([x, y], 2)
 
 
-def create_dir(path):
-    if not os.path.exists(path):
-        os.makedirs(path)
-
-
 def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-    angle = -torch.atan(x / y) + (y < 0) * PI + 0.5 * PI
+    angle = -torch.atan(x / y) - (y < 0) * PI + 0.5 * PI
     return angle
 
 
-def depth_sample(depth_range: Tuple[float, float], n: int, lindisp: bool) -> torch.Tensor:
-    """
-    Get [n_layers] foreground layers whose diopters are distributed uniformly
-    in  [depth_range] plus a background layer
-
-    :param depth_range: depth range of foreground layers
-    :param n_layers: number of foreground layers
-    :return: list of [n_layers+1] depths
-    """
-    if lindisp:
-        depth_range = (1 / depth_range[0], 1 / depth_range[1])
-    samples = torch.linspace(depth_range[0], depth_range[1], n)
-    return samples
-
-
 def broadcast_cat(input: torch.Tensor,
                   s: Union[Number, List[Number], torch.Tensor],
                   dim=-1,
@@ -130,4 +114,73 @@ def view_like(input: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
     return input.view(out_shape)
 
 
-def values(map, *keys): return list(map[key] for key in keys)
+def format_time(seconds):
+    days = int(seconds / 3600 / 24)
+    seconds = seconds - days * 3600 * 24
+    hours = int(seconds / 3600)
+    seconds = seconds - hours * 3600
+    minutes = int(seconds / 60)
+    seconds = seconds - minutes * 60
+    seconds_final = int(seconds)
+    seconds = seconds - seconds_final
+    millis = int(seconds * 1000)
+
+    if days > 0:
+        output = f"{days}D{hours:0>2d}h{minutes:0>2d}m"
+    elif hours > 0:
+        output = f"{hours:0>2d}h{minutes:0>2d}m{seconds_final:0>2d}s"
+    elif minutes > 0:
+        output = f"{minutes:0>2d}m{seconds_final:0>2d}s"
+    elif seconds_final > 0:
+        output = f"{seconds_final:0>2d}s{millis:0>3d}ms"
+    elif millis > 0:
+        output = f"{millis:0>3d}ms"
+    else:
+        output = '0ms'
+    return output
+
+
+def print_and_log(s):
+    print(s)
+    logging.info(s)
+
+
+def masked_scatter(mask: torch.Tensor, value: torch.Tensor, initial: Union[torch.Tensor, Number] = 0):
+    """
+    Extend PyTorch's built-in `masked_scatter` function
+
+    :param mask `Tensor(M...)`: the boolean mask
+    :param value `Tensor(N, D...)`: the value to fill in with, should have at least as many elements 
+                                    as the number of ones in `mask`
+    :param destination `Tensor(M..., D...)`: (optional) the destination tensor to fill,
+                                             if not specified, a new tensor filled with 
+                                             `empty_value` will be created and used as destination
+    :param empty_value `Number`: the initial elements in the newly created destination tensor, 
+                                 defaults to 0
+    :return `Tensor(M..., D...)`: the destination tensor after filled
+    """
+    M_ = mask.size()
+    D_ = value.size()[1:]
+    if not isinstance(initial, torch.Tensor):
+        initial = value.new_full([*M_, *D_], initial)
+    return initial.masked_scatter(mask.reshape(*M_, *repeat(1, len(D_))), value)
+
+
+def list_epochs(dir: Path, pattern: str) -> List[int]:
+    prefix = pattern.split("*")[0]
+    epoch_list = [int(str(path.stem)[len(prefix):]) for path in dir.glob(pattern)]
+    epoch_list.sort()
+    return epoch_list
+
+
+def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int):
+    start, end = re.search(r'%0\dd', file_pattern).span()
+    prefix, suffix = start, len(file_pattern) - end
+
+    seqs = [
+        int(path.name[prefix:-suffix])
+        for path in dir.glob(re.sub(r'%0\dd', "*", file_pattern))
+    ]
+    seqs.sort(reverse=offset > 0)
+    for i in seqs:
+        (dir / (file_pattern % i)).rename(dir / (file_pattern % (i + offset)))
diff --git a/utils/perf.py b/utils/perf.py
index 2b9f278..5a3dd58 100644
--- a/utils/perf.py
+++ b/utils/perf.py
@@ -1,32 +1,137 @@
+from numpy import average
+import torch
 import torch.cuda
+from typing import Dict, List, OrderedDict
 
 
 class Perf(object):
+    frames: List[Dict[str, float]]
 
-    def __init__(self, enable, start=False) -> None:
+    class Node:
+        def __init__(self, name, parent=None) -> None:
+            self.name = name
+            self.parent = parent
+            self.events = []
+            self.event_names = []
+            self.child_nodes = []
+            self.child_nodes_event_idx = []
+            self.add_checkpoint("Start")
+
+        def add_checkpoint(self, name):
+            event = torch.cuda.Event(enable_timing=True)
+            event.record()
+            self.events.append(event)
+            self.event_names.append(name)
+
+        def add_child(self, name):
+            child = Perf.Node(name, self)
+            self.child_nodes.append(child)
+            self.child_nodes_event_idx.append(len(self.events))
+            return child
+
+        def close(self):
+            self.add_checkpoint("End")
+            return self.parent
+
+        def duration(self, i0=0, i1=-1) -> float:
+            return self.events[i0].elapsed_time(self.events[i1])
+
+        def result(self, prefix: str = '') -> OrderedDict[str, float]:
+            path = f"{prefix}{self.name}"
+            res = {path: self.duration()}
+            j = 0
+            for i in range(1, len(self.events) - 1):
+                event_path = f"{path}/{self.event_names[i]}"
+                res[event_path] = self.duration(i - 1, i)
+                while j < len(self.child_nodes):
+                    if self.child_nodes_event_idx[j] > i:
+                        break
+                    res.update(self.child_nodes[j].result(f"{event_path}/"))
+                    j += 1
+            while j < len(self.child_nodes):
+                res.update(self.child_nodes[j].result(f"{path}/"))
+                j += 1
+            return res
+
+    def __init__(self) -> None:
         super().__init__()
-        self.enable = enable
-        self.start_event = None
-        if start:
-            self.start()
-
-    def start(self):
-        if not self.enable:
-            return
-        if self.start_event == None:
-            self.start_event = torch.cuda.Event(enable_timing=True)
-            self.end_event = torch.cuda.Event(enable_timing=True)
-        torch.cuda.synchronize()
-        self.start_event.record()
-
-    def checkpoint(self, name: str = None, end: bool = False):
-        if not self.enable:
-            return 0
-        self.end_event.record()
-        torch.cuda.synchronize()
-        duration = self.start_event.elapsed_time(self.end_event)
-        if name:
-            print('%s: %.1fms' % (name, duration))
-        if not end:
-            self.start_event.record()
-        return duration
+        self.root_node = None
+        self.current_node = None
+        self.frames = []
+
+    def start_node(self, name):
+        if self.current_node is None:
+            self.root_node = self.current_node = Perf.Node(name)
+        else:
+            self.current_node = self.current_node.add_child(name)
+
+    def checkpoint(self, name):
+        self.current_node.add_checkpoint(name)
+
+    def end_node(self):
+        self.current_node = self.current_node.close()
+        if self.current_node is None:
+            torch.cuda.synchronize()
+            self.frames.append(self.root_node.result())
+
+    def get_result(self, i=None):
+        if i is not None:
+            return self.frames[i]
+        if len(self.frames) == 0:
+            return {}
+        res = {key: [val] for key, val in self.frames[0].items()}
+        for i in range(1, len(self.frames)):
+            for key, val in self.frames[i].items():
+                res[key].append(val)
+        return {key: average(val) for key, val in res.items()}
+
+
+default_perf_object = None
+
+
+def enable_perf():
+    global default_perf_object
+    default_perf_object = Perf()
+
+
+def perf(fn_or_name):
+    if isinstance(fn_or_name, str):
+        name = fn_or_name
+
+        def perf_with_name(fn):
+            def wrap_perf(*args, **kwargs):
+                start_node(name)
+                ret = fn(*args, **kwargs)
+                end_node()
+                return ret
+            return wrap_perf
+        return perf_with_name
+    fn = fn_or_name
+
+    def wrap_perf(*args, **kwargs):
+        start_node(fn.__qualname__)
+        ret = fn(*args, **kwargs)
+        end_node()
+        return ret
+    return wrap_perf
+
+
+def start_node(name):
+    if default_perf_object is not None:
+        default_perf_object.start_node(name)
+
+
+def end_node():
+    if default_perf_object is not None:
+        default_perf_object.end_node()
+
+
+def checkpoint(name):
+    if default_perf_object is not None:
+        default_perf_object.checkpoint(name)
+
+
+def get_perf_result(i=None):
+    if default_perf_object is not None:
+        return default_perf_object.get_result(i)
+    return None
diff --git a/utils/progress_bar.py b/utils/progress_bar.py
index 144586d..fd84774 100644
--- a/utils/progress_bar.py
+++ b/utils/progress_bar.py
@@ -1,78 +1,50 @@
+import shutil
 import sys
 import time
-import os
+from .misc import format_time
+from .constants import NAN
 
-bar_length = 50
-LAST_T = time.time()
-BEGIN_T = LAST_T
 
+last_time = time.time()
+begin_time = last_time
 
-def get_terminal_columns():
-    return os.get_terminal_size().columns
 
-
-def progress_bar(current, total, msg=None, premsg=None):
-    global LAST_T, BEGIN_T
+def progress_bar(current, total, msg=None, premsg=None, barmsg=None):
+    global last_time, begin_time
     if current == 0:
-        BEGIN_T = time.time()  # Reset for new bar.
+        begin_time = time.time()  # Reset for new bar.
     current_time = time.time()
-    step_time = current_time - LAST_T
-    LAST_T = current_time
-    total_time = current_time - BEGIN_T
+    step_time = current_time - last_time
+    total_time = current_time - begin_time
+    last_time = current_time
+    estimated_time = 0 if current == 0 else total_time / current * (total - current)
+
+    show_opt = int(current_time) % 6 >= 3 and current < total
+    show_barmsg = barmsg is not None and show_opt
 
     str0 = f"{premsg} [" if premsg else '['
-    str1 = f"] {current + 1:d}/{total:d} | Step: {format_time(step_time)} | Tot: {format_time(total_time)}"
+    str1 = f"] {current:d}/{total:d} | Step: {format_time(step_time)} | " + (
+        f"Eta: {format_time(estimated_time)}" if show_opt else f"Tot: {format_time(total_time)}"
+    )
     if msg:
         str1 += f" | {msg}"
 
-    tot_cols = get_terminal_columns()
+    tot_cols = shutil.get_terminal_size().columns - 10
     bar_length = tot_cols - len(str0) - len(str1)
-    current_len = int(bar_length * (current + 1) / total)
-    rest_len = int(bar_length - current_len)
-
-    if current_len == 0:
-        str_bar = '.' * rest_len
+    if show_barmsg and bar_length < len(barmsg):
+        sys.stdout.write(str0[:-1] + barmsg)
+    elif bar_length <= 0:
+        sys.stdout.write(str0[:-1] + str1[2:])
     else:
-        str_bar = '=' * (current_len - 1) + '>' + '.' * rest_len
-
-    sys.stdout.write(str0 + str_bar + str1)
-
-    if current < total - 1:
-        sys.stdout.write('\r')
-    else:
-        sys.stdout.write('\n')
+        current_len = int(bar_length * current / total)
+        rest_len = int(bar_length - current_len)
+        str_bar = ''
+        if current_len > 0:
+            str_bar += '=' * (current_len - 1) + '>'
+        str_bar += '.' * rest_len
+        if show_barmsg:
+            str_bar = barmsg + str_bar[len(barmsg):]
+        sys.stdout.write(str0 + str_bar + str1)
+
+    sys.stdout.write('\r' if current < total else '\n')
     sys.stdout.flush()
-
-
-# return the formatted time
-def format_time(seconds):
-    days = int(seconds / 3600 / 24)
-    seconds = seconds - days * 3600 * 24
-    hours = int(seconds / 3600)
-    seconds = seconds - hours * 3600
-    minutes = int(seconds / 60)
-    seconds = seconds - minutes * 60
-    seconds_final = int(seconds)
-    seconds = seconds - seconds_final
-    millis = int(seconds * 1000)
-
-    output = ''
-    time_index = 1
-    if days > 0:
-        output += str(days) + 'D'
-        time_index += 1
-    if hours > 0 and time_index <= 2:
-        output += str(hours) + 'h'
-        time_index += 1
-    if minutes > 0 and time_index <= 2:
-        output += str(minutes) + 'm'
-        time_index += 1
-    if seconds_final > 0 and time_index <= 2:
-        output += '%02ds' % seconds_final
-        time_index += 1
-    if millis > 0 and time_index <= 2:
-        output += '%03dms' % millis
-        time_index += 1
-    if output == '':
-        output = '0ms'
-    return output
diff --git a/utils/sphere.py b/utils/sphere.py
index 8feb6d5..24d6309 100644
--- a/utils/sphere.py
+++ b/utils/sphere.py
@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Union
 import torch
 import math
 from . import misc
@@ -13,12 +13,12 @@ def cartesian2spherical(cart: torch.Tensor, inverse_r: bool = False) -> torch.Te
     :return `Tensor(..., 3)`: coordinates in Spherical (r, theta, phi)
     """
     rho = torch.sqrt(torch.sum(cart * cart, dim=-1))
-    theta = misc.get_angle(cart[..., 0], cart[..., 2])
+    theta = misc.get_angle(cart[..., 2], cart[..., 0])
     if inverse_r:
         rho = rho.reciprocal()
-        phi = torch.acos(cart[..., 1] * rho)
+        phi = torch.asin(cart[..., 1] * rho)
     else:
-        phi = torch.acos(cart[..., 1] / rho)
+        phi = torch.asin(cart[..., 1] / rho)
     return torch.stack([rho, theta, phi], dim=-1)
 
 
@@ -34,9 +34,9 @@ def spherical2cartesian(spher: torch.Tensor, inverse_r: bool = False) -> torch.T
         rho = rho.reciprocal()
     sin_theta_phi = torch.sin(spher[..., 1:3])
     cos_theta_phi = torch.cos(spher[..., 1:3])
-    x = rho * cos_theta_phi[..., 0] * sin_theta_phi[..., 1]
-    y = rho * cos_theta_phi[..., 1]
-    z = rho * sin_theta_phi[..., 0] * sin_theta_phi[..., 1]
+    x = rho * sin_theta_phi[..., 0] * cos_theta_phi[..., 1]
+    y = rho * sin_theta_phi[..., 1]
+    z = rho * cos_theta_phi[..., 0] * cos_theta_phi[..., 1]
     return torch.stack([x, y, z], dim=-1)
 
 
diff --git a/utils/voxels.py b/utils/voxels.py
new file mode 100644
index 0000000..824f2de
--- /dev/null
+++ b/utils/voxels.py
@@ -0,0 +1,174 @@
+import torch
+from typing import Tuple, Union
+
+
+def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor:
+    """
+    Get grid steps alone every dim.
+
+    :param bbox `Tensor(2, D)`: bounding box
+    :param step_size `Tensor(1|D) | float`: step size
+    :return `Tensor(D)`: grid steps alone every dim
+    """
+    return ((bbox[1] - bbox[0]) / step_size).ceil().long()
+
+
+def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
+                   step_size: Union[torch.Tensor, float] = None,
+                   steps: torch.Tensor = None) -> torch.Tensor:
+    """
+    Get discretized (integer) grid coordinates of points.
+
+    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
+    specified, then the grid coordinates will be calculated according to the step size, ignoring
+    the value of `steps`.
+
+    :param pts `Tensor(N..., D)`: points
+    :param bbox `Tensor(2, D)`: bounding box
+    :param step_size `Tensor(1|D) | float`: (optional) step size
+    :param steps `Tensor(1|D)`: (optional) steps alone every dim
+    :return `Tensor(N..., D)`: discretized grid coordinates
+    """
+    if step_size is not None:
+        return ((pts - bbox[0]) / step_size).floor().long()
+    return ((pts - bbox[0]) / (bbox[1] - bbox[0]) * steps).floor().long()
+
+
+def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
+                    step_size: Union[torch.Tensor, float] = None,
+                    steps: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Get flattened grid indices of points.
+
+    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
+    specified, then the grid indices will be calculated according to the step size, ignoring
+    the value of `steps`.
+
+    :param pts `Tensor(N..., D)`: points
+    :param bbox `Tensor(2, D)`: bounding box
+    :param step_size `Tensor(1|D) | float`: (optional) step size
+    :param steps `Tensor(1|D)`: (optional) steps alone every dim
+    :return `Tensor(N...)`: grid indices
+    :return `Tensor(N...)`: a mask tensor indicating the returned indices are outside or not
+    """
+    if step_size is not None:
+        steps = get_grid_steps(bbox, step_size)  # (D)
+    grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps)  # (N..., D)
+    outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1)  # (N...)
+    if pts.size(-1) == 1:
+        grid_indices = grid_coords[..., 0]
+    elif pts.size(-1) == 2:
+        grid_indices = grid_coords[..., 0] * steps[1] + grid_coords[..., 1]
+    elif pts.size(-1) == 3:
+        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] \
+            + grid_coords[..., 1] * steps[2] + grid_coords[..., 2]
+    elif pts.size(-1) == 4:
+        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] * steps[3] \
+            + grid_coords[..., 1] * steps[2] * steps[3] \
+            + grid_coords[..., 2] * steps[3] \
+            + grid_coords[..., 3]
+    else:
+        raise NotImplementedError("The function does not support D>4")
+    return grid_indices, outside_mask
+
+
+def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
+    """
+    Initialize voxels.
+    """
+    x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)])
+    return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps)
+
+
+def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
+                     step_size: Union[torch.Tensor, float] = None,
+                     steps: torch.Tensor = None) -> torch.Tensor:
+    """
+    Get discretized (integer) grid coordinates of points.
+
+    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
+    specified, then the grid coordinates will be calculated according to the step size, ignoring
+    the value of `steps`.
+
+    :param pts `Tensor(N..., D)`: points
+    :param bbox `Tensor(2, D)`: bounding box
+    :param step_size `Tensor(1|D) | float`: (optional) step size
+    :param steps `Tensor(1|D)`: (optional) steps alone every dim
+    :return `Tensor(N..., D)`: discretized grid coordinates
+    """
+    grid_coords = grid_coords.float() + 0.5
+    if step_size is not None:
+        return grid_coords * step_size + bbox[0]
+    return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0]
+
+
+def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_border: bool = True,
+                       dims=3, *, dtype: torch.dtype = None, device: torch.device = None,
+                       like: torch.Tensor = None):
+    """
+    [summary]
+
+    :param voxel_size `Tensor(D)|float`: [description]
+    :param n `int`: [description]
+    :param align_border `bool`: [description], defaults to False
+    :param dims `int`: [description], defaults to 3
+    :param dtype `dtype`: [description], defaults to None
+    :param device `device`: [description], defaults to None
+    :param like `Tensor(*)`:
+    :return `Tensor(X, D)`: [description]
+    """
+    if like is not None:
+        dtype = like.dtype
+        device = like.device
+    c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
+    offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) * voxel_size / 2 /\
+        (n - 1 if align_border else n)
+    return offset
+
+
+def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, float], n: int,
+                 align_border: bool = True):
+    """
+    [summary]
+
+    :param voxel_centers `Tensor(N, D)`: [description]
+    :param voxel_size `Tensor(D)|float`: [description]
+    :param n `int`: [description]
+    :param align_border `bool`: [description], defaults to False
+    :param return_local `bool`: [description], defaults to False
+    :return `Tensor(N, X, D)`: [description]
+    """
+    return voxel_centers[:, None] + split_voxels_local(
+        voxel_size, n, align_border, voxel_centers.shape[-1], like=voxel_centers)
+
+
+def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
+    expand_bbox = bbox
+    expand_bbox[0] -= 0.5 * half_voxel_size
+    expand_bbox[1] += 0.5 * half_voxel_size
+    double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size)
+    # (M, 3) -> [1, 3, 5, ...]
+
+    corner_coords = split_voxels(double_grid_coords, 2, 2).reshape(-1, 3)
+    # (8M, 3) -> [0, 2, 4, ...]
+
+    corner_coords, corner_indices = corner_coords.unique(dim=0, sorted=True, return_inverse=True)
+    corners = to_voxel_centers(corner_coords, expand_bbox, step_size=half_voxel_size)
+
+    return corners, corner_indices.reshape(-1, 8)
+
+
+def trilinear_interp(pts: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
+    """
+    Perform trilinear interpolation in unit voxel ([0,0,0] ~ [1,1,1]).
+
+    :param pts `Tensor(N, 3)`: uniform coordinates in voxels
+    :param corner_values `Tensor(N, 8X)|Tensor(N, 8, X)`: values at corners of voxels
+    :return `Tensor(N, X)`: interpolated values
+    """
+    pts = pts[:, None]  # (N, 1, 3)
+    corners = split_voxels_local(1, 2, like=pts) + 0.5  # (8, 3)
+    weights = (pts * corners * 2 - pts - corners + 1).prod(-1, keepdim=True)  # (N, 8, 1)
+    corner_values = corner_values.reshape(corner_values.size(0), 8, -1)  # (N, 8, X)
+    return (weights * corner_values).sum(1)
-- 
GitLab