sampler.py 4.97 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
class VoxelSampler(Module):
    """
    Sample points along rays, restricted to the voxels of a `Space` that each ray intersects.
    """

    def __init__(self, *, sample_step: float, **kwargs):
        """
        Initialize a VoxelSampler module

        :param sample_step `float`: target distance between consecutive samples along a ray
        :param kwargs: accepted for config compatibility; ignored
        """
        super().__init__()
        self.sample_step = sample_step

    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space, *,
                 perturb_sample: bool, **kwargs) -> tuple[Samples, torch.Tensor]:
        """
        Uniformly sample points along rays between their first voxel entry and last voxel exit.

        Alternative to `forward`: samples are evenly spaced (actual step never larger than
        `self.sample_step`) over the whole span from the first intersection's near depth to
        the last intersection's far depth, instead of being drawn by inverse-CDF sampling.

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space_module `Space`: space whose `ray_intersect` gives ray-voxel intersections
        :param perturb_sample `bool`: if True, jitter each sample uniformly within its step;
            otherwise place it at the middle of its step
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel), or
            None when no ray hits any voxel
        :return `Tensor(N)`: valid rays mask
        """
        # NOTE(review): 100 is presumably the max number of intersections kept per ray -- confirm
        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')
        n_rays = rays_o.size(0)
        if n_rays == 0:
            # No ray hits any voxel. Return early as `forward` does; `rays_steps.max()` below
            # would otherwise raise on an empty tensor.
            return None, valid_rays_mask
        ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long)  # (N')

        hits = intersections.hits
        min_depths = intersections.min_depths
        max_depths = intersections.max_depths
        voxel_indices = intersections.voxel_indices

        # Sampling span: near depth of the first hit .. far depth of the last real hit
        rays_near_depth = min_depths[:, :1]  # (N', 1)
        rays_far_depth = max_depths[ray_index_list, hits - 1][:, None]  # (N', 1)
        rays_length = rays_far_depth - rays_near_depth
        # Round the step count up so the per-ray step size never exceeds `sample_step`
        rays_steps = (rays_length / self.sample_step).ceil().long()
        rays_step_size = rays_length / rays_steps
        max_steps = rays_steps.max().item()
        rays_step = torch.arange(max_steps, device=rays_o.device,
                                 dtype=torch.float)[None].repeat(n_rays, 1)  # (N', P)
        # Rays that need fewer than `max_steps` samples are padded; mark the padding
        invalid_samples_mask = rays_step >= rays_steps
        samples_min_depth = rays_near_depth + rays_step * rays_step_size
        samples_depth = samples_min_depth + rays_step_size \
            * (torch.rand_like(samples_min_depth) if perturb_sample else 0.5)  # (N', P)
        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)
        # Pick the first intersection whose far depth is >= the sample depth (assumes a ray's
        # intersections are sorted by depth -- TODO confirm).
        # Clamp to the ray's last real hit: padded samples, zero-length rays (NaN depths) and
        # float round-off at the far end can otherwise yield an index past the valid hits,
        # or equal to the intersection count, which is out of range for `voxel_indices`.
        # Padded samples are overwritten with -1 below regardless.
        hit_index = torch.searchsorted(max_depths, samples_depth)  # (N', P)
        hit_index = torch.minimum(hit_index, (hits - 1).long()[:, None])
        samples_voxel_index = voxel_indices[ray_index_list[:, None], hit_index]  # (N', P)
        # NOTE(review): `math` is presumably the project's math utils (stdlib has no `huge`)
        samples_depth[invalid_samples_mask] = math.huge
        samples_dist[invalid_samples_mask] = 0
        samples_voxel_index[invalid_samples_mask] = -1

        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
        return Samples(
            pts=rays_o + rays_d * samples_depth[..., None],
            dirs=rays_d.expand(-1, max_steps, -1),
            depths=samples_depth,
            dists=samples_dist,
            voxel_indices=samples_voxel_index
        ), valid_rays_mask

    @profile
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                space: Space, *, perturb_sample: bool, **kwargs) -> tuple[Samples, torch.Tensor]:
        """
        Sample points along rays inside the voxels they intersect, via inverse-CDF sampling.

        Each ray's voxel segments are weighted by their length. The number of samples per ray
        is derived from the total intersected length divided by `self.sample_step`.

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space `Space`: space whose `ray_intersect` gives ray-voxel intersections
        :param perturb_sample `bool`: randomly perturb samples if True (passed inverted to
            `inverse_cdf_sampling`, presumably as its deterministic flag -- confirm)
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel), or
            None when no ray hits any voxel
        :return `Tensor(N)`: valid rays mask
        """
        with profile("Ray intersect"):
            # NOTE(review): 100 is presumably the max intersections kept per ray -- confirm
            intersections = space.ray_intersect(rays_o, rays_d, 100)
            valid_rays_mask = intersections.hits > 0
            rays_o = rays_o[valid_rays_mask]
            rays_d = rays_d[valid_rays_mask]
            intersections = intersections[valid_rays_mask]  # (N) -> (N')

        if intersections.size == 0:
            return None, valid_rays_mask

        with profile("Inverse CDF sampling"):
            min_depth = intersections.min_depths
            max_depth = intersections.max_depths
            pts_idx = intersections.voxel_indices
            # Length of each ray-voxel segment, normalized into per-segment weights
            dists = max_depth - min_depth
            tot_dists = dists.sum(dim=-1, keepdim=True)  # (N', 1)
            probs = dists / tot_dists
            # Per-ray sample count so that samples are about `sample_step` apart
            steps = tot_dists[:, 0] / self.sample_step

            # sample points and use middle point approximation
            # NOTE(review): -1 appears to be the index assigned to invalid/padded samples
            # (checked via `eq(-1)` below) -- confirm against inverse_cdf_sampling
            sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
                pts_idx, min_depth, max_depth, probs, steps, -1, not perturb_sample)
            sampled_indices = sampled_indices.long()
            invalid_idx_mask = sampled_indices.eq(-1)
            sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
            sampled_depths.masked_fill_(invalid_idx_mask, math.huge)

        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
        return Samples(
            pts=rays_o + rays_d * sampled_depths[..., None],
            dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
            depths=sampled_depths,
            dists=sampled_dists,
            voxel_indices=sampled_indices
        ), valid_rays_mask