from typing import Tuple
import torch
import torch.nn as nn
from utils import device
from utils import sphere
from utils.constants import *
from .generic import *


class Bins(object):
    """
    A set of sample bins described by their center values.

    Bin bounds are the midpoints between neighbouring centers. The first and last
    bins are closed by the first and last centers themselves. All computations act
    on the last dimension, so `vals` may be `Tensor(M)` or batched `Tensor(..., M)`.
    """

    def __init__(self, vals: torch.Tensor):
        """
        :param vals `Tensor(..., M)`: centers of the bins (along the last dimension)
        """
        self.vals = vals
        # Work on the last dimension so that batched centers `Tensor(B, M)` produce
        # per-row bounds `Tensor(B, M+1)` instead of mixing values across rows.
        self.bounds = torch.cat([
            self.vals[..., :1],
            0.5 * (self.vals[..., 1:] + self.vals[..., :-1]),
            self.vals[..., -1:]
        ], dim=-1)  # [..., M+1]
        self.up = self.bounds[..., 1:]  # [..., M] bound after each center (upper bound if vals ascend)
        self.lo = self.bounds[..., :-1]  # [..., M] bound before each center (lower bound if vals ascend)

    @staticmethod
    def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
        """
        Create `N` evenly spaced bins whose centers span `val_range` (inclusive).

        :param val_range: (first, last) center values
        :param N: number of bins
        :param device: device of the created tensors
        :return `Bins`: the created bins
        """
        return Bins(torch.linspace(*val_range, N, device=device))

    def to(self, device: torch.device):
        """
        Move the bins to `device`.

        :param device: target device
        :return `Bins`: `self`, so calls can be chained like `torch.Tensor.to`
        """
        self.vals = self.vals.to(device)
        self.bounds = self.bounds.to(device)
        self.up = self.bounds[..., 1:]
        self.lo = self.bounds[..., :-1]
        return self


class Sampler(nn.Module):

    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool, lindisp: bool):
        """
        Initialize a Sampler module

        :param sample_range: (near, far) range of sample values; reciprocals are used
                             when `lindisp` is True
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param spherical: if True, `forward` returns spherical coordinates computed by
                          `sphere.ray_sphere_intersect` + `sphere.cartesian2spherical`
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.spherical = spherical
        self.perturb_sample = perturb_sample
        s_range = (1 / sample_range[0], 1 / sample_range[1]) if self.lindisp else sample_range
        # NOTE: `Bins` is a plain object (not a registered buffer), so `nn.Module.to()`
        # does not move it; it is created on `device.default()` here.
        self.bins = Bins.linspace(s_range, n_samples, device=device.default())

    def forward(self, rays_o, rays_d):
        """
        Sample points along rays.
        Return Spherical or Cartesian coordinates, specified by `self.spherical`

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :return `Tensor(B, N, 3)`: sampled points (spherical coordinates if `self.spherical`)
        :return `Tensor(B, N)`: corresponding depths along rays
        :return `Tensor(B, N)`: raw sample values (inverse depths if `self.lindisp`)
        :return `Tensor|None`: Cartesian points from the ray-sphere intersection if
                               `self.spherical`, otherwise `None`
        """
        s = self.bins.vals.expand(rays_o.size(0), -1)
        if self.perturb_sample:
            # Draw one uniform sample inside each bin instead of using the bin centers
            s = self.bins.lo + (self.bins.up - self.bins.lo) * torch.rand_like(s)
        z = torch.reciprocal(s) if self.lindisp else s
        if self.spherical:
            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
            return sphers, depths, s, pts
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s, None


class PdfSampler(nn.Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool, lindisp: bool):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths (random instead of evenly
                               spaced positions when inverting the cdf)
        :param spherical: if True, `forward` returns spherical coordinates computed by
                          `sphere.ray_sphere_intersect` + `sphere.cartesian2spherical`
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.n_samples = n_samples
        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range

    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
        """
        Sample points along rays.
        Return Spherical or Cartesian coordinates, specified by `self.spherical`

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param weights `Tensor(B, M)`: weights of sample bins
        :param s_vals `Tensor(B, M)|Tensor(M)`: (optional) center of sample bins; defaults
                      to `self.n_samples` values evenly spaced over `self.s_range`
        :param include_s_vals `bool`: (default to `False`) include `s_vals` in the sample array
        :return `Tensor(B, N, 3)`: sampled points (spherical coordinates if `self.spherical`)
        :return `Tensor(B, N)`: corresponding depths along rays
        :return `Tensor(B, N)`: raw sample values (inverse depths if `self.lindisp`), sorted
                                so that depth increases
        :return `Tensor`: (only if `self.spherical`) Cartesian intersection points.
                          NOTE(review): unlike `Sampler.forward`, the Cartesian branch
                          returns 3 values, not 4; kept for caller compatibility.
        """
        if s_vals is None:
            s_vals = torch.linspace(*self.s_range, self.n_samples, device=device.default())
        # `det` means deterministic (evenly spaced) cdf inversion, i.e. the opposite
        # of perturbing the samples.
        s = self.sample_pdf(Bins(s_vals).bounds, weights, self.n_samples,
                            det=not self.perturb_sample)
        if include_s_vals:
            # `s_vals` may be 1-D (the default above); broadcast it to the batch shape
            # of `s` so the concatenation along the sample dimension is valid.
            s = torch.cat([s, s_vals.expand(*s.shape[:-1], -1)], dim=-1)
        # With lindisp, `s` is inverse depth: descending order == ascending depth
        s = torch.sort(s, descending=self.lindisp)[0]
        z = torch.reciprocal(s) if self.lindisp else s
        if self.spherical:
            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
            return sphers, depths, s, pts
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s

    def sample_pdf(self, bins: torch.Tensor, weights: torch.Tensor, N: int, det=True):
        '''
        Draw samples from the piecewise-constant distribution given by `weights` over
        `bins` by inverting its cdf.

        :param bins `Tensor(..., M+1)`: bounds of bins
        :param weights `Tensor(..., M)`: weights of bins
        :param N `int`: # of samples along each ray
        :param det `bool`: (default to `True`) perform deterministic sampling or not
        :return `Tensor(..., N)`: samples
        '''
        # Get pdf
        weights = weights + TINY_FLOAT  # prevent nans
        pdf = weights / torch.sum(weights, dim=-1, keepdim=True)  # [..., M]
        cdf = torch.cat([
            torch.zeros_like(pdf[..., :1]),
            torch.cumsum(pdf, dim=-1)
        ], dim=-1)  # [..., M+1]

        # Take uniform samples
        dots_sh = list(weights.shape[:-1])
        M = weights.shape[-1]
        u = torch.linspace(0, 1, N, device=bins.device).expand(dots_sh + [N]) \
            if det else torch.rand(dots_sh + [N], device=bins.device)  # [..., N]

        # Invert CDF
        # [..., N, 1] >= [..., 1, M] ----> [..., N, M] ----> [..., N,]
        above_inds = torch.sum(u[..., None] >= cdf[..., None, :-1], dim=-1).long()

        # Indices of the two cdf entries bracketing each u (i.e. the bin containing u)
        below_inds = torch.clamp(above_inds - 1, min=0)
        inds_g = torch.stack((below_inds, above_inds), dim=-1)  # [..., N, 2]

        cdf = cdf[..., None, :].expand(dots_sh + [N, M + 1])  # [..., N, M+1]
        cdf_g = torch.gather(cdf, dim=-1, index=inds_g)  # [..., N, 2]

        bins = bins[..., None, :].expand(dots_sh + [N, M + 1])  # [..., N, M+1]
        bins_g = torch.gather(bins, dim=-1, index=inds_g)  # [..., N, 2]

        # fix numeric issue: avoid dividing by (near-)zero cdf increments
        denom = cdf_g[..., 1] - cdf_g[..., 0]  # [..., N]
        denom = torch.where(denom < TINY_FLOAT, torch.ones_like(denom), denom)
        # Linear interpolation of u's position within the bracketing bin
        t = (u - cdf_g[..., 0]) / denom

        samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_FLOAT)

        return samples