from .__common__ import * from clib import * #from model.utils import load from utils.nn import Parameter from utils.geometry import * from utils.voxels import * __all__ = ["Space", "Voxels", "Octree"] class Intersections: min_depths: torch.Tensor """`Tensor(N, P)` Min ray depths of intersected voxels""" max_depths: torch.Tensor """`Tensor(N, P)` Max ray depths of intersected voxels""" voxel_indices: torch.Tensor """`Tensor(N, P)` Indices of intersected voxels""" hits: torch.Tensor """`Tensor(N)` Number of hits""" @property def shape(self): return self.hits.shape def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor, voxel_indices: torch.Tensor, hits: torch.Tensor) -> None: self.min_depths = min_depths self.max_depths = max_depths self.voxel_indices = voxel_indices self.hits = hits def __getitem__(self, index): return Intersections( min_depths=self.min_depths[index], max_depths=self.max_depths[index], voxel_indices=self.voxel_indices[index], hits=self.hits[index]) class Space(nn.Module): bbox: torch.Tensor | None """`Tensor(2, D)` Bounding box""" @property def dims(self) -> int: """`int` Number of dimensions""" return self.bbox.shape[1] if self.bbox is not None else 3 @staticmethod def create(type: str, args: dict[str, Any]) -> 'Space': match type: case "Space": return Space(**args) case "Octree": return Octree(**args) case "Voxels": return Voxels(**args) case _: return load(type).space def __init__(self, clone_src: "Space" = None, *, bbox: list[float] = None, **kwargs): super().__init__() if clone_src: self.device = clone_src.device self.register_temp('bbox', clone_src.bbox) else: self.register_temp('bbox', None if not bbox else torch.tensor(bbox).reshape(2, -1)) def ray_intersect_with_bbox(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> Intersections: """ [summary] :param rays_o `Tensor(N..., D)`: rays' origin :param rays_d `Tensor(N..., D)`: rays' direction :param max_hits `int?`: max number of hits of each ray, have no effect for this method :return `Intersect(N...)`: rays' intersection with the bounding box """ if self.bbox is None: raise RuntimeError("The space has no bounding box") inv_d = rays_d.reciprocal().unsqueeze(-2) t = (self.bbox - rays_o.unsqueeze(-2)) * inv_d # (N..., 2, D) t0 = t.min(dim=-2)[0].max(dim=-1, keepdim=True)[0].clamp(min=1e-4) # (N..., 1) t1 = t.max(dim=-2)[0].min(dim=-1, keepdim=True)[0] miss = t1 <= t0 t0[miss], t1[miss] = -1., -1. hit = torch.logical_not(miss).long() return Intersections(t0, t1, hit - 1, hit.squeeze(-1)) def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, max_hits: int) -> Intersections: return self.ray_intersect_with_bbox(rays_o, rays_d) def get_voxel_indices(self, pts: torch.Tensor) -> int | torch.Tensor: if self.bbox is None: return 0 voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long) out_bbox = get_out_of_bound_mask(pts, self.bbox) # (N...) voxel_indices[out_bbox] = -1 return voxel_indices @torch.no_grad() def prune(self, keeps: torch.Tensor) -> tuple[int, int]: raise NotImplementedError() @torch.no_grad() def split(self) -> tuple[int, int]: raise NotImplementedError() @torch.no_grad() def clone(self): return self.__class__(self) class Voxels(Space): bbox: torch.Tensor """`Tensor(2, D)` Bounding box""" steps: torch.Tensor """`Tensor(3)` Steps along each dimension""" corners: torch.Tensor """`Tensor(C, 3)` Corner positions""" voxels: torch.Tensor """`Tensor(M, 3)` Voxel centers""" corner_indices: torch.Tensor """`Tensor(M, 8)` Voxel corner indices""" voxel_indices_in_grid: torch.Tensor """`Tensor(G)` Indices in voxel list or -1 for pruned space Note that the first element is perserved for 'invalid voxel'(-1), so the grid index should be offset by 1 before querying for corresponding voxel index. """ @property def n_voxels(self) -> int: """`int` Number of voxels""" return self.voxels.size(0) @property def n_corners(self) -> int: """`int` Number of corners""" return self.corners.size(0) @property def n_grids(self) -> int: """`int` Number of grids, i.e. steps[0] * steps[1] * ... * steps[D]""" return self.steps.prod().item() @property def voxel_size(self) -> torch.Tensor: """`Tensor(3)` Voxel size""" if self.bbox is None: raise RuntimeError("Cannot get property 'voxel_size' of a space which " "doesn't have bounding box") return (self.bbox[1] - self.bbox[0]) / self.steps @property def corner_embeddings(self) -> dict[str, torch.nn.Embedding]: return {name[4:]: emb for name, emb in self.named_modules() if name.startswith("emb_")} @property def voxel_embeddings(self) -> dict[str, torch.nn.Embedding]: return {name[5:]: emb for name, emb in self.named_modules() if name.startswith("vemb_")} def __init__(self, clone_src: "Voxels" = None, *, bbox: list[float] = None, voxel_size: float = None, steps: torch.Tensor | tuple[int, ...] = None, **kwargs) -> None: if clone_src: super().__init__(clone_src) self.register_buffer('steps', clone_src.steps) self.register_buffer('voxels', clone_src.voxels) self.register_buffer("corners", clone_src.corners) self.register_buffer("corner_indices", clone_src.corner_indices) self.register_buffer('voxel_indices_in_grid', clone_src.voxel_indices_in_grid) else: if bbox is None: raise ValueError("Missing argument 'bbox'") super().__init__(bbox=bbox) if steps is not None: self.register_buffer('steps', torch.tensor(steps, dtype=torch.long)) else: self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size)) self.register_buffer('voxels', init_voxels(self.bbox, self.steps)) corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps) self.register_buffer("corners", corners) self.register_buffer("corner_indices", corner_indices) self.register_buffer('voxel_indices_in_grid', torch.arange(-1, self.n_voxels)) def to_vi(self, gi: torch.Tensor) -> torch.Tensor: return self.voxel_indices_in_grid[gi + 1] def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding: """ Create a embedding on voxel corners. :param name `str`: embedding name :param n_dims `int`: embedding dimension :return `Embedding(n_corners, n_dims)`: new embedding on voxel corners """ if self.get_embedding(name) is not None: raise KeyError(f"Embedding '{name}' already existed") emb = torch.nn.Embedding(self.n_corners, n_dims).to(self.device) setattr(self, f'emb_{name}', emb) return emb def get_embedding(self, name: str = 'default') -> torch.nn.Embedding: return getattr(self, f'emb_{name}', None) def set_embedding(self, weight: torch.Tensor, name: str = 'default'): emb = torch.nn.Embedding(*weight.shape, _weight=weight).to(self.device) setattr(self, f'emb_{name}', emb) return emb def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor, name: str = 'default') -> torch.Tensor: """ Extract embedding values at given points using trilinear interpolation. :param pts `Tensor(N, 3)`: points to extract values :param voxel_indices `Tensor(N)`: corresponding voxel indices :param name `str`: embedding name, default to 'default' :return `Tensor(N, X)`: extracted values """ emb = self.get_embedding(name) if emb is None: raise KeyError(f"Embedding '{name}' doesn't exist") voxels = self.voxels[voxel_indices] # (N, 3) corner_indices = self.corner_indices[voxel_indices] # (N, 8) p = (pts - voxels) / self.voxel_size + .5 # (N, 3) normed-coords in voxel return linear_interp(p, emb(corner_indices)) def create_voxel_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding: """ Create a embedding on voxels. :param name `str`: embedding name :param n_dims `int`: embedding dimension :return `Embedding(n_corners, n_dims)`: new embedding on voxels """ if self.get_voxel_embedding(name) is not None: raise KeyError(f"Embedding '{name}' already existed") emb = torch.nn.Embedding(self.n_voxels, n_dims).to(self.device) setattr(self, f'vemb_{name}', emb) return emb def get_voxel_embedding(self, name: str = 'default') -> torch.nn.Embedding: return getattr(self, f'vemb_{name}', None) def set_voxel_embedding(self, weight: torch.Tensor, name: str = 'default'): emb = torch.nn.Embedding(*weight.shape, _weight=weight).to(self.device) setattr(self, f'vemb_{name}', emb) return emb def extract_voxel_embedding(self, voxel_indices: torch.Tensor, name: str = 'default') -> torch.Tensor: """ Extract embedding values at given voxels. :param voxel_indices `Tensor(N)`: voxel indices :param name `str`: embedding name, default to 'default' :return `Tensor(N, X)`: extracted values """ emb = self.get_voxel_embedding(name) if emb is None: raise KeyError(f"Embedding '{name}' doesn't exist") return emb(voxel_indices) @profile def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections: """ Calculate intersections of rays and voxels. :param rays_o `Tensor(N, 3)`: rays' origin :param rays_d `Tensor(N, 3)`: rays' direction :param n_max_hits `int`: maximum number of hits (for allocating enough space) :return `Intersection`: intersections of rays and voxels """ # Prepend a dim to meet the requirement of external call rays_o = rays_o[None].contiguous() rays_d = rays_d[None].contiguous() voxel_indices, min_depths, max_depths = self._ray_intersect(rays_o, rays_d, n_max_hits) invalid_voxel_mask = voxel_indices.eq(-1) hits = n_max_hits - invalid_voxel_mask.sum(-1) # Sort intersections according to their depths min_depths.masked_fill_(invalid_voxel_mask, math.huge) max_depths.masked_fill_(invalid_voxel_mask, math.huge) min_depths, sorted_idx = min_depths.sort(dim=-1) max_depths = max_depths.gather(-1, sorted_idx) voxel_indices = voxel_indices.gather(-1, sorted_idx) return Intersections( min_depths=min_depths[0], max_depths=max_depths[0], voxel_indices=voxel_indices[0], hits=hits[0] ) @profile def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor: """ Get voxel indices of points. If a point is not in any valid voxels, its corresponding voxel index is -1. :param pts `Tensor(N..., 3)`: points :return `Tensor(N...)`: corresponding voxel indices """ gi = to_grid_indices(pts, self.bbox, self.steps) return self.to_vi(gi) @profile def get_corners(self, vidxs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: vidxs = vidxs.unique() if vidxs[0] == -1: vidxs = vidxs[1:] cidxs = self.corner_indices[vidxs].unique() fi_cidxs = torch.full([self.n_corners], -1, dtype=torch.long, device=self.device) fi_cidxs[cidxs] = torch.arange(cidxs.shape[0], device=self.device) fi_corner_indices = fi_cidxs[self.corner_indices] fi_corners = self.corners[cidxs] return fi_corner_indices, fi_corners @torch.no_grad() def split(self) -> tuple[int, int]: """ Split voxels into smaller voxels with half size. """ # Calculate new voxels and corners new_steps = self.steps * 2 new_voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\ .reshape(-1, 3) new_corners, new_corner_indices = get_corners(new_voxels, self.bbox, new_steps) # Split corner embeddings through interpolation corner_embs = self.corner_embeddings if len(corner_embs) > 0: gi_of_new_corners = to_grid_indices(new_corners, self.bbox, self.steps) vi_of_new_corners = self.to_vi(gi_of_new_corners) for name, emb in corner_embs.items(): new_emb_weight = self.extract_embedding(new_corners, vi_of_new_corners, name=name) self.set_embedding(new_emb_weight, name=name) # Remove old embedding weight and related state from optimizer self._update_optimizer(emb.weight) # Split voxel embeddings self._update_voxel_embeddings(lambda val: torch.repeat_interleave(val, 8, dim=0)) # Apply new tensors self.steps = new_steps self.voxels = new_voxels self.corners = new_corners self.corner_indices = new_corner_indices self._update_gi2vi() return self.n_voxels // 8, self.n_voxels @torch.no_grad() def prune(self, keeps: torch.Tensor) -> tuple[int, int]: self.voxels = self.voxels[keeps] self.corner_indices = self.corner_indices[keeps] self._update_gi2vi() # Prune voxel embeddings self._update_voxel_embeddings(lambda val: val[keeps]) return keeps.size(0), keeps.sum().item() def _update_voxel_embeddings(self, update_fn): for name, emb in self.voxel_embeddings.items(): new_emb = self.set_voxel_embedding(update_fn(emb.weight), name) self._update_optimizer(emb.weight, new_emb.weight, update_fn) def _update_optimizer(self, old_param: Parameter, new_param: Parameter, update_fn): optimizer = get_env()["trainer"].optimizer if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)): # Update related states in optimizer if old_param in optimizer.state: if new_param is not None: # Transfer state from old parameter to new parameter state = optimizer.state[old_param] state.update({ key: update_fn(state[key]) for key in ['exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'] if key in state }) optimizer.state[new_param] = state # Remove state of old parameter optimizer.state.pop(old_param) # Update parameter list in optimizer for group in optimizer.param_groups: try: if new_param is not None: # Replace old parameter with new one idx = group['params'].index(old_param) group['params'][idx] = new_param else: # Or just remove old parameter if new parameter is not specified group['params'].remove(old_param) except Exception: pass def n_voxels_along_dim(self, dim: int) -> torch.Tensor: sum_dims = [val for val in range(self.dims) if val != dim] return self.voxel_indices_in_grid[1:].reshape(*self.steps).ne(-1).sum(sum_dims) def balance_cut(self, dim: int, n_parts: int) -> list[int]: n_voxels_list = self.n_voxels_along_dim(dim) cdf = (n_voxels_list.cumsum(0) / self.n_voxels * n_parts).tolist() bins = [] part = 1 offset = 0 for i in range(len(cdf)): if cdf[i] > part: bins.append(i - offset) offset = i part = int(cdf[i]) + 1 bins.append(len(cdf) - offset) return bins def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> tuple[torch.Tensor, torch.Tensor]: """ For each voxel, sample `S^3` points uniformly, with small perturb if `perturb` is `True`. When `perturb` is `False`, `include_border` can specify whether to sample points from border to border or at centers of sub-voxels. When `perturb` is `True`, points are sampled at centers of sub-voxels, then applying a random offset in sub-voxels. :param S `int`: number of samples along each dim :param perturb `bool?`: whether perturb samples, defaults to `False` :param include_border `bool?`: whether include border, defaults to `True` :return `Tensor(N*S^3, 3)`: sampled points :return `Tensor(N*S^3)`: voxel indices of sampled points """ pts = split_voxels(self.voxels, self.voxel_size, S, align_border=not perturb and include_border) # (N, X, D) voxel_indices = torch.arange(self.n_voxels, device=self.device)[:, None]\ .expand(*pts.shape[:-1]) # (N) -> (N, X) if perturb: pts += (torch.rand_like(pts) - .5) * self.voxel_size / S return pts.reshape(-1, 3), voxel_indices.flatten() def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d) def _update_gi2vi(self): """ Update voxel indices in grid. """ gi = to_grid_indices(self.voxels, self.bbox, self.steps) # Perserve the first element in voxel_indices_in_grid for 'invalid voxel'(-1) self.voxel_indices_in_grid = gi.new_full([self.n_grids + 1], -1) self.voxel_indices_in_grid[gi + 1] = torch.arange(self.n_voxels, device=self.device) def _before_load_state_dict(self, state_dict, prefix, *args): # Handle buffers for name, buffer in self.named_buffers(recurse=False): if name in self._non_persistent_buffers_set: continue buffer.resize_as_(state_dict[prefix + name]) # Handle embeddings for name, module in self.named_modules(): if name.startswith('emb_'): setattr(self, name, torch.nn.Embedding(self.n_corners, module.embedding_dim)) if name.startswith('vemb_'): setattr(self, name, torch.nn.Embedding(self.n_voxels, module.embedding_dim)) def _after_load_state_dict(self): self._update_gi2vi() class Octree(Voxels): def __init__(self, clone_src: "Octree" = None, **kwargs) -> None: super().__init__(clone_src, **kwargs) self.nodes_cached = None self.tree_cached = None def get(self) -> tuple[torch.Tensor, torch.Tensor]: if self.nodes_cached is None: self.nodes_cached, self.tree_cached = build_easy_octree( self.voxels, 0.5 * self.voxel_size) return self.nodes_cached, self.tree_cached def clear(self): self.nodes_cached = None self.tree_cached = None def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int): nodes, tree = self.get() return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d) @torch.no_grad() def split(self): ret = super().split() self.clear() return ret @torch.no_grad() def prune(self, keeps: torch.Tensor) -> tuple[int, int]: ret = super().prune(keeps) self.clear() return ret