import torch
from typing import Optional, Tuple, Union


def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor:
    """
    Get the number of grid steps along every dim.

    The extent of the bounding box is divided by the step size and rounded up, so the grid
    always covers the whole box.

    :param bbox `Tensor(2, D)`: bounding box (`bbox[0]` is the lower corner, `bbox[1]` the upper)
    :param step_size `Tensor(1|D) | float`: step size
    :return `Tensor(D)`: grid steps along every dim (int64)
    """
    return ((bbox[1] - bbox[0]) / step_size).ceil().long()


def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
                   step_size: Optional[Union[torch.Tensor, float]] = None,
                   steps: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Get discretized (integer) grid coordinates of points.

    At least one of the parameters `step_size` and `steps` should be specified. If `step_size`
    is specified, the grid coordinates are calculated according to the step size and the value
    of `steps` is ignored.

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N..., D)`: discretized grid coordinates (int64); points outside the box get
        coordinates outside `[0, steps)`
    :raises ValueError: if neither `step_size` nor `steps` is given
    """
    if step_size is not None:
        return ((pts - bbox[0]) / step_size).floor().long()
    if steps is None:
        raise ValueError("Either `step_size` or `steps` must be specified")
    return ((pts - bbox[0]) / (bbox[1] - bbox[0]) * steps).floor().long()


def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
                    step_size: Optional[Union[torch.Tensor, float]] = None,
                    steps: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get flattened (row-major) grid indices of points.

    At least one of the parameters `step_size` and `steps` should be specified. If `step_size`
    is specified, the grid indices are calculated according to the step size and the value of
    `steps` is ignored (it is recomputed from `bbox` and `step_size`).

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N...)`: grid indices, with the last dim varying fastest
    :return `Tensor(N...)`: boolean mask, True where the point lies outside the grid (its index
        is then meaningless)
    :raises ValueError: if neither `step_size` nor `steps` is given
    """
    if step_size is not None:
        steps = get_grid_steps(bbox, step_size)  # (D)
    grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps)  # (N..., D)
    dims = grid_coords.size(-1)
    if steps.numel() == 1:
        # A single step count applies to every dim.
        steps = steps.reshape(1).expand(dims)
    outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1)  # (N...)
    # Row-major strides: stride[i] = prod(steps[i+1:]), stride[D-1] = 1.
    strides = torch.cat([steps[1:].flip(0).cumprod(0).flip(0), steps.new_ones(1)])  # (D)
    grid_indices = (grid_coords * strides).sum(-1)  # (N...)
    return grid_indices, outside_mask


def init_voxels(bbox: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
    """
    Get the centers of all voxels of a regular grid filling the bounding box.

    :param bbox `Tensor(2, D)`: bounding box
    :param steps `Tensor(D)`: number of voxels along every dim
    :return `Tensor(prod(steps), D)`: voxel centers, enumerated with the last dim varying fastest
    """
    dims = bbox.size(-1)
    # Build the index grid on the same device as `bbox` so the arithmetic below does not mix devices.
    axes = [torch.arange(int(steps[i]), device=bbox.device) for i in range(dims)]
    grid_coords = torch.stack(torch.meshgrid(*axes), -1).reshape(-1, dims)
    return to_voxel_centers(grid_coords, bbox, steps=steps)


def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
                     step_size: Optional[Union[torch.Tensor, float]] = None,
                     steps: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Get the world-space centers of the voxels at the given grid coordinates.

    This is the inverse of `to_grid_coords`: cell `c` maps to `bbox[0] + (c + 0.5) * cell_size`.
    At least one of the parameters `step_size` and `steps` should be specified. If `step_size`
    is specified, the cell size is the step size and the value of `steps` is ignored.

    :param grid_coords `Tensor(N..., D)`: grid coordinates (any numeric dtype; cast to float32)
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N..., D)`: voxel centers
    :raises ValueError: if neither `step_size` nor `steps` is given
    """
    centers = grid_coords.float() + 0.5
    if step_size is not None:
        return centers * step_size + bbox[0]
    if steps is None:
        raise ValueError("Either `step_size` or `steps` must be specified")
    return centers / steps * (bbox[1] - bbox[0]) + bbox[0]


def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_border: bool = True,
                       dims=3, *, dtype: torch.dtype = None, device: torch.device = None,
                       like: torch.Tensor = None):
    """
    Get offsets, relative to a voxel's center, of `n` evenly spaced points per dim (n**dims total).

    With `align_border=False` the points are the centers of the `n**dims` equal sub-voxels.
    With `align_border=True` the outermost points lie on the voxel's border (e.g. `n=2` gives
    the voxel corners). For `n=1` the single offset is the zero vector.

    :param voxel_size `Tensor(D)|float`: voxel size
    :param n `int`: number of points along every dim
    :param align_border `bool`: put the outermost points on the voxel border, defaults to True
    :param dims `int`: number of dims, defaults to 3
    :param dtype `dtype`: dtype of the index range, defaults to None (overridden by `like`)
    :param device `device`: device of the result, defaults to None (overridden by `like`)
    :param like `Tensor(*)`: (optional) take `dtype` and `device` from this tensor
    :return `Tensor(n**dims, D)`: offsets, enumerated with the last dim varying fastest
    """
    if like is not None:
        dtype = like.dtype
        device = like.device
    # Odd integers 1-n, 3-n, ..., n-1: symmetric around zero.
    c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
    # Guard n == 1: only the zero offset exists, so avoid dividing by n - 1 == 0.
    divisor = n - 1 if align_border and n > 1 else n
    offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) * voxel_size / 2 / divisor
    return offset


def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, float], n: int,
                 align_border: bool = True):
    """
    Split every voxel into `n**D` points (see `split_voxels_local`).

    :param voxel_centers `Tensor(N, D)`: voxel centers
    :param voxel_size `Tensor(D)|float`: voxel size
    :param n `int`: number of points along every dim
    :param align_border `bool`: put the outermost points on the voxel border, defaults to True
    :return `Tensor(N, n**D, D)`: points of every voxel
    """
    return voxel_centers[:, None] + split_voxels_local(
        voxel_size, n, align_border, voxel_centers.shape[-1], like=voxel_centers)


def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor,
                steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get the unique corners of a set of grid voxels and, per voxel, the indices of its corners.

    `bbox` is not modified.

    :param voxel_centers `Tensor(M, D)`: centers of voxels on the grid defined by `bbox`/`steps`
    :param bbox `Tensor(2, D)`: bounding box
    :param steps `Tensor(1|D)`: steps along every dim
    :return `Tensor(C, D)`: unique corner positions, sorted lexicographically by grid coordinate
    :return `Tensor(M, 2**D)`: indices into the corners tensor for every voxel, in the same corner
        order as `split_voxels(..., 2, True)`
    """
    dims = voxel_centers.size(-1)
    half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
    # Shift the origin by a quarter voxel. Voxel centers then fall strictly inside a cell of the
    # half-voxel grid (floor yields odd coords 1, 3, 5, ... robustly, instead of landing exactly on
    # a cell border). The shift also cancels the +0.5 cell offset `to_voxel_centers` adds when
    # mapping back. Build a new tensor: updating `bbox` in place would alter the caller's tensor.
    expand_bbox = torch.stack([bbox[0] - 0.5 * half_voxel_size,
                               bbox[1] + 0.5 * half_voxel_size])
    double_grid_coords = to_grid_coords(voxel_centers, expand_bbox,
                                        step_size=half_voxel_size)  # (M, D) -> [1, 3, 5, ...]
    corner_coords = split_voxels(double_grid_coords, 2, 2).reshape(-1, dims)  # (2^D*M, D) -> [0, 2, 4, ...]
    corner_coords, corner_indices = corner_coords.unique(dim=0, sorted=True, return_inverse=True)
    corners = to_voxel_centers(corner_coords, expand_bbox, step_size=half_voxel_size)
    return corners, corner_indices.reshape(-1, 2 ** dims)


def trilinear_interp(pts: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
    """
    Perform trilinear (generally multilinear) interpolation in a unit voxel ([0,...,0] ~ [1,...,1]).

    :param pts `Tensor(N, D)`: local coordinates inside the unit voxel
    :param corner_values `Tensor(N, 2**D * X)|Tensor(N, 2**D, X)`: values at the voxel corners,
        corner-major, in the corner order of `split_voxels_local(1, 2, dims=D)`
    :return `Tensor(N, X)`: interpolated values
    """
    dims = pts.size(-1)
    pts = pts[:, None]  # (N, 1, D)
    corners = split_voxels_local(1, 2, True, dims, like=pts) + 0.5  # (2^D, D), entries in {0, 1}
    # Per-axis weight is p for a corner at 1 and (1 - p) for a corner at 0: 2pc - p - c + 1.
    weights = (pts * corners * 2 - pts - corners + 1).prod(-1, keepdim=True)  # (N, 2^D, 1)
    corner_values = corner_values.reshape(corner_values.size(0), 2 ** dims, -1)  # (N, 2^D, X)
    return (weights * corner_values).sum(1)