sampler.py 9.05 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
from .__common__ import *
Nianchen Deng's avatar
sync    
Nianchen Deng committed
2
3
from .space import Space
from clib import *
Nianchen Deng's avatar
Nianchen Deng committed
4
from utils import sphere
Nianchen Deng's avatar
sync    
Nianchen Deng committed
5
from utils.misc import grid2d
Nianchen Deng's avatar
sync    
Nianchen Deng committed
6

Nianchen Deng's avatar
sync    
Nianchen Deng committed
7
# Public API of this module: the abstract base sampler and its two concrete samplers
__all__ = [
    "Sampler",
    "UniformSampler",
    "PdfSampler",
]
Nianchen Deng's avatar
sync    
Nianchen Deng committed
8

Nianchen Deng's avatar
Nianchen Deng committed
9

Nianchen Deng's avatar
sync    
Nianchen Deng committed
10
11
class Sampler(nn.Module):
    """
    Base class of ray samplers. Subclasses generate sample steps ("t" values) along rays and turn
    them into `Samples` via `_get_samples`.
    """
    # Lazily built grid of (ray index, sample index) entries created by `grid2d(..., indexing="ij")`;
    # reused across calls and sliced to the requested size in `_get_samples_indices`
    _samples_indices_cached: torch.Tensor | None

    def __init__(self, x_chns: int, d_chns: int):
        """
        Initialize a Sampler module

        :param x_chns `int`: channels of "x" (presumably sample positions), forwarded to the base
            module's channel spec
        :param d_chns `int`: channels of "d" (presumably sample directions), forwarded to the base
            module's channel spec
        """
        super().__init__({}, {"x": x_chns, "d": d_chns})
        self._samples_indices_cached = None

    # stub method for type hint
    # NOTE(review): defining `__call__` here shadows `nn.Module.__call__`; this presumably relies on
    # the project's `nn.Module` still dispatching calls to `forward` — confirm.
    def __call__(self, rays: Rays, space: Space, **kwargs) -> Samples:
        ...

    def _get_samples_indices(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Get 2D indices of samples. The first value is the index of ray, while the second value is
        the index of sample in a ray.

        The index grid is cached and only rebuilt when it is missing, lives on another device, or is
        too small in either of the first two dimensions; otherwise a slice of the cache is returned.

        :param pts `Tensor(B, P, 3)`: the sample points
        :return `Tensor(B, P, ...)`: the 2D indices of samples (a view of the cached grid; the
            trailing dimension is whatever `grid2d` produces — presumably 2, TODO confirm)
        """
        if self._samples_indices_cached is None\
                or self._samples_indices_cached.device != pts.device\
                or self._samples_indices_cached.shape[0] < pts.shape[0]\
                or self._samples_indices_cached.shape[1] < pts.shape[1]:
            self._samples_indices_cached = grid2d(*pts.shape[:2], indexing="ij", device=pts.device)
        return self._samples_indices_cached[:pts.shape[0], :pts.shape[1]]

    def _get_samples(self, rays: Rays, space: Space, t_vals: torch.Tensor, mode: str) -> Samples:
        """
        Get samples along rays at sample steps specified by `t_vals`.

        :param rays `Rays(B)`: rays to sample along
        :param space `Space`: sample space used to compute the voxel indices of samples; when it is
            falsy (e.g. `None`) all voxel indices are set to 0
        :param t_vals `Tensor(B, P)`: sample steps. In "xyz" mode they are depths; in the other modes
            they are reciprocals (depth for "xyz_disp"/"spherical", presumably sphere radius for
            "spherical_radius")
        :param mode `str`: sample mode, one of "xyz", "xyz_disp", "spherical", "spherical_radius"
        :return `Samples(B, P)`: samples
        :raises ValueError: if `mode` is not one of the supported modes
        """
        if mode == "xyz":
            z_vals = t_vals
            pts = rays.get_points(z_vals)
        elif mode == "xyz_disp":
            z_vals = t_vals.reciprocal()
            pts = rays.get_points(z_vals)
        elif mode == "spherical":
            z_vals = t_vals.reciprocal()
            pts = sphere.cartesian2spherical(rays.get_points(z_vals), inverse_r=True)
        elif mode == "spherical_radius":
            z_vals = sphere.ray_sphere_intersect(rays, t_vals.reciprocal())
            pts = sphere.cartesian2spherical(rays.get_points(z_vals), inverse_r=True)
        else:
            raise ValueError(f"Unknown mode: {mode}")

        rays_d = rays.rays_d.unsqueeze(1)  # (B, 1, 3)
        # Step between adjacent samples, padded with `math.huge` for the last sample of each ray
        dists = union(z_vals[..., 1:] - z_vals[..., :-1], math.huge)  # (B, P)
        # Scale by |rays_d| so un-normalized directions still give Euclidean distances
        # (assumes `rays.get_points` computes o + d * z — confirm)
        dists *= torch.norm(rays_d, dim=-1)
        return Samples(
            pts=pts,
            dirs=rays_d.expand(*pts.shape[:2], -1),
            depths=z_vals,
            t_vals=t_vals,
            dists=dists,
            # NOTE(review): truthiness test — if `Space` defines `__len__`/`__bool__`, an "empty"
            # space also yields 0 here; `space is not None` may be what is intended
            voxel_indices=space.get_voxel_indices(pts) if space else 0,
            indices=self._get_samples_indices(pts)
        )
Nianchen Deng's avatar
Nianchen Deng committed
76
77


Nianchen Deng's avatar
sync    
Nianchen Deng committed
78
79
80
81
82
class UniformSampler(Sampler):
    """
    This module expands NeRF's code of uniform sampling to support our spherical sampling and enable
    the trace of samples' indices.
    """

    def __init__(self):
        super().__init__(3, 3)

    def _sample(self, range: tuple[float, float], n_rays: int, n_samples: int, perturb: bool) -> torch.Tensor:
        """
        Generate sample steps along rays in the specified range.

        Without perturbation the returned tensor is an expanded view of one shared `(P)` tensor, so
        it must not be modified in place.

        :param range `float, float`: sampling range
        :param n_rays `int`: number of rays (B)
        :param n_samples `int`: number of samples per ray (P)
        :param perturb `bool`: whether perturb sampling
        :return `Tensor(B, P)`: sampled "t"s along rays
        """
        t_vals = torch.linspace(*range, n_samples, device=self.device)  # (P)
        if perturb:
            # Interval around each evenly spaced step: [previous midpoint, next midpoint], clipped
            # to the range ends
            mids = .5 * (t_vals[..., 1:] + t_vals[..., :-1])
            upper = union(mids, t_vals[..., -1:])
            lower = union(t_vals[..., :1], mids)
            # stratified samples in those intervals (independent random offsets for every ray)
            t_vals = t_vals.expand(n_rays, -1)
            t_vals = lower + (upper - lower) * torch.rand_like(t_vals)
        else:
            t_vals = t_vals.expand(n_rays, -1)
        return t_vals

    # stub method for type hint
    def __call__(self, rays: Rays, space: Space, *,
                 range: tuple[float, float],
                 mode: str,
                 n_samples: int,
                 perturb: bool) -> Samples:
        """
        Sample points along rays.

        :param rays `Rays(B)`: rays
        :param space `Space`: sample space
        :param range `float, float`: sampling range (depths); in modes other than "xyz" both ends
            must be non-zero since sampling happens between their reciprocals
        :param mode `str`: sample mode, one of "xyz", "xyz_disp", "spherical", "spherical_radius"
        :param n_samples `int`: number of samples per ray
        :param perturb `bool`: whether perturb sampling (stratified random samples)
        :return `Samples(B, P)`: samples
        """
        ...

    @profile
    def forward(self, rays: Rays, space: Space, *,
                range: tuple[float, float],
                mode: str,
                n_samples: int,
                perturb: bool) -> Samples:
        # Modes other than "xyz" interpret sample steps as reciprocals, so sample uniformly between
        # the reciprocals of the range ends (a zero end raises ZeroDivisionError here)
        t_range = range if mode == "xyz" else (1. / range[0], 1. / range[1])
        t_vals = self._sample(t_range, rays.shape[0], n_samples, perturb)  # (B, P)
        return self._get_samples(rays, space, t_vals, mode)
Nianchen Deng's avatar
sync    
Nianchen Deng committed
137
138


Nianchen Deng's avatar
sync    
Nianchen Deng committed
139
140
141
142
class PdfSampler(Sampler):
    """
    Hierarchical sampling (section 5.2 of NeRF)
    """

    def __init__(self):
        super().__init__(3, 3)

    def _sample(self, t_vals: torch.Tensor, weights: torch.Tensor, n_importance: int,
                perturb: bool, include_existed: bool, sort_descending: bool) -> torch.Tensor:
        """
        Generate sample steps by PDF according to existed sample steps and their weights.

        :param t_vals `Tensor(B, P)`: existed sample steps
        :param weights `Tensor(B, P)`: weights of existed sample steps
        :param n_importance `int`: number of samples to generate for each ray
        :param perturb `bool`: whether to draw random (instead of evenly spaced) CDF positions
        :param include_existed `bool`: whether to include existed samples in the output
        :param sort_descending `bool`: whether the merged output is sorted in descending order;
            only used when `include_existed` is true
        :return `Tensor(B, P'[+P])`: the output sample steps (new steps are detached from autograd)
        """
        # Bin edges are midpoints of existed steps; the first and last weights are dropped so each
        # of the P - 2 interior weights corresponds to one bin
        bins = .5 * (t_vals[..., 1:] + t_vals[..., :-1])  # (B, P - 1)
        # `math.tiny` (presumably a small positive constant) keeps the PDF defined for all-zero weights
        weights = weights[..., 1:-1] + math.tiny  # (B, P - 2)

        # Get PDF
        pdf = weights / torch.sum(weights, -1, keepdim=True)
        cdf = union(0., torch.cumsum(pdf, -1))  # (B, P - 1)

        # Take uniform samples. Build them with cdf's dtype/device: `torch.searchsorted` requires the
        # values to have the same dtype as the sorted sequence
        u_shape = (*cdf.shape[:-1], n_importance)
        if perturb:
            u = torch.rand(u_shape, dtype=cdf.dtype, device=cdf.device)
        else:
            u = torch.linspace(0., 1., steps=n_importance, dtype=cdf.dtype, device=cdf.device)\
                .expand(u_shape)

        # Invert CDF: locate the CDF interval [below, above] that contains each u
        u = u.contiguous()  # (B, P'); searchsorted works on contiguous values (expand() is not)
        inds = torch.searchsorted(cdf, u, right=True)  # (B, P')
        inds_g = torch.stack([
            (inds - 1).clamp_min(0),  # below
            inds.clamp_max(cdf.shape[-1] - 1)  # above
        ], -1)  # (B, P', 2)

        # Gather CDF values and bin edges at both ends of each interval (any leading batch dims)
        matched_shape = [*inds_g.shape[:-1], cdf.shape[-1]]  # [B, P', P - 1]
        cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, inds_g)  # (B, P', 2)
        bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), -1, inds_g)  # (B, P', 2)

        # Linearly interpolate inside the interval; intervals with (near-)zero CDF width use
        # denominator 1 to avoid division by zero
        denom = cdf_g[..., 1] - cdf_g[..., 0]
        denom = torch.where(denom < math.tiny, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom
        t_samples = (bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])).detach()
        if include_existed:
            return torch.sort(union(t_vals, t_samples), -1, descending=sort_descending)[0]
        return t_samples

    # stub method for type hint
    def __call__(self, rays: Rays, space: Space, t_vals: torch.Tensor, weights: torch.Tensor, *,
                 mode: str,
                 n_importance: int,
                 perturb: bool,
                 include_existed_samples: bool) -> Samples:
        """
        Sample points along rays using PDF sampling based on existed samples.

        :param rays `Rays(B)`: rays
        :param space `Space`: sample space
        :param t_vals `Tensor(B, P)`: existed sample steps
        :param weights `Tensor(B, P)`: weights of existed sample steps
        :param mode `str`: sample mode, one of "xyz", "xyz_disp", "spherical", "spherical_radius"
        :param n_importance `int`: number of samples to generate using PDF sampling for each ray
        :param perturb `bool`: whether perturb sampling (random instead of evenly spaced CDF positions)
        :param include_existed_samples `bool`: whether to include existed samples in the output
        :return `Samples(B, P'[+P])`: samples
        """
        ...

    @profile
    def forward(self, rays: Rays, space: Space, t_vals: torch.Tensor, weights: torch.Tensor, *,
                mode: str,
                n_importance: int,
                perturb: bool,
                include_existed_samples: bool) -> Samples:
        # Outside "xyz" mode steps are reciprocals of depth, so sorting them descending keeps the
        # resulting depths ascending
        t_vals = self._sample(t_vals, weights, n_importance, perturb, include_existed_samples,
                              mode != "xyz")
        return self._get_samples(rays, space, t_vals, mode)