from .space import Space, Voxels
import torch
import torch.nn as nn
from typing import Tuple

from utils import device
from utils import sphere
from utils.constants import *
from utils.perf import perf, checkpoint
from .generic import *
from clib import *


class Bins(object):

    @property
    def up(self):
        return self.bounds[..., 1:]

    @property
    def lo(self):
        return self.bounds[..., :-1]

    def __init__(self, vals: torch.Tensor):
        self.vals = vals
        # Bin bounds along the last dim: range endpoints plus midpoints of adjacent vals,
        # so this also works for batched vals of shape (..., M)
        self.bounds = torch.cat([
            self.vals[..., :1],
            0.5 * (self.vals[..., 1:] + self.vals[..., :-1]),
            self.vals[..., -1:]
        ], -1)

    @staticmethod
    def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
        return Bins(torch.linspace(*val_range, N, device=device))

    def to(self, device: torch.device):
        self.vals = self.vals.to(device)
        self.bounds = self.bounds.to(device)
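

# Illustrative sketch (not part of the original code): how `Bins` turns a vector of sample
# values into per-bin bounds. `bounds` holds the range endpoints plus the midpoints of
# adjacent values, so `lo`/`up` give each bin's lower/upper edge. Values are arbitrary.
def _demo_bins():
    bins = Bins.linspace((1.0, 5.0), 5)
    # bins.vals   -> [1.0, 2.0, 3.0, 4.0, 5.0]
    # bins.bounds -> [1.0, 1.5, 2.5, 3.5, 4.5, 5.0]
    # bins.lo     -> [1.0, 1.5, 2.5, 3.5, 4.5];  bins.up -> [1.5, 2.5, 3.5, 4.5, 5.0]
    return bins.lo, bins.up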


class Samples:
    pts: torch.Tensor
    """`Tensor(N[, P], 3)`"""

    dirs: torch.Tensor
    """`Tensor(N[, P], 3)`"""

    depths: torch.Tensor
    """`Tensor(N[, P])`"""

    dists: torch.Tensor
    """`Tensor(N[, P])`"""

    voxel_indices: torch.Tensor
    """`Tensor(N[, P])`"""

    @property
    def size(self):
        return self.pts.size()[:-1]

    @property
    def device(self):
        return self.pts.device

    def __init__(self, pts: torch.Tensor, dirs: torch.Tensor, depths: torch.Tensor,
                 dists: torch.Tensor, voxel_indices: torch.Tensor) -> None:
        self.pts = pts
        self.dirs = dirs
        self.depths = depths
        self.dists = dists
        self.voxel_indices = voxel_indices

    def __getitem__(self, index):
        return Samples(
            pts=self.pts[index],
            dirs=self.dirs[index],
            depths=self.depths[index],
            dists=self.dists[index],
            voxel_indices=self.voxel_indices[index])

    def reshape(self, *shape: int):
        return Samples(
            pts=self.pts.reshape(*shape, 3),
            dirs=self.dirs.reshape(*shape, 3),
            depths=self.depths.reshape(*shape),
            dists=self.dists.reshape(*shape),
            voxel_indices=self.voxel_indices.reshape(*shape))
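

# Illustrative sketch (not part of the original code): `Samples` is a plain container of
# per-sample tensors shaped `(N rays[, P samples])`. Indexing with a boolean ray mask or
# reshaping applies to every field at once, as the samplers below rely on.
def _demo_samples_filtering(samples: Samples) -> Samples:
    # Keep only rays for which at least one sample falls inside a voxel (index != -1)
    valid_rays_mask = samples.voxel_indices.ne(-1).any(dim=-1)
    return samples[valid_rays_mask]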


class Sampler(nn.Module):

    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int, lindisp: bool, **kwargs):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        s_range = (1 / sample_range[0], 1 / sample_range[1]) if self.lindisp else sample_range
        # Nudge both ends inward slightly so samples stay strictly inside the given range
        if s_range[1] > s_range[0]:
            s_range[0] += 1e-4
            s_range[1] -= 1e-4
        else:
            s_range[0] -= 1e-4
            s_range[1] += 1e-4
        self.bins = Bins.linspace(s_range, n_samples, device=device.default())

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
                perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample points along rays. return Spherical or Cartesian coordinates, 
        specified by `self.shperical`

        :param rays_o `Tensor(N, 3)`: rays' origin
        :param rays_d `Tensor(N, 3)`: rays' direction
        :return `Samples(N', P)`: samples along valid rays (rays with at least one sample inside the space)
        :return `Tensor(N)`: valid rays mask
        """
        s = self.bins.vals.expand(rays_o.size(0), -1)
        if perturb_sample:
            # Stratified sampling: jitter each sample uniformly within its bin
            s = self.bins.lo + (self.bins.up - self.bins.lo) * torch.rand_like(s)
        pts, depths = self._get_sample_points(rays_o, rays_d, s)
        voxel_indices = space_module.get_voxel_indices(pts)
        valid_rays_mask = voxel_indices.ne(-1).any(dim=-1)
        return Samples(
            pts=pts,
            dirs=rays_d[:, None].expand(-1, depths.size(1), -1),
            depths=depths,
            dists=self._calc_dists(depths),
            voxel_indices=voxel_indices
        )[valid_rays_mask], valid_rays_mask

    def _get_sample_points(self, rays_o, rays_d, s):
        z = torch.reciprocal(s) if self.lindisp else s
        pts = rays_o[:, None] + rays_d[:, None] * z[..., None]
        depths = z
        return pts, depths

    def _calc_dists(self, vals):
        # Compute 'distance' (in time) between each integration time along a ray.
        # The last sample has no successor, so it gets a tiny placeholder distance (TINY_FLOAT).
        # dists: (N_rays, N)
        dists = vals[..., 1:] - vals[..., :-1]
        last_dist = torch.zeros_like(vals[..., :1]) + TINY_FLOAT
        return torch.cat([dists, last_dist], -1)
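

# Minimal usage sketch (assumptions: `space` is any `Space` implementation providing
# `get_voxel_indices`, and `rays_o`/`rays_d` are `Tensor(N, 3)` on the default device;
# the range and sample count are arbitrary).
def _demo_sampler_usage(space: Space, rays_o: torch.Tensor, rays_d: torch.Tensor):
    sampler = Sampler(sample_range=(1.0, 10.0), n_samples=64, lindisp=False)
    # `samples` only contains the rays flagged True in `valid_mask`
    samples, valid_mask = sampler(rays_o, rays_d, space, perturb_sample=True)
    return samples, valid_mask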


class SphericalSampler(Sampler):

    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, **kwargs):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__(sample_range=sample_range, n_samples=n_samples,
                         perturb_sample=perturb_sample, lindisp=False)

    def _get_sample_points(self, rays_o, rays_d, s):
        r = torch.reciprocal(s)  # sample values are inverse radii; convert to sphere radii
        pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, r)
        pts = sphere.cartesian2spherical(pts, inverse_r=True)
        return pts, depths
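

# Minimal usage sketch (assumptions: consistent with `_get_sample_points` above, the values in
# `sample_range` are inverse sphere radii, so (1.0, 1.0 / 50.0) covers radii from 1 to 50;
# `space` must accept the spherical sample coordinates in `get_voxel_indices`).
def _demo_spherical_sampler(space: Space, rays_o: torch.Tensor, rays_d: torch.Tensor):
    sampler = SphericalSampler(sample_range=(1.0, 1.0 / 50.0), n_samples=32, perturb_sample=False)
    samples, valid_mask = sampler(rays_o, rays_d, space, perturb_sample=False)
    return samples, valid_mask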


class PdfSampler(nn.Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
                 spherical: bool, lindisp: bool, **kwargs):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.n_samples = n_samples
        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range

    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
        """
        Sample points along rays. return Spherical or Cartesian coordinates, 
        specified by `self.shperical`

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param weights `Tensor(B, M)`: weights of sample bins
        :param s_vals `Tensor(B, M)`: (optional) center of sample bins
        :param include_s_vals `bool`: (default to `False`) include `s_vals` in the sample array
        :return `Tensor(B, N, 3)`: sampled points
        :return `Tensor(B, N)`: corresponding depths along rays
        """
        if s_vals is None:
            s_vals = torch.linspace(*self.s_range, self.n_samples, device=device.default())
        # `det=True` in sample_pdf means deterministic (uniform) samples, so negate perturb_sample
        s = self.sample_pdf(Bins(s_vals).bounds, weights, self.n_samples, det=not self.perturb_sample)
        if include_s_vals:
            s = torch.cat([s, s_vals], dim=-1)
        s = torch.sort(s, descending=self.lindisp)[0]
        z = torch.reciprocal(s) if self.lindisp else s
        if self.spherical:
            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
            return sphers, depths, s, pts
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s

    def sample_pdf(self, bins: torch.Tensor, weights: torch.Tensor, N: int, det=True):
        '''
        :param bins `Tensor(..., M+1)`: bounds of bins
        :param weights `Tensor(..., M)`: weights of bins
        :param N `int`: # of samples along each ray
        :param det `bool`: (default to `True`) perform deterministic sampling or not
        :return `Tensor(..., N)`: samples
        '''
        # Get pdf
        weights = weights + TINY_FLOAT                                          # prevent nans
        pdf = weights / torch.sum(weights, dim=-1, keepdim=True)                # [..., M]
        cdf = torch.cat([
            torch.zeros_like(pdf[..., :1]),
            torch.cumsum(pdf, dim=-1)
        ], dim=-1)                                                              # [..., M+1]

        # Take uniform samples
        dots_sh = list(weights.shape[:-1])
        M = weights.shape[-1]

        u = torch.linspace(0, 1, N, device=bins.device).expand(dots_sh + [N]) \
            if det else torch.rand(dots_sh + [N], device=bins.device)           # [..., N]

        # Invert CDF
        # [..., N, 1] >= [..., 1, M] ----> [..., N, M] ----> [..., N,]
        above_inds = torch.sum(u[..., None] >= cdf[..., None, :-1], dim=-1).long()

        # indices of the CDF entries bracketing each u (below/above)
        below_inds = torch.clamp(above_inds - 1, min=0)
        inds_g = torch.stack((below_inds, above_inds), dim=-1)                  # [..., N, 2]

        cdf = cdf[..., None, :].expand(dots_sh + [N, M + 1])                    # [..., N, M+1]
        cdf_g = torch.gather(cdf, dim=-1, index=inds_g)                         # [..., N, 2]

        bins = bins[..., None, :].expand(dots_sh + [N, M + 1])                  # [..., N, M+1]
        bins_g = torch.gather(bins, dim=-1, index=inds_g)                       # [..., N, 2]

        # fix numeric issue
        denom = cdf_g[..., 1] - cdf_g[..., 0]                                   # [..., N]
        denom = torch.where(denom < TINY_FLOAT, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom

        samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_FLOAT)

        return samples
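

# Worked example of the inverse-CDF step above (illustrative numbers): weights [1, 3] give
# PDF [0.25, 0.75] and CDF [0, 0.25, 1.0]. A uniform draw u = 0.5 falls in the second CDF
# segment (0.25 -> 1.0) at fraction (0.5 - 0.25) / 0.75 ~= 0.33, so the sample lands about a
# third of the way between the second bin's bounds.
#
# Minimal usage sketch for hierarchical resampling (assumptions: `coarse_s_vals` and
# `coarse_weights` are `(B, M)` outputs of a coarse pass; range and sample count are arbitrary).
def _demo_pdf_resampling(rays_o: torch.Tensor, rays_d: torch.Tensor,
                         coarse_s_vals: torch.Tensor, coarse_weights: torch.Tensor):
    sampler = PdfSampler(depth_range=(1.0, 10.0), n_samples=64, perturb_sample=True,
                         spherical=False, lindisp=False)
    pts, depths, s = sampler(rays_o, rays_d, weights=coarse_weights,
                             s_vals=coarse_s_vals, include_s_vals=True)
    return pts, depths, s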


class VoxelSampler(nn.Module):

    def __init__(self, *, perturb_sample: bool, sample_step: float, **kwargs):
        """
        Initialize a VoxelSampler module

        :param perturb_sample: perturb the sample depths
        :param sample_step: size of sampling steps along each ray
        """
        super().__init__()
        self.perturb_sample = perturb_sample
        self.sample_step = sample_step

    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
                 **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample along rays by stepping uniformly through the voxels they intersect (pure PyTorch variant)

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space_module `Space`: voxel space to sample in (step size is `self.sample_step`)
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
        :return `Tensor(N)`: valid rays mask
        """
        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')
        n_rays = rays_o.size(0)
        ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long)  # (N')

        hits = intersections.hits
        min_depths = intersections.min_depths
        max_depths = intersections.max_depths
        voxel_indices = intersections.voxel_indices

        rays_near_depth = min_depths[:, :1]  # (N', 1)
        rays_far_depth = max_depths[ray_index_list, hits - 1][:, None]  # (N', 1)
        rays_length = rays_far_depth - rays_near_depth
        rays_steps = (rays_length / self.sample_step).ceil().long()
        rays_step_size = rays_length / rays_steps
        max_steps = rays_steps.max().item()
        rays_step = torch.arange(max_steps, device=rays_o.device,
                                 dtype=torch.float)[None].repeat(n_rays, 1)  # (N', P)
        invalid_samples_mask = rays_step >= rays_steps
        samples_min_depth = rays_near_depth + rays_step * rays_step_size
        samples_depth = samples_min_depth + rays_step_size \
            * (torch.rand_like(samples_min_depth) if self.perturb_sample else 0.5)  # (N', P)
        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)
        samples_voxel_index = voxel_indices[
            ray_index_list[:, None],
            torch.searchsorted(max_depths, samples_depth)
        ]  # (N', P)
        samples_depth[invalid_samples_mask] = HUGE_FLOAT
        samples_dist[invalid_samples_mask] = 0
        samples_voxel_index[invalid_samples_mask] = -1

        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
        return Samples(
            pts=rays_o + rays_d * samples_depth[..., None],
            dirs=rays_d.expand(-1, max_steps, -1),
            depths=samples_depth,
            dists=samples_dist,
            voxel_indices=samples_voxel_index
        ), valid_rays_mask

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
                **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        [summary]

        :param rays_o `Tensor(N, 3)`: [description]
        :param rays_d `Tensor(N, 3)`: [description]
        :param step_size `float`: [description]
        :return `Samples(N, P)`: [description]
        """
        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')

        checkpoint("Ray intersect")

        if intersections.size == 0:
            return None, valid_rays_mask
        else:
            min_depth = intersections.min_depths
            max_depth = intersections.max_depths
            pts_idx = intersections.voxel_indices
            dists = max_depth - min_depth
            tot_dists = dists.sum(dim=-1, keepdim=True)  # (N, 1)
            probs = dists / tot_dists
            steps = tot_dists[:, 0] / self.sample_step

            # sample points and use middle point approximation
            sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
                pts_idx, min_depth, max_depth, probs, steps, -1, not self.perturb_sample)
            sampled_indices = sampled_indices.long()
            invalid_idx_mask = sampled_indices.eq(-1)
            sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
            sampled_depths.masked_fill_(invalid_idx_mask, HUGE_FLOAT)

            checkpoint("Inverse CDF sampling")

            rays_o, rays_d = rays_o[:, None], rays_d[:, None]
            return Samples(
                pts=rays_o + rays_d * sampled_depths[..., None],
                dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
                depths=sampled_depths,
                dists=sampled_dists,
                voxel_indices=sampled_indices
            ), valid_rays_mask
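

# Minimal usage sketch (assumptions: `space` is a voxel-based `Space` whose `ray_intersect`
# returns hits/min_depths/max_depths/voxel_indices as used above; the step size is arbitrary).
def _demo_voxel_sampler(space: Space, rays_o: torch.Tensor, rays_d: torch.Tensor):
    sampler = VoxelSampler(perturb_sample=False, sample_step=0.01)
    samples, valid_mask = sampler(rays_o, rays_d, space)
    if samples is None:
        return None, valid_mask  # no ray hit any voxel
    # Padded (invalid) samples have `dists == 0` and `voxel_indices == -1`
    return samples, valid_mask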