sampler.py 7.58 KB
Newer Older
Nianchen Deng's avatar
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from typing import Tuple
import torch
import torch.nn as nn
from utils import device
from utils import sphere
from utils.constants import *
from .generic import *


class Bins(object):
    """
    A set of sample bins described by their centers and the bounds between them.

    Bins are laid out along the last dimension of `vals`, so both a single 1-D set of centers
    `Tensor(N)` and a batch of per-ray centers `Tensor(..., N)` are supported.

    Attributes:
        vals `Tensor(..., N)`: bin centers
        bounds `Tensor(..., N+1)`: bin boundaries
        lo `Tensor(..., N)`: first boundary of every bin (`bounds[..., :-1]`)
        up `Tensor(..., N)`: second boundary of every bin (`bounds[..., 1:]`)
    """

    def __init__(self, vals: torch.Tensor):
        """
        Build bins around the given centers

        :param vals `Tensor(..., N)`: bin centers along the last dimension
        """
        self.vals = vals
        # Inner bounds are the midpoints of neighbouring centers; the two outermost bounds
        # coincide with the first and the last center. Everything is done along the last
        # dimension so that batched centers are not mixed across the batch.
        self.bounds = torch.cat([
            vals[..., :1],
            0.5 * (vals[..., 1:] + vals[..., :-1]),
            vals[..., -1:]
        ], dim=-1)
        self._refresh_edges()

    def _refresh_edges(self):
        """Re-derive the `up`/`lo` views from the current `bounds` tensor."""
        self.up = self.bounds[..., 1:]
        self.lo = self.bounds[..., :-1]

    @staticmethod
    def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
        """
        Create bins whose centers are evenly spaced over a range

        :param val_range: first and last center value (both included)
        :param N: number of bins
        :param device: (optional) device to create the tensors on
        :return: a new `Bins` object
        """
        return Bins(torch.linspace(*val_range, N, device=device))

    def to(self, device: torch.device):
        """
        Move all tensors of this object to `device` (in place)

        :param device: target device
        :return: this object, so that calls can be chained
        """
        self.vals = self.vals.to(device)
        self.bounds = self.bounds.to(device)
        # `up`/`lo` are views of `bounds`, so they must be rebuilt after the move
        self._refresh_edges()
        return self


class Sampler(nn.Module):
    """
    Sample points along rays at a fixed set of bins (optionally jittered inside each bin).
    """

    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool, lindisp: bool):
        """
        Initialize a Sampler module

        :param sample_range: range to sample along the ray
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the samples randomly inside their bins
        :param spherical: If True, `forward` returns spherical coordinates instead of Cartesian
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.spherical = spherical
        self.perturb_sample = perturb_sample
        # With `lindisp` the bins live in reciprocal space; `forward` inverts them back.
        s_range = (1 / sample_range[0], 1 / sample_range[1]) if self.lindisp else sample_range
        self.bins = Bins.linspace(s_range, n_samples, device=device.default())

    def forward(self, rays_o, rays_d):
        """
        Sample points along rays. Return Spherical or Cartesian coordinates,
        specified by `self.spherical`

        :param rays_o `Tensor(..., 3)`: rays' origin
        :param rays_d `Tensor(..., 3)`: rays' direction
        :return: a 4-tuple.
            If `self.spherical`: `(sphers, depths, s, pts)` where `pts, depths` come from
            `sphere.ray_sphere_intersect(rays_o, rays_d, z)` and `sphers` from
            `sphere.cartesian2spherical(pts, inverse_r=self.lindisp)`.
            Otherwise: `(pts, z, s, None)` with `pts` of shape `Tensor(..., N, 3)` computed as
            `rays_o + rays_d * z`, and `z`, `s` of shape `Tensor(..., N)`.
            `s` are the sample values in bin space and `z` is `1 / s` when `self.lindisp`
            is set, else `s` itself.
        """
        # Broadcast the bin centers over every leading (batch) dimension of the rays
        batch_shape = rays_o.shape[:-1]
        s = self.bins.vals.expand(*batch_shape, -1)
        if self.perturb_sample:
            # Draw one uniform sample inside each bin, independently for every ray
            s = self.bins.lo + (self.bins.up - self.bins.lo) * torch.rand_like(s)
        z = torch.reciprocal(s) if self.lindisp else s
        if self.spherical:
            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
            return sphers, depths, s, pts
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s, None


class PdfSampler(nn.Module):
    """
    Sample points along rays by inverting the CDF of per-bin weights (importance sampling).
    """

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
                 spherical: bool, lindisp: bool):
        """
        Initialize a PdfSampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths (i.e. draw random rather than
            evenly spaced CDF positions)
        :param spherical: If True, `forward` returns spherical coordinates instead of Cartesian
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.n_samples = n_samples
        # With `lindisp` the sample space is the reciprocal of depth; `forward` inverts it back.
        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range

    @staticmethod
    def _bin_bounds(centers: torch.Tensor) -> torch.Tensor:
        """
        Compute bin bounds from bin centers along the last dimension

        :param centers `Tensor(..., M)`: bin centers
        :return `Tensor(..., M+1)`: midpoints of neighbouring centers, padded with the first
            and the last center
        """
        return torch.cat([
            centers[..., :1],
            0.5 * (centers[..., 1:] + centers[..., :-1]),
            centers[..., -1:]
        ], dim=-1)

    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
        """
        Sample points along rays. Return Spherical or Cartesian coordinates,
        specified by `self.spherical`

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param weights `Tensor(B, M)`: weights of sample bins
        :param s_vals `Tensor(B, M)` or `Tensor(M)`: (optional) center of sample bins, defaults
            to `self.n_samples` evenly spaced values over `self.s_range`
        :param include_s_vals `bool`: (default to `False`) include `s_vals` in the sample array
        :return: If `self.spherical`: the 4-tuple `(sphers, depths, s, pts)` built from
            `sphere.ray_sphere_intersect` and `sphere.cartesian2spherical`.
            Otherwise the 3-tuple `(pts, z, s)` with `pts` = `rays_o + rays_d * z`.
            `s` are the sorted samples in bin space, `z` is `1 / s` when `self.lindisp` is set,
            else `s` itself.
        """
        if s_vals is None:
            # Create the default centers next to the weights so all tensors share one device
            s_vals = torch.linspace(*self.s_range, self.n_samples, device=weights.device)
        # Deterministic (evenly spaced) CDF positions are used only when perturbation is off
        s = self.sample_pdf(self._bin_bounds(s_vals), weights, self.n_samples,
                            det=not self.perturb_sample)
        if include_s_vals:
            # `s_vals` may be un-batched (the default); broadcast it over the batch first
            s = torch.cat([s, s_vals.expand(*s.shape[:-1], -1)], dim=-1)
        # Keep samples ordered near-to-far: reciprocal-space values must be descending
        s = torch.sort(s, descending=self.lindisp)[0]
        z = torch.reciprocal(s) if self.lindisp else s
        if self.spherical:
            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
            return sphers, depths, s, pts
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s

    def sample_pdf(self, bins: torch.Tensor, weights: torch.Tensor, N: int, det=True):
        '''
        Draw samples from the piecewise-constant distribution defined by `weights` over `bins`

        :param bins `Tensor(..., M+1)`: bounds of bins
        :param weights `Tensor(..., M)`: weights of bins
        :param N `int`: # of samples along each ray
        :param det `bool`: (default to `True`) perform deterministic sampling or not
        :return `Tensor(..., N)`: samples
        '''
        # Get pdf
        weights = weights + TINY_FLOAT                                          # prevent nans
        pdf = weights / torch.sum(weights, dim=-1, keepdim=True)                # [..., M]
        cdf = torch.cat([
            torch.zeros_like(pdf[..., :1]),
            torch.cumsum(pdf, dim=-1)
        ], dim=-1)                                                              # [..., M+1]

        # Take uniform samples
        dots_sh = list(weights.shape[:-1])
        M = weights.shape[-1]

        u = torch.linspace(0, 1, N, device=bins.device).expand(dots_sh + [N]) \
            if det else torch.rand(dots_sh + [N], device=bins.device)           # [..., N]

        # Invert CDF: count how many lower bin edges each u has passed
        # [..., N, 1] >= [..., 1, M] ----> [..., N, M] ----> [..., N,]
        above_inds = torch.sum(u[..., None] >= cdf[..., None, :-1], dim=-1).long()

        # Indices of the two cdf/bin edges that enclose every u
        below_inds = torch.clamp(above_inds - 1, min=0)
        inds_g = torch.stack((below_inds, above_inds), dim=-1)                  # [..., N, 2]

        cdf = cdf[..., None, :].expand(dots_sh + [N, M + 1])                    # [..., N, M+1]
        cdf_g = torch.gather(cdf, dim=-1, index=inds_g)                         # [..., N, 2]

        bins = bins[..., None, :].expand(dots_sh + [N, M + 1])                  # [..., N, M+1]
        bins_g = torch.gather(bins, dim=-1, index=inds_g)                       # [..., N, 2]

        # fix numeric issue: avoid dividing by a (near-)zero cdf interval
        denom = cdf_g[..., 1] - cdf_g[..., 0]                                   # [..., N]
        denom = torch.where(denom < TINY_FLOAT, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom

        # Linear interpolation between the two enclosing bin edges
        samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_FLOAT)

        return samples


class VoxelSampler(nn.Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
                 lindisp: bool, space):
        """
        Initialize a VoxelSampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        :param space: stored as `self.space`; its type and role are not visible in this
            constructor (NOTE(review): presumably the voxel space to sample in -- confirm)
        """
        super().__init__()
        self.lindisp = lindisp
        self.perturb_sample = perturb_sample
        self.n_samples = n_samples
        self.space = space
        # With `lindisp` the stored range is the reciprocal of each end of `depth_range`
        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range

    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):