class VoxelSampler(Module):
    """Sample points along rays, restricted to the voxels of a `Space`."""

    def __init__(self, *, sample_step: float, **kwargs):
        """
        Initialize a VoxelSampler module

        :param sample_step `float`: target distance between consecutive samples along a ray
        :param kwargs: accepted for interface compatibility and ignored
        """
        super().__init__()
        self.sample_step = sample_step

    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space, *,
                 perturb_sample: bool, **kwargs) -> tuple[Samples, torch.Tensor]:
        """
        Sample evenly spaced depths over the span covered by voxels on each ray.

        For every ray that hits at least one voxel, the interval from the entry depth of the
        first hit to the exit depth of the last hit is split into
        `ceil(length / self.sample_step)` equal steps, and one sample is drawn per step (at the
        step's midpoint, or uniformly inside the step when `perturb_sample` is set). Rays with
        fewer steps than the longest ray are padded with invalid samples
        (depth `math.huge`, dist 0, voxel index -1).

        Note: no min-depth check is done here, so a sample lying in a gap between two voxels
        is assigned to the next voxel along the ray (first column whose max depth >= depth).

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space_module `Space`: space whose `ray_intersect` provides voxel intersections
        :param perturb_sample `bool`: jitter each sample uniformly within its step instead of
            placing it at the step's midpoint
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel), or
            None if no ray hits any voxel
        :return `Tensor(N)`: valid rays mask
        """
        # NOTE(review): 100 is presumably the max number of intersections recorded per ray —
        # confirm against Space.ray_intersect.
        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')
        n_rays = rays_o.size(0)
        if n_rays == 0:
            # `rays_steps.max()` below would raise on an empty tensor; mirror `forward`.
            return None, valid_rays_mask

        ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long)  # (N')
        hits = intersections.hits
        min_depths = intersections.min_depths
        max_depths = intersections.max_depths
        voxel_indices = intersections.voxel_indices

        # Span covered by voxels: entry of the first hit to exit of the last hit (column hits-1)
        rays_near_depth = min_depths[:, :1]  # (N', 1)
        rays_far_depth = max_depths[ray_index_list, hits - 1][:, None]  # (N', 1)
        rays_length = rays_far_depth - rays_near_depth
        rays_steps = (rays_length / self.sample_step).ceil().long()
        rays_step_size = rays_length / rays_steps
        max_steps = rays_steps.max().item()

        rays_step = torch.arange(max_steps, device=rays_o.device,
                                 dtype=torch.float)[None].repeat(n_rays, 1)  # (N', P)
        invalid_samples_mask = rays_step >= rays_steps
        samples_min_depth = rays_near_depth + rays_step * rays_step_size
        samples_depth = samples_min_depth + rays_step_size \
            * (torch.rand_like(samples_min_depth) if perturb_sample else 0.5)  # (N', P)
        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)

        # searchsorted returns max_depths.size(1) for depths beyond every stored exit depth
        # (padding samples, NaN depths of zero-length rays, float round-off at the far end).
        # Clamp so the gather below stays in bounds; such samples are either invalid (and
        # overwritten with -1 below) or lie at the very end of the last voxel.
        # NOTE(review): searchsorted assumes max_depths is sorted ascending along dim 1.
        voxel_slot = torch.searchsorted(max_depths, samples_depth) \
            .clamp(max=max_depths.size(1) - 1)
        samples_voxel_index = voxel_indices[ray_index_list[:, None], voxel_slot]  # (N', P)

        # NOTE(review): stdlib `math` has no `huge`; this relies on a project module providing
        # it — confirm the import at the top of the file.
        samples_depth[invalid_samples_mask] = math.huge
        samples_dist[invalid_samples_mask] = 0
        samples_voxel_index[invalid_samples_mask] = -1

        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
        return Samples(
            pts=rays_o + rays_d * samples_depth[..., None],
            dirs=rays_d.expand(-1, max_steps, -1),
            depths=samples_depth,
            dists=samples_dist,
            voxel_indices=samples_voxel_index
        ), valid_rays_mask

    @profile
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space: Space, *,
                perturb_sample: bool, **kwargs) -> tuple[Samples, torch.Tensor]:
        """
        Sample points along rays by inverse-CDF sampling over the voxels each ray hits.

        Each hit voxel is weighted by the length of its segment along the ray
        (max depth - min depth), and the number of samples per ray is the total covered
        length divided by `self.sample_step`.

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space `Space`: space whose `ray_intersect` provides voxel intersections
        :param perturb_sample `bool`: when False, the sampler is asked for deterministic samples
            (passed as `not perturb_sample` to `inverse_cdf_sampling`)
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel), or
            None if no ray hits any voxel
        :return `Tensor(N)`: valid rays mask
        """
        with profile("Ray intersect"):
            # NOTE(review): 100 is presumably the max number of intersections per ray — confirm.
            intersections = space.ray_intersect(rays_o, rays_d, 100)
            valid_rays_mask = intersections.hits > 0
            rays_o = rays_o[valid_rays_mask]
            rays_d = rays_d[valid_rays_mask]
            intersections = intersections[valid_rays_mask]  # (N) -> (N')

        if intersections.size == 0:
            return None, valid_rays_mask

        with profile("Inverse CDF sampling"):
            min_depth = intersections.min_depths
            max_depth = intersections.max_depths
            pts_idx = intersections.voxel_indices
            dists = max_depth - min_depth
            tot_dists = dists.sum(dim=-1, keepdim=True)  # (N', 1)
            probs = dists / tot_dists  # per-voxel weight proportional to its segment length
            steps = tot_dists[:, 0] / self.sample_step

            # sample points and use middle point approximation
            # NOTE(review): -1 is presumably the index filled in for unused sample slots,
            # which is what the `eq(-1)` mask below relies on — confirm in inverse_cdf_sampling.
            sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
                pts_idx, min_depth, max_depth, probs, steps, -1, not perturb_sample)
            sampled_indices = sampled_indices.long()
            invalid_idx_mask = sampled_indices.eq(-1)
            sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
            # NOTE(review): stdlib `math` has no `huge`; relies on a project module.
            sampled_depths.masked_fill_(invalid_idx_mask, math.huge)

        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
        return Samples(
            pts=rays_o + rays_d * sampled_depths[..., None],
            dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
            depths=sampled_depths,
            dists=sampled_dists,
            voxel_indices=sampled_indices
        ), valid_rays_mask