import torch
from typing import Tuple

from .generic import *
from .space import Space
from clib import *
from utils import device
from utils import sphere
from utils import misc
from utils import math
from utils.module import Module
from utils.samples import Samples
from utils.perf import perf, checkpoint


class Bins(object):
    """
    A set of bin centers (`vals`) together with the bin bounds derived from them.

    Bounds are the midpoints between neighbouring centers plus the first and last centers
    themselves, so `M` centers yield `M+1` bounds. Everything is computed along the last
    dimension, so both `Tensor(M)` and batched `Tensor(..., M)` centers are supported.
    """

    @property
    def up(self):
        """`Tensor(..., M)`: upper bound of every bin"""
        return self.bounds[..., 1:]

    @property
    def lo(self):
        """`Tensor(..., M)`: lower bound of every bin"""
        return self.bounds[..., :-1]

    def __init__(self, vals: torch.Tensor):
        """
        :param vals `Tensor(..., M)`: centers of bins
        """
        self.vals = vals
        # Work on the last dim so batched centers (e.g. per-ray `Tensor(B, M)`) get
        # per-row bounds instead of being mixed across rows.
        self.bounds = torch.cat([
            vals[..., :1],
            0.5 * (vals[..., 1:] + vals[..., :-1]),
            vals[..., -1:]
        ], dim=-1)  # (..., M+1)

    @staticmethod
    def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
        """
        Create `N` evenly spaced bin centers over `val_range` (both ends included).
        """
        return Bins(torch.linspace(*val_range, N, device=device))

    def to(self, device: torch.device):
        """
        Move centers and bounds to `device`.

        :return `Bins`: self, so the call can be chained like `Tensor.to`
        """
        self.vals = self.vals.to(device)
        self.bounds = self.bounds.to(device)
        return self


class Sampler(Module):

    def __init__(self, **kwargs):
        """
        Initialize a Sampler module
        """
        super().__init__()
        # Lazily built index grid from `misc.meshgrid`; grown on demand and sliced per call
        self._samples_indices_cached = None

    def _sample(self, range: Tuple[float, float], n_rays: int, n_samples: int, perturb: bool,
                device: torch.device) -> torch.Tensor:
        """
        Generate sampling bounds along rays.

        :param range `float, float`: sampling range
        :param n_rays `int`: number of rays (B)
        :param n_samples `int`: number of samples per ray (P)
        :param perturb `bool`: whether perturb sampling
        :param device `torch.device`: the device used to create tensors
        :return `Tensor(B, P+1)`: sampling bounds of t
        """
        bounds = torch.linspace(*range, n_samples + 1, device=device)  # (P+1)
        if not perturb:
            return bounds[None].expand(n_rays, -1)
        # Jitter every bound independently per ray inside the cell delimited by the
        # midpoints to its neighbours; the first/last bounds stay inside `range`.
        rand_bounds = torch.cat([
            bounds[:1],
            0.5 * (bounds[1:] + bounds[:-1]),
            bounds[-1:]
        ])  # (P+2)
        rand_vals = torch.rand(n_rays, n_samples + 1, device=device)  # (B, P+1)
        return rand_bounds[:-1] * (1 - rand_vals) + rand_bounds[1:] * rand_vals

    def _get_samples_indices(self, pts: torch.Tensor):
        """
        Get the sample index grid for `pts`, reusing a cached grid when possible.

        :param pts `Tensor(B, P, ...)`: sample points; only the first two dims and the device are used
        :return: `misc.meshgrid(B, P, swap_dim=True)` output, sliced from the cache
        """
        n_rays, n_samples = pts.shape[:2]
        cached = self._samples_indices_cached
        if cached is None or cached.device != pts.device:
            grid_shape = (n_rays, n_samples)
        elif cached.shape[0] < n_rays or cached.shape[1] < n_samples:
            # Grow to cover both the old and the new extents, so alternating shapes
            # do not rebuild the grid on every call.
            grid_shape = (max(cached.shape[0], n_rays), max(cached.shape[1], n_samples))
        else:
            grid_shape = None
        if grid_shape is not None:
            self._samples_indices_cached = misc.meshgrid(
                *grid_shape, swap_dim=True, device=pts.device)
        return self._samples_indices_cached[:n_rays, :n_samples]

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_: Space, *,
                sample_range: Tuple[float, float], n_samples: int,
                lindisp: bool = False, perturb_sample: bool = True, spherical: bool = False,
                **kwargs) -> Samples:
        """
        Sample points along rays.

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param space_ `Space`: space used to look up the voxel index of every sample
        :param sample_range `float, float`: sampling range
        :param n_samples `int`: number of samples per ray
        :param lindisp `bool`: whether sample linearly in disparity space (1/depth)
        :param perturb_sample `bool`: whether perturb sampling
        :param spherical `bool`: if True, `sample_range` is sampled as `t` and `1/t` is used as
                                 the sphere radius passed to `sphere.ray_sphere_intersect`;
                                 points are returned in spherical coordinates
        :return `Samples(B, P)`: samples
        """
        if spherical:
            t_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
                                    rays_o.device)
            t0, t1 = t_bounds[:, :-1], t_bounds[:, 1:]  # (B, P)
            t = (t0 + t1) * .5
            p, z = sphere.ray_sphere_intersect(rays_o, rays_d, t.reciprocal())
            p = sphere.cartesian2spherical(p, inverse_r=True)
            vidxs = space_.get_voxel_indices(p)
            return Samples(
                pts=p,
                dirs=rays_d[:, None].expand(-1, n_samples, -1),
                depths=z,
                # `math.tiny` guards the reciprocal of a zero upper bound
                dists=(t1 + math.tiny).reciprocal() - t0.reciprocal(),
                voxel_indices=vidxs,
                indices=self._get_samples_indices(p),
                t=t
            )
        else:
            sample_range = (1 / sample_range[0], 1 / sample_range[1]) if lindisp else sample_range
            z_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
                                    rays_o.device)
            if lindisp:
                z_bounds = z_bounds.reciprocal()
            z0, z1 = z_bounds[:, :-1], z_bounds[:, 1:]  # (B, P)
            z = (z0 + z1) * .5
            p = rays_o[:, None] + rays_d[:, None] * z[..., None]
            vidxs = space_.get_voxel_indices(p)
            return Samples(
                pts=p,
                dirs=rays_d[:, None].expand(-1, n_samples, -1),
                depths=z,
                dists=z1 - z0,
                voxel_indices=vidxs,
                indices=self._get_samples_indices(p),
                t=z
            )


class PdfSampler(Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
                 spherical: bool, lindisp: bool, **kwargs):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param spherical: return spherical coordinates (see `forward`)
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.n_samples = n_samples
        # Sampling space: inverse depth when `lindisp`, otherwise depth
        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range

    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
        """
        Sample points along rays by importance sampling the distribution given by `weights`.
        Returns spherical or Cartesian coordinates, as specified by `self.spherical`.

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param weights `Tensor(B, M)`: weights of sample bins
        :param s_vals `Tensor(B, M)|Tensor(M)`: (optional) center of sample bins; defaults to
                                                `self.n_samples` evenly spaced values over
                                                `self.s_range` (then M must equal `self.n_samples`)
        :param include_s_vals `bool`: (default to `False`) include `s_vals` in the sample array
        :return: if `self.spherical`, `(spherical_pts, depths, s, cartesian_pts)`;
                 otherwise `(pts Tensor(B, N, 3), z Tensor(B, N), s Tensor(B, N))`
        """
        if s_vals is None:
            # Create on the same device as `weights` so `sample_pdf` does not mix devices
            s_vals = torch.linspace(*self.s_range, self.n_samples, device=weights.device)
        # `det` means deterministic sampling, i.e. the opposite of perturbing
        s = self.sample_pdf(Bins(s_vals).bounds, weights, self.n_samples,
                            det=not self.perturb_sample)
        if include_s_vals:
            # Broadcast 1D `s_vals` to every ray before concatenating
            s = torch.cat([s, s_vals.expand(*s.shape[:-1], -1)], dim=-1)
            s = torch.sort(s, descending=self.lindisp)[0]
        z = torch.reciprocal(s) if self.lindisp else s
        if self.spherical:
            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
            return sphers, depths, s, pts
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s

    def sample_pdf(self, bins: torch.Tensor, weights: torch.Tensor, N: int, det=True):
        '''
        Draw samples from the piecewise-constant pdf defined by `weights` over `bins`
        by inverting its CDF.

        :param bins `Tensor(..., M+1)`: bounds of bins
        :param weights `Tensor(..., M)`: weights of bins
        :param N `int`: # of samples along each ray
        :param det `bool`: (default to `True`) perform deterministic sampling or not
        :return `Tensor(..., N)`: samples, monotonic in bin order along the last dim
        '''
        # Get pdf
        weights = weights + math.tiny  # prevent nans
        pdf = weights / torch.sum(weights, dim=-1, keepdim=True)  # [..., M]
        cdf = torch.cat([
            torch.zeros_like(pdf[..., :1]),
            torch.cumsum(pdf, dim=-1)
        ], dim=-1)  # [..., M+1]

        # Take uniform samples
        dots_sh = list(weights.shape[:-1])
        M = weights.shape[-1]
        if det:
            u = torch.linspace(0, 1, N, device=bins.device).expand(dots_sh + [N])  # [..., N]
        else:
            # Sort so random samples come out ordered along the ray, like deterministic ones
            u = torch.rand(dots_sh + [N], device=bins.device).sort(dim=-1)[0]  # [..., N]

        # Invert CDF
        # [..., N, 1] >= [..., 1, M] ----> [..., N, M] ----> [..., N,]
        above_inds = torch.sum(u[..., None] >= cdf[..., None, :-1], dim=-1).long()

        # Indices of the bin containing each sample: cdf[below] <= u < cdf[above]
        below_inds = torch.clamp(above_inds - 1, min=0)
        inds_g = torch.stack((below_inds, above_inds), dim=-1)  # [..., N, 2]

        cdf = cdf[..., None, :].expand(dots_sh + [N, M + 1])  # [..., N, M+1]
        cdf_g = torch.gather(cdf, dim=-1, index=inds_g)  # [..., N, 2]

        bins = bins[..., None, :].expand(dots_sh + [N, M + 1])  # [..., N, M+1]
        bins_g = torch.gather(bins, dim=-1, index=inds_g)  # [..., N, 2]

        # fix numeric issue: avoid dividing by an (almost) empty bin's probability
        denom = cdf_g[..., 1] - cdf_g[..., 0]  # [..., N]
        denom = torch.where(denom < math.tiny, torch.ones_like(denom), denom)
        # Linearly place the sample inside its bin
        t = (u - cdf_g[..., 0]) / denom

        samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + math.tiny)
        return samples


class VoxelSampler(Module):

    def __init__(self, *, sample_step: float, **kwargs):
        """
        Initialize a VoxelSampler module

        :param sample_step: step size (distance between consecutive samples along a ray)
        """
        super().__init__()
        self.sample_step = sample_step

    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space, *,
                 perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample rays with a uniform step between the first and the last voxel they hit.

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space_module `Space`: space providing `ray_intersect`
        :param perturb_sample `bool`: jitter samples inside their steps instead of using midpoints
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel),
                                  or None if no ray hits any voxel
        :return `Tensor(N)`: valid rays mask
        """
        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')
        n_rays = rays_o.size(0)
        if n_rays == 0:
            # Nothing to sample (and `.max()` below would fail on an empty tensor);
            # mirrors the behavior of `forward`
            return None, valid_rays_mask
        ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long)  # (N')
        hits = intersections.hits
        min_depths = intersections.min_depths
        max_depths = intersections.max_depths
        voxel_indices = intersections.voxel_indices

        rays_near_depth = min_depths[:, :1]  # (N', 1)
        rays_far_depth = max_depths[ray_index_list, hits - 1][:, None]  # (N', 1)
        rays_length = rays_far_depth - rays_near_depth
        rays_steps = (rays_length / self.sample_step).ceil().long()
        rays_step_size = rays_length / rays_steps
        max_steps = rays_steps.max().item()
        rays_step = torch.arange(max_steps, device=rays_o.device,
                                 dtype=torch.float)[None].repeat(n_rays, 1)  # (N', P)
        # Rays with fewer steps than the longest ray are padded with invalid samples
        invalid_samples_mask = rays_step >= rays_steps
        samples_min_depth = rays_near_depth + rays_step * rays_step_size
        samples_depth = samples_min_depth + rays_step_size \
            * (torch.rand_like(samples_min_depth) if perturb_sample else 0.5)  # (N', P)
        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)
        # Padded samples lie beyond the far depth, where `searchsorted` may return one past
        # the last intersection; clamp to stay in bounds (those samples are masked below).
        max_hits = max_depths.size(-1)
        samples_voxel_index = voxel_indices[
            ray_index_list[:, None],
            torch.searchsorted(max_depths, samples_depth).clamp_max(max_hits - 1)
        ]  # (N', P)
        samples_depth[invalid_samples_mask] = math.huge
        samples_dist[invalid_samples_mask] = 0
        samples_voxel_index[invalid_samples_mask] = -1

        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
        return Samples(
            pts=rays_o + rays_d * samples_depth[..., None],
            dirs=rays_d.expand(-1, max_steps, -1),
            depths=samples_depth,
            dists=samples_dist,
            voxel_indices=samples_voxel_index
        ), valid_rays_mask

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space: Space, *,
                perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample rays inside the voxels they hit, using `inverse_cdf_sampling` with the
        probability of each hit segment proportional to its length.

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space `Space`: space providing `ray_intersect`
        :param perturb_sample `bool`: whether perturb sampling (passed to
                                      `inverse_cdf_sampling` as `not perturb_sample`)
        :return `Samples(N', P)`: samples along valid rays, or None if there is no intersection
        :return `Tensor(N)`: valid rays mask
        """
        intersections = space.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')
        checkpoint("Ray intersect")

        if intersections.size == 0:
            return None, valid_rays_mask
        else:
            min_depth = intersections.min_depths
            max_depth = intersections.max_depths
            pts_idx = intersections.voxel_indices
            dists = max_depth - min_depth
            tot_dists = dists.sum(dim=-1, keepdim=True)  # (N, 1)
            probs = dists / tot_dists
            steps = tot_dists[:, 0] / self.sample_step

            # sample points and use middle point approximation
            sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
                pts_idx, min_depth, max_depth, probs, steps, -1, not perturb_sample)
            sampled_indices = sampled_indices.long()
            # Index -1 marks padded (invalid) samples: zero distance, "infinite" depth
            invalid_idx_mask = sampled_indices.eq(-1)
            sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
            sampled_depths.masked_fill_(invalid_idx_mask, math.huge)
            checkpoint("Inverse CDF sampling")

            rays_o, rays_d = rays_o[:, None], rays_d[:, None]
            return Samples(
                pts=rays_o + rays_d * sampled_depths[..., None],
                dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
                depths=sampled_depths,
                dists=sampled_dists,
                voxel_indices=sampled_indices
            ), valid_rays_mask