Commit 6294701e authored by Nianchen Deng's avatar Nianchen Deng
Browse files

sync

parent 2824f796
from typing import Union
from .train_with_space import TrainWithSpace
import torch
from pathlib import Path
from model.cnerf import CNeRF
from data.loader import DataLoader, MultiScaleDataLoader
from modules import Voxels
from utils.misc import print_and_log
class TrainMultiScale(TrainWithSpace):
model: CNeRF
data_loader: Union[DataLoader, MultiScaleDataLoader]
def __init__(self, model: CNeRF, run_dir: Path, states: dict) -> None:
super().__init__(model, run_dir, states)
self.freeze_epochs = self._arg("freeze_epochs", [])
self.level_by_level = True#self._arg("level_by_level", False)
def _train_epoch(self):
l = self._check_epoch_matches(self.freeze_epochs, self.epoch)
if l >= 0:
self.model.trigger_stage(l + 1)
cur_level = self.model.stage
if isinstance(self.data_loader, MultiScaleDataLoader):
if self.level_by_level:
self.data_loader.set_active_sub_loaders(cur_level)
else:
self.data_loader.set_active_sub_loaders(slice(cur_level, None))
self._split()
space: Voxels = self.model.model(self.model.stage).space
if self._check_epoch_matches(self.prune_epochs, self.epoch + 1) >= 0:
self.voxel_access = torch.zeros(space.n_voxels, dtype=torch.long, device=space.device)
super(TrainWithSpace, self)._train_epoch()
if self.voxel_access is not None:
before, after = space.prune(self.voxel_access > 0)
print_and_log(f"Prune by weights: {before} -> {after}")
self.voxel_access = None
from modules.sampler import Samples from .train import Train
from modules.space import Octree, Voxels import sys
import torch
from pathlib import Path
from typing import List
from modules import Voxels
from model.base import BaseModel
from data.loader import DataLoader
from utils.samples import Samples
from utils.mem_profiler import MemProfiler from utils.mem_profiler import MemProfiler
from utils.misc import print_and_log from utils.misc import print_and_log
from .base import * from utils.type import InputData, ReturnData
class TrainWithSpace(Train): class TrainWithSpace(Train):
def __init__(self, model: BaseModel, pruning_loop: int = 10000, splitting_loop: int = 10000, def __init__(self, model: BaseModel, run_dir: Path, states: dict) -> None:
**kwargs) -> None: super().__init__(model, run_dir, states)
super().__init__(model, **kwargs) self.prune_epochs = [] if self.perf_mode else self._arg("prune_epochs", [])
self.pruning_loop = pruning_loop self.split_epochs = [] if self.perf_mode else self._arg("split_epochs", [])
self.splitting_loop = splitting_loop self.voxel_access = None
#MemProfiler.enable = True #MemProfiler.enable = True
def _train_epoch(self): def _train_epoch(self):
if not self.perf_mode: self._split()
if self.epoch != 1: space: Voxels = self.model.space
if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1: if self._check_epoch_matches(self.prune_epochs, self.epoch + 1) >= 0:
try: self.voxel_access = torch.zeros(space.n_voxels, dtype=torch.long, device=space.device)
with torch.no_grad():
before, after = self.model.split()
print_and_log(f"Splitting done: {before} -> {after}")
except NotImplementedError:
print_and_log(
"Note: The space does not support splitting operation. Just skip it.")
if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
try:
with torch.no_grad():
# self._prune_voxels_by_densities()
self._prune_voxels_by_weights()
except NotImplementedError:
print_and_log(
"Note: The space does not support pruning operation. Just skip it.")
super()._train_epoch() super()._train_epoch()
if self.voxel_access is not None:
before, after = space.prune(self.voxel_access > 0)
print_and_log(f"Prune by weights: {before} -> {after}")
self.voxel_access = None
# self._prune()
def _forward(self, data: InputData) -> ReturnData:
if self.voxel_access is None:
return super()._forward(data)
out = self.model(data, 'color', 'energies', 'speculars', 'weights', "samples")
with torch.no_grad():
access_voxels = out['samples'].voxel_indices[out['weights'][..., 0] > 0.01]
self.voxel_access.index_add_(0, access_voxels, torch.ones_like(access_voxels))
return out
@torch.no_grad()
def _split(self):
if self._check_epoch_matches(self.split_epochs) < 0:
return
try:
before, after = self.model.split()
print_and_log(f"Splitting done: {before} -> {after}")
except NotImplementedError:
print_and_log("The space does not support splitting operation. Just skip it.")
@torch.no_grad()
def _prune(self):
if not self._check_epoch_matches(self.prune_epochs):
return
try:
# self._prune_voxels_by_densities()
self._prune_voxels_by_weights()
except NotImplementedError:
print_and_log("The space does not support pruning operation. Just skip it.")
def _check_epoch_matches(self, key_epochs: List[int], epoch: int = None):
epoch = epoch if epoch is not None else self.epoch
if epoch == 0 or len(key_epochs) == 0:
return -1
if len(key_epochs) == 1:
return 0 if epoch % key_epochs[0] == 0 else -1
try:
return key_epochs.index(epoch)
except ValueError:
return -1
def _prune_voxels_by_densities(self): def _prune_voxels_by_densities(self):
space: Voxels = self.model.space space: Voxels = self.model.space
...@@ -43,7 +81,7 @@ class TrainWithSpace(Train): ...@@ -43,7 +81,7 @@ class TrainWithSpace(Train):
@torch.no_grad() @torch.no_grad()
def get_scores(sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor: def get_scores(sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
densities = self.model.render( densities = self.model.render_pass(
Samples(sampled_points, None, None, None, sampled_voxel_indices), Samples(sampled_points, None, None, None, sampled_voxel_indices),
'density') 'density')
return 1 - (-densities).exp() return 1 - (-densities).exp()
...@@ -57,28 +95,38 @@ class TrainWithSpace(Train): ...@@ -57,28 +95,38 @@ class TrainWithSpace(Train):
], 0) # (M[, ...]) ], 0) # (M[, ...])
return space.prune(scores > threshold) return space.prune(scores > threshold)
def _prune_voxels_by_weights(self): def _prune_voxels_by_weights(self, data_loader: DataLoader = None):
space: Voxels = self.model.space space: Voxels = self.model.space
voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long, data_loader = data_loader or self.data_loader
device=space.voxels.device) batch_size = data_loader.batch_size
data_loader.batch_size = 2 ** 12
# Note:
# The first element of `voxel_access_counts` and `vidx_map` is perserved for 'invalid voxel'(-1)
# So the index should be offset by 1 when querying these variables
voxel_access_counts = torch.zeros(space.n_voxels + 1, dtype=torch.long, device=space.device)
vidx_map = torch.arange(-1, space.n_voxels, device=space.device)
iters_in_epoch = 0 iters_in_epoch = 0
batch_size = self.data_loader.batch_size for data in data_loader:
self.data_loader.batch_size = 2 ** 14 samples = self.model.sample(data, perturb_sample=False)
for _, rays_o, rays_d, _ in self.data_loader: rays_mask = vidx_map[samples.voxel_indices + 1].ne(-1).any(dim=-1) # (N)
ret = self.model(rays_o, rays_d, samples = samples[rays_mask] # (N, P) -> (N', P)
raymarching_early_stop_tolerance=0, ret = self.model.render(samples, 'weights',
raymarching_chunk_size_or_sections=None, raymarching_early_stop_tolerance=0,
perturb_sample=False, raymarching_chunk_size_or_sections=None)
extra_outputs=['weights'])
valid_mask = ret['weights'][..., 0] > 0.01 valid_mask = ret['weights'][..., 0] > 0.01
accessed_voxels = ret['samples'].voxel_indices[valid_mask] accessed_voxels = samples.voxel_indices[valid_mask]
voxel_access_counts.index_add_(0, accessed_voxels, torch.ones_like(accessed_voxels)) voxel_access_counts.index_add_(0, accessed_voxels + 1, torch.ones_like(accessed_voxels))
# Filter out accessed voxels to speed-up pruning
vidx_map[voxel_access_counts.ne(0)] = -1
iters_in_epoch += 1 iters_in_epoch += 1
percent = iters_in_epoch / len(self.data_loader) * 100 percent = iters_in_epoch / len(data_loader) * 100
sys.stdout.write(f'Pruning by weights...{percent:.1f}% \r') n_voxel_accessed = voxel_access_counts.count_nonzero()
self.data_loader.batch_size = batch_size sys.stdout.write(f'Pruning by weights...{percent: .1f}% '
before, after = space.prune(voxel_access_counts > 0) f'(Accessed {n_voxel_accessed}, Tot {space.n_voxels}) \r')
print_and_log(f"Prune by weights: {before} -> {after}") data_loader.batch_size = batch_size
before, after = space.prune(voxel_access_counts[1:] > 0)
print_and_log(f"Prune by weights: {before} -> {after} ")
def _prune_voxels_by_voxel_weights(self): def _prune_voxels_by_voxel_weights(self):
space: Voxels = self.model.space space: Voxels = self.model.space
...@@ -88,12 +136,14 @@ class TrainWithSpace(Train): ...@@ -88,12 +136,14 @@ class TrainWithSpace(Train):
batch_size = self.data_loader.batch_size batch_size = self.data_loader.batch_size
self.data_loader.batch_size = 2 ** 14 self.data_loader.batch_size = 2 ** 14
iters_in_epoch = 0 iters_in_epoch = 0
for _, rays_o, rays_d, _ in self.data_loader: for data in self.data_loader:
ret = self.model(rays_o, rays_d, ret = self.model(**data,
raymarching_early_stop_tolerance=0, extra_outputs=['weights'],
raymarching_chunk_size_or_sections=None, extra_args={
perturb_sample=False, 'raymarching_early_stop_tolerance': 0,
extra_outputs=['weights']) 'raymarching_chunk_size_or_sections': None,
'perturb_sample': False
})
self._accumulate_access_count_by_weight(ret['samples'], ret['weights'][..., 0], self._accumulate_access_count_by_weight(ret['samples'], ret['weights'][..., 0],
voxel_access_counts) voxel_access_counts)
iters_in_epoch += 1 iters_in_epoch += 1
......
import os import os
import sys import sys
import argparse import argparse
import shutil
from typing import Mapping
from utils.constants import TINY_FLOAT
import torch import torch
import torch.optim import torch.optim
import math
import time import time
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from torch import nn from torch import nn
from numpy.core.numeric import NaN
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# Arguments for train >>> # Arguments for train >>>
...@@ -49,10 +44,11 @@ print("Set CUDA:%d as current device." % torch.cuda.current_device()) ...@@ -49,10 +44,11 @@ print("Set CUDA:%d as current device." % torch.cuda.current_device())
from utils import netio from utils import netio
from utils import misc from utils import math
from utils import device from utils import device
from utils import img from utils import img
from utils import interact from utils import interact
from utils import color
from utils.progress_bar import progress_bar from utils.progress_bar import progress_bar
from utils.perf import Perf from utils.perf import Perf
from data.spherical_view_syn import * from data.spherical_view_syn import *
...@@ -341,8 +337,8 @@ def test(): ...@@ -341,8 +337,8 @@ def test():
out['bins'] = out['bins'].permute(0, 3, 1, 2) out['bins'] = out['bins'].permute(0, 3, 1, 2)
if args.output_flags['perf']: if args.output_flags['perf']:
perf_errors = torch.ones(n) * NaN perf_errors = torch.ones(n) * math.nan
perf_ssims = torch.ones(n) * NaN perf_ssims = torch.ones(n) * math.nan
if dataset.view_images != None: if dataset.view_images != None:
for i in range(n): for i in range(n):
perf_errors[i] = loss_func(dataset.view_images[i], out['color'][i]).item() perf_errors[i] = loss_func(dataset.view_images[i], out['color'][i]).item()
......
from utils import netio
from pathlib import Path
dir = "/home/dengnc/dvs/data/classroom/_nets/ms_train_t0.8/_cnerfadv_ioc/"
for epochs in range(1, 151):
path = f"{dir}checkpoint_{epochs}.tar"
if not Path(path).exists():
continue
print(f"Update epoch {epochs}")
s = netio.load_checkpoint(path)[0]
args0 = s["args"]
args0_for_submodel = {
key: value for key, value in args0.items()
if key != "sub_models" and key != "interp_on_coarse"
}
for i in range(len(args0["sub_models"])):
args0["sub_models"][i] = {**args0_for_submodel, **args0["sub_models"][i]}
if epochs >= 30:
args0["sub_models"][0]["n_samples"] = 64
elif epochs >= 10:
args0["sub_models"][0]["n_samples"] = 32
if epochs >= 70:
args0["sub_models"][1]["n_samples"] = 128
if epochs >= 120:
args0["sub_models"][2]["n_samples"] = 256
netio.save_checkpoint(s, dir, epochs)
\ No newline at end of file
import math
HUGE_FLOAT = 1e10
TINY_FLOAT = 1e-6
PI = math.pi
NAN = math.nan
E = math.e
\ No newline at end of file
env = None
def get_env():
return env
def set_env(new_env: dict):
global env
env = new_env
...@@ -7,8 +7,7 @@ import numpy as np ...@@ -7,8 +7,7 @@ import numpy as np
import torch.nn.functional as nn_f import torch.nn.functional as nn_f
from typing import List, Tuple, Union from typing import List, Tuple, Union
from . import misc from . import misc
from .constants import * from . import math
def is_image_file(filename): def is_image_file(filename):
""" """
...@@ -186,7 +185,7 @@ def translate(input: torch.Tensor, offset: Tuple[float, float]) -> torch.Tensor: ...@@ -186,7 +185,7 @@ def translate(input: torch.Tensor, offset: Tuple[float, float]) -> torch.Tensor:
def mse2psnr(x): def mse2psnr(x):
logfunc = torch.log if isinstance(x, torch.Tensor) else np.log logfunc = torch.log if isinstance(x, torch.Tensor) else np.log
return -10. * logfunc(x + TINY_FLOAT) / np.log(10.) return -10. * logfunc(x + math.tiny) / np.log(10.)
def colorize_depthmap(depthmap: torch.Tensor, depth_range, inverse=True, colormap='binary'): def colorize_depthmap(depthmap: torch.Tensor, depth_range, inverse=True, colormap='binary'):
......
import torch
from math import *
huge = 1e10
tiny = 1e-6
def expected_sin(x: torch.Tensor, x_var: torch.Tensor):
"""Estimates mean and variance of sin(z), z ~ N(x, var)."""
# When the variance is wide, shrink sin towards zero.
y = (-.5 * x_var).exp() * x.sin()
y_var = torch.clamp_min(.5 * (1 - (-2 * x_var).exp() * (2 * x).cos()) - y**2, 0)
return y, y_var
def lift_gaussian(d: torch.Tensor, t_mean: torch.Tensor, t_var: torch.Tensor,
r_var: torch.Tensor, diag: bool):
"""Lift a Gaussian defined along a ray to 3D coordinates."""
mean = d[..., None, :] * t_mean[..., None]
d_sq = d**2
d_mag_sq = torch.sum(d_sq, -1, keepdim=True).clamp_min(1e-10)
if diag:
d_outer_diag = d_sq
null_outer_diag = 1 - d_outer_diag / d_mag_sq
t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
cov_diag = t_cov_diag + xy_cov_diag
return mean, cov_diag
else:
d_outer = d[..., :, None] * d[..., None, :]
eye = torch.eye(d.shape[-1], device=d.device)
null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
cov = t_cov + xy_cov
return mean, cov
def conical_frustum_to_gaussian(d: torch.Tensor, t0: float, t1: float, base_radius: float,
diag: bool, stable: bool = True):
"""Approximate a conical frustum as a Gaussian distribution (mean+cov).
Assumes the ray is originating from the origin, and base_radius is the
radius at dist=1. Doesn't assume `d` is normalized.
Args:
d: torch.float32 3-vector, the axis of the cone
t0: float, the starting distance of the frustum.
t1: float, the ending distance of the frustum.
base_radius: float, the scale of the radius as a function of distance.
diag: boolean, whether or the Gaussian will be diagonal or full-covariance.
stable: boolean, whether or not to use the stable computation described in
the paper (setting this to False will cause catastrophic failure).
Returns:
a Gaussian (mean and covariance).
"""
if stable:
mu = (t0 + t1) / 2
hw = (t1 - t0) / 2
t_mean = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2)
t_var = (hw**2) / 3 - (4 / 15) * ((hw**4 * (12 * mu**2 - hw**2)) /
(3 * mu**2 + hw**2)**2)
r_var = base_radius**2 * ((mu**2) / 4 + (5 / 12) * hw**2 - 4 / 15 *
(hw**4) / (3 * mu**2 + hw**2))
else:
t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
r_var = base_radius**2 * (3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3))
t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)
t_var = t_mosq - t_mean**2
return lift_gaussian(d, t_mean, t_var, r_var, diag)
def cylinder_to_gaussian(d: torch.Tensor, t0: float, t1: float, radius: float, diag: bool):
"""Approximate a cylinder as a Gaussian distribution (mean+cov).
Assumes the ray is originating from the origin, and radius is the
radius. Does not renormalize `d`.
Args:
d: torch.float32 3-vector, the axis of the cylinder
t0: float, the starting distance of the cylinder.
t1: float, the ending distance of the cylinder.
radius: float, the radius of the cylinder
diag: boolean, whether or the Gaussian will be diagonal or full-covariance.
Returns:
a Gaussian (mean and covariance).
"""
t_mean = (t0 + t1) / 2
r_var = radius**2 / 4
t_var = (t1 - t0)**2 / 12
return lift_gaussian(d, t_mean, t_var, r_var, diag)
...@@ -7,9 +7,9 @@ import torch ...@@ -7,9 +7,9 @@ import torch
import glm import glm
import csv import csv
import numpy as np import numpy as np
from typing import List, Union from typing import List, Tuple, Union
from torch.types import Number from torch.types import Number
from .constants import * from . import math
from .device import * from .device import *
...@@ -46,7 +46,7 @@ def glm2torch(val) -> torch.Tensor: ...@@ -46,7 +46,7 @@ def glm2torch(val) -> torch.Tensor:
return torch.from_numpy(np.array(val)) return torch.from_numpy(np.array(val))
def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> torch.Tensor: def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False, device: torch.device = None) -> torch.Tensor:
""" """
Generate a mesh grid Generate a mesh grid
...@@ -57,7 +57,8 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor ...@@ -57,7 +57,8 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
""" """
if len(size) == 1: if len(size) == 1:
size = (size[0], size[0]) size = (size[0], size[0])
y, x = torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]), indexing='ij') y, x = torch.meshgrid(torch.arange(size[0], device=device),
torch.arange(size[1], device=device))
if normalize: if normalize:
x.div_(size[1] - 1.) x.div_(size[1] - 1.)
y.div_(size[0] - 1.) y.div_(size[0] - 1.)
...@@ -65,7 +66,7 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor ...@@ -65,7 +66,7 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
angle = -torch.atan(x / y) - (y < 0) * PI + 0.5 * PI angle = -torch.atan(x / y) - (y < 0) * math.pi + 0.5 * math.pi
return angle return angle
...@@ -167,13 +168,6 @@ def masked_scatter(mask: torch.Tensor, value: torch.Tensor, initial: Union[torch ...@@ -167,13 +168,6 @@ def masked_scatter(mask: torch.Tensor, value: torch.Tensor, initial: Union[torch
return initial.masked_scatter(mask.reshape(*M_, *repeat(1, len(D_))), value) return initial.masked_scatter(mask.reshape(*M_, *repeat(1, len(D_))), value)
def list_epochs(dir: Path, pattern: str) -> List[int]:
prefix = pattern.split("*")[0]
epoch_list = [int(str(path.stem)[len(prefix):]) for path in dir.glob(pattern)]
epoch_list.sort()
return epoch_list
def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int): def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int):
start, end = re.search(r'%0\dd', file_pattern).span() start, end = re.search(r'%0\dd', file_pattern).span()
prefix, suffix = start, len(file_pattern) - end prefix, suffix = start, len(file_pattern) - end
...@@ -185,3 +179,43 @@ def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int): ...@@ -185,3 +179,43 @@ def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int):
seqs.sort(reverse=offset > 0) seqs.sort(reverse=offset > 0)
for i in seqs: for i in seqs:
(dir / (file_pattern % i)).rename(dir / (file_pattern % (i + offset))) (dir / (file_pattern % i)).rename(dir / (file_pattern % (i + offset)))
def merge(*args: dict, **kwargs) -> dict:
ret_args = {}
for arg in args:
ret_args.update(arg)
ret_args.update(kwargs)
return ret_args
def union(*tensors: torch.Tensor) -> torch.Tensor:
return torch.cat(tensors, dim=-1)
def split(tensor: torch.Tensor, *sizes: int) -> Tuple[torch.Tensor, ...]:
sizes = list(sizes)
tot_size = sum(sizes)
for i in range(len(sizes)):
if sizes[i] == -1:
sizes[i] = tensor.shape[-1] - tot_size - 1
tot_size = tensor.shape[-1]
break
if tot_size > tensor.shape[-1]:
raise ValueError("The total number of sizes is larger than the last dim of input tensor")
if any([size < 0 for size in sizes]):
raise ValueError("Only one element in sizes could be -1")
if tot_size < tensor.shape[-1]:
sizes = [*sizes, tensor.shape[-1] - tot_size]
return torch.split(tensor, sizes, -1)[:-1]
else:
return torch.split(tensor, sizes, -1)
def dump_tensors_to_csv(path, *tensors: torch.Tensor, open_mode = "w"):
for i in range(len(tensors)):
if len(tensors[i].shape) == 1:
tensors[i] = tensors[i][:, None]
elif len(tensors[i].shape) > 2:
tensors[i] = tensors[i].flatten(1, -1)
with open(path, open_mode) as fp:
csv_writer = csv.writer(fp)
csv_writer.writerows(torch.cat(tensors, -1).tolist())
\ No newline at end of file
import torch
from collections import OrderedDict
from typing import Any, Optional, Union
class Module(torch.nn.Module):
@property
def cls(self) -> str:
return self._get_name()
def __init__(self):
super().__init__()
self.device = torch.device("cpu")
self._temp = OrderedDict()
self._register_load_state_dict_pre_hook(self._before_load_state_dict)
@staticmethod
def create_multiple(fn, n: int) -> torch.nn.ModuleList:
return torch.nn.ModuleList([fn() for _ in range(n)])
def register_temp(self, name: str, value: Union[torch.Tensor, torch.nn.Module]):
temp = self.__dict__.get('_temp')
if temp is None:
raise AttributeError("cannot assign temp before Module.__init__() call")
if '.' in name:
raise KeyError("temp name can't contain \".\"")
if name == '':
raise KeyError("temp name can't be empty string \"\"")
if hasattr(self, name) and name not in temp:
raise KeyError(f"attribute '{name}' already exists")
if not isinstance(value, (type(None), torch.Tensor, torch.nn.Module)):
raise TypeError(f"cannot assign '{torch.typename(value)}' object to temp '{name}' "
"(torch Tensor, Module or None required)")
if value is not None:
value = value.to(self.device)
temp[name] = value
def freeze(self):
for parameter in self.parameters():
parameter.requires_grad = False
return self
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
if isinstance(module, torch.nn.Module):
module = module.to(self.device)
return super().add_module(name, module)
def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
if isinstance(param, torch.nn.Parameter):
param = param.to(self.device)
return super().register_parameter(name, param)
def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None:
if isinstance(tensor, torch.Tensor):
tensor = tensor.to(self.device)
return super().register_buffer(name, tensor, persistent=persistent)
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
target_device = None
try:
target_device = torch.device(kwargs['device'] if 'device' in kwargs else args[0])
except Exception:
pass
if target_device is not None:
def move_to_device(m):
if isinstance(m, Module):
m.device = target_device
self.apply(move_to_device)
return self
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
ret = super().load_state_dict(state_dict, strict=strict)
def fn(module):
if isinstance(module, Module):
module._after_load_state_dict()
self.apply(fn)
return ret
def __getattr__(self, name: str) -> Union[torch.Tensor, torch.nn.Module, Any]:
temp = self.__dict__.get('_temp')
if temp is not None and name in temp:
return temp[name]
return super().__getattr__(name)
def __setattr__(self, name: str, value: Union[torch.Tensor, torch.nn.Module, Any]) -> None:
if isinstance(value, (torch.Tensor, torch.nn.Module)):
value = value.to(self.device)
temp = self.__dict__.get('_temp')
if temp is not None and name in temp:
if not isinstance(value, (type(None), torch.Tensor, torch.nn.Module)):
raise TypeError("cannot assign '{}' object to temp '{}' "
"(torch Tensor, Module or None required)"
.format(torch.typename(value), name))
temp[name] = value
else:
super().__setattr__(name, value)
def __delattr__(self, name) -> None:
temp = self.__dict__.get('_temp')
if temp is not None and name in temp:
del self._temp[name]
else:
super().__delattr__(name)
def _apply(self, fn):
for key, value in self._temp.items():
if value is not None:
self._temp[key] = fn(value)
return super()._apply(fn)
def _before_load_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs):
pass
def _after_load_state_dict(self) -> None:
pass
from typing import List, Tuple, Union
import torch import torch
import numpy as np import numpy as np
from utils import device from pathlib import Path
checkpoint_file_prefix = "checkpoint_"
checkpoint_file_suffix = ".tar"
def get_checkpoint_filename(epoch):
return f"{checkpoint_file_prefix}{epoch}{checkpoint_file_suffix}"
def list_epochs(directory: Union[str, Path]) -> List[int]:
directory = Path(directory)
epoch_list = [
int(file_path.stem[len(checkpoint_file_prefix):])
for file_path in directory.glob(get_checkpoint_filename("*"))
]
epoch_list.sort()
return epoch_list
def load_checkpoint(path: Union[str, Path]) -> Tuple[dict, Path]:
path = Path(path)
if path.suffix != checkpoint_file_suffix:
existed_epochs = list_epochs(path)
if len(existed_epochs) == 0:
raise FileNotFoundError(f"{path} does not contain checkpoint files")
path = path / get_checkpoint_filename(existed_epochs[-1])
return torch.load(path), path
def save_checkpoint(states_dict: dict, directory: Union[str, Path], epoch: int):
torch.save(states_dict, Path(directory) / get_checkpoint_filename(epoch))
def log(model): def log(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters()) model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters]) params = sum([np.prod(p.size()) for p in model_parameters])
print("%d" % params) print("%d" % params)
def load(path, model, **extra_models):
print('Load net from %s ...' % path)
whole_dict = torch.load(path, map_location=device.default())
model.load_state_dict(whole_dict['model'])
for model, key in extra_models:
if key in whole_dict:
model.load_state_dict(whole_dict[key])
return whole_dict['iters'] if 'iters' in whole_dict else 0
def save(path, model, iters, print_log=True, **extra_models):
if print_log:
print('Saving net to %s ...' % path)
whole_dict = {
'iters': iters,
'model': model.state_dict(),
}
for model, key in extra_models:
whole_dict[key] = model.state_dict()
torch.save(whole_dict, path)
from numpy import average
import torch import torch
import torch.cuda import torch.cuda
from typing import Dict, List, OrderedDict from numpy import average
from typing import Any, Callable, Dict, List, OrderedDict, Union
class Perf(object): class Perf(object):
...@@ -94,24 +94,60 @@ def enable_perf(): ...@@ -94,24 +94,60 @@ def enable_perf():
default_perf_object = Perf() default_perf_object = Perf()
def perf(fn_or_name): class _PerfWrap(object):
def __init__(self, fn: Callable = None, name: str = None) -> None:
super().__init__()
self.fn = fn
self.name = name
def __call__(self, *args: Any, **kwargs: Any) -> Any:
if self.fn == None and len(args) == 1 and isinstance(args[0], Callable):
self.fn = args[0]
return lambda *args, **kwargs: self(*args, **kwargs)
self.__enter__()
ret = self.fn(*args, **kwargs)
self.__exit__()
return ret
def __enter__(self):
#print(f"Start node \"{self.name or self.fn.__qualname__}\"")
start_node(self.name or self.fn.__qualname__)
return self
def __exit__(self, *args: Any, **kwargs: Any):
#print(f"End node \"{self.name or self.fn.__qualname__}\"")
end_node()
def perf(arg: Union[str, Callable]):
if isinstance(arg, str):
return _PerfWrap(name=arg)
else:
return lambda *args, **kwargs: _PerfWrap(fn=arg)(*args, **kwargs)
def debug_perf(fn_or_name):
if isinstance(fn_or_name, str): if isinstance(fn_or_name, str):
name = fn_or_name name = fn_or_name
def perf_with_name(fn): def perf_with_name(fn):
def wrap_perf(*args, **kwargs): def wrap_perf(*args, **kwargs):
start_node(name) node = Perf.Node(name)
ret = fn(*args, **kwargs) ret = fn(*args, **kwargs)
end_node() node.close()
torch.cuda.synchronize()
print(f"Debug Node {name}: {node.duration():.1f}ms")
return ret return ret
return wrap_perf return wrap_perf
return perf_with_name return perf_with_name
fn = fn_or_name fn = fn_or_name
def wrap_perf(*args, **kwargs): def wrap_perf(*args, **kwargs):
start_node(fn.__qualname__) node = Perf.Node(fn.__qualname__)
ret = fn(*args, **kwargs) ret = fn(*args, **kwargs)
end_node() node.close()
torch.cuda.synchronize()
print(f"Debug Node {fn.__qualname__}: {node.duration():.1f}ms")
return ret return ret
return wrap_perf return wrap_perf
......
...@@ -2,22 +2,27 @@ import shutil ...@@ -2,22 +2,27 @@ import shutil
import sys import sys
import time import time
from .misc import format_time from .misc import format_time
from .constants import NAN
last_time = time.time() begin_time = 0
begin_time = last_time recent_times = []
def progress_bar(current, total, msg=None, premsg=None, barmsg=None): def progress_bar(current, total, msg=None, premsg=None, barmsg=None):
global last_time, begin_time global begin_time, recent_times
current_time = time.time()
if current == 0: if current == 0:
begin_time = time.time() # Reset for new bar. begin_time = time.time() # Reset for new bar.
current_time = time.time() recent_times = [begin_time]
step_time = current_time - last_time total_time = 0
total_time = current_time - begin_time estimated_time = 0
last_time = current_time step_time = 0
estimated_time = 0 if current == 0 else total_time / current * (total - current) else:
recent_elapse = current_time - recent_times[0]
step_time = recent_elapse / len(recent_times)
total_time = current_time - begin_time
estimated_time = (total - current) * step_time
recent_times = (recent_times + [current_time])[-100:]
show_opt = int(current_time) % 6 >= 3 and current < total show_opt = int(current_time) % 6 >= 3 and current < total
show_barmsg = barmsg is not None and show_opt show_barmsg = barmsg is not None and show_opt
......
import torch
from typing import Optional, Union, Any, List, Tuple
from . import math
class Samples(object):
indices: torch.Tensor
""" Tensor(N[, P], 2)` The unique indices of samples, e.g. (i-th ray, j-th sample)"""
pts: torch.Tensor
"""`Tensor(N[, P], 3)`"""
dirs: torch.Tensor
"""`Tensor(N[, P], 3)`"""
depths: torch.Tensor
"""`Tensor(N[, P])`"""
dists: torch.Tensor
"""`Tensor(N[, P])`"""
voxel_indices: torch.Tensor
"""`Tensor(N[, P])`"""
size: List[int]
"""Size of the samples"""
device: torch.device
"""Device where tensors of this object locate"""
def __init__(self, **_data: Union[torch.Tensor, float, int]) -> None:
super().__init__()
super().__setattr__("_data", _data)
super().__setattr__("size", self.pts.size()[:-1])
super().__setattr__("device", self.pts.device)
def __getitem__(self, index: Union[int, slice, list, tuple, torch.Tensor, None]):
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = index.nonzero(as_tuple=True)
return Samples(**{
key: value[index] if isinstance(value, torch.Tensor) else value
for key, value in self._data.items()
})
def __getattr__(self, __name: str) -> Union[torch.Tensor, Any]:
try:
return self._data[__name]
except KeyError:
return None
def __setattr__(self, __name: str, __value: Any) -> None:
self._data[__name] = __value
def reshape(self, *shape: int):
return Samples(**{
key: value.reshape(*shape, *value.shape[len(self.size):])
if isinstance(value, torch.Tensor) else value
for key, value in self._data.items()
})
def filter_rays(self) -> Optional[torch.Tensor]:
if isinstance(self.voxel_indices, torch.Tensor):
valid_rays_mask = self.voxel_indices.ne(-1).any(dim=-1) # (N)
rays_filter = valid_rays_mask.nonzero(as_tuple=True)[0] # (N) -> (N')
super().__setattr__("_data", {
key: value[rays_filter] if isinstance(value, torch.Tensor) else value
for key, value in self._data.items()
})
super().__setattr__("size", self.pts.size()[:-1])
return rays_filter
return None
def interpolate(self, fine_samples, *values: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
P1 = self.size[-1]
P2 = fine_samples.size[-1]
K = P2 // P1
if K > 1:
# Do interpolation
t1 = self.t # ([N, ]P1)
t2 = fine_samples.t # ([N, ]P2)
lo = torch.arange(P1, device=fine_samples.device).repeat_interleave(K)[:P2]
up = (lo + 1).clamp(max=P1 - 1)
t1_lo, t1_up = t1[..., lo], t1[..., up]
k = ((t2 - t1_lo) / (t1_up - t1_lo + math.tiny))[..., None] # ([N, ]P2, 1)
values = [
value[..., lo, :] * (1 - k) + value[..., up, :] * k # ([N, ]P2, X)
for value in values
]
return values[0] if len(values) == 1 else tuple(values)
import numpy as np import numpy as np
from .constants import * from . import math
def helix(t_range, loops, n): def helix(t_range, loops, n):
n_per_loop = n // loops n_per_loop = n // loops
angles = np.linspace(0, 2 * PI, n_per_loop, endpoint=False)[None, :]. \ angles = np.linspace(0, 2 * math.math.pi, n_per_loop, endpoint=False)[None, :]. \
repeat(loops, axis=0).flatten() repeat(loops, axis=0).flatten()
centers = np.empty([n, 3]) centers = np.empty([n, 3])
centers[:, 0] = 0.5 * t_range[0] * np.cos(angles) centers[:, 0] = 0.5 * t_range[0] * np.cos(angles)
...@@ -18,9 +17,9 @@ def helix(t_range, loops, n): ...@@ -18,9 +17,9 @@ def helix(t_range, loops, n):
def scan_around(t_range, circles, n): def scan_around(t_range, circles, n):
angles = np.linspace(-PI, PI, n // circles, endpoint=False) angles = np.linspace(-math.pi, math.pi, n // circles, endpoint=False)
x_rots = angles[None, :].repeat(circles, axis=0).flatten() x_rots = angles[None, :].repeat(circles, axis=0).flatten()
c_angles = angles + 0.8 * PI c_angles = angles + 0.8 * math.pi
centers = np.empty([n, 3]) centers = np.empty([n, 3])
for i in range(circles): for i in range(circles):
r = (0.5 * t_range[0] / circles * (i + 1), r = (0.5 * t_range[0] / circles * (i + 1),
...@@ -28,16 +27,16 @@ def scan_around(t_range, circles, n): ...@@ -28,16 +27,16 @@ def scan_around(t_range, circles, n):
0.5 * t_range[2]) 0.5 * t_range[2])
s = slice(i * len(angles), (i + 1) * len(angles)) s = slice(i * len(angles), (i + 1) * len(angles))
centers[s, 0] = r[0] * np.sin(c_angles) centers[s, 0] = r[0] * np.sin(c_angles)
centers[s, 1] = r[1] * np.sin(angles * 10 + i * 2 * PI / circles) centers[s, 1] = r[1] * np.sin(angles * 10 + i * 2 * math.pi / circles)
centers[s, 2] = r[2] * np.cos(c_angles) centers[s, 2] = r[2] * np.cos(c_angles)
rots = np.stack([x_rots, np.zeros_like(x_rots)], axis=1) rots = np.stack([x_rots, np.zeros_like(x_rots)], axis=1)
return centers, rots return centers, rots
def look_around(t_range, n): def look_around(t_range, n):
angles = np.linspace(-PI, PI, n, endpoint=False) angles = np.linspace(-math.pi, math.pi, n, endpoint=False)
x_rots = angles x_rots = angles
c_angles = angles + 0.8 * PI c_angles = angles + 0.8 * math.pi
centers = np.empty([n, 3]) centers = np.empty([n, 3])
r = (0.5 * t_range[0], 0.5 * t_range[1], 0.5 * t_range[2]) r = (0.5 * t_range[0], 0.5 * t_range[1], 0.5 * t_range[2])
centers[:, 0] = r[0] * np.sin(c_angles) centers[:, 0] = r[0] * np.sin(c_angles)
......
from typing import Union from typing import Union
import torch import torch
import math from . import math
from . import misc from . import misc
......
from typing import Any, Dict, Union
import torch
InputData = Dict[str, Union[torch.Tensor, Any]]
ReturnData = Dict[str, Union[torch.Tensor, Any]]
NetOutput = Dict[str, torch.Tensor]
class NetInput:
def __init__(self, x: torch.Tensor = None, d: torch.Tensor = None, f: torch.Tensor = None) -> None:
self.x = x
self.d = d
self.f = f
if x is not None:
self.shape = x.shape[:-1]
elif d is not None:
self.shape = d.shape[:-1]
else:
self.shape = [0]
def __getitem__(self, index: Union[int, slice, list, tuple, torch.Tensor, None]) -> 'NetInput':
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = index.nonzero(as_tuple=True)
return NetInput(
self.x[index] if self.x is not None else None,
self.d[index] if self.d is not None else None,
self.f[index] if self.f is not None else None
)
from typing import List, Mapping, Tuple, Union from typing import List, Mapping, Tuple, Union
from numpy import ones_like
import torch import torch
import math
import glm import glm
from . import misc from . import misc
from . import math
def fov2length(angle): def fov2length(angle):
......
import torch import torch
from typing import Tuple, Union from typing import Tuple, Union
from . import math
def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor: def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor:
""" """
...@@ -13,6 +15,18 @@ def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> ...@@ -13,6 +15,18 @@ def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) ->
return ((bbox[1] - bbox[0]) / step_size).ceil().long() return ((bbox[1] - bbox[0]) / step_size).ceil().long()
def get_out_of_bound_mask(pts: torch.Tensor, bbox: torch.Tensor) -> torch.Tensor:
"""
Get a mask tensor indicating which elements in `pts` are out of the bound `bbox`
:param pts `Tensor(N..., D)`: points
:param bbox `Tensor(2, D)`: bounding box
:return `Tensor(N...)`: a mask tensor
"""
k = (pts - bbox[0]) / (bbox[1] - bbox[0])
return torch.logical_or(k < -math.tiny, k > 1 + math.tiny).any(-1)
def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Tensor: def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
indices = grid_coords[..., 0] indices = grid_coords[..., 0]
for i in range(1, grid_coords.shape[-1]): for i in range(1, grid_coords.shape[-1]):
...@@ -20,9 +34,7 @@ def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Ten ...@@ -20,9 +34,7 @@ def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Ten
return indices return indices
def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *, def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
step_size: Union[torch.Tensor, float] = None,
steps: torch.Tensor = None) -> torch.Tensor:
""" """
Get discretized (integer) grid coordinates of points. Get discretized (integer) grid coordinates of points.
...@@ -32,18 +44,13 @@ def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *, ...@@ -32,18 +44,13 @@ def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
:param pts `Tensor(N..., D)`: points :param pts `Tensor(N..., D)`: points
:param bbox `Tensor(2, D)`: bounding box :param bbox `Tensor(2, D)`: bounding box
:param step_size `Tensor(1|D) | float`: (optional) step size :param steps `Tensor(1|D)`: steps alone every dim
:param steps `Tensor(1|D)`: (optional) steps alone every dim
:return `Tensor(N..., D)`: discretized grid coordinates :return `Tensor(N..., D)`: discretized grid coordinates
""" """
if step_size is not None:
return ((pts - bbox[0]) / step_size).floor().long()
return ((pts - bbox[0]) / (bbox[1] - bbox[0]) * steps).floor().long() return ((pts - bbox[0]) / (bbox[1] - bbox[0]) * steps).floor().long()
def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *, def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
step_size: Union[torch.Tensor, float] = None,
steps: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Get flattened grid indices of points. Get flattened grid indices of points.
...@@ -53,46 +60,34 @@ def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *, ...@@ -53,46 +60,34 @@ def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
:param pts `Tensor(N..., D)`: points :param pts `Tensor(N..., D)`: points
:param bbox `Tensor(2, D)`: bounding box :param bbox `Tensor(2, D)`: bounding box
:param step_size `Tensor(1|D) | float`: (optional) step size :param steps `Tensor(1|D)`: steps alone every dim
:param steps `Tensor(1|D)`: (optional) steps alone every dim :return `Tensor(N...)`: flattened grid indices
:return `Tensor(N...)`: grid indices """
:return `Tensor(N...)`: a mask tensor indicating the returned indices are outside or not grid_coords = to_grid_coords(pts, bbox, steps).minimum(steps - 1) # (N..., D)
""" outside_mask = get_out_of_bound_mask(pts, bbox) # (N...)
if step_size is not None:
steps = get_grid_steps(bbox, step_size) # (D)
grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps) # (N..., D)
outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1) # (N...)
grid_indices = to_flat_indices(grid_coords, steps) grid_indices = to_flat_indices(grid_coords, steps)
return grid_indices, outside_mask grid_indices[outside_mask] = -1
return grid_indices
def init_voxels(bbox: torch.Tensor, steps: torch.Tensor): def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
""" """
Initialize voxels. Initialize voxels.
""" """
x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)], indexing="ij") x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)])
return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps) return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps)
def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *, def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
step_size: Union[torch.Tensor, float] = None,
steps: torch.Tensor = None) -> torch.Tensor:
""" """
Get discretized (integer) grid coordinates of points. Get center positions of grids.
At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is :param grid_coords `Tensor(N..., D)`: grid coordinates
specified, then the grid coordinates will be calculated according to the step size, ignoring
the value of `steps`.
:param pts `Tensor(N..., D)`: points
:param bbox `Tensor(2, D)`: bounding box :param bbox `Tensor(2, D)`: bounding box
:param step_size `Tensor(1|D) | float`: (optional) step size :param steps `Tensor(1|D)`: steps alone every dim
:param steps `Tensor(1|D)`: (optional) steps alone every dim :return `Tensor(N..., D)`: center positions of grids
:return `Tensor(N..., D)`: discretized grid coordinates
""" """
grid_coords = grid_coords.float() + .5 grid_coords = grid_coords.float() + .5
if step_size is not None:
return grid_coords * step_size + bbox[0]
return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0] return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0]
...@@ -115,7 +110,7 @@ def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_bor ...@@ -115,7 +110,7 @@ def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_bor
dtype = like.dtype dtype = like.dtype
device = like.device device = like.device
c = torch.arange(1 - n, n, 2, dtype=dtype, device=device) c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
offset = torch.stack(torch.meshgrid([c] * dims, indexing='ij'), -1).flatten(0, -2)\ offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2)\
* voxel_size * .5 / (n - 1 if align_border else n) * voxel_size * .5 / (n - 1 if align_border else n)
return offset return offset
...@@ -141,14 +136,14 @@ def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Te ...@@ -141,14 +136,14 @@ def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Te
expand_bbox = bbox.clone() expand_bbox = bbox.clone()
expand_bbox[0] -= 0.5 * half_voxel_size expand_bbox[0] -= 0.5 * half_voxel_size
expand_bbox[1] += 0.5 * half_voxel_size expand_bbox[1] += 0.5 * half_voxel_size
double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size) double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, steps * 2 + 1)
# (M, 3) -> [1, 3, 5, ...] # (M, 3) -> [1, 3, 5, ...]
corner_coords = split_voxels(double_grid_coords, 2, 2).reshape(-1, 3) corner_coords = split_voxels(double_grid_coords, 2, 2).reshape(-1, 3)
# (8M, 3) -> [0, 2, 4, ...] # (8M, 3) -> [0, 2, 4, ...]
corner_coords, corner_indices = corner_coords.unique(dim=0, sorted=True, return_inverse=True) corner_coords, corner_indices = corner_coords.unique(dim=0, sorted=True, return_inverse=True)
corners = to_voxel_centers(corner_coords, expand_bbox, step_size=half_voxel_size) corners = to_voxel_centers(corner_coords, expand_bbox, steps * 2 + 1)
return corners, corner_indices.reshape(-1, 8) return corners, corner_indices.reshape(-1, 8)
...@@ -163,6 +158,7 @@ def trilinear_interp(pts: torch.Tensor, corner_values: torch.Tensor) -> torch.Te ...@@ -163,6 +158,7 @@ def trilinear_interp(pts: torch.Tensor, corner_values: torch.Tensor) -> torch.Te
""" """
pts = pts[:, None] # (N, 1, 3) pts = pts[:, None] # (N, 1, 3)
corners = split_voxels_local(1, 2, like=pts) + 0.5 # (8, 3) corners = split_voxels_local(1, 2, like=pts) + 0.5 # (8, 3)
weights = (pts * corners * 2 - pts - corners + 1).prod(-1, keepdim=True) # (N, 8, 1)
corner_values = corner_values.reshape(corner_values.size(0), 8, -1) # (N, 8, X) corner_values = corner_values.reshape(corner_values.size(0), 8, -1) # (N, 8, X)
weights = (pts * corners * 2 - pts - corners + 1).prod(-1, keepdim=True) # (N, 8, 1)
return (weights * corner_values).sum(1) return (weights * corner_values).sum(1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment