voxels.py 7.24 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import torch
from typing import Tuple, Union


def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor:
    """
    Get the number of grid cells along every dim of a bounding box.

    The extent of the box along each dim is divided by the step size and rounded up, so a partial
    cell at the upper border still counts as a whole cell.

    :param bbox `Tensor(2, D)`: bounding box (`bbox[1] - bbox[0]` is its extent)
    :param step_size `Tensor(1|D) | float`: step size
    :return `Tensor(D)`: grid steps along every dim, as `torch.long`
    """
    extent = bbox[1] - bbox[0]
    return torch.ceil(extent / step_size).to(torch.long)


def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
                   step_size: Union[torch.Tensor, float] = None,
                   steps: torch.Tensor = None) -> torch.Tensor:
    """
    Get discretized (integer) grid coordinates of points.

    The coordinate along each dim is `floor((p - bbox[0]) / cell_size)`. The cell size is
    `step_size` if given, otherwise `(bbox[1] - bbox[0]) / steps`. At least one of `step_size` and
    `steps` must be specified; when both are given, `steps` is ignored. Coordinates are not
    clamped, so points outside the bounding box get values < 0 or >= steps.

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N..., D)`: discretized grid coordinates (`torch.long`)
    """
    origin = bbox[0]
    offset = pts - origin
    if step_size is None:
        scaled = offset / (bbox[1] - origin) * steps
    else:
        scaled = offset / step_size
    return torch.floor(scaled).long()


def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
                    step_size: Union[torch.Tensor, float] = None,
                    steps: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get flattened (row-major, last dim varying fastest) grid indices of points.

    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
    specified, then the grid indices will be calculated according to the step size, ignoring
    the value of `steps`. Any number of dims D is supported.

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim; a single value is shared by all
        dims
    :return `Tensor(N...)`: grid indices (`torch.long`); values for points outside the grid are
        not meaningful, so check the mask
    :return `Tensor(N...)`: a mask tensor that is True where the point is outside the grid
    """
    if step_size is not None:
        steps = get_grid_steps(bbox, step_size)  # (D)
    grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps)  # (N..., D)
    dims = pts.size(-1)
    # Broadcast `steps` to (D) so a single shared value (Tensor(1)) also works. Put it on the
    # coordinates' device so the elementwise ops below do not mix devices.
    steps_d = torch.as_tensor(steps, device=grid_coords.device).expand(dims)  # (D)
    outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps_d).any(-1)  # (N...)
    # Row-major strides: stride[i] = steps[i + 1] * ... * steps[D - 1], and stride[D - 1] = 1
    steps_l = steps_d.to(grid_coords.dtype)
    strides = torch.cat([steps_l[1:].flip(0).cumprod(0).flip(0), steps_l.new_ones(1)])  # (D)
    grid_indices = (grid_coords * strides).sum(-1)  # (N...)
    return grid_indices, outside_mask


def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
    """
    Initialize voxels: get the centers of all cells of the grid that divides `bbox` into `steps`
    cells along every dim.

    :param bbox `Tensor(2, D)`: bounding box
    :param steps `Tensor(D)`: grid steps along every dim
    :return `Tensor(N, D)`: centers of all N = prod(steps) voxels, ordered with the last dim
        varying fastest
    """
    dims = bbox.size(-1)
    # Create the integer coordinates on bbox's device; CPU coordinates would fail to mix with
    # a CUDA bbox inside `to_voxel_centers`.
    axes = [torch.arange(int(steps[i]), device=bbox.device) for i in range(dims)]
    grid_coords = torch.stack(torch.meshgrid(*axes), -1).reshape(-1, dims)  # (N, D)
    return to_voxel_centers(grid_coords, bbox, steps=steps)


def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
                     step_size: Union[torch.Tensor, float] = None,
                     steps: torch.Tensor = None) -> torch.Tensor:
    """
    Map integer grid coordinates (as produced by `to_grid_coords`) to the centers of the
    corresponding grid cells.

    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
    specified, the cell size is `step_size` and `steps` is ignored; otherwise the cell size is
    `(bbox[1] - bbox[0]) / steps`.

    :param grid_coords `Tensor(N..., D)`: integer grid coordinates
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N..., D)`: coordinates of cell centers (float)
    """
    centers = grid_coords.float() + 0.5  # center of a cell in grid units
    origin = bbox[0]
    if step_size is None:
        return centers / steps * (bbox[1] - origin) + origin
    return centers * step_size + origin


def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_border: bool = True,
                       dims=3, *, dtype: torch.dtype = None, device: torch.device = None,
                       like: torch.Tensor = None):
    """
    Get `n` evenly spaced sample offsets along every dim of a voxel, relative to its center.

    Along each dim the offsets are `k * voxel_size / 2 / m` for `k` in `1 - n, 3 - n, ..., n - 1`.
    `m` is `n - 1` when `align_border` is true, which puts the outermost samples on the voxel's
    border. Otherwise `m` is `n`, so the samples are the centers of the `n` equal sub-voxels.
    When `n == 1` the single offset is 0 in both modes.

    :param voxel_size `Tensor(D)|float`: size of the voxel
    :param n `int`: number of samples along every dim
    :param align_border `bool`: whether the outermost samples lie on the voxel's border,
        defaults to True
    :param dims `int`: number of dims D, defaults to 3
    :param dtype `dtype`: dtype of the sample index tensor, defaults to None (ignored if `like` is
        given)
    :param device `device`: device of the result, defaults to None (ignored if `like` is given)
    :param like `Tensor(*)`: if given, its dtype and device are used
    :return `Tensor(X, D)`: offsets, X = n^D, ordered with the last dim varying fastest
    """
    if like is not None:
        dtype = like.dtype
        device = like.device
    c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
    # With n == 1 the only sample is the center (offset 0). Dividing by n - 1 == 0 would turn
    # that 0 into NaN, so fall back to n.
    divisor = n - 1 if align_border and n > 1 else n
    offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) * voxel_size / 2 / \
        divisor
    return offset


def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, float], n: int,
                 align_border: bool = True):
    """
    Get `n` evenly spaced sample points along every dim inside each voxel (n^D points per voxel).

    The points are the voxel centers plus the offsets from `split_voxels_local`. That function is
    called with `dims` equal to the last dim of `voxel_centers` and with the same dtype and device.

    :param voxel_centers `Tensor(N, D)`: centers of voxels
    :param voxel_size `Tensor(D)|float`: size of voxels
    :param n `int`: number of samples along every dim
    :param align_border `bool`: forwarded to `split_voxels_local`, defaults to True
    :return `Tensor(N, X, D)`: sample points of every voxel, X = n^D
    """
    dims = voxel_centers.shape[-1]
    offsets = split_voxels_local(voxel_size, n, align_border, dims, like=voxel_centers)  # (X, D)
    return voxel_centers.unsqueeze(1) + offsets


def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get the unique corners of voxels and, for every voxel, the indices of its corners.

    :param voxel_centers `Tensor(M, D)`: centers of voxels; these are expected to be cell centers
        of the grid defined by `bbox` and `steps`
    :param bbox `Tensor(2, D)`: bounding box of the grid (not modified)
    :param steps `Tensor(D)`: grid steps along every dim
    :return `Tensor(C, D)`: positions of the unique corners
    :return `Tensor(M, 2^D)`: for every voxel, indices into the returned corners, ordered like
        the offsets of `split_voxels_local(..., 2)` (last dim varying fastest)
    """
    dims = voxel_centers.size(-1)
    half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
    # Work on a "double-resolution" grid whose cell size is half a voxel. On it, voxel centers
    # fall on odd coordinates and voxel corners on even ones. The origin is shifted by half a
    # double-grid cell so that flooring the centers tolerates rounding errors.
    # Build a new tensor instead of updating `bbox` in place, which would modify the caller's bbox.
    expand_bbox = torch.stack([bbox[0] - 0.5 * half_voxel_size,
                               bbox[1] + 0.5 * half_voxel_size])
    double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size)
    # (M, D) -> [1, 3, 5, ...]

    corner_coords = split_voxels(double_grid_coords, 2, 2).reshape(-1, dims)
    # (2^D * M, D) -> [0, 2, 4, ...]

    corner_coords, corner_indices = corner_coords.unique(dim=0, sorted=True, return_inverse=True)
    corners = to_voxel_centers(corner_coords, expand_bbox, step_size=half_voxel_size)

    return corners, corner_indices.reshape(-1, 2 ** dims)


def trilinear_interp(pts: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
    """
    Perform trilinear interpolation in unit voxel ([0,0,0] ~ [1,1,1]).

    Corner k of the unit voxel is at ((k >> 2) & 1, (k >> 1) & 1, k & 1). In other words, corners
    are ordered with the last coordinate varying fastest, the same order as
    `split_voxels_local(1, 2) + 0.5`.

    :param pts `Tensor(N, 3)`: local coordinates of points inside their voxels
    :param corner_values `Tensor(N, 8X)|Tensor(N, 8, X)`: values at corners of voxels
    :return `Tensor(N, X)`: interpolated values
    """
    # Corner coordinate along one axis: (-1, 1) / 2 + 0.5 -> (0, 1)
    axis = torch.arange(-1, 2, 2, dtype=pts.dtype, device=pts.device) / 2 + 0.5
    corners = torch.stack(torch.meshgrid([axis] * 3), -1).flatten(0, -2)  # (8, 3)
    p = pts.unsqueeze(1)  # (N, 1, 3)
    # Per-axis weight: p where the corner coordinate is 1, (1 - p) where it is 0
    per_axis = p * corners * 2 - p - corners + 1  # (N, 8, 3)
    weights = torch.prod(per_axis, dim=-1, keepdim=True)  # (N, 8, 1)
    values = corner_values.reshape(corner_values.size(0), 8, -1)  # (N, 8, X)
    return torch.sum(weights * values, dim=1)