from typing import Tuple

import torch
import torch.nn as nn

from utils import device, sphere
from utils.constants import *
from utils.perf import perf, checkpoint
from clib import *
from .space import Space, Voxels
from .generic import *


class Bins(object):

    @property
    def up(self):
        return self.bounds[1:]

    @property
    def lo(self):
        return self.bounds[:-1]

    def __init__(self, vals: torch.Tensor):
        self.vals = vals
        self.bounds = torch.cat([
            self.vals[:1],
            0.5 * (self.vals[1:] + self.vals[:-1]),
            self.vals[-1:]
        ])

    @staticmethod
    def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
        return Bins(torch.linspace(*val_range, N, device=device))

    def to(self, device: torch.device):
        self.vals = self.vals.to(device)
        self.bounds = self.bounds.to(device)
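
# Worked example for `Bins` (illustrative numbers only):
#   Bins.linspace((0.0, 6.0), 4) gives
#     vals   = [0, 2, 4, 6]        (bin centers)
#     bounds = [0, 1, 3, 5, 6]     (first value, midpoints, last value)
#   so `lo` = [0, 1, 3, 5] and `up` = [1, 3, 5, 6] are the lower/upper edges of the
#   four bins, each bin containing its center value.
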

class Samples:
    pts: torch.Tensor
    """`Tensor(N[, P], 3)`"""

    dirs: torch.Tensor
    """`Tensor(N[, P], 3)`"""

    depths: torch.Tensor
    """`Tensor(N[, P])`"""

    dists: torch.Tensor
    """`Tensor(N[, P])`"""

    voxel_indices: torch.Tensor
    """`Tensor(N[, P])`"""

    @property
    def size(self):
        return self.pts.size()[:-1]

    @property
    def device(self):
        return self.pts.device

    def __init__(self, pts: torch.Tensor, dirs: torch.Tensor, depths: torch.Tensor,
                 dists: torch.Tensor, voxel_indices: torch.Tensor) -> None:
        self.pts = pts
        self.dirs = dirs
        self.depths = depths
        self.dists = dists
        self.voxel_indices = voxel_indices

    def __getitem__(self, index):
        return Samples(
            pts=self.pts[index],
            dirs=self.dirs[index],
            depths=self.depths[index],
            dists=self.dists[index],
            voxel_indices=self.voxel_indices[index])

    def reshape(self, *shape: int):
        return Samples(
            pts=self.pts.reshape(*shape, 3),
            dirs=self.dirs.reshape(*shape, 3),
            depths=self.depths.reshape(*shape),
            dists=self.dists.reshape(*shape),
            voxel_indices=self.voxel_indices.reshape(*shape))


class Sampler(nn.Module):

    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int, lindisp: bool,
                 **kwargs):
        """
        Initialize a Sampler module

        :param sample_range: depth range to sample along each ray
        :param n_samples: number of samples along each ray
        :param lindisp: if True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        s_range = [1 / sample_range[0], 1 / sample_range[1]] if self.lindisp \
            else list(sample_range)
        # Shrink the range slightly so samples never fall exactly on its boundaries
        if s_range[1] > s_range[0]:
            s_range[0] += 1e-4
            s_range[1] -= 1e-4
        else:
            s_range[0] -= 1e-4
            s_range[1] += 1e-4
        self.bins = Bins.linspace(s_range, n_samples, device=device.default())

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
                perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample points along rays.

        :param rays_o `Tensor(N, 3)`: rays' origin
        :param rays_d `Tensor(N, 3)`: rays' direction
        :param perturb_sample: perturb the sample depths
        :return `Samples(N', P)`: samples along rays that hit at least one voxel
        :return `Tensor(N)`: valid rays mask
        """
        s = self.bins.vals.expand(rays_o.size(0), -1)
        if perturb_sample:
            # Jitter each sample uniformly within its bin
            s = self.bins.lo + (self.bins.up - self.bins.lo) * torch.rand_like(s)
        pts, depths = self._get_sample_points(rays_o, rays_d, s)
        voxel_indices = space_module.get_voxel_indices(pts)
        valid_rays_mask = voxel_indices.ne(-1).any(dim=-1)
        return Samples(
            pts=pts,
            dirs=rays_d[:, None].expand(-1, depths.size(1), -1),
            depths=depths,
            dists=self._calc_dists(depths),
            voxel_indices=voxel_indices
        )[valid_rays_mask], valid_rays_mask

    def _get_sample_points(self, rays_o, rays_d, s):
        z = torch.reciprocal(s) if self.lindisp else s
        pts = rays_o[:, None] + rays_d[:, None] * z[..., None]
        depths = z
        return pts, depths

    def _calc_dists(self, vals):
        # Compute 'distance' (in time) between each integration time along a ray.
        # The last sample gets a tiny placeholder distance.
        # dists: (N_rays, N)
        dists = vals[..., 1:] - vals[..., :-1]
        last_dist = torch.zeros_like(vals[..., :1]) + TINY_FLOAT
        return torch.cat([dists, last_dist], -1)


class SphericalSampler(Sampler):

    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, **kwargs):
        """
        Initialize a SphericalSampler module

        :param sample_range: depth range to sample along each ray
        :param n_samples: number of samples along each ray
        :param perturb_sample: perturb the sample depths
        """
        super().__init__(sample_range=sample_range, n_samples=n_samples,
                         perturb_sample=perturb_sample, lindisp=False)

    def _get_sample_points(self, rays_o, rays_d, s):
        r = torch.reciprocal(s)
        pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, r)
        pts = sphere.cartesian2spherical(pts, inverse_r=True)
        return pts, depths
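
# Illustrative usage of `Sampler` (hypothetical values; the actual construction of the
# `Space` module and the ray batch happens elsewhere in the pipeline):
#
#   sampler = Sampler(sample_range=(1.0, 50.0), n_samples=64, lindisp=True)
#   samples, valid_mask = sampler(rays_o, rays_d, space_module=space, perturb_sample=True)
#
# `samples` covers only the rays that hit at least one voxel; `valid_mask` (Tensor(N))
# maps them back to the original ray batch.
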

class PdfSampler(nn.Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool, lindisp: bool, **kwargs):
        """
        Initialize a PdfSampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param spherical: return spherical coordinates instead of Cartesian ones
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.n_samples = n_samples
        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range

    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
        """
        Sample points along rays guided by the given weights.

        Return spherical or Cartesian coordinates, specified by `self.spherical`.

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param weights `Tensor(B, M)`: weights of sample bins
        :param s_vals `Tensor(B, M)`: (optional) center of sample bins
        :param include_s_vals `bool`: (default to `False`) include `s_vals` in the sample array
        :return `Tensor(B, N, 3)`: sampled points (spherical or Cartesian)
        :return `Tensor(B, N)`: corresponding depths along rays
        :return `Tensor(B, N)`: sample positions `s` (in the spherical case, the Cartesian
            points are also returned)
        """
        if s_vals is None:
            s_vals = torch.linspace(*self.s_range, self.n_samples, device=device.default())
        # Evenly spaced (deterministic) samples unless perturbation is requested
        s = self.sample_pdf(Bins(s_vals).bounds, weights, self.n_samples,
                            det=not self.perturb_sample)
        if include_s_vals:
            s = torch.cat([s, s_vals], dim=-1)
        s = torch.sort(s, descending=self.lindisp)[0]
        z = torch.reciprocal(s) if self.lindisp else s
        if self.spherical:
            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
            return sphers, depths, s, pts
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s

    def sample_pdf(self, bins: torch.Tensor, weights: torch.Tensor, N: int, det=True):
        '''
        :param bins `Tensor(..., M+1)`: bounds of bins
        :param weights `Tensor(..., M)`: weights of bins
        :param N `int`: # of samples along each ray
        :param det `bool`: (default to `True`) perform deterministic sampling or not
        :return `Tensor(..., N)`: samples
        '''
        # Get pdf
        weights = weights + TINY_FLOAT  # prevent nans
        pdf = weights / torch.sum(weights, dim=-1, keepdim=True)  # [..., M]
        cdf = torch.cat([
            torch.zeros_like(pdf[..., :1]),
            torch.cumsum(pdf, dim=-1)
        ], dim=-1)  # [..., M+1]

        # Take uniform samples
        dots_sh = list(weights.shape[:-1])
        M = weights.shape[-1]
        u = torch.linspace(0, 1, N, device=bins.device).expand(dots_sh + [N]) \
            if det else torch.rand(dots_sh + [N], device=bins.device)  # [..., N]

        # Invert CDF
        # [..., N, 1] >= [..., 1, M] ----> [..., N, M] ----> [..., N,]
        above_inds = torch.sum(u[..., None] >= cdf[..., None, :-1], dim=-1).long()

        # Indices of the bins below/above each sample
        below_inds = torch.clamp(above_inds - 1, min=0)
        inds_g = torch.stack((below_inds, above_inds), dim=-1)  # [..., N, 2]

        cdf = cdf[..., None, :].expand(dots_sh + [N, M + 1])  # [..., N, M+1]
        cdf_g = torch.gather(cdf, dim=-1, index=inds_g)  # [..., N, 2]

        bins = bins[..., None, :].expand(dots_sh + [N, M + 1])  # [..., N, M+1]
        bins_g = torch.gather(bins, dim=-1, index=inds_g)  # [..., N, 2]

        # Fix numeric issue
        denom = cdf_g[..., 1] - cdf_g[..., 0]  # [..., N]
        denom = torch.where(denom < TINY_FLOAT, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom

        samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_FLOAT)

        return samples
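
# Worked example for `PdfSampler.sample_pdf` (illustrative numbers, ignoring TINY_FLOAT):
#   weights = [0.1, 0.7, 0.2] over bins with bounds [0, 1, 2, 3]
#   pdf = [0.1, 0.7, 0.2], cdf = [0, 0.1, 0.8, 1.0]
#   A uniform sample u = 0.5 satisfies u >= 0 and u >= 0.1 only, so it falls in the
#   second bin; t = (0.5 - 0.1) / (0.8 - 0.1) ≈ 0.57 and the returned sample is
#   1 + 0.57 * (2 - 1) ≈ 1.57, i.e. samples concentrate where the weights are large.
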

class VoxelSampler(nn.Module):

    def __init__(self, *, perturb_sample: bool, sample_step: float, **kwargs):
        """
        Initialize a VoxelSampler module

        :param perturb_sample: perturb the sample depths
        :param sample_step: gap between samples along a ray
        """
        super().__init__()
        self.perturb_sample = perturb_sample
        self.sample_step = sample_step

    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
                 **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample points at a fixed step size within the voxels each ray hits.

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
        :return `Tensor(N)`: valid rays mask
        """
        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')
        n_rays = rays_o.size(0)
        ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long)  # (N')

        hits = intersections.hits
        min_depths = intersections.min_depths
        max_depths = intersections.max_depths
        voxel_indices = intersections.voxel_indices

        rays_near_depth = min_depths[:, :1]  # (N', 1)
        rays_far_depth = max_depths[ray_index_list, hits - 1][:, None]  # (N', 1)
        rays_length = rays_far_depth - rays_near_depth
        rays_steps = (rays_length / self.sample_step).ceil().long()
        rays_step_size = rays_length / rays_steps
        max_steps = rays_steps.max().item()

        rays_step = torch.arange(max_steps, device=rays_o.device,
                                 dtype=torch.float)[None].repeat(n_rays, 1)  # (N', P)
        invalid_samples_mask = rays_step >= rays_steps
        samples_min_depth = rays_near_depth + rays_step * rays_step_size
        samples_depth = samples_min_depth + rays_step_size \
            * (torch.rand_like(samples_min_depth) if self.perturb_sample else 0.5)  # (N', P)
        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)
        samples_voxel_index = voxel_indices[
            ray_index_list[:, None],
            torch.searchsorted(max_depths, samples_depth)
        ]  # (N', P)

        samples_depth[invalid_samples_mask] = HUGE_FLOAT
        samples_dist[invalid_samples_mask] = 0
        samples_voxel_index[invalid_samples_mask] = -1

        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
        return Samples(
            pts=rays_o + rays_d * samples_depth[..., None],
            dirs=rays_d.expand(-1, max_steps, -1),
            depths=samples_depth,
            dists=samples_dist,
            voxel_indices=samples_voxel_index
        ), valid_rays_mask

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space,
                **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample points within the voxels each ray hits, using inverse CDF sampling over
        the intersected segments.

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel)
        :return `Tensor(N)`: valid rays mask
        """
        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')
        checkpoint("Ray intersect")

        if intersections.size == 0:
            return None, valid_rays_mask

        min_depth = intersections.min_depths
        max_depth = intersections.max_depths
        pts_idx = intersections.voxel_indices
        dists = max_depth - min_depth
        tot_dists = dists.sum(dim=-1, keepdim=True)  # (N, 1)
        probs = dists / tot_dists
        steps = tot_dists[:, 0] / self.sample_step

        # Sample points and use middle point approximation
        sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
            pts_idx, min_depth, max_depth, probs, steps, -1, not self.perturb_sample)
        sampled_indices = sampled_indices.long()
        invalid_idx_mask = sampled_indices.eq(-1)
        sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
        sampled_depths.masked_fill_(invalid_idx_mask, HUGE_FLOAT)
        checkpoint("Inverse CDF sampling")

        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
        return Samples(
            pts=rays_o + rays_d * sampled_depths[..., None],
            dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
            depths=sampled_depths,
            dists=sampled_dists,
            voxel_indices=sampled_indices
        ), valid_rays_mask
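
if __name__ == "__main__":
    # Minimal smoke test of PdfSampler.sample_pdf (an illustrative sketch; the values
    # below are arbitrary placeholders, not settings used by the project).
    pdf_sampler = PdfSampler(depth_range=(1.0, 10.0), n_samples=8, perturb_sample=False,
                             spherical=False, lindisp=False)
    bins = Bins.linspace((1.0, 10.0), 4)
    weights = torch.tensor([[0.1, 0.7, 0.2, 0.0]])
    # Deterministic (evenly spaced) inverse-CDF samples: more samples land in high-weight bins
    print(pdf_sampler.sample_pdf(bins.bounds[None], weights, 8, det=True))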