from math import ceil
import torch
import numpy as np
from typing import List, NoReturn, Tuple, Union
from torch import nn
from plyfile import PlyData, PlyElement
from utils.geometry import *
from utils.constants import *
from utils.voxels import *
from utils.perf import perf
from clib import *


class Intersections:
    """Container for the result of a ray/voxel intersection query."""

    min_depths: torch.Tensor
    """`Tensor(N, P)` Min ray depths of intersected voxels"""
    max_depths: torch.Tensor
    """`Tensor(N, P)` Max ray depths of intersected voxels"""
    voxel_indices: torch.Tensor
    """`Tensor(N, P)` Indices of intersected voxels"""
    hits: torch.Tensor
    """`Tensor(N)` Number of hits"""

    @property
    def size(self):
        """Number of entries along the first dimension of `hits`."""
        return self.hits.size(0)

    def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor,
                 voxel_indices: torch.Tensor, hits: torch.Tensor) -> None:
        self.min_depths = min_depths
        self.max_depths = max_depths
        self.voxel_indices = voxel_indices
        self.hits = hits

    def __getitem__(self, index):
        """Return a new `Intersections` with every member tensor indexed by `index`."""
        return Intersections(
            min_depths=self.min_depths[index],
            max_depths=self.max_depths[index],
            voxel_indices=self.voxel_indices[index],
            hits=self.hits[index])


class Space(nn.Module):
    """
    Base class of spaces. Subclasses implement embedding storage, ray intersection,
    pruning and splitting; this class only provides an optional bounding box.
    """

    bbox: Union[torch.Tensor, None]
    """`Tensor(2, 3)` Bounding box"""

    def __init__(self, *, bbox: List[float] = None, **kwargs):
        """
        :param bbox `List[float]`: 6 values reshaped to `(2, 3)` (rows are compared as
            lower/upper bounds in `get_voxel_indices`), or None for an unbounded space
        """
        super().__init__()
        if bbox is None:
            self.bbox = None
        else:
            # Non-persistent: the bounding box is not saved to / loaded from state dicts
            self.register_buffer('bbox', torch.Tensor(bbox).reshape(2, 3), persistent=False)

    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
        raise NotImplementedError

    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
                          name: str = 'default') -> torch.Tensor:
        raise NotImplementedError

    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                      n_max_hits: int) -> Intersections:
        raise NotImplementedError

    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Get voxel indices of points: 0 for every point, or -1 for points outside the
        bounding box (when a bounding box is set).

        :param pts `Tensor(N..., 3)`: points
        :return `Tensor(N...)`: corresponding voxel indices
        """
        voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
        if self.bbox is not None:
            out_bbox = torch.logical_or(pts < self.bbox[0], pts >= self.bbox[1]).any(-1)  # (N...)
            voxel_indices[out_bbox] = -1
        return voxel_indices

    @torch.no_grad()
    def pruning(self, score_fn, threshold: float = 0.5, train_stats=False):
        raise NotImplementedError()

    @torch.no_grad()
    def splitting(self):
        raise NotImplementedError()


class Voxels(Space):
    """Space made of a (possibly pruned) regular grid of voxels with embeddings on corners."""

    steps: torch.Tensor
    """`Tensor(3)` Steps along each dimension"""
    corners: torch.Tensor
    """`Tensor(C, 3)` Corner positions"""
    voxels: torch.Tensor
    """`Tensor(M, 3)` Voxel centers"""
    corner_indices: torch.Tensor
    """`Tensor(M, 8)` Voxel corner indices"""
    voxel_indices_in_grid: torch.Tensor
    """`Tensor(G)` Indices in voxel list or -1 for pruned space"""

    @property
    def dims(self) -> int:
        """`int` Number of dimensions"""
        return self.steps.size(0)

    @property
    def n_voxels(self) -> int:
        """`int` Number of voxels"""
        return self.voxels.size(0)

    @property
    def n_corners(self) -> int:
        """`int` Number of corners"""
        return self.corners.size(0)

    @property
    def n_corner(self) -> int:
        """`int` Number of corners (alias of `n_corners`, kept for backward compatibility)"""
        return self.n_corners

    @property
    def voxel_size(self) -> torch.Tensor:
        """`Tensor(3)` Voxel size"""
        return (self.bbox[1] - self.bbox[0]) / self.steps

    @property
    def device(self) -> torch.device:
        """Device on which the voxel centers are stored."""
        return self.voxels.device

    def __init__(self, *, voxel_size: float = None,
                 steps: Union[torch.Tensor, Tuple[int, int, int]] = None, **kwargs) -> None:
        """
        Build the voxel grid from a bounding box and either a voxel size or explicit steps.
        `voxel_size` takes precedence when both are given.

        :param voxel_size `float`: voxel size used to derive the grid steps
        :param steps `Tensor(3)|Tuple[int, int, int]`: steps along each dimension
        :param kwargs: forwarded to `Space.__init__` (must include `bbox`)
        :raise ValueError: if `bbox` is missing, or neither `voxel_size` nor `steps` is given
        """
        super().__init__(**kwargs)
        if self.bbox is None:
            raise ValueError("Missing argument 'bbox'")
        if voxel_size is not None:
            self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
        elif steps is not None:
            self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
        else:
            # Fail early with a clear message instead of an obscure torch.tensor(None) error
            raise ValueError("Either 'voxel_size' or 'steps' must be specified")
        self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
        corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
        self.register_buffer("corners", corners)
        self.register_buffer("corner_indices", corner_indices)
        self.register_buffer('voxel_indices_in_grid', torch.arange(self.n_voxels))
        self._register_load_state_dict_pre_hook(self._before_load_state_dict)

    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
        """
        Create a embedding on voxel corners.

        :param name `str`: embedding name
        :param n_dims `int`: embedding dimension
        :return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
        """
        name = f'emb_{name}'
        self.add_module(name, torch.nn.Embedding(self.n_corners, n_dims))
        return getattr(self, name)

    def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
        """
        Get the embedding registered under `name`.

        :param name `str`: embedding name, default to 'default'
        :return `Embedding`: the embedding module
        :raise AttributeError: if no such embedding has been created
        """
        return getattr(self, f'emb_{name}')

    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
                          name: str = 'default') -> torch.Tensor:
        """
        Extract embedding values at given points using trilinear interpolation.

        :param pts `Tensor(N, 3)`: points to extract values
        :param voxel_indices `Tensor(N)`: corresponding voxel indices
        :param name `str`: embedding name, default to 'default'
        :return `Tensor(N, X)`: extracted values
        :raise KeyError: if the embedding doesn't exist
        """
        # Look up with a default so that the KeyError below is actually reachable
        emb = getattr(self, f'emb_{name}', None)
        if emb is None:
            raise KeyError(f"Embedding '{name}' doesn't exist")
        voxels = self.voxels[voxel_indices]  # (N, 3)
        corner_indices = self.corner_indices[voxel_indices]  # (N, 8)
        p = (pts - voxels) / self.voxel_size + 0.5  # (N, 3) normed-coords in voxel
        features = emb(corner_indices).reshape(pts.size(0), 8, -1)  # (N, 8, X)
        return trilinear_interp(p, features)

    @perf
    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                      n_max_hits: int) -> Intersections:
        """
        Calculate intersections of rays and voxels.

        :param rays_o `Tensor(N, 3)`: rays' origin
        :param rays_d `Tensor(N, 3)`: rays' direction
        :param n_max_hits `int`: maximum number of hits (for allocating enough space)
        :return `Intersection`: intersections of rays and voxels
        """
        # Prepend a dim to meet the requirement of external call
        rays_o = rays_o[None].contiguous()
        rays_d = rays_d[None].contiguous()

        voxel_indices, min_depths, max_depths = self._ray_intersect(rays_o, rays_d, n_max_hits)
        invalid_voxel_mask = voxel_indices.eq(-1)
        hits = n_max_hits - invalid_voxel_mask.sum(-1)

        # Sort intersections according to their depths
        # (invalid slots get a huge depth so that they are moved to the end)
        min_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
        max_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
        min_depths, sorted_idx = min_depths.sort(dim=-1)
        max_depths = max_depths.gather(-1, sorted_idx)
        voxel_indices = voxel_indices.gather(-1, sorted_idx)

        # Strip the dim prepended above
        return Intersections(
            min_depths=min_depths[0],
            max_depths=max_depths[0],
            voxel_indices=voxel_indices[0],
            hits=hits[0]
        )

    @perf
    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Get voxel indices of points.

        If a point is not in any valid voxels, its corresponding voxel index is -1.

        :param pts `Tensor(N..., 3)`: points
        :return `Tensor(N...)`: corresponding voxel indices
        """
        grid_indices, out_mask = to_grid_indices(pts, self.bbox, steps=self.steps)
        # Use a valid placeholder index for the lookup; the result is overwritten with -1 below
        grid_indices[out_mask] = 0
        voxel_indices = self.voxel_indices_in_grid[grid_indices]
        voxel_indices[out_mask] = -1
        return voxel_indices

    @torch.no_grad()
    def splitting(self) -> Tuple[int, int]:
        """
        Split voxels into smaller voxels with half size.

        :return `Tuple[int, int]`: number of voxels before and after splitting
        """
        n_voxels_before = self.n_voxels
        # Doubling the steps halves `voxel_size`, which is derived from `bbox` and `steps`
        self.steps *= 2
        self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
            .reshape(-1, 3)
        self._update_corners()
        self._update_voxel_indices_in_grid()
        return n_voxels_before, self.n_voxels

    @torch.no_grad()
    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
        """
        Keep only the voxels selected by `keeps`.

        :param keeps `Tensor(M)`: mask of voxels to keep
        :return `Tuple[int, int]`: number of voxels before and after pruning
        """
        self.voxels = self.voxels[keeps]
        self.corner_indices = self.corner_indices[keeps]
        self._update_voxel_indices_in_grid()
        return keeps.size(0), keeps.sum().item()

    @torch.no_grad()
    def pruning(self, score_fn, threshold: float = 0.5, train_stats=False) -> Tuple[int, int]:
        """
        Prune voxels whose maximum sampled score does not exceed `threshold`.

        :param score_fn: callable `(pts, voxel_indices) -> scores` evaluated on sampled points
        :param threshold `float`: voxels with max score > threshold are kept
        :param train_stats: accepted for signature compatibility with `Space.pruning`; ignored
        :return `Tuple[int, int]`: number of voxels before and after pruning
        """
        scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0])  # (M)
        return self.prune(scores > threshold)

    def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
        """
        Count the valid (not pruned) voxels in every slice along a dimension.

        :param dim `int`: the dimension to keep
        :return `Tensor(steps[dim])`: number of valid voxels in each slice
        """
        sum_dims = [val for val in range(self.dims) if val != dim]
        return self.voxel_indices_in_grid.reshape(*self.steps.tolist()).ne(-1).sum(sum_dims)

    def balance_cut(self, dim: int, n_parts: int) -> List[int]:
        """
        Cut the grid along a dimension into parts, placing a cut whenever the cumulative
        share of valid voxels (scaled by `n_parts`) reaches the next integer.

        :param dim `int`: the dimension to cut along
        :param n_parts `int`: desired number of parts
        :return `List[int]`: number of slices in each part
        """
        n_voxels_list = self.n_voxels_along_dim(dim)
        cdf = (n_voxels_list.cumsum(0) / self.n_voxels * n_parts).tolist()
        bins = []
        part = 1
        offset = 0
        for i, cdf_value in enumerate(cdf):
            if cdf_value >= part:
                bins.append(i + 1 - offset)
                offset = i + 1
                # A single slice may cross several thresholds at once; skip past all of them
                part = int(cdf_value) + 1
        return bins

    def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample points inside every voxel by splitting it with `split_voxels`.

        :param bits `int`: split factor passed to `split_voxels`
        :param perturb `bool`: currently unused
        :return `Tensor(M*P, 3)`: sampled points
        :return `Tensor(M*P)`: index of the voxel each sampled point belongs to
        """
        sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits)
        sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand(
            *sampled_xyz.shape[:2])
        sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.flatten()
        return sampled_xyz, sampled_idx

    @torch.no_grad()
    def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor:
        """
        Evaluate `score_fn` on points sampled in every voxel, chunk by chunk.

        :param score_fn: callable `(pts, voxel_indices) -> scores`
        :param reduce_fn: optional callable applied to the `(B, P)` scores of each chunk
        :param bits `int`: split factor used for sampling; `bits ** 3` points per voxel
            are assumed (consistent with the reshape below)
        :return `Tensor(M[, ...])`: scores of voxels
        """
        pts_per_voxel = bits ** 3

        def get_scores_once(pts, idxs):
            scores = score_fn(pts, idxs).reshape(-1, pts_per_voxel)  # (B, P)
            if reduce_fn is not None:
                scores = reduce_fn(scores)  # (B[, ...])
            return scores

        sampled_xyz, sampled_idx = self.sample(bits)
        chunk_size = 64  # number of voxels handled per call of score_fn
        # The sampled arrays are flattened per point, so a chunk of voxels spans
        # chunk_size * pts_per_voxel consecutive entries
        pts_per_chunk = chunk_size * pts_per_voxel
        return torch.cat([
            get_scores_once(sampled_xyz[i:i + pts_per_chunk], sampled_idx[i:i + pts_per_chunk])
            for i in range(0, sampled_xyz.size(0), pts_per_chunk)
        ], 0)  # (M[, ...])

    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                       n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Low-level intersection call; returns `(voxel_indices, min_depths, max_depths)` as
        unpacked by `ray_intersect`.
        """
        return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)

    def _update_corners(self):
        """
        Update voxel corners.
        """
        corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
        self.register_buffer("corners", corners)
        self.register_buffer("corner_indices", corner_indices)

    def _update_voxel_indices_in_grid(self):
        """
        Update voxel indices in grid.
        """
        grid_indices, _ = to_grid_indices(self.voxels, self.bbox, steps=self.steps)
        # Start with every grid cell marked as pruned (-1), then fill in the existing voxels
        self.voxel_indices_in_grid = grid_indices.new_full([self.steps.prod().item()], -1)
        self.voxel_indices_in_grid[grid_indices] = torch.arange(self.n_voxels, device=self.device)

    @torch.no_grad()
    def _before_load_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
                                unexpected_keys, error_msgs):
        """
        Pre-hook of `load_state_dict`: resize buffers and re-create embeddings so that their
        shapes match the incoming state (the voxel set may differ after pruning/splitting).
        """
        # Handle buffers
        for name, buffer in self.named_buffers(recurse=False):
            if name in self._non_persistent_buffers_set:
                continue
            key = prefix + name
            # Leave absent keys to the regular missing-key reporting of load_state_dict
            if key in state_dict:
                buffer.resize_as_(state_dict[key])

        # Handle embeddings (buffers are resized first, so `n_corners` reflects the new state)
        # Materialize the list because modules are replaced while iterating
        for name, module in list(self.named_modules()):
            if name.startswith('emb_'):
                # Keep the new embedding on the device of the one it replaces
                new_emb = torch.nn.Embedding(self.n_corners, module.embedding_dim)
                setattr(self, name, new_emb.to(module.weight.device))


class Octree(Voxels):
    """Voxels accelerated by an octree, which is built lazily and cached."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.nodes_cached = None
        self.tree_cached = None

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the octree of current voxels, building it on first use.

        :return: the cached `(nodes, tree)` pair produced by `build_easy_octree`
        """
        if self.nodes_cached is None:
            self.nodes_cached, self.tree_cached = build_easy_octree(
                self.voxels, 0.5 * self.voxel_size)
        return self.nodes_cached, self.tree_cached

    def clear(self):
        """Drop the cached octree so that it is rebuilt on the next `get()`."""
        self.nodes_cached = None
        self.tree_cached = None

    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int):
        """Low-level intersection call using the octree instead of the flat voxel list."""
        nodes, tree = self.get()
        return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)

    @torch.no_grad()
    def splitting(self):
        """Split voxels (see `Voxels.splitting`) and invalidate the cached octree."""
        ret = super().splitting()
        self.clear()
        return ret

    @torch.no_grad()
    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
        """Prune voxels (see `Voxels.prune`) and invalidate the cached octree."""
        ret = super().prune(keeps)
        self.clear()
        return ret