voxels.py 6.21 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
import torch
from typing import Tuple, Union

Nianchen Deng's avatar
sync    
Nianchen Deng committed
4
5
from . import math

Nianchen Deng's avatar
sync    
Nianchen Deng committed
6
7
8
9
10
11
12
13
14
15
16
17

def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor:
    """
    Compute the number of grid cells along each dimension of a bounding box.

    The extent of the box in every dimension is divided by the step size and rounded up, so
    the resulting grid always covers the whole box.

    :param bbox `Tensor(2, D)`: bounding box, row 0 is the lower corner and row 1 the upper corner
    :param step_size `Tensor(1|D) | float`: cell size (broadcast against the box extent)
    :return `Tensor(D)`: int64 number of cells along every dim
    """
    extent = bbox[1] - bbox[0]
    cell_counts = torch.ceil(extent / step_size)
    return cell_counts.long()


Nianchen Deng's avatar
sync    
Nianchen Deng committed
18
19
20
21
22
23
24
25
26
27
28
29
def get_out_of_bound_mask(pts: torch.Tensor, bbox: torch.Tensor) -> torch.Tensor:
    """
    Flag the points that lie outside a bounding box.

    Points are normalized so that the box maps onto [0, 1] in every dimension. A point counts as
    outside when any normalized coordinate is below `-math.tiny` or above `1 + math.tiny`, where
    `math.tiny` comes from the project `math` module (presumably a small epsilon, so points
    exactly on the border count as inside).

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box (lower corner, upper corner)
    :return `Tensor(N...)`: boolean mask, True where the point is outside
    """
    lower, upper = bbox[0], bbox[1]
    normalized = (pts - lower) / (upper - lower)
    below = normalized < -math.tiny
    above = normalized > 1 + math.tiny
    return (below | above).any(-1)


Nianchen Deng's avatar
sync    
Nianchen Deng committed
30
31
32
33
34
35
36
def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
    """
    Flatten integer grid coordinates into linear (row-major) indices.

    The last dimension varies fastest, e.g. for D = 3 the index is
    ``(x * steps[1] + y) * steps[2] + z``.

    :param grid_coords `Tensor(N..., D)`: integer grid coordinates
    :param steps `Tensor(1|D)`: steps along every dim; a single-element tensor (or a scalar) is
        shared by all dims
    :return `Tensor(N...)`: flattened indices
    """
    dims = grid_coords.shape[-1]
    # Broadcast a shared step count to one entry per dim so that `steps[i]` is valid for every i
    # (indexing a one-element `steps` at i >= 1 would otherwise raise IndexError).
    steps = torch.as_tensor(steps).expand(dims)
    indices = grid_coords[..., 0]
    for i in range(1, dims):
        indices = indices * steps[i] + grid_coords[..., i]
    return indices


Nianchen Deng's avatar
sync    
Nianchen Deng committed
37
def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
    """
    Discretize points into integer grid coordinates.

    Each point is mapped into the normalized box space (0 to 1 per dim inside `bbox`), scaled by
    the number of steps and floored. No clamping is applied: points outside `bbox` yield
    coordinates below 0 or at/above `steps`, and a point exactly on the upper border yields
    `steps`.

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box (lower corner, upper corner)
    :param steps `Tensor(1|D)`: number of grid steps along every dim
    :return `Tensor(N..., D)`: int64 grid coordinates
    """
    origin, extent = bbox[0], bbox[1] - bbox[0]
    scaled = (pts - origin) / extent * steps
    return torch.floor(scaled).long()


Nianchen Deng's avatar
sync    
Nianchen Deng committed
53
def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
    """
    Get flattened grid indices of points.

    Points are discretized with `to_grid_coords`, clamped into the valid range
    `[0, steps - 1]` along every dim, and flattened with `to_flat_indices`. Points flagged by
    `get_out_of_bound_mask` get index -1.

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box
    :param steps `Tensor(1|D)`: steps along every dim
    :return `Tensor(N...)`: flattened grid indices, -1 for points outside `bbox`
    """
    # Clamp on both sides. A point on the upper border floors to `steps`, and a point slightly
    # below the lower border (still accepted by the tolerance of `get_out_of_bound_mask`) floors
    # to -1. Unclamped, either would flatten to the index of a different cell (or a negative one).
    grid_coords = to_grid_coords(pts, bbox, steps).clamp(min=0).minimum(steps - 1)  # (N..., D)
    outside_mask = get_out_of_bound_mask(pts, bbox)  # (N...)
    grid_indices = to_flat_indices(grid_coords, steps)
    grid_indices[outside_mask] = -1
    return grid_indices
Nianchen Deng's avatar
sync    
Nianchen Deng committed
71
72
73
74
75
76


def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
    """
    Create the center positions of all voxels of a regular grid.

    :param bbox `Tensor(2, D)`: bounding box
    :param steps `Tensor(1|D)`: number of voxels along every dim; a single-element tensor (or a
        scalar) is shared by all dims
    :return `Tensor(X, D)`: voxel centers, X = product of the per-dim steps, ordered with the
        last dim varying fastest
    """
    dims = bbox.shape[-1]
    # One step count per dim, on the same device as `bbox` so that the index grid and the
    # conversion in `to_voxel_centers` do not mix devices.
    steps = torch.as_tensor(steps, device=bbox.device).expand(dims)
    axes = [torch.arange(int(n), device=bbox.device) for n in steps]
    grid_coords = torch.stack(torch.meshgrid(*axes), -1).reshape(-1, dims)
    return to_voxel_centers(grid_coords, bbox, steps)
Nianchen Deng's avatar
sync    
Nianchen Deng committed
79
80


Nianchen Deng's avatar
sync    
Nianchen Deng committed
81
def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
    """
    Map integer grid coordinates to the positions of the corresponding cell centers.

    Along every dim, the center of cell `i` lies at
    `bbox[0] + (i + 0.5) / steps * (bbox[1] - bbox[0])`.

    :param grid_coords `Tensor(N..., D)`: grid coordinates
    :param bbox `Tensor(2, D)`: bounding box (lower corner, upper corner)
    :param steps `Tensor(1|D)`: steps along every dim
    :return `Tensor(N..., D)`: center positions of the cells
    """
    lower = bbox[0]
    cell_centers = grid_coords.float() + .5
    return cell_centers / steps * (bbox[1] - lower) + lower


def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_border: bool = True,
                       dims=3, *, dtype: torch.dtype = None, device: torch.device = None,
                       like: torch.Tensor = None):
    """
    Compute the offsets of the sub-points obtained by splitting a voxel into `n` parts per dim.

    Offsets are relative to the voxel center. Along every dim, `n` evenly spaced values are
    generated:

    - `align_border=True`: the values span the whole voxel including its faces, from
      `-voxel_size / 2` to `+voxel_size / 2` (for `n = 2` these are the voxel corners);
    - `align_border=False`: the values are the centers of `n` equal sub-voxels.

    For `n = 1` the single offset is 0 (the voxel center) in both modes.

    :param voxel_size `Tensor(D)|float`: size of the voxel (per dim or shared)
    :param n `int`: number of sub-points per dim
    :param align_border `bool`: whether the outermost sub-points lie on the voxel faces,
        defaults to True
    :param dims `int`: number of dimensions D, defaults to 3
    :param dtype `dtype`: dtype of the index range the offsets are built from, defaults to None;
        overridden by `like`
    :param device `device`: device of the index range the offsets are built from, defaults to
        None; overridden by `like`
    :param like `Tensor(*)`: if given, its dtype and device replace `dtype` and `device`
    :return `Tensor(n^D, D)`: offsets, ordered with the last dim varying fastest
    """
    if like is not None:
        dtype = like.dtype
        device = like.device
    c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
    # With border alignment the n points span n - 1 intervals. A single point (n == 1) has no
    # interval to span, so fall back to `n`; dividing by n - 1 = 0 would turn 0 into NaN.
    divisor = n - 1 if align_border and n > 1 else n
    offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) \
        * voxel_size * .5 / divisor
    return offset


def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, float], n: int,
                 align_border: bool = True):
    """
    Split every voxel into `n` sub-points per dim, returned in absolute coordinates.

    The per-voxel offsets come from `split_voxels_local`. D, dtype and device are taken from
    `voxel_centers`, and the offsets are added to every voxel center.

    :param voxel_centers `Tensor(N, D)`: centers of the voxels
    :param voxel_size `Tensor(D)|float`: size of the voxels
    :param n `int`: number of sub-points per dim
    :param align_border `bool`: forwarded to `split_voxels_local`, defaults to True
    :return `Tensor(N, X, D)`: sub-points of every voxel, X = n^D
    """
    local_offsets = split_voxels_local(voxel_size, n, align_border, voxel_centers.shape[-1],
                                       like=voxel_centers)  # (X, D)
    return voxel_centers.unsqueeze(1) + local_offsets


def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get the unique corner positions of a set of voxels and the corner indices of every voxel.

    Corners shared by neighboring voxels are deduplicated. The work happens on a "double" grid
    whose cells are half a voxel wide and which starts a quarter voxel before `bbox[0]`. On that
    grid, voxel centers fall on odd coordinates and voxel corners on even ones, so corners can be
    deduplicated exactly on integer-valued coordinates.

    :param voxel_centers `Tensor(M, D)`: voxel centers (assumed to be cell centers of the grid
        defined by `bbox` and `steps`)
    :param bbox `Tensor(2, D)`: bounding box
    :param steps `Tensor(1|D)`: steps along every dim
    :return `Tensor(C, D)`: unique corner positions
    :return `Tensor(M, 2^D)`: for each voxel, indices into the corner positions, with columns
        ordered like the offsets of `split_voxels_local` (last dim varying fastest)
    """
    dims = voxel_centers.shape[-1]
    half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
    # Expand the box by a quarter voxel on each side. A grid of (2 * steps + 1) cells over the
    # expanded box then has cells of exactly half the voxel size.
    expand_bbox = bbox.clone()
    expand_bbox[0] -= 0.5 * half_voxel_size
    expand_bbox[1] += 0.5 * half_voxel_size
    double_steps = steps * 2 + 1
    # Voxel centers land on odd double-grid coordinates: (M, D) -> [1, 3, 5, ...]
    double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, double_steps)
    # Step +-1 along every dim to reach the corners on even coordinates: (2^D * M, D) -> [0, 2, 4, ...]
    corner_coords = split_voxels(double_grid_coords, 2, 2).reshape(-1, dims)
    corner_coords, corner_indices = corner_coords.unique(dim=0, sorted=True, return_inverse=True)
    corners = to_voxel_centers(corner_coords, expand_bbox, double_steps)
    return corners, corner_indices.reshape(-1, 2 ** dims)


def trilinear_interp(pts: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
    """
    Trilinearly interpolate corner values inside the unit voxel ([0,0,0] ~ [1,1,1]).

    The 8 corners are ordered as produced by `split_voxels_local(1, 2)`: corner k sits at
    ``((k >> 2) & 1, (k >> 1) & 1, k & 1)``, with the last coordinate varying fastest.

    :param pts `Tensor(N, 3)`: local coordinates of the points inside their voxels
    :param corner_values `Tensor(N, 8X)|Tensor(N, 8, X)`: values at the corners of each point's voxel
    :return `Tensor(N, X)`: interpolated values
    """
    unit_corners = split_voxels_local(1, 2, like=pts) + 0.5  # (8, 3), entries in {0, 1}
    values = corner_values.reshape(corner_values.size(0), 8, -1)  # (N, 8, X)
    p = pts.unsqueeze(1)  # (N, 1, 3)
    # The per-axis weight is p for a corner at 1 and (1 - p) for a corner at 0. For c in {0, 1},
    # 2*p*c - p - c + 1 evaluates to exactly that.
    weights = (p * unit_corners * 2 - p - unit_corners + 1).prod(-1, keepdim=True)  # (N, 8, 1)
    return (weights * values).sum(1)