Commit 2824f796 authored by Nianchen Deng's avatar Nianchen Deng
Browse files

sync

parent 5699ccbf
......@@ -24,14 +24,19 @@
"program": "train.py",
"args": [
//"-c",
//"snerf_voxels",
"/home/dengnc/dvs/data/__new/barbershop_fovea_r360x80_t0.6/_nets/train_t0.3/snerfadvx_voxels_x4/checkpoint_10.tar",
//"snerf_voxels+ls+f32",
"/data1/dnc/dvs/data/__nerf/room/_nets/train/snerf_voxels+ls+f32/checkpoint_1.tar",
"--prune",
"100",
"1",
"--split",
"100"
//"data/__new/barbershop_fovea_r360x80_t0.6/train_t0.3.json"
"1",
"-e",
"100",
"--views",
"5",
//"data/__nerf/room/train.json"
],
"justMyCode": false,
"console": "integratedTerminal"
},
{
......
{
"model": "SNeRF",
"args": {
"color": "rgb",
"n_pot_encode": 10,
"n_dir_encode": 4,
"fc_params": {
"nf": 256,
"n_layers": 8,
"activation": "relu",
"skips": [ 4 ]
},
"n_featdim": 32,
"space": "voxels",
"steps": [4, 16, 8],
"n_samples": 16,
"perturb_sample": true,
"density_regularization_weight": 1e-4,
"density_regularization_scale": 1e4
}
}
\ No newline at end of file
......@@ -16,7 +16,7 @@ class BaseModelMeta(type):
class BaseModel(nn.Module, metaclass=BaseModelMeta):
trainer = "Train"
TrainerClass = "Train"
@property
def args(self):
......
......@@ -10,7 +10,7 @@ from utils.misc import masked_scatter
class NeRF(BaseModel):
trainer = "TrainWithSpace"
TrainerClass = "TrainWithSpace"
SamplerClass = Sampler
RendererClass = VolumnRenderer
......@@ -124,21 +124,11 @@ class NeRF(BaseModel):
return self.pot_encoder(x)
def encode_d(self, samples: Samples) -> torch.Tensor:
return self.dir_encoder(samples.dirs) if self.dir_encoder is not None else None
return self.dir_encoder(samples.dirs) if self.dir_encoder else None
@torch.no_grad()
def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
densities = self.render(Samples(sampled_points, None, None, None, sampled_voxel_indices),
'density')
return 1 - (-densities).exp()
@torch.no_grad()
def pruning(self, threshold: float = 0.5, train_stats=False):
return self.space.pruning(self.get_scores, threshold, train_stats)
@torch.no_grad()
def splitting(self):
ret = self.space.splitting()
def split(self):
ret = self.space.split()
if 'n_samples' in self.args0:
self.args0['n_samples'] *= 2
if 'voxel_size' in self.args0:
......@@ -149,12 +139,10 @@ class NeRF(BaseModel):
if 'sample_step' in self.args0:
self.args0['sample_step'] /= 2
self.sampler = self.SamplerClass(**self.args)
if self.args.get('n_featdim') and hasattr(self, "trainer"):
self.trainer.reset_optimizer()
return ret
@torch.no_grad()
def double_samples(self):
pass
@perf
def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
extra_outputs: List[str] = [], **kwargs) -> torch.Tensor:
......
......@@ -40,16 +40,8 @@ class SNeRFAdvanceX(SNeRFAdvance):
return self.cores[chunk_id](x, d, outputs, **extras)
@torch.no_grad()
def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
@torch.no_grad()
def pruning(self, threshold: float = 0.5, train_stats=False):
raise NotImplementedError()
@torch.no_grad()
def splitting(self):
ret = super().splitting()
def split(self):
ret = super().split()
k = self.args["n_samples"] // self.space.steps[0].item()
net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))]
if len(net_samples) != len(self.cores):
......
......@@ -4,10 +4,6 @@ from .snerf import *
class SNeRFX(SNeRF):
trainer = "TrainWithSpace"
SamplerClass = SphericalSampler
RendererClass = VolumnRenderer
def __init__(self, args0: dict, args1: dict = {}):
"""
Initialize a multi-sphere-layer net
......@@ -42,16 +38,8 @@ class SNeRFX(SNeRF):
return self.cores[chunk_id](x, d, outputs)
@torch.no_grad()
def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
@torch.no_grad()
def pruning(self, threshold: float = 0.5, train_stats=False):
raise NotImplementedError()
@torch.no_grad()
def splitting(self):
ret = super().splitting()
def split(self):
ret = super().split()
k = self.args["n_samples"] // self.space.steps[0].item()
net_samples = [
val * k for val in self.space.balance_cut(0, len(self.cores))
......
from math import ceil
import torch
import numpy as np
from typing import List, NoReturn, Tuple, Union
from typing import List, Tuple, Union
from torch import nn
from plyfile import PlyData, PlyElement
from utils.geometry import *
from utils.constants import *
......@@ -73,11 +70,11 @@ class Space(nn.Module):
return voxel_indices
@torch.no_grad()
def pruning(self, score_fn, threshold: float = 0.5, train_stats=False):
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
raise NotImplementedError()
@torch.no_grad()
def splitting(self):
def split(self):
raise NotImplementedError()
......@@ -108,7 +105,7 @@ class Voxels(Space):
return self.voxels.size(0)
@property
def n_corner(self) -> int:
def n_corners(self) -> int:
"""`int` Number of corners"""
return self.corners.size(0)
......@@ -145,12 +142,18 @@ class Voxels(Space):
:param n_dims `int`: embedding dimension
:return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
"""
name = f'emb_{name}'
self.add_module(name, torch.nn.Embedding(self.n_corners.item(), n_dims))
return self.__getattr__(name)
if self.get_embedding(name) is not None:
raise KeyError(f"Embedding '{name}' already existed")
emb = torch.nn.Embedding(self.n_corners, n_dims, device=self.device)
setattr(self, f'emb_{name}', emb)
return emb
def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
return getattr(self, f'emb_{name}')
return getattr(self, f'emb_{name}', None)
def set_embedding(self, weight: torch.Tensor, name: str = 'default'):
emb = torch.nn.Embedding(*weight.shape, _weight=weight, device=self.device)
setattr(self, f'emb_{name}', emb)
def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
name: str = 'default') -> torch.Tensor:
......@@ -167,9 +170,8 @@ class Voxels(Space):
raise KeyError(f"Embedding '{name}' doesn't exist")
voxels = self.voxels[voxel_indices] # (N, 3)
corner_indices = self.corner_indices[voxel_indices] # (N, 8)
p = (pts - voxels) / self.voxel_size + 0.5 # (N, 3) normed-coords in voxel
features = emb(corner_indices).reshape(pts.size(0), 8, -1) # (N, 8, X)
return trilinear_interp(p, features)
p = (pts - voxels) / self.voxel_size + .5 # (N, 3) normed-coords in voxel
return trilinear_interp(p, emb(corner_indices))
@perf
def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
......@@ -220,17 +222,34 @@ class Voxels(Space):
return voxel_indices
@torch.no_grad()
def splitting(self) -> None:
def split(self) -> None:
"""
Split voxels into smaller voxels with half size.
"""
n_voxels_before = self.n_voxels
self.steps *= 2
self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
new_steps = self.steps * 2
new_voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
.reshape(-1, 3)
self._update_corners()
new_corners, new_corner_indices = get_corners(new_voxels, self.bbox, new_steps)
# Calculate new embeddings through trilinear interpolation
grid_indices_of_new_corners = to_flat_indices(
to_grid_coords(new_corners, self.bbox, steps=self.steps).min(self.steps - 1),
self.steps)
voxel_indices_of_new_corners = self.voxel_indices_in_grid[grid_indices_of_new_corners]
for name, _ in self.named_modules():
if not name.startswith("emb_"):
continue
new_emb_weight = self.extract_embedding(new_corners, voxel_indices_of_new_corners,
name=name[4:])
self.set_embedding(new_emb_weight, name=name[4:])
# Apply new tensors
self.steps = new_steps
self.voxels = new_voxels
self.corners = new_corners
self.corner_indices = new_corner_indices
self._update_voxel_indices_in_grid()
return n_voxels_before, self.n_voxels
return self.n_voxels // 8, self.n_voxels
@torch.no_grad()
def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
......@@ -239,11 +258,6 @@ class Voxels(Space):
self._update_voxel_indices_in_grid()
return keeps.size(0), keeps.sum().item()
@torch.no_grad()
def pruning(self, score_fn, threshold: float = 0.5) -> None:
scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0]) # (M)
return self.prune(scores > threshold)
def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
sum_dims = [val for val in range(self.dims) if val != dim]
return self.voxel_indices_in_grid.reshape(*self.steps).ne(-1).sum(sum_dims)
......@@ -261,39 +275,30 @@ class Voxels(Space):
part = int(cdf[i]) + 1
return bins
def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits)
sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand(
*sampled_xyz.shape[:2])
sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.flatten()
def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
"""
For each voxel, sample `S^3` points uniformly, with small perturb if `perturb` is `True`.
When `perturb` is `False`, `include_border` can specify whether to sample points from border to border or at centers of sub-voxels.
When `perturb` is `True`, points are sampled at centers of sub-voxels, then applying a random offset in sub-voxels.
@torch.no_grad()
def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor:
def get_scores_once(pts, idxs):
scores = score_fn(pts, idxs).reshape(-1, bits ** 3) # (B, P)
if reduce_fn is not None:
scores = reduce_fn(scores) # (B[, ...])
return scores
sampled_xyz, sampled_idx = self.sample(bits)
chunk_size = 64
return torch.cat([
get_scores_once(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
for i in range(0, self.voxels.size(0), chunk_size)
], 0) # (M[, ...])
:param S `int`: number of samples along each dim
:param perturb `bool?`: whether perturb samples, defaults to `False`
:param include_border `bool?`: whether include border, defaults to `True`
:return `Tensor(N*S^3, 3)`: sampled points
:return `Tensor(N*S^3)`: voxel indices of sampled points
"""
pts = split_voxels(self.voxels, self.voxel_size, S,
align_border=not perturb and include_border) # (N, X, D)
voxel_indices = torch.arange(self.n_voxels, device=self.device)[:, None]\
.expand(*pts.shape[:-1]) # (N) -> (N, X)
if perturb:
pts += (torch.rand_like(pts) - .5) * self.voxel_size / S
return pts.reshape(-1, 3), voxel_indices.flatten()
def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
def _update_corners(self):
"""
Update voxel corners.
"""
corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
self.register_buffer("corners", corners)
self.register_buffer("corner_indices", corner_indices)
def _update_voxel_indices_in_grid(self):
"""
Update voxel indices in grid.
......@@ -314,7 +319,7 @@ class Voxels(Space):
# Handle embeddings
for name, module in self.named_modules():
if name.startswith('emb_'):
setattr(self, name, torch.nn.Embedding(self.n_corners.item(), module.embedding_dim))
setattr(self, name, torch.nn.Embedding(self.n_corners, module.embedding_dim))
class Octree(Voxels):
......@@ -339,8 +344,8 @@ class Octree(Voxels):
return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)
@torch.no_grad()
def splitting(self):
ret = super().splitting()
def split(self):
ret = super().split()
self.clear()
return ret
......
{
"cells": [
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from utils.voxels import *\n",
"\n",
"bbox, steps = torch.tensor([[-2, -3.14159, 1], [2, 3.14159, 0]]), torch.tensor([2, 3, 3])\n",
"voxel_size = (bbox[1] - bbox[0]) / steps\n",
"voxels = init_voxels(bbox, steps)\n",
"corners, corner_indices = get_corners(voxels, bbox, steps)\n",
"voxel_indices_in_grid = torch.arange(voxels.shape[0])\n",
"emb = torch.nn.Embedding(corners.shape[0], 3, _weight=corners)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([11, 3]) tensor([ 0, -1, -1, 1, -1, -1, 2, 3, 4, -1, 5, 6, -1, 7, 8, -1, 9, 10])\n"
]
}
],
"source": [
"keeps = torch.tensor([True]*18)\n",
"keeps[torch.tensor([1,2,4,5,9,12,15])] = False\n",
"voxels = voxels[keeps]\n",
"corner_indices = corner_indices[keeps]\n",
"grid_indices, _ = to_grid_indices(voxels, bbox, steps=steps)\n",
"voxel_indices_in_grid = grid_indices.new_full([steps.prod().item()], -1)\n",
"voxel_indices_in_grid[grid_indices] = torch.arange(voxels.shape[0])\n",
"print(voxels.shape, voxel_indices_in_grid)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([88, 3]) torch.Size([185, 3]) torch.Size([88, 8])\n"
]
}
],
"source": [
"new_voxels = split_voxels(voxels, (bbox[1] - bbox[0]) / steps, 2, align_border=False).reshape(-1, 3)\n",
"new_corners, new_corner_indices = get_corners(new_voxels, bbox, steps * 2)\n",
"print(new_voxels.shape, new_corners.shape, new_corner_indices.shape)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([ 0, 0, -1, 0, 0, -1, 1, 1, -1, 1, 1, -1, 2, 2, 3, 3, 4, 4,\n",
" 4, 2, 2, 3, 3, 4, 4, 4, 2, 2, 3, 3, 4, 4, 4, 0, 0, -1,\n",
" 0, 0, -1, 1, 1, -1, 1, 1, -1, 2, 2, 3, 3, 4, 4, 4, 2, 2,\n",
" 3, 3, 4, 4, 4, 2, 2, 3, 3, 4, 4, 4, -1, -1, 5, 5, 6, 6,\n",
" 6, -1, -1, 5, 5, 6, 6, 6, -1, -1, 7, 7, 8, 8, 8, -1, -1, 7,\n",
" 7, 8, 8, 8, -1, -1, 9, 9, 10, 10, 10, -1, -1, 9, 9, 10, 10, 10,\n",
" -1, -1, 9, 9, 10, 10, 10, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6, 7,\n",
" 7, 8, 8, 8, 7, 7, 8, 8, 8, 9, 9, 10, 10, 10, 9, 9, 10, 10,\n",
" 10, 9, 9, 10, 10, 10, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6, 7, 7,\n",
" 8, 8, 8, 7, 7, 8, 8, 8, 9, 9, 10, 10, 10, 9, 9, 10, 10, 10,\n",
" 9, 9, 10, 10, 10])\n",
"tensor(0)\n"
]
}
],
"source": [
"voxel_indices_of_new_corner = voxel_indices_in_grid[to_flat_indices(to_grid_coords(new_corners, bbox, steps=steps).min(steps - 1), steps)]\n",
"print(voxel_indices_of_new_corner)\n",
"p_of_new_corners = (new_corners - voxels[voxel_indices_of_new_corner]) / voxel_size + .5\n",
"print(((new_corners - trilinear_interp(p_of_new_corners, emb(corner_indices[voxel_indices_of_new_corner]))) > 1e-6).sum())"
]
}
],
"metadata": {
"interpreter": {
"hash": "08b118544df3cb8970a671e5837a88fd458f4d4c799ef1fb2709465a22a45b92"
},
"kernelspec": {
"display_name": "Python 3.9.5 64-bit ('base': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
......@@ -38,7 +38,7 @@ from data.loader import DataLoader
from utils.constants import HUGE_FLOAT
RAYS_PER_BATCH = 2 ** 14
RAYS_PER_BATCH = 2 ** 12
DATA_LOADER_CHUNK_SIZE = 1e8
......
This diff is collapsed.
This diff is collapsed.
......@@ -13,8 +13,9 @@ from data.loader import DataLoader
from utils.misc import list_epochs, print_and_log
RAYS_PER_BATCH = 2 ** 16
RAYS_PER_BATCH = 2 ** 12
DATA_LOADER_CHUNK_SIZE = 1e8
root_dir = Path.cwd()
parser = argparse.ArgumentParser()
......@@ -68,7 +69,7 @@ if args.mdl_path:
model_args = model.args
else:
# Create model from specified configuration
with Path(f'{sys.path[0]}/configs/{args.config}.json').open() as fp:
with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
config = json.load(fp)
model_name = args.config
model_class = config['model']
......@@ -76,7 +77,7 @@ else:
model_args['bbox'] = dataset.bbox
model_args['depth_range'] = dataset.depth_range
model, states = mdl.create(model_class, model_args), None
model.to(device.default()).train()
model.to(device.default())
run_dir = Path(f"_nets/{dataset.name}/{model_name}")
run_dir.mkdir(parents=True, exist_ok=True)
......
......@@ -22,5 +22,5 @@ def get_class(class_name: str) -> type:
def get_trainer(model: BaseModel, **kwargs) -> base.Train:
train_class = get_class(model.trainer)
train_class = get_class(model.TrainerClass)
return train_class(model, **kwargs)
......@@ -42,8 +42,9 @@ class Train(object, metaclass=BaseTrainMeta):
self.iters = 0
self.run_dir = run_dir
self.model.trainer = self
self.model.train()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
self.reset_optimizer()
if states:
if 'epoch' in states:
......@@ -58,6 +59,9 @@ class Train(object, metaclass=BaseTrainMeta):
if self.perf_mode:
enable_perf()
def reset_optimizer(self):
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
def train(self, data_loader: DataLoader, max_epochs: int):
self.data_loader = data_loader
self.iters_per_epoch = self.perf_frames or len(data_loader)
......
......@@ -20,18 +20,15 @@ class TrainWithSpace(Train):
if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
try:
with torch.no_grad():
before, after = self.model.splitting()
print_and_log(
f"Splitting done. # of voxels before: {before}, after: {after}")
before, after = self.model.split()
print_and_log(f"Splitting done: {before} -> {after}")
except NotImplementedError:
print_and_log(
"Note: The space does not support splitting operation. Just skip it.")
if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
try:
with torch.no_grad():
#before, after = self.model.pruning()
# print(f"Pruning by voxel densities done. # of voxels before: {before}, after: {after}")
# self._prune_inner_voxels()
# self._prune_voxels_by_densities()
self._prune_voxels_by_weights()
except NotImplementedError:
print_and_log(
......@@ -39,26 +36,26 @@ class TrainWithSpace(Train):
super()._train_epoch()
def _prune_inner_voxels(self):
def _prune_voxels_by_densities(self):
space: Voxels = self.model.space
voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
device=space.voxels.device)
iters_in_epoch = 0
batch_size = self.data_loader.batch_size
self.data_loader.batch_size = 2 ** 14
for _, rays_o, rays_d, _ in self.data_loader:
self.model(rays_o, rays_d,
raymarching_early_stop_tolerance=0.01,
raymarching_chunk_size_or_sections=[1],
perturb_sample=False,
voxel_access_counts=voxel_access_counts,
voxel_access_tolerance=0)
iters_in_epoch += 1
percent = iters_in_epoch / len(self.data_loader) * 100
sys.stdout.write(f'Pruning inner voxels...{percent:.1f}% \r')
self.data_loader.batch_size = batch_size
before, after = space.prune(voxel_access_counts > 0)
print(f"Prune inner voxels: {before} -> {after}")
threshold = .5
bits = 16
@torch.no_grad()
def get_scores(sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
densities = self.model.render(
Samples(sampled_points, None, None, None, sampled_voxel_indices),
'density')
return 1 - (-densities).exp()
sampled_xyz, sampled_idx = space.sample(bits)
chunk_size = 64
scores = torch.cat([
torch.max(get_scores(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
.reshape(-1, bits ** 3), -1)[0]
for i in range(0, self.voxels.size(0), chunk_size)
], 0) # (M[, ...])
return space.prune(scores > threshold)
def _prune_voxels_by_weights(self):
space: Voxels = self.model.space
......
......@@ -57,10 +57,11 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
"""
if len(size) == 1:
size = (size[0], size[0])
y, x = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1]))
if swap_dim:
return torch.stack([y / (size[0] - 1.), x / (size[1] - 1.)], 2) if normalize else torch.stack([y, x], 2)
return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2) if normalize else torch.stack([x, y], 2)
y, x = torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]), indexing='ij')
if normalize:
x.div_(size[1] - 1.)
y.div_(size[0] - 1.)
return torch.stack([y, x], 2) if swap_dim else torch.stack([x, y], 2)
def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
......
......@@ -13,6 +13,13 @@ def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) ->
return ((bbox[1] - bbox[0]) / step_size).ceil().long()
def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
indices = grid_coords[..., 0]
for i in range(1, grid_coords.shape[-1]):
indices = indices * steps[i] + grid_coords[..., i]
return indices
def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
step_size: Union[torch.Tensor, float] = None,
steps: torch.Tensor = None) -> torch.Tensor:
......@@ -55,20 +62,7 @@ def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
steps = get_grid_steps(bbox, step_size) # (D)
grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps) # (N..., D)
outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1) # (N...)
if pts.size(-1) == 1:
grid_indices = grid_coords[..., 0]
elif pts.size(-1) == 2:
grid_indices = grid_coords[..., 0] * steps[1] + grid_coords[..., 1]
elif pts.size(-1) == 3:
grid_indices = grid_coords[..., 0] * steps[1] * steps[2] \
+ grid_coords[..., 1] * steps[2] + grid_coords[..., 2]
elif pts.size(-1) == 4:
grid_indices = grid_coords[..., 0] * steps[1] * steps[2] * steps[3] \
+ grid_coords[..., 1] * steps[2] * steps[3] \
+ grid_coords[..., 2] * steps[3] \
+ grid_coords[..., 3]
else:
raise NotImplementedError("The function does not support D>4")
grid_indices = to_flat_indices(grid_coords, steps)
return grid_indices, outside_mask
......@@ -76,7 +70,7 @@ def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
"""
Initialize voxels.
"""
x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)])
x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)], indexing="ij")
return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps)
......@@ -96,7 +90,7 @@ def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
:param steps `Tensor(1|D)`: (optional) steps alone every dim
:return `Tensor(N..., D)`: discretized grid coordinates
"""
grid_coords = grid_coords.float() + 0.5
grid_coords = grid_coords.float() + .5
if step_size is not None:
return grid_coords * step_size + bbox[0]
return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0]
......@@ -121,8 +115,8 @@ def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_bor
dtype = like.dtype
device = like.device
c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) * voxel_size / 2 /\
(n - 1 if align_border else n)
offset = torch.stack(torch.meshgrid([c] * dims, indexing='ij'), -1).flatten(0, -2)\
* voxel_size * .5 / (n - 1 if align_border else n)
return offset
......@@ -144,7 +138,7 @@ def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, fl
def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
expand_bbox = bbox
expand_bbox = bbox.clone()
expand_bbox[0] -= 0.5 * half_voxel_size
expand_bbox[1] += 0.5 * half_voxel_size
double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment