space.py 20 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
from .__common__ import *
Nianchen Deng's avatar
sync    
Nianchen Deng committed
2
from clib import *
Nianchen Deng's avatar
sync    
Nianchen Deng committed
3
4
#from model.utils import load
from utils.nn import Parameter
Nianchen Deng's avatar
sync    
Nianchen Deng committed
5
6
from utils.geometry import *
from utils.voxels import *
Nianchen Deng's avatar
sync    
Nianchen Deng committed
7
8

# Public API of this module
__all__ = [
    "Space",
    "Voxels",
    "Octree",
]
Nianchen Deng's avatar
sync    
Nianchen Deng committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24


class Intersections:
    """
    Container for intersections between a batch of rays and voxels.

    All fields are indexed by the same leading batch dimensions, so an `Intersections`
    object can be sliced like a tensor (see `__getitem__`).
    """

    # `Tensor(N, P)` Min ray depths of intersected voxels
    min_depths: torch.Tensor
    # `Tensor(N, P)` Max ray depths of intersected voxels
    max_depths: torch.Tensor
    # `Tensor(N, P)` Indices of intersected voxels
    voxel_indices: torch.Tensor
    # `Tensor(N)` Number of hits
    hits: torch.Tensor

    def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor,
                 voxel_indices: torch.Tensor, hits: torch.Tensor) -> None:
        self.min_depths, self.max_depths = min_depths, max_depths
        self.voxel_indices, self.hits = voxel_indices, hits

    @property
    def shape(self):
        """`torch.Size` Batch shape, i.e. the shape of `hits`"""
        hit_counts = self.hits
        return hit_counts.shape

    def __getitem__(self, index):
        """
        Select a subset of rays.

        :param index: any index accepted by tensor indexing (int, slice, mask, ...)
        :return `Intersections`: intersections of the selected rays
        """
        fields = ("min_depths", "max_depths", "voxel_indices", "hits")
        selected = {field: getattr(self, field)[index] for field in fields}
        return Intersections(**selected)


Nianchen Deng's avatar
sync    
Nianchen Deng committed
43
44
45
class Space(nn.Module):
    bbox: torch.Tensor | None
    """`Tensor(2, D)` Bounding box"""
Nianchen Deng's avatar
sync    
Nianchen Deng committed
46

Nianchen Deng's avatar
sync    
Nianchen Deng committed
47
48
49
50
51
52
    @property
    def dims(self) -> int:
        """`int` Number of dimensions"""
        return self.bbox.shape[1] if self.bbox is not None else 3

    @staticmethod
Nianchen Deng's avatar
sync    
Nianchen Deng committed
53
54
55
56
57
58
59
60
61
62
63
64
    def create(type: str, args: dict[str, Any]) -> 'Space':
        match type:
            case "Space":
                return Space(**args)
            case "Octree":
                return Octree(**args)
            case "Voxels":
                return Voxels(**args)
            case _:
                return load(type).space

    def __init__(self, clone_src: "Space" = None, *, bbox: list[float] = None, **kwargs):
Nianchen Deng's avatar
sync    
Nianchen Deng committed
65
        super().__init__()
Nianchen Deng's avatar
sync    
Nianchen Deng committed
66
67
68
        if clone_src:
            self.device = clone_src.device
            self.register_temp('bbox', clone_src.bbox)
Nianchen Deng's avatar
sync    
Nianchen Deng committed
69
        else:
Nianchen Deng's avatar
sync    
Nianchen Deng committed
70
            self.register_temp('bbox', None if not bbox else torch.tensor(bbox).reshape(2, -1))
Nianchen Deng's avatar
sync    
Nianchen Deng committed
71

Nianchen Deng's avatar
sync    
Nianchen Deng committed
72
73
74
    def ray_intersect_with_bbox(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> Intersections:
        """
        [summary]
Nianchen Deng's avatar
sync    
Nianchen Deng committed
75

Nianchen Deng's avatar
sync    
Nianchen Deng committed
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
        :param rays_o `Tensor(N..., D)`: rays' origin
        :param rays_d `Tensor(N..., D)`: rays' direction
        :param max_hits `int?`: max number of hits of each ray, have no effect for this method
        :return `Intersect(N...)`: rays' intersection with the bounding box
        """
        if self.bbox is None:
            raise RuntimeError("The space has no bounding box")
        inv_d = rays_d.reciprocal().unsqueeze(-2)
        t = (self.bbox - rays_o.unsqueeze(-2)) * inv_d  # (N..., 2, D)
        t0 = t.min(dim=-2)[0].max(dim=-1, keepdim=True)[0].clamp(min=1e-4)  # (N..., 1)
        t1 = t.max(dim=-2)[0].min(dim=-1, keepdim=True)[0]
        miss = t1 <= t0
        t0[miss], t1[miss] = -1., -1.
        hit = torch.logical_not(miss).long()
        return Intersections(t0, t1, hit - 1, hit.squeeze(-1))

    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, max_hits: int) -> Intersections:
        return self.ray_intersect_with_bbox(rays_o, rays_d)

    def get_voxel_indices(self, pts: torch.Tensor) -> int | torch.Tensor:
Nianchen Deng's avatar
sync    
Nianchen Deng committed
96
97
        if self.bbox is None:
            return 0
Nianchen Deng's avatar
sync    
Nianchen Deng committed
98
        voxel_indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
Nianchen Deng's avatar
sync    
Nianchen Deng committed
99
100
        out_bbox = get_out_of_bound_mask(pts, self.bbox)  # (N...)
        voxel_indices[out_bbox] = -1
Nianchen Deng's avatar
sync    
Nianchen Deng committed
101
102
103
        return voxel_indices

    @torch.no_grad()
Nianchen Deng's avatar
sync    
Nianchen Deng committed
104
    def prune(self, keeps: torch.Tensor) -> tuple[int, int]:
Nianchen Deng's avatar
sync    
Nianchen Deng committed
105
106
107
        raise NotImplementedError()

    @torch.no_grad()
Nianchen Deng's avatar
sync    
Nianchen Deng committed
108
    def split(self) -> tuple[int, int]:
Nianchen Deng's avatar
sync    
Nianchen Deng committed
109
110
        raise NotImplementedError()

Nianchen Deng's avatar
sync    
Nianchen Deng committed
111
112
    @torch.no_grad()
    def clone(self):
Nianchen Deng's avatar
sync    
Nianchen Deng committed
113
        return self.__class__(self)
Nianchen Deng's avatar
sync    
Nianchen Deng committed
114

Nianchen Deng's avatar
sync    
Nianchen Deng committed
115
116

class Voxels(Space):
    """
    A space partitioned into a regular grid of axis-aligned voxels inside a bounding box.

    Only a subset of grid cells may hold valid voxels (after pruning). Embeddings can be
    attached to voxel corners (submodules `emb_<name>`, interpolated by
    `extract_embedding`) or to voxels (submodules `vemb_<name>`, looked up by
    `extract_voxel_embedding`).
    """

    bbox: torch.Tensor
    """`Tensor(2, D)` Bounding box"""

    steps: torch.Tensor
    """`Tensor(3)` Steps along each dimension"""

    corners: torch.Tensor
    """`Tensor(C, 3)` Corner positions"""

    voxels: torch.Tensor
    """`Tensor(M, 3)` Voxel centers"""

    corner_indices: torch.Tensor
    """`Tensor(M, 8)` Voxel corner indices"""

    voxel_indices_in_grid: torch.Tensor
    """`Tensor(G + 1)` Indices in voxel list or -1 for pruned space

       Note that the first element is reserved for 'invalid voxel'(-1), so the grid
       index should be offset by 1 before querying for corresponding voxel index.
    """

    @property
    def n_voxels(self) -> int:
        """`int` Number of voxels"""
        return self.voxels.size(0)

    @property
    def n_corners(self) -> int:
        """`int` Number of corners"""
        return self.corners.size(0)

    @property
    def n_grids(self) -> int:
        """`int` Number of grids, i.e. steps[0] * steps[1] * ... * steps[D]"""
        return self.steps.prod().item()

    @property
    def voxel_size(self) -> torch.Tensor:
        """`Tensor(3)` Voxel size"""
        if self.bbox is None:
            raise RuntimeError("Cannot get property 'voxel_size' of a space which "
                               "doesn't have bounding box")
        return (self.bbox[1] - self.bbox[0]) / self.steps

    @property
    def corner_embeddings(self) -> dict[str, torch.nn.Embedding]:
        """`dict[str, Embedding]` Corner embeddings keyed by name (submodules `emb_<name>`)"""
        return {name[4:]: emb for name, emb in self.named_modules() if name.startswith("emb_")}

    @property
    def voxel_embeddings(self) -> dict[str, torch.nn.Embedding]:
        """`dict[str, Embedding]` Voxel embeddings keyed by name (submodules `vemb_<name>`)"""
        return {name[5:]: emb for name, emb in self.named_modules() if name.startswith("vemb_")}

    def __init__(self, clone_src: "Voxels" = None, *, bbox: list[float] = None,
                 voxel_size: float = None, steps: torch.Tensor | tuple[int, ...] = None,
                 **kwargs) -> None:
        """
        :param clone_src `Voxels?`: if given, share the bounding box and all voxel buffers
            with this object (the tensors are not copied)
        :param bbox `list[float]?`: bounding box, required when `clone_src` is not given
        :param voxel_size `float?`: used to derive `steps` when `steps` is not given
        :param steps `Tensor|tuple[int, ...]?`: number of grid cells along each dimension
        :raises ValueError: if neither `clone_src` nor `bbox` is given
        """
        if clone_src:
            super().__init__(clone_src)
            self.register_buffer('steps', clone_src.steps)
            self.register_buffer('voxels', clone_src.voxels)
            self.register_buffer("corners", clone_src.corners)
            self.register_buffer("corner_indices", clone_src.corner_indices)
            self.register_buffer('voxel_indices_in_grid', clone_src.voxel_indices_in_grid)
        else:
            if bbox is None:
                raise ValueError("Missing argument 'bbox'")
            super().__init__(bbox=bbox)
            if steps is not None:
                self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
            else:
                self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
            self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
            corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
            self.register_buffer("corners", corners)
            self.register_buffer("corner_indices", corner_indices)
            # Identity mapping: slot 0 -> -1 (invalid), slot gi + 1 -> voxel gi
            self.register_buffer('voxel_indices_in_grid', torch.arange(-1, self.n_voxels))

    def to_vi(self, gi: torch.Tensor) -> torch.Tensor:
        """
        Convert grid indices to voxel indices.

        :param gi `Tensor(N...)`: grid indices, -1 for invalid
        :return `Tensor(N...)`: voxel indices, -1 for invalid or pruned grid cells
        """
        return self.voxel_indices_in_grid[gi + 1]

    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
        """
        Create a embedding on voxel corners.

        :param n_dims `int`: embedding dimension
        :param name `str`: embedding name
        :return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
        :raises KeyError: if an embedding with the same name already exists
        """
        if self.get_embedding(name) is not None:
            raise KeyError(f"Embedding '{name}' already existed")
        emb = torch.nn.Embedding(self.n_corners, n_dims).to(self.device)
        setattr(self, f'emb_{name}', emb)
        return emb

    def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
        """Get the corner embedding named `name`, or `None` if it doesn't exist."""
        return getattr(self, f'emb_{name}', None)

    def set_embedding(self, weight: torch.Tensor, name: str = 'default'):
        """
        Create (or replace) the corner embedding named `name` using `weight` as its weight.

        :param weight `Tensor(C, X)`: embedding weight
        :param name `str`: embedding name
        :return `Embedding`: the new embedding
        """
        emb = torch.nn.Embedding(*weight.shape, _weight=weight).to(self.device)
        setattr(self, f'emb_{name}', emb)
        return emb

    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
                          name: str = 'default') -> torch.Tensor:
        """
        Extract embedding values at given points using trilinear interpolation.

        :param pts `Tensor(N, 3)`: points to extract values
        :param voxel_indices `Tensor(N)`: corresponding voxel indices
        :param name `str`: embedding name, default to 'default'
        :return `Tensor(N, X)`: extracted values
        :raises KeyError: if the embedding doesn't exist
        """
        emb = self.get_embedding(name)
        if emb is None:
            raise KeyError(f"Embedding '{name}' doesn't exist")
        voxels = self.voxels[voxel_indices]  # (N, 3)
        corner_indices = self.corner_indices[voxel_indices]  # (N, 8)
        p = (pts - voxels) / self.voxel_size + .5  # (N, 3) normed-coords in voxel
        return linear_interp(p, emb(corner_indices))

    def create_voxel_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
        """
        Create a embedding on voxels.

        :param n_dims `int`: embedding dimension
        :param name `str`: embedding name
        :return `Embedding(n_voxels, n_dims)`: new embedding on voxels
        :raises KeyError: if a voxel embedding with the same name already exists
        """
        if self.get_voxel_embedding(name) is not None:
            raise KeyError(f"Embedding '{name}' already existed")
        emb = torch.nn.Embedding(self.n_voxels, n_dims).to(self.device)
        setattr(self, f'vemb_{name}', emb)
        return emb

    def get_voxel_embedding(self, name: str = 'default') -> torch.nn.Embedding:
        """Get the voxel embedding named `name`, or `None` if it doesn't exist."""
        return getattr(self, f'vemb_{name}', None)

    def set_voxel_embedding(self, weight: torch.Tensor, name: str = 'default'):
        """
        Create (or replace) the voxel embedding named `name` using `weight` as its weight.

        :param weight `Tensor(M, X)`: embedding weight
        :param name `str`: embedding name
        :return `Embedding`: the new embedding
        """
        emb = torch.nn.Embedding(*weight.shape, _weight=weight).to(self.device)
        setattr(self, f'vemb_{name}', emb)
        return emb

    def extract_voxel_embedding(self, voxel_indices: torch.Tensor, name: str = 'default') -> torch.Tensor:
        """
        Extract embedding values at given voxels.

        :param voxel_indices `Tensor(N)`: voxel indices
        :param name `str`: embedding name, default to 'default'
        :return `Tensor(N, X)`: extracted values
        :raises KeyError: if the embedding doesn't exist
        """
        emb = self.get_voxel_embedding(name)
        if emb is None:
            raise KeyError(f"Embedding '{name}' doesn't exist")
        return emb(voxel_indices)

    @profile
    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
        """
        Calculate intersections of rays and voxels.

        :param rays_o `Tensor(N, 3)`: rays' origin
        :param rays_d `Tensor(N, 3)`: rays' direction
        :param n_max_hits `int`: maximum number of hits (for allocating enough space)
        :return `Intersection`: intersections of rays and voxels, sorted by entry depth
        """
        # Prepend a dim to meet the requirement of external call
        rays_o = rays_o[None].contiguous()
        rays_d = rays_d[None].contiguous()

        voxel_indices, min_depths, max_depths = self._ray_intersect(rays_o, rays_d, n_max_hits)
        invalid_voxel_mask = voxel_indices.eq(-1)
        hits = n_max_hits - invalid_voxel_mask.sum(-1)

        # Sort intersections according to their depths. Empty slots are filled with
        # `math.huge` (presumably a very large value) so the ascending sort moves them last.
        # NOTE(review): Python's standard `math` module has no `huge`; this relies on a
        # project `math` coming from `.__common__` — confirm.
        min_depths.masked_fill_(invalid_voxel_mask, math.huge)
        max_depths.masked_fill_(invalid_voxel_mask, math.huge)
        min_depths, sorted_idx = min_depths.sort(dim=-1)
        max_depths = max_depths.gather(-1, sorted_idx)
        voxel_indices = voxel_indices.gather(-1, sorted_idx)

        return Intersections(
            min_depths=min_depths[0],
            max_depths=max_depths[0],
            voxel_indices=voxel_indices[0],
            hits=hits[0]
        )

    @profile
    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Get voxel indices of points.

        If a point is not in any valid voxels, its corresponding voxel index is -1.

        :param pts `Tensor(N..., 3)`: points
        :return `Tensor(N...)`: corresponding voxel indices
        """
        gi = to_grid_indices(pts, self.bbox, self.steps)
        return self.to_vi(gi)

    @profile
    def get_corners(self, vidxs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Gather the corners used by the given voxels and re-index them compactly.

        :param vidxs `Tensor(N...)`: voxel indices; entries equal to -1 are ignored
        :return `Tensor(M, 8)`: corner indices of ALL voxels remapped into the compact corner
            list, -1 for corners not used by `vidxs`
        :return `Tensor(C', 3)`: positions of the used corners
        """
        vidxs = vidxs.unique()
        # Drop invalid voxels by masking, so an empty input does not fail on `vidxs[0]`
        vidxs = vidxs[vidxs != -1]
        cidxs = self.corner_indices[vidxs].unique()
        fi_cidxs = torch.full([self.n_corners], -1, dtype=torch.long, device=self.device)
        fi_cidxs[cidxs] = torch.arange(cidxs.shape[0], device=self.device)
        fi_corner_indices = fi_cidxs[self.corner_indices]
        fi_corners = self.corners[cidxs]
        return fi_corner_indices, fi_corners

    @torch.no_grad()
    def split(self) -> tuple[int, int]:
        """
        Split voxels into smaller voxels with half size.

        :return `(int, int)`: number of voxels before and after splitting
        """
        # Calculate new voxels and corners
        new_steps = self.steps * 2
        new_voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
            .reshape(-1, 3)
        new_corners, new_corner_indices = get_corners(new_voxels, self.bbox, new_steps)

        # Split corner embeddings through interpolation
        corner_embs = self.corner_embeddings
        if len(corner_embs) > 0:
            gi_of_new_corners = to_grid_indices(new_corners, self.bbox, self.steps)
            # NOTE(review): if a new corner maps to a pruned grid cell, its voxel index is -1
            # and `self.voxels[-1]` silently selects the last voxel — confirm this cannot
            # happen with `to_grid_indices`.
            vi_of_new_corners = self.to_vi(gi_of_new_corners)
            for name, emb in corner_embs.items():
                new_emb_weight = self.extract_embedding(new_corners, vi_of_new_corners, name=name)
                self.set_embedding(new_emb_weight, name=name)
                # Remove old embedding weight and related state from optimizer
                self._update_optimizer(emb.weight)

        # Split voxel embeddings
        self._update_voxel_embeddings(lambda val: torch.repeat_interleave(val, 8, dim=0))

        # Apply new tensors
        self.steps = new_steps
        self.voxels = new_voxels
        self.corners = new_corners
        self.corner_indices = new_corner_indices
        self._update_gi2vi()
        return self.n_voxels // 8, self.n_voxels

    @torch.no_grad()
    def prune(self, keeps: torch.Tensor) -> tuple[int, int]:
        """
        Keep only the voxels selected by `keeps` (corners are left untouched).

        :param keeps `Tensor(M)`: boolean mask of voxels to keep
        :return `(int, int)`: number of voxels before and after pruning
        """
        self.voxels = self.voxels[keeps]
        self.corner_indices = self.corner_indices[keeps]
        self._update_gi2vi()

        # Prune voxel embeddings
        self._update_voxel_embeddings(lambda val: val[keeps])

        return keeps.size(0), keeps.sum().item()

    def _update_voxel_embeddings(self, update_fn):
        """Replace every voxel embedding by `update_fn(weight)` and migrate optimizer state."""
        for name, emb in self.voxel_embeddings.items():
            new_emb = self.set_voxel_embedding(update_fn(emb.weight), name)
            self._update_optimizer(emb.weight, new_emb.weight, update_fn)

    def _update_optimizer(self, old_param: Parameter, new_param: Parameter | None = None,
                          update_fn=None):
        """
        Swap or remove a parameter in the trainer's Adam/AdamW optimizer.

        Other optimizer types are left untouched.

        :param old_param `Parameter`: parameter to replace or remove
        :param new_param `Parameter?`: replacement; if `None`, `old_param` is just removed
        :param update_fn `(Tensor) -> Tensor`?: maps old moment tensors to the new parameter's
            shape; if `None`, the old state is dropped instead of transferred
        """
        optimizer = get_env()["trainer"].optimizer
        if not isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
            return

        # Update related states in optimizer
        if old_param in optimizer.state:
            state = optimizer.state.pop(old_param)
            if new_param is not None and update_fn is not None:
                # Transfer state from old parameter to new parameter
                state.update({
                    key: update_fn(state[key])
                    for key in ['exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'] if key in state
                })
                optimizer.state[new_param] = state

        # Update parameter list in optimizer. Look params up by identity: `list.index` /
        # `list.remove` compare tensors with `==`, which raises for non-identical tensors.
        for group in optimizer.param_groups:
            params = group['params']
            idx = next((i for i, p in enumerate(params) if p is old_param), None)
            if idx is None:
                continue
            if new_param is not None:
                # Replace old parameter with new one
                params[idx] = new_param
            else:
                # Or just remove old parameter if new parameter is not specified
                del params[idx]

    def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
        """
        Count valid voxels in each grid slice along `dim`.

        :param dim `int`: dimension
        :return `Tensor(steps[dim])`: number of valid voxels in each slice
        """
        sum_dims = [val for val in range(self.dims) if val != dim]
        return self.voxel_indices_in_grid[1:].reshape(*self.steps).ne(-1).sum(sum_dims)

    def balance_cut(self, dim: int, n_parts: int) -> list[int]:
        """
        Cut the grid along `dim` into consecutive parts holding roughly equal numbers of voxels.

        :param dim `int`: dimension to cut along
        :param n_parts `int`: desired number of parts
        :return `list[int]`: number of grid slices in each part (they sum to `steps[dim]`)
        """
        n_voxels_list = self.n_voxels_along_dim(dim)
        cdf = (n_voxels_list.cumsum(0) / self.n_voxels * n_parts).tolist()
        bins = []
        part = 1
        offset = 0
        for i, value in enumerate(cdf):
            if value > part:
                bins.append(i - offset)
                offset = i
                part = int(value) + 1
        bins.append(len(cdf) - offset)
        return bins

    def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For each voxel, sample `S^3` points uniformly, with small perturb if `perturb` is `True`.

        When `perturb` is `False`, `include_border` can specify whether to sample points from border to border or at centers of sub-voxels.
        When `perturb` is `True`, points are sampled at centers of sub-voxels, then applying a random offset in sub-voxels.

        :param S `int`: number of samples along each dim
        :param perturb `bool?`: whether perturb samples, defaults to `False`
        :param include_border `bool?`: whether include border, defaults to `True`
        :return `Tensor(N*S^3, 3)`: sampled points
        :return `Tensor(N*S^3)`: voxel indices of sampled points
        """
        pts = split_voxels(self.voxels, self.voxel_size, S,
                           align_border=not perturb and include_border)  # (N, X, D)
        voxel_indices = torch.arange(self.n_voxels, device=self.device)[:, None]\
            .expand(*pts.shape[:-1])  # (N) -> (N, X)
        if perturb:
            pts += (torch.rand_like(pts) - .5) * self.voxel_size / S
        return pts.reshape(-1, 3), voxel_indices.flatten()

    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Raw ray/voxel intersection; returns `(voxel_indices, min_depths, max_depths)`."""
        return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)

    def _update_gi2vi(self):
        """
        Update voxel indices in grid.
        """
        gi = to_grid_indices(self.voxels, self.bbox, self.steps)
        # Reserve the first element in voxel_indices_in_grid for 'invalid voxel'(-1)
        self.voxel_indices_in_grid = gi.new_full([self.n_grids + 1], -1)
        self.voxel_indices_in_grid[gi + 1] = torch.arange(self.n_voxels, device=self.device)

    def _before_load_state_dict(self, state_dict, prefix, *args):
        """
        Resize buffers and re-create embeddings so their shapes match `state_dict`.
        """
        # Handle buffers
        for name, buffer in self.named_buffers(recurse=False):
            if name in self._non_persistent_buffers_set:
                continue
            buffer.resize_as_(state_dict[prefix + name])

        # Handle embeddings. Snapshot the modules first because they are replaced below, and
        # create each new embedding on the same device as the one it replaces.
        for name, module in list(self.named_modules()):
            if name.startswith('emb_'):
                setattr(self, name, torch.nn.Embedding(self.n_corners, module.embedding_dim,
                                                       device=module.weight.device))
            elif name.startswith('vemb_'):
                setattr(self, name, torch.nn.Embedding(self.n_voxels, module.embedding_dim,
                                                       device=module.weight.device))

    def _after_load_state_dict(self):
        """Rebuild the grid-to-voxel lookup from the loaded voxels."""
        self._update_gi2vi()
Nianchen Deng's avatar
sync    
Nianchen Deng committed
474
475
476
477


class Octree(Voxels):
    """
    Voxels with a lazily built octree used for ray intersection.

    The octree is cached and must be invalidated (`clear`) whenever the voxel set changes.
    """

    def __init__(self, clone_src: "Octree" = None, **kwargs) -> None:
        super().__init__(clone_src, **kwargs)
        self.nodes_cached = None
        self.tree_cached = None

    def get(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the cached `(nodes, tree)`, building them on first use."""
        if self.nodes_cached is None:
            self.nodes_cached, self.tree_cached = build_easy_octree(
                self.voxels, 0.5 * self.voxel_size)
        return self.nodes_cached, self.tree_cached

    def clear(self):
        """Drop the cached octree so that it is rebuilt on next use."""
        self.nodes_cached = None
        self.tree_cached = None

    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int):
        """Raw ray/voxel intersection through the octree."""
        nodes, tree = self.get()
        return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)

    @torch.no_grad()
    def split(self):
        """Split voxels (see `Voxels.split`) and invalidate the cached octree."""
        ret = super().split()
        self.clear()
        return ret

    @torch.no_grad()
    def prune(self, keeps: torch.Tensor) -> tuple[int, int]:
        """Prune voxels (see `Voxels.prune`) and invalidate the cached octree."""
        ret = super().prune(keeps)
        self.clear()
        return ret

    def _after_load_state_dict(self):
        """
        Rebuild the grid lookup and invalidate the cached octree after loading a state dict.

        Without this, an octree built before loading would keep describing the old voxels.
        """
        super()._after_load_state_dict()
        self.clear()