Commit 2824f796 authored by Nianchen Deng's avatar Nianchen Deng
Browse files

sync

parent 5699ccbf
...@@ -24,14 +24,19 @@ ...@@ -24,14 +24,19 @@
"program": "train.py", "program": "train.py",
"args": [ "args": [
//"-c", //"-c",
//"snerf_voxels", //"snerf_voxels+ls+f32",
"/home/dengnc/dvs/data/__new/barbershop_fovea_r360x80_t0.6/_nets/train_t0.3/snerfadvx_voxels_x4/checkpoint_10.tar", "/data1/dnc/dvs/data/__nerf/room/_nets/train/snerf_voxels+ls+f32/checkpoint_1.tar",
"--prune", "--prune",
"100", "1",
"--split", "--split",
"100" "1",
//"data/__new/barbershop_fovea_r360x80_t0.6/train_t0.3.json" "-e",
"100",
"--views",
"5",
//"data/__nerf/room/train.json"
], ],
"justMyCode": false,
"console": "integratedTerminal" "console": "integratedTerminal"
}, },
{ {
......
{
"model": "SNeRF",
"args": {
"color": "rgb",
"n_pot_encode": 10,
"n_dir_encode": 4,
"fc_params": {
"nf": 256,
"n_layers": 8,
"activation": "relu",
"skips": [ 4 ]
},
"n_featdim": 32,
"space": "voxels",
"steps": [4, 16, 8],
"n_samples": 16,
"perturb_sample": true,
"density_regularization_weight": 1e-4,
"density_regularization_scale": 1e4
}
}
\ No newline at end of file
...@@ -16,7 +16,7 @@ class BaseModelMeta(type): ...@@ -16,7 +16,7 @@ class BaseModelMeta(type):
class BaseModel(nn.Module, metaclass=BaseModelMeta): class BaseModel(nn.Module, metaclass=BaseModelMeta):
trainer = "Train" TrainerClass = "Train"
@property @property
def args(self): def args(self):
......
...@@ -10,7 +10,7 @@ from utils.misc import masked_scatter ...@@ -10,7 +10,7 @@ from utils.misc import masked_scatter
class NeRF(BaseModel): class NeRF(BaseModel):
trainer = "TrainWithSpace" TrainerClass = "TrainWithSpace"
SamplerClass = Sampler SamplerClass = Sampler
RendererClass = VolumnRenderer RendererClass = VolumnRenderer
...@@ -124,21 +124,11 @@ class NeRF(BaseModel): ...@@ -124,21 +124,11 @@ class NeRF(BaseModel):
return self.pot_encoder(x) return self.pot_encoder(x)
def encode_d(self, samples: Samples) -> torch.Tensor: def encode_d(self, samples: Samples) -> torch.Tensor:
return self.dir_encoder(samples.dirs) if self.dir_encoder is not None else None return self.dir_encoder(samples.dirs) if self.dir_encoder else None
@torch.no_grad() @torch.no_grad()
def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor: def split(self):
densities = self.render(Samples(sampled_points, None, None, None, sampled_voxel_indices), ret = self.space.split()
'density')
return 1 - (-densities).exp()
@torch.no_grad()
def pruning(self, threshold: float = 0.5, train_stats=False):
return self.space.pruning(self.get_scores, threshold, train_stats)
@torch.no_grad()
def splitting(self):
ret = self.space.splitting()
if 'n_samples' in self.args0: if 'n_samples' in self.args0:
self.args0['n_samples'] *= 2 self.args0['n_samples'] *= 2
if 'voxel_size' in self.args0: if 'voxel_size' in self.args0:
...@@ -149,12 +139,10 @@ class NeRF(BaseModel): ...@@ -149,12 +139,10 @@ class NeRF(BaseModel):
if 'sample_step' in self.args0: if 'sample_step' in self.args0:
self.args0['sample_step'] /= 2 self.args0['sample_step'] /= 2
self.sampler = self.SamplerClass(**self.args) self.sampler = self.SamplerClass(**self.args)
if self.args.get('n_featdim') and hasattr(self, "trainer"):
self.trainer.reset_optimizer()
return ret return ret
@torch.no_grad()
def double_samples(self):
pass
@perf @perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *, def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
extra_outputs: List[str] = [], **kwargs) -> torch.Tensor: extra_outputs: List[str] = [], **kwargs) -> torch.Tensor:
......
...@@ -40,16 +40,8 @@ class SNeRFAdvanceX(SNeRFAdvance): ...@@ -40,16 +40,8 @@ class SNeRFAdvanceX(SNeRFAdvance):
return self.cores[chunk_id](x, d, outputs, **extras) return self.cores[chunk_id](x, d, outputs, **extras)
@torch.no_grad() @torch.no_grad()
def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor: def split(self):
raise NotImplementedError() ret = super().split()
@torch.no_grad()
def pruning(self, threshold: float = 0.5, train_stats=False):
raise NotImplementedError()
@torch.no_grad()
def splitting(self):
ret = super().splitting()
k = self.args["n_samples"] // self.space.steps[0].item() k = self.args["n_samples"] // self.space.steps[0].item()
net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))] net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))]
if len(net_samples) != len(self.cores): if len(net_samples) != len(self.cores):
......
...@@ -4,10 +4,6 @@ from .snerf import * ...@@ -4,10 +4,6 @@ from .snerf import *
class SNeRFX(SNeRF): class SNeRFX(SNeRF):
trainer = "TrainWithSpace"
SamplerClass = SphericalSampler
RendererClass = VolumnRenderer
def __init__(self, args0: dict, args1: dict = {}): def __init__(self, args0: dict, args1: dict = {}):
""" """
Initialize a multi-sphere-layer net Initialize a multi-sphere-layer net
...@@ -42,16 +38,8 @@ class SNeRFX(SNeRF): ...@@ -42,16 +38,8 @@ class SNeRFX(SNeRF):
return self.cores[chunk_id](x, d, outputs) return self.cores[chunk_id](x, d, outputs)
@torch.no_grad() @torch.no_grad()
def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor: def split(self):
raise NotImplementedError() ret = super().split()
@torch.no_grad()
def pruning(self, threshold: float = 0.5, train_stats=False):
raise NotImplementedError()
@torch.no_grad()
def splitting(self):
ret = super().splitting()
k = self.args["n_samples"] // self.space.steps[0].item() k = self.args["n_samples"] // self.space.steps[0].item()
net_samples = [ net_samples = [
val * k for val in self.space.balance_cut(0, len(self.cores)) val * k for val in self.space.balance_cut(0, len(self.cores))
......
from math import ceil
import torch import torch
import numpy as np from typing import List, Tuple, Union
from typing import List, NoReturn, Tuple, Union
from torch import nn from torch import nn
from plyfile import PlyData, PlyElement
from utils.geometry import * from utils.geometry import *
from utils.constants import * from utils.constants import *
...@@ -73,11 +70,11 @@ class Space(nn.Module): ...@@ -73,11 +70,11 @@ class Space(nn.Module):
return voxel_indices return voxel_indices
@torch.no_grad() @torch.no_grad()
def pruning(self, score_fn, threshold: float = 0.5, train_stats=False): def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
raise NotImplementedError() raise NotImplementedError()
@torch.no_grad() @torch.no_grad()
def splitting(self): def split(self):
raise NotImplementedError() raise NotImplementedError()
...@@ -108,7 +105,7 @@ class Voxels(Space): ...@@ -108,7 +105,7 @@ class Voxels(Space):
return self.voxels.size(0) return self.voxels.size(0)
@property @property
def n_corner(self) -> int: def n_corners(self) -> int:
"""`int` Number of corners""" """`int` Number of corners"""
return self.corners.size(0) return self.corners.size(0)
...@@ -145,12 +142,18 @@ class Voxels(Space): ...@@ -145,12 +142,18 @@ class Voxels(Space):
:param n_dims `int`: embedding dimension :param n_dims `int`: embedding dimension
:return `Embedding(n_corners, n_dims)`: new embedding on voxel corners :return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
""" """
name = f'emb_{name}' if self.get_embedding(name) is not None:
self.add_module(name, torch.nn.Embedding(self.n_corners.item(), n_dims)) raise KeyError(f"Embedding '{name}' already existed")
return self.__getattr__(name) emb = torch.nn.Embedding(self.n_corners, n_dims, device=self.device)
setattr(self, f'emb_{name}', emb)
return emb
def get_embedding(self, name: str = 'default') -> torch.nn.Embedding: def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
return getattr(self, f'emb_{name}') return getattr(self, f'emb_{name}', None)
def set_embedding(self, weight: torch.Tensor, name: str = 'default'):
emb = torch.nn.Embedding(*weight.shape, _weight=weight, device=self.device)
setattr(self, f'emb_{name}', emb)
def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor, def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
name: str = 'default') -> torch.Tensor: name: str = 'default') -> torch.Tensor:
...@@ -167,9 +170,8 @@ class Voxels(Space): ...@@ -167,9 +170,8 @@ class Voxels(Space):
raise KeyError(f"Embedding '{name}' doesn't exist") raise KeyError(f"Embedding '{name}' doesn't exist")
voxels = self.voxels[voxel_indices] # (N, 3) voxels = self.voxels[voxel_indices] # (N, 3)
corner_indices = self.corner_indices[voxel_indices] # (N, 8) corner_indices = self.corner_indices[voxel_indices] # (N, 8)
p = (pts - voxels) / self.voxel_size + 0.5 # (N, 3) normed-coords in voxel p = (pts - voxels) / self.voxel_size + .5 # (N, 3) normed-coords in voxel
features = emb(corner_indices).reshape(pts.size(0), 8, -1) # (N, 8, X) return trilinear_interp(p, emb(corner_indices))
return trilinear_interp(p, features)
@perf @perf
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections: def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
...@@ -220,17 +222,34 @@ class Voxels(Space): ...@@ -220,17 +222,34 @@ class Voxels(Space):
return voxel_indices return voxel_indices
@torch.no_grad() @torch.no_grad()
def splitting(self) -> None: def split(self) -> None:
""" """
Split voxels into smaller voxels with half size. Split voxels into smaller voxels with half size.
""" """
n_voxels_before = self.n_voxels new_steps = self.steps * 2
self.steps *= 2 new_voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
.reshape(-1, 3) .reshape(-1, 3)
self._update_corners() new_corners, new_corner_indices = get_corners(new_voxels, self.bbox, new_steps)
# Calculate new embeddings through trilinear interpolation
grid_indices_of_new_corners = to_flat_indices(
to_grid_coords(new_corners, self.bbox, steps=self.steps).min(self.steps - 1),
self.steps)
voxel_indices_of_new_corners = self.voxel_indices_in_grid[grid_indices_of_new_corners]
for name, _ in self.named_modules():
if not name.startswith("emb_"):
continue
new_emb_weight = self.extract_embedding(new_corners, voxel_indices_of_new_corners,
name=name[4:])
self.set_embedding(new_emb_weight, name=name[4:])
# Apply new tensors
self.steps = new_steps
self.voxels = new_voxels
self.corners = new_corners
self.corner_indices = new_corner_indices
self._update_voxel_indices_in_grid() self._update_voxel_indices_in_grid()
return n_voxels_before, self.n_voxels return self.n_voxels // 8, self.n_voxels
@torch.no_grad() @torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]: def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
...@@ -239,11 +258,6 @@ class Voxels(Space): ...@@ -239,11 +258,6 @@ class Voxels(Space):
self._update_voxel_indices_in_grid() self._update_voxel_indices_in_grid()
return keeps.size(0), keeps.sum().item() return keeps.size(0), keeps.sum().item()
@torch.no_grad()
def pruning(self, score_fn, threshold: float = 0.5) -> None:
scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0]) # (M)
return self.prune(scores > threshold)
def n_voxels_along_dim(self, dim: int) -> torch.Tensor: def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
sum_dims = [val for val in range(self.dims) if val != dim] sum_dims = [val for val in range(self.dims) if val != dim]
return self.voxel_indices_in_grid.reshape(*self.steps).ne(-1).sum(sum_dims) return self.voxel_indices_in_grid.reshape(*self.steps).ne(-1).sum(sum_dims)
...@@ -261,39 +275,30 @@ class Voxels(Space): ...@@ -261,39 +275,30 @@ class Voxels(Space):
part = int(cdf[i]) + 1 part = int(cdf[i]) + 1
return bins return bins
def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits) """
sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand( For each voxel, sample `S^3` points uniformly, with small perturb if `perturb` is `True`.
*sampled_xyz.shape[:2])
sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.flatten() When `perturb` is `False`, `include_border` can specify whether to sample points from border to border or at centers of sub-voxels.
When `perturb` is `True`, points are sampled at centers of sub-voxels, then applying a random offset in sub-voxels.
@torch.no_grad() :param S `int`: number of samples along each dim
def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor: :param perturb `bool?`: whether perturb samples, defaults to `False`
def get_scores_once(pts, idxs): :param include_border `bool?`: whether include border, defaults to `True`
scores = score_fn(pts, idxs).reshape(-1, bits ** 3) # (B, P) :return `Tensor(N*S^3, 3)`: sampled points
if reduce_fn is not None: :return `Tensor(N*S^3)`: voxel indices of sampled points
scores = reduce_fn(scores) # (B[, ...]) """
return scores pts = split_voxels(self.voxels, self.voxel_size, S,
align_border=not perturb and include_border) # (N, X, D)
sampled_xyz, sampled_idx = self.sample(bits) voxel_indices = torch.arange(self.n_voxels, device=self.device)[:, None]\
chunk_size = 64 .expand(*pts.shape[:-1]) # (N) -> (N, X)
return torch.cat([ if perturb:
get_scores_once(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size]) pts += (torch.rand_like(pts) - .5) * self.voxel_size / S
for i in range(0, self.voxels.size(0), chunk_size) return pts.reshape(-1, 3), voxel_indices.flatten()
], 0) # (M[, ...])
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d) return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
def _update_corners(self):
"""
Update voxel corners.
"""
corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
self.register_buffer("corners", corners)
self.register_buffer("corner_indices", corner_indices)
def _update_voxel_indices_in_grid(self): def _update_voxel_indices_in_grid(self):
""" """
Update voxel indices in grid. Update voxel indices in grid.
...@@ -314,7 +319,7 @@ class Voxels(Space): ...@@ -314,7 +319,7 @@ class Voxels(Space):
# Handle embeddings # Handle embeddings
for name, module in self.named_modules(): for name, module in self.named_modules():
if name.startswith('emb_'): if name.startswith('emb_'):
setattr(self, name, torch.nn.Embedding(self.n_corners.item(), module.embedding_dim)) setattr(self, name, torch.nn.Embedding(self.n_corners, module.embedding_dim))
class Octree(Voxels): class Octree(Voxels):
...@@ -339,8 +344,8 @@ class Octree(Voxels): ...@@ -339,8 +344,8 @@ class Octree(Voxels):
return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d) return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)
@torch.no_grad() @torch.no_grad()
def splitting(self): def split(self):
ret = super().splitting() ret = super().split()
self.clear() self.clear()
return ret return ret
......
{
"cells": [
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from utils.voxels import *\n",
"\n",
"bbox, steps = torch.tensor([[-2, -3.14159, 1], [2, 3.14159, 0]]), torch.tensor([2, 3, 3])\n",
"voxel_size = (bbox[1] - bbox[0]) / steps\n",
"voxels = init_voxels(bbox, steps)\n",
"corners, corner_indices = get_corners(voxels, bbox, steps)\n",
"voxel_indices_in_grid = torch.arange(voxels.shape[0])\n",
"emb = torch.nn.Embedding(corners.shape[0], 3, _weight=corners)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([11, 3]) tensor([ 0, -1, -1, 1, -1, -1, 2, 3, 4, -1, 5, 6, -1, 7, 8, -1, 9, 10])\n"
]
}
],
"source": [
"keeps = torch.tensor([True]*18)\n",
"keeps[torch.tensor([1,2,4,5,9,12,15])] = False\n",
"voxels = voxels[keeps]\n",
"corner_indices = corner_indices[keeps]\n",
"grid_indices, _ = to_grid_indices(voxels, bbox, steps=steps)\n",
"voxel_indices_in_grid = grid_indices.new_full([steps.prod().item()], -1)\n",
"voxel_indices_in_grid[grid_indices] = torch.arange(voxels.shape[0])\n",
"print(voxels.shape, voxel_indices_in_grid)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([88, 3]) torch.Size([185, 3]) torch.Size([88, 8])\n"
]
}
],
"source": [
"new_voxels = split_voxels(voxels, (bbox[1] - bbox[0]) / steps, 2, align_border=False).reshape(-1, 3)\n",
"new_corners, new_corner_indices = get_corners(new_voxels, bbox, steps * 2)\n",
"print(new_voxels.shape, new_corners.shape, new_corner_indices.shape)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([ 0, 0, -1, 0, 0, -1, 1, 1, -1, 1, 1, -1, 2, 2, 3, 3, 4, 4,\n",
" 4, 2, 2, 3, 3, 4, 4, 4, 2, 2, 3, 3, 4, 4, 4, 0, 0, -1,\n",
" 0, 0, -1, 1, 1, -1, 1, 1, -1, 2, 2, 3, 3, 4, 4, 4, 2, 2,\n",
" 3, 3, 4, 4, 4, 2, 2, 3, 3, 4, 4, 4, -1, -1, 5, 5, 6, 6,\n",
" 6, -1, -1, 5, 5, 6, 6, 6, -1, -1, 7, 7, 8, 8, 8, -1, -1, 7,\n",
" 7, 8, 8, 8, -1, -1, 9, 9, 10, 10, 10, -1, -1, 9, 9, 10, 10, 10,\n",
" -1, -1, 9, 9, 10, 10, 10, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6, 7,\n",
" 7, 8, 8, 8, 7, 7, 8, 8, 8, 9, 9, 10, 10, 10, 9, 9, 10, 10,\n",
" 10, 9, 9, 10, 10, 10, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6, 7, 7,\n",
" 8, 8, 8, 7, 7, 8, 8, 8, 9, 9, 10, 10, 10, 9, 9, 10, 10, 10,\n",
" 9, 9, 10, 10, 10])\n",
"tensor(0)\n"
]
}
],
"source": [
"voxel_indices_of_new_corner = voxel_indices_in_grid[to_flat_indices(to_grid_coords(new_corners, bbox, steps=steps).min(steps - 1), steps)]\n",
"print(voxel_indices_of_new_corner)\n",
"p_of_new_corners = (new_corners - voxels[voxel_indices_of_new_corner]) / voxel_size + .5\n",
"print(((new_corners - trilinear_interp(p_of_new_corners, emb(corner_indices[voxel_indices_of_new_corner]))) > 1e-6).sum())"
]
}
],
"metadata": {
"interpreter": {
"hash": "08b118544df3cb8970a671e5837a88fd458f4d4c799ef1fb2709465a22a45b92"
},
"kernelspec": {
"display_name": "Python 3.9.5 64-bit ('base': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
...@@ -38,7 +38,7 @@ from data.loader import DataLoader ...@@ -38,7 +38,7 @@ from data.loader import DataLoader
from utils.constants import HUGE_FLOAT from utils.constants import HUGE_FLOAT
RAYS_PER_BATCH = 2 ** 14 RAYS_PER_BATCH = 2 ** 12
DATA_LOADER_CHUNK_SIZE = 1e8 DATA_LOADER_CHUNK_SIZE = 1e8
......
This diff is collapsed.
This diff is collapsed.
...@@ -13,8 +13,9 @@ from data.loader import DataLoader ...@@ -13,8 +13,9 @@ from data.loader import DataLoader
from utils.misc import list_epochs, print_and_log from utils.misc import list_epochs, print_and_log
RAYS_PER_BATCH = 2 ** 16 RAYS_PER_BATCH = 2 ** 12
DATA_LOADER_CHUNK_SIZE = 1e8 DATA_LOADER_CHUNK_SIZE = 1e8
root_dir = Path.cwd()
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -68,7 +69,7 @@ if args.mdl_path: ...@@ -68,7 +69,7 @@ if args.mdl_path:
model_args = model.args model_args = model.args
else: else:
# Create model from specified configuration # Create model from specified configuration
with Path(f'{sys.path[0]}/configs/{args.config}.json').open() as fp: with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
config = json.load(fp) config = json.load(fp)
model_name = args.config model_name = args.config
model_class = config['model'] model_class = config['model']
...@@ -76,7 +77,7 @@ else: ...@@ -76,7 +77,7 @@ else:
model_args['bbox'] = dataset.bbox model_args['bbox'] = dataset.bbox
model_args['depth_range'] = dataset.depth_range model_args['depth_range'] = dataset.depth_range
model, states = mdl.create(model_class, model_args), None model, states = mdl.create(model_class, model_args), None
model.to(device.default()).train() model.to(device.default())
run_dir = Path(f"_nets/{dataset.name}/{model_name}") run_dir = Path(f"_nets/{dataset.name}/{model_name}")
run_dir.mkdir(parents=True, exist_ok=True) run_dir.mkdir(parents=True, exist_ok=True)
......
...@@ -22,5 +22,5 @@ def get_class(class_name: str) -> type: ...@@ -22,5 +22,5 @@ def get_class(class_name: str) -> type:
def get_trainer(model: BaseModel, **kwargs) -> base.Train: def get_trainer(model: BaseModel, **kwargs) -> base.Train:
train_class = get_class(model.trainer) train_class = get_class(model.TrainerClass)
return train_class(model, **kwargs) return train_class(model, **kwargs)
...@@ -42,8 +42,9 @@ class Train(object, metaclass=BaseTrainMeta): ...@@ -42,8 +42,9 @@ class Train(object, metaclass=BaseTrainMeta):
self.iters = 0 self.iters = 0
self.run_dir = run_dir self.run_dir = run_dir
self.model.trainer = self
self.model.train() self.model.train()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4) self.reset_optimizer()
if states: if states:
if 'epoch' in states: if 'epoch' in states:
...@@ -58,6 +59,9 @@ class Train(object, metaclass=BaseTrainMeta): ...@@ -58,6 +59,9 @@ class Train(object, metaclass=BaseTrainMeta):
if self.perf_mode: if self.perf_mode:
enable_perf() enable_perf()
def reset_optimizer(self):
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
def train(self, data_loader: DataLoader, max_epochs: int): def train(self, data_loader: DataLoader, max_epochs: int):
self.data_loader = data_loader self.data_loader = data_loader
self.iters_per_epoch = self.perf_frames or len(data_loader) self.iters_per_epoch = self.perf_frames or len(data_loader)
......
...@@ -20,18 +20,15 @@ class TrainWithSpace(Train): ...@@ -20,18 +20,15 @@ class TrainWithSpace(Train):
if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1: if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
try: try:
with torch.no_grad(): with torch.no_grad():
before, after = self.model.splitting() before, after = self.model.split()
print_and_log( print_and_log(f"Splitting done: {before} -> {after}")
f"Splitting done. # of voxels before: {before}, after: {after}")
except NotImplementedError: except NotImplementedError:
print_and_log( print_and_log(
"Note: The space does not support splitting operation. Just skip it.") "Note: The space does not support splitting operation. Just skip it.")
if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1: if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
try: try:
with torch.no_grad(): with torch.no_grad():
#before, after = self.model.pruning() # self._prune_voxels_by_densities()
# print(f"Pruning by voxel densities done. # of voxels before: {before}, after: {after}")
# self._prune_inner_voxels()
self._prune_voxels_by_weights() self._prune_voxels_by_weights()
except NotImplementedError: except NotImplementedError:
print_and_log( print_and_log(
...@@ -39,26 +36,26 @@ class TrainWithSpace(Train): ...@@ -39,26 +36,26 @@ class TrainWithSpace(Train):
super()._train_epoch() super()._train_epoch()
def _prune_inner_voxels(self): def _prune_voxels_by_densities(self):
space: Voxels = self.model.space space: Voxels = self.model.space
voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long, threshold = .5
device=space.voxels.device) bits = 16
iters_in_epoch = 0
batch_size = self.data_loader.batch_size @torch.no_grad()
self.data_loader.batch_size = 2 ** 14 def get_scores(sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
for _, rays_o, rays_d, _ in self.data_loader: densities = self.model.render(
self.model(rays_o, rays_d, Samples(sampled_points, None, None, None, sampled_voxel_indices),
raymarching_early_stop_tolerance=0.01, 'density')
raymarching_chunk_size_or_sections=[1], return 1 - (-densities).exp()
perturb_sample=False,
voxel_access_counts=voxel_access_counts, sampled_xyz, sampled_idx = space.sample(bits)
voxel_access_tolerance=0) chunk_size = 64
iters_in_epoch += 1 scores = torch.cat([
percent = iters_in_epoch / len(self.data_loader) * 100 torch.max(get_scores(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
sys.stdout.write(f'Pruning inner voxels...{percent:.1f}% \r') .reshape(-1, bits ** 3), -1)[0]
self.data_loader.batch_size = batch_size for i in range(0, self.voxels.size(0), chunk_size)
before, after = space.prune(voxel_access_counts > 0) ], 0) # (M[, ...])
print(f"Prune inner voxels: {before} -> {after}") return space.prune(scores > threshold)
def _prune_voxels_by_weights(self): def _prune_voxels_by_weights(self):
space: Voxels = self.model.space space: Voxels = self.model.space
......
...@@ -57,10 +57,11 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor ...@@ -57,10 +57,11 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
""" """
if len(size) == 1: if len(size) == 1:
size = (size[0], size[0]) size = (size[0], size[0])
y, x = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1])) y, x = torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]), indexing='ij')
if swap_dim: if normalize:
return torch.stack([y / (size[0] - 1.), x / (size[1] - 1.)], 2) if normalize else torch.stack([y, x], 2) x.div_(size[1] - 1.)
return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2) if normalize else torch.stack([x, y], 2) y.div_(size[0] - 1.)
return torch.stack([y, x], 2) if swap_dim else torch.stack([x, y], 2)
def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
......
...@@ -13,6 +13,13 @@ def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> ...@@ -13,6 +13,13 @@ def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) ->
return ((bbox[1] - bbox[0]) / step_size).ceil().long() return ((bbox[1] - bbox[0]) / step_size).ceil().long()
def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
indices = grid_coords[..., 0]
for i in range(1, grid_coords.shape[-1]):
indices = indices * steps[i] + grid_coords[..., i]
return indices
def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *, def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
step_size: Union[torch.Tensor, float] = None, step_size: Union[torch.Tensor, float] = None,
steps: torch.Tensor = None) -> torch.Tensor: steps: torch.Tensor = None) -> torch.Tensor:
...@@ -55,20 +62,7 @@ def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *, ...@@ -55,20 +62,7 @@ def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
steps = get_grid_steps(bbox, step_size) # (D) steps = get_grid_steps(bbox, step_size) # (D)
grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps) # (N..., D) grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps) # (N..., D)
outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1) # (N...) outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1) # (N...)
if pts.size(-1) == 1: grid_indices = to_flat_indices(grid_coords, steps)
grid_indices = grid_coords[..., 0]
elif pts.size(-1) == 2:
grid_indices = grid_coords[..., 0] * steps[1] + grid_coords[..., 1]
elif pts.size(-1) == 3:
grid_indices = grid_coords[..., 0] * steps[1] * steps[2] \
+ grid_coords[..., 1] * steps[2] + grid_coords[..., 2]
elif pts.size(-1) == 4:
grid_indices = grid_coords[..., 0] * steps[1] * steps[2] * steps[3] \
+ grid_coords[..., 1] * steps[2] * steps[3] \
+ grid_coords[..., 2] * steps[3] \
+ grid_coords[..., 3]
else:
raise NotImplementedError("The function does not support D>4")
return grid_indices, outside_mask return grid_indices, outside_mask
...@@ -76,7 +70,7 @@ def init_voxels(bbox: torch.Tensor, steps: torch.Tensor): ...@@ -76,7 +70,7 @@ def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
""" """
Initialize voxels. Initialize voxels.
""" """
x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)]) x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)], indexing="ij")
return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps) return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps)
...@@ -96,7 +90,7 @@ def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *, ...@@ -96,7 +90,7 @@ def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
:param steps `Tensor(1|D)`: (optional) steps alone every dim :param steps `Tensor(1|D)`: (optional) steps alone every dim
:return `Tensor(N..., D)`: discretized grid coordinates :return `Tensor(N..., D)`: discretized grid coordinates
""" """
grid_coords = grid_coords.float() + 0.5 grid_coords = grid_coords.float() + .5
if step_size is not None: if step_size is not None:
return grid_coords * step_size + bbox[0] return grid_coords * step_size + bbox[0]
return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0] return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0]
...@@ -121,8 +115,8 @@ def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_bor ...@@ -121,8 +115,8 @@ def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_bor
dtype = like.dtype dtype = like.dtype
device = like.device device = like.device
c = torch.arange(1 - n, n, 2, dtype=dtype, device=device) c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) * voxel_size / 2 /\ offset = torch.stack(torch.meshgrid([c] * dims, indexing='ij'), -1).flatten(0, -2)\
(n - 1 if align_border else n) * voxel_size * .5 / (n - 1 if align_border else n)
return offset return offset
...@@ -144,7 +138,7 @@ def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, fl ...@@ -144,7 +138,7 @@ def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, fl
def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5 half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
expand_bbox = bbox expand_bbox = bbox.clone()
expand_bbox[0] -= 0.5 * half_voxel_size expand_bbox[0] -= 0.5 * half_voxel_size
expand_bbox[1] += 0.5 * half_voxel_size expand_bbox[1] += 0.5 * half_voxel_size
double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size) double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment