sampler.py 14.1 KB
Newer Older
Nianchen Deng's avatar
Nianchen Deng committed
1
import torch
Nianchen Deng's avatar
sync    
Nianchen Deng committed
2
3
from typing import Tuple

Nianchen Deng's avatar
sync    
Nianchen Deng committed
4
5
6
from .generic import *
from .space import Space
from clib import *
Nianchen Deng's avatar
Nianchen Deng committed
7
8
from utils import device
from utils import sphere
Nianchen Deng's avatar
sync    
Nianchen Deng committed
9
10
11
12
from utils import misc
from utils import math
from utils.module import Module
from utils.samples import Samples
Nianchen Deng's avatar
sync    
Nianchen Deng committed
13
from utils.perf import perf, checkpoint
Nianchen Deng's avatar
Nianchen Deng committed
14
15
16
17


class Bins(object):
    """
    Bin centers (`vals`) along the last dimension, plus the bin bounds derived from them.

    Interior bounds are midpoints between neighbouring centers; the outermost bounds are the
    first and last centers themselves. So `bounds` has one more element than `vals` along the
    last dimension. Leading (batch) dimensions of `vals` are preserved.
    """

    @property
    def up(self):
        """Upper bound of every bin, `Tensor(..., M)`."""
        return self.bounds[..., 1:]

    @property
    def lo(self):
        """Lower bound of every bin, `Tensor(..., M)`."""
        return self.bounds[..., :-1]

    def __init__(self, vals: torch.Tensor):
        """
        :param vals `Tensor(..., M)`: bin centers, ordered along the last dimension
        """
        self.vals = vals
        # Work on the last dim so batched centers (e.g. `Tensor(B, M)`) get per-row bounds
        # instead of being mixed across the batch dimension.
        self.bounds = torch.cat([
            vals[..., :1],
            0.5 * (vals[..., 1:] + vals[..., :-1]),
            vals[..., -1:]
        ], dim=-1)  # (..., M+1)

    @staticmethod
    def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
        """
        Create `N` evenly spaced bin centers covering `val_range` (both ends inclusive).

        :param val_range `float, float`: first and last center value
        :param N `int`: number of bins
        :param device `torch.device`: device for the created tensors (torch default if None)
        :return `Bins`: the new bins
        """
        return Bins(torch.linspace(*val_range, N, device=device))

    def to(self, device: torch.device):
        """
        Move centers and bounds to `device`, in place.

        :param device `torch.device`: target device
        :return `Bins`: this object, so calls can be chained
        """
        self.vals = self.vals.to(device)
        self.bounds = self.bounds.to(device)
        return self


class Sampler(Module):
    """
    Stratified sampler placing `n_samples` points per ray inside a fixed sampling range.

    Samples are either placed along the ray in Cartesian space (optionally linear in
    disparity) or on spheres, which is selected by the `spherical` argument of `forward`.
    """

    def __init__(self, **kwargs):
        """
        Initialize a Sampler module

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        # Lazily built index grid, reused across calls (see `_get_samples_indices`)
        self._samples_indices_cached = None

    def _sample(self, range: Tuple[float, float], n_rays: int, n_samples: int, perturb: bool,
                device: torch.device) -> torch.Tensor:
        """
        Split `range` into `n_samples` bins and return the bin bounds for every ray.

        Without perturbation, all rays share the same evenly spaced bounds (returned as an
        expanded view, so rows share memory). With perturbation, every bound is jittered
        uniformly between the midpoints to its neighbouring bounds. The first and last bounds
        only move inward, by at most half a bin.

        :param range `float, float`: sampling range (start, end)
        :param n_rays `int`: number of rays (B)
        :param n_samples `int`: number of samples per ray (P)
        :param perturb `bool`: whether to perturb the bounds randomly
        :param device `torch.device`: the device used to create tensors
        :return `Tensor(B, P+1)`: sampling bounds
        """
        bounds = torch.linspace(*range, n_samples + 1, device=device)  # (P+1)
        if perturb:
            # Jitter interval of bound i is [rand_bounds[i], rand_bounds[i+1]]
            rand_bounds = torch.cat([
                bounds[:1],
                0.5 * (bounds[1:] + bounds[:-1]),
                bounds[-1:]
            ])  # (P+2)
            rand_vals = torch.rand(n_rays, n_samples + 1, device=device)
            bounds = rand_bounds[:-1] * (1 - rand_vals) + rand_bounds[1:] * rand_vals
        else:
            bounds = bounds[None].expand(n_rays, -1)
        return bounds

    def _get_samples_indices(self, pts: torch.Tensor):
        """
        Get the per-sample index grid for a sample tensor whose first two dims are (B, P).

        The grid is built by `misc.meshgrid(rows, cols, swap_dim=True, device=...)` and cached.
        Every call returns the `[:B, :P]` slice of the cache, which assumes the grid value at
        `[i, j]` does not depend on the total grid size.

        The cache is not moved along with the module, so it is rebuilt when it lives on a
        different device than `pts`. It is also rebuilt when it is too small in either dim.
        On a size-driven rebuild it grows to cover both the old and the new size, so
        alternating shapes don't trigger a rebuild on every call.

        :param pts `Tensor(B, P, ...)`: samples whose first two dims define the grid size
        :return: the cached grid sliced to `[:B, :P]`
        """
        rows, cols = pts.shape[:2]
        cached = self._samples_indices_cached
        if cached is not None and cached.device != pts.device:
            cached = None  # stale device: rebuild from scratch
        if cached is None or cached.shape[0] < rows or cached.shape[1] < cols:
            if cached is not None:
                rows, cols = max(rows, cached.shape[0]), max(cols, cached.shape[1])
            cached = misc.meshgrid(rows, cols, swap_dim=True, device=pts.device)
            self._samples_indices_cached = cached
        return cached[:pts.shape[0], :pts.shape[1]]

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_: Space, *,
                sample_range: Tuple[float, float], n_samples: int, lindisp: bool = False,
                perturb_sample: bool = True, spherical: bool = False,
                **kwargs) -> Samples:
        """
        Sample points along rays.

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param space_ `Space`: space used to look up the voxel index of every sample
        :param sample_range `float, float`: sampling range. In spherical mode this is the
            range of `t`, whose reciprocal is passed as the sphere radius to
            `sphere.ray_sphere_intersect`.
        :param n_samples `int`: number of samples per ray (P)
        :param lindisp `bool`: whether sample linearly in disparity space (1/depth);
            ignored in spherical mode
        :param perturb_sample `bool`: whether perturb sampling
        :param spherical `bool`: whether to sample on spheres and return spherical coordinates
        :return `Samples(B, P)`: samples
        """
        if spherical:
            t_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
                                    rays_o.device)
            t0, t1 = t_bounds[:, :-1], t_bounds[:, 1:]  # (B, P)
            t = (t0 + t1) * .5

            p, z = sphere.ray_sphere_intersect(rays_o, rays_d, t.reciprocal())
            p = sphere.cartesian2spherical(p, inverse_r=True)
            vidxs = space_.get_voxel_indices(p)
            return Samples(
                pts=p,
                dirs=rays_d[:, None].expand(-1, n_samples, -1),
                depths=z,
                # Bin length in radius (1/t); `tiny` guards against t1 == 0
                dists=(t1 + math.tiny).reciprocal() - t0.reciprocal(),
                voxel_indices=vidxs,
                indices=self._get_samples_indices(p),
                t=t
            )
        else:
            # With lindisp, sample uniformly in 1/depth and convert the bounds back to depth
            sample_range = (1 / sample_range[0], 1 / sample_range[1]) if lindisp else sample_range
            z_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
                                    rays_o.device)
            if lindisp:
                z_bounds = z_bounds.reciprocal()
            z0, z1 = z_bounds[:, :-1], z_bounds[:, 1:]  # (B, P)
            z = (z0 + z1) * .5
            p = rays_o[:, None] + rays_d[:, None] * z[..., None]
            vidxs = space_.get_voxel_indices(p)
            return Samples(
                pts=p,
                dirs=rays_d[:, None].expand(-1, n_samples, -1),
                depths=z,
                dists=z1 - z0,
                voxel_indices=vidxs,
                indices=self._get_samples_indices(p),
                t=z
            )


class PdfSampler(Module):
    """
    Importance sampler: draws samples along rays according to a piecewise-constant PDF
    given by per-bin weights (presumably produced by a previous, coarser sampling pass).
    """

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
                 spherical: bool, lindisp: bool, **kwargs):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths (random rather than deterministic
            inverse-CDF sampling)
        :param spherical: return spherical coordinates of ray/sphere intersections instead of
            Cartesian points
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.n_samples = n_samples
        # Range in sampling space `s` (inverse depth when lindisp, depth otherwise)
        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range

    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
        """
        Sample points along rays. Returns Spherical or Cartesian coordinates,
        as specified by `self.spherical`.

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param weights `Tensor(B, M)`: weights of sample bins
        :param s_vals `Tensor(B, M)` or `Tensor(M)`: (optional) center of sample bins; defaults
            to M evenly spaced values over `self.s_range`
        :param include_s_vals `bool`: (default to `False`) include `s_vals` in the sample array
        :return: if spherical, `(spherical coords, depths, s, cartesian pts)` from
            `sphere.ray_sphere_intersect`; otherwise `(Tensor(B, N, 3) points, Tensor(B, N)
            depths, Tensor(B, N) s)`. N is `n_samples`, plus M if `include_s_vals`.
        """
        if s_vals is None:
            # One bin per weight, created where the weights live
            s_vals = torch.linspace(*self.s_range, weights.shape[-1], device=weights.device)
        # Random inverse-CDF sampling when perturbing, deterministic otherwise
        s = self.sample_pdf(Bins(s_vals).bounds, weights, self.n_samples,
                            det=not self.perturb_sample)
        if include_s_vals:
            # `s_vals` may be shared by all rays (1-D): broadcast it to the batch shape of `s`
            s = torch.cat([s, s_vals.expand(*s.shape[:-1], s_vals.shape[-1])], dim=-1)
        # In inverse-depth space, descending `s` means ascending depth
        s = torch.sort(s, descending=self.lindisp)[0]
        z = torch.reciprocal(s) if self.lindisp else s
        if self.spherical:
            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
            return sphers, depths, s, pts
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s

    def sample_pdf(self, bins: torch.Tensor, weights: torch.Tensor, N: int, det=True):
        '''
        Draw samples from the piecewise-constant PDF defined by `weights` over `bins`
        by inverting its CDF.

        :param bins `Tensor(..., M+1)`: bounds of bins (broadcast over leading dims)
        :param weights `Tensor(..., M)`: weights of bins
        :param N `int`: # of samples along each ray
        :param det `bool`: (default to `True`) perform deterministic sampling (evenly spaced
            CDF values) or random sampling (uniform CDF values)
        :return `Tensor(..., N)`: samples
        '''
        # Get pdf
        weights = weights + math.tiny                                          # prevent nans
        pdf = weights / torch.sum(weights, dim=-1, keepdim=True)                # [..., M]
        cdf = torch.cat([
            torch.zeros_like(pdf[..., :1]),
            torch.cumsum(pdf, dim=-1)
        ], dim=-1)                                                              # [..., M+1]

        # Take uniform samples
        dots_sh = list(weights.shape[:-1])
        M = weights.shape[-1]

        u = torch.linspace(0, 1, N, device=bins.device).expand(dots_sh + [N]) \
            if det else torch.rand(dots_sh + [N], device=bins.device)           # [..., N]

        # Invert CDF: count the bin starts each u has passed
        # [..., N, 1] >= [..., 1, M] ----> [..., N, M] ----> [..., N,]
        above_inds = torch.sum(u[..., None] >= cdf[..., None, :-1], dim=-1).long()

        # random sample inside each bin
        below_inds = torch.clamp(above_inds - 1, min=0)
        inds_g = torch.stack((below_inds, above_inds), dim=-1)                  # [..., N, 2]

        cdf = cdf[..., None, :].expand(dots_sh + [N, M + 1])                    # [..., N, M+1]
        cdf_g = torch.gather(cdf, dim=-1, index=inds_g)                         # [..., N, 2]

        bins = bins[..., None, :].expand(dots_sh + [N, M + 1])                  # [..., N, M+1]
        bins_g = torch.gather(bins, dim=-1, index=inds_g)                       # [..., N, 2]

        # fix numeric issue: avoid dividing by a (near) zero-mass bin
        denom = cdf_g[..., 1] - cdf_g[..., 0]                                   # [..., N]
        denom = torch.where(denom < math.tiny, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom

        # Linear interpolation of the CDF position inside the selected bin
        samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + math.tiny)

        return samples


class VoxelSampler(Module):
    """
    Sampler that places samples only inside the voxels of a `Space` hit by each ray,
    using a fixed step size along the ray.
    """

    def __init__(self, *, sample_step: float, **kwargs):
        """
        Initialize a VoxelSampler module

        :param sample_step: target gap between consecutive samples along a ray
        """
        super().__init__()
        self.sample_step = sample_step

    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space, *,
                 perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample uniformly between each ray's first voxel entry and last voxel exit
        (pure-PyTorch alternative to `forward`).

        Each valid ray gets ceil(length / sample_step) equal steps. Padding samples beyond a
        ray's own step count get depth `math.huge`, dist 0 and voxel index -1.

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space_module `Space`: space providing `ray_intersect`
        :param perturb_sample `bool`: jitter samples randomly within each step instead of
            using the step's midpoint
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel),
            or None if no ray hits any voxel
        :return `Tensor(N)`: valid rays mask
        """
        # NOTE(review): 100 looks like the max number of voxel hits per ray — confirm
        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')
        n_rays = rays_o.size(0)
        if n_rays == 0:
            # No ray hits any voxel; `rays_steps.max()` below would fail on an empty tensor
            return None, valid_rays_mask
        ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long)  # (N')

        hits = intersections.hits
        min_depths = intersections.min_depths
        max_depths = intersections.max_depths
        voxel_indices = intersections.voxel_indices

        rays_near_depth = min_depths[:, :1]  # (N', 1)
        rays_far_depth = max_depths[ray_index_list, hits - 1][:, None]  # (N', 1)
        rays_length = rays_far_depth - rays_near_depth
        rays_steps = (rays_length / self.sample_step).ceil().long()
        rays_step_size = rays_length / rays_steps
        max_steps = rays_steps.max().item()
        rays_step = torch.arange(max_steps, device=rays_o.device,
                                 dtype=torch.float)[None].repeat(n_rays, 1)  # (N', P)
        invalid_samples_mask = rays_step >= rays_steps
        samples_min_depth = rays_near_depth + rays_step * rays_step_size
        samples_depth = samples_min_depth + rays_step_size \
            * (torch.rand_like(samples_min_depth) if perturb_sample else 0.5)  # (N', P)
        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)
        # The voxel containing a sample is the first hit whose exit depth is >= the sample depth
        samples_voxel_index = voxel_indices[
            ray_index_list[:, None],
            torch.searchsorted(max_depths, samples_depth)
        ]  # (N', P)
        samples_depth[invalid_samples_mask] = math.huge
        samples_dist[invalid_samples_mask] = 0
        samples_voxel_index[invalid_samples_mask] = -1

        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
        return Samples(
            pts=rays_o + rays_d * samples_depth[..., None],
            dirs=rays_d.expand(-1, max_steps, -1),
            depths=samples_depth,
            dists=samples_dist,
            voxel_indices=samples_voxel_index
        ), valid_rays_mask

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                space: Space, *, perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample points inside the voxels hit by each ray via `inverse_cdf_sampling`.

        Each hit voxel segment is chosen with probability proportional to its length along
        the ray. The number of samples per ray is driven by total segment length divided by
        `self.sample_step`. Invalid (padding) samples get voxel index -1, dist 0 and depth
        `math.huge`.

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space `Space`: space providing `ray_intersect`
        :param perturb_sample `bool`: whether to perturb sampling (forwarded negated as the
            last argument of `inverse_cdf_sampling`, presumably its deterministic flag)
        :return `Samples(N', P)`: samples along valid rays, or None if no ray hits any voxel
        :return `Tensor(N)`: valid rays mask
        """
        intersections = space.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')

        checkpoint("Ray intersect")

        if intersections.size == 0:
            return None, valid_rays_mask
        else:
            min_depth = intersections.min_depths
            max_depth = intersections.max_depths
            pts_idx = intersections.voxel_indices
            dists = max_depth - min_depth
            tot_dists = dists.sum(dim=-1, keepdim=True)  # (N, 1)
            probs = dists / tot_dists
            steps = tot_dists[:, 0] / self.sample_step

            # sample points and use middle point approximation
            sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
                pts_idx, min_depth, max_depth, probs, steps, -1, not perturb_sample)
            sampled_indices = sampled_indices.long()
            invalid_idx_mask = sampled_indices.eq(-1)
            sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
            sampled_depths.masked_fill_(invalid_idx_mask, math.huge)

            checkpoint("Inverse CDF sampling")

            rays_o, rays_d = rays_o[:, None], rays_d[:, None]
            return Samples(
                pts=rays_o + rays_d * sampled_depths[..., None],
                dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
                depths=sampled_depths,
                dists=sampled_dists,
                voxel_indices=sampled_indices
            ), valid_rays_mask