import torch

# NOTE: this module assumes the following repository-local names are already in scope:
# `InputEncoder`, `profile`, `split_voxels_local`, `to_flat_indices`, `multires_hash_encode`
# and a project-local `math` module providing `expected_sin` and `pi`. Their import paths are
# repo-specific and therefore not spelled out here.


class IntegratedPosEncoder(InputEncoder):

    def __init__(self, chns, L, shape: str, cat_input=False):
        super().__init__(chns)
        self.shape = shape

    def _lift_gaussian(self, d: torch.Tensor, t_mean: torch.Tensor, t_var: torch.Tensor,
                       r_var: torch.Tensor, diag: bool):
        """Lift a Gaussian defined along a ray to 3D coordinates."""
        mean = d[..., None, :] * t_mean[..., None]
        d_sq = d**2
        d_mag_sq = torch.sum(d_sq, -1, keepdim=True).clamp_min(1e-10)

        if diag:
            d_outer_diag = d_sq
            null_outer_diag = 1 - d_outer_diag / d_mag_sq
            t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
            xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
            cov_diag = t_cov_diag + xy_cov_diag
            return mean, cov_diag
        else:
            d_outer = d[..., :, None] * d[..., None, :]
            eye = torch.eye(d.shape[-1], device=d.device)
            null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
            t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
            xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
            cov = t_cov + xy_cov
            return mean, cov

    def _conical_frustum_to_gaussian(self, d: torch.Tensor, t0: float, t1: float,
                                     base_radius: float, diag: bool, stable: bool = True):
        """Approximate a conical frustum as a Gaussian distribution (mean+cov).

        Assumes the ray originates from the origin and `base_radius` is the radius at dist=1.
        Does not assume `d` is normalized.

        Args:
          d: torch.float32 3-vector, the axis of the cone.
          t0: float, the starting distance of the frustum.
          t1: float, the ending distance of the frustum.
          base_radius: float, the scale of the radius as a function of distance.
          diag: boolean, whether the Gaussian will be diagonal or full-covariance.
          stable: boolean, whether or not to use the stable computation described in the paper
            (setting this to False will cause catastrophic failure).

        Returns:
          a Gaussian (mean and covariance).
        """
        if stable:
            mu = (t0 + t1) / 2
            hw = (t1 - t0) / 2
            t_mean = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2)
            t_var = (hw**2) / 3 - (4 / 15) * ((hw**4 * (12 * mu**2 - hw**2)) /
                                              (3 * mu**2 + hw**2)**2)
            r_var = base_radius**2 * ((mu**2) / 4 + (5 / 12) * hw**2 -
                                      4 / 15 * (hw**4) / (3 * mu**2 + hw**2))
        else:
            t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
            r_var = base_radius**2 * (3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3))
            t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)
            t_var = t_mosq - t_mean**2
        return self._lift_gaussian(d, t_mean, t_var, r_var, diag)

    def _cylinder_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, radius: float,
                              diag: bool):
        """Approximate a cylinder as a Gaussian distribution (mean+cov).

        Assumes the ray originates from the origin and `radius` is the radius.
        Does not renormalize `d`.

        Args:
          d: torch.float32 3-vector, the axis of the cylinder.
          t0: float, the starting distance of the cylinder.
          t1: float, the ending distance of the cylinder.
          radius: float, the radius of the cylinder.
          diag: boolean, whether the Gaussian will be diagonal or full-covariance.

        Returns:
          a Gaussian (mean and covariance).
        """
        t_mean = (t0 + t1) / 2
        r_var = radius**2 / 4
        t_var = (t1 - t0)**2 / 12
        return self._lift_gaussian(d, t_mean, t_var, r_var, diag)

    def cast_rays(self, t_vals: torch.Tensor, rays_o: torch.Tensor, rays_d: torch.Tensor,
                  rays_r: torch.Tensor, diag: bool = True):
        """Cast rays (cone- or cylinder-shaped) and featurize sections of them.

        The ray shape ('cone' or 'cylinder') is taken from `self.shape`.

        Args:
          t_vals: float array, the "fencepost" distances along the ray.
          rays_o: float array, the ray origin coordinates.
          rays_d: float array, the ray direction vectors.
          rays_r: float array, the radii (base radii for cones) of the rays.
          diag: boolean, whether or not the covariance matrices should be diagonal.

        Returns:
          a tuple of arrays of means and covariances.
        """
        t0 = t_vals[..., :-1]
        t1 = t_vals[..., 1:]
        if self.shape == 'cone':
            gaussian_fn = self._conical_frustum_to_gaussian
        elif self.shape == 'cylinder':
            gaussian_fn = self._cylinder_to_gaussian
        else:
            raise ValueError(f"Unsupported ray shape: {self.shape}")
        means, covs = gaussian_fn(rays_d, t0, t1, rays_r, diag)
        means = means + rays_o[..., None, :]
        return means, covs
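
# A minimal sanity-check sketch (assumption: `IntegratedPosEncoder(chns, L, shape)` can be
# constructed as below; the exact `InputEncoder` base signature is repo-specific). It checks
# that the stable and naive moment computations in `_conical_frustum_to_gaussian` agree for a
# well-conditioned frustum, which is what the `stable` flag guards numerically when t0 ~= t1.
def _check_frustum_gaussian_consistency():
    enc = IntegratedPosEncoder(3, 0, 'cone')            # hypothetical construction
    d = torch.tensor([[0.0, 0.0, 1.0]])                 # a single ray direction
    t0, t1 = torch.tensor([1.0]), torch.tensor([1.5])   # near/far of one frustum
    mean_s, cov_s = enc._conical_frustum_to_gaussian(d, t0, t1, 0.01, diag=True, stable=True)
    mean_n, cov_n = enc._conical_frustum_to_gaussian(d, t0, t1, 0.01, diag=True, stable=False)
    assert torch.allclose(mean_s, mean_n, atol=1e-4)
    assert torch.allclose(cov_s, cov_n, atol=1e-4)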
def integrated_pos_enc(x_coord: tuple[torch.Tensor, torch.Tensor], min_deg: int, max_deg: int,
                       diag: bool = True):
    """Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1].

    Args:
      x_coord: a tuple containing:
        x, torch.Tensor, variables to be encoded. Should be in [-pi, pi].
        x_cov, torch.Tensor, covariance matrices for `x`.
      min_deg: int, the min degree of the encoding.
      max_deg: int, the max degree of the encoding.
      diag: bool, if true, expects input covariances to be diagonal (full otherwise).

    Returns:
      encoded: torch.Tensor, encoded variables.
    """
    if diag:
        x, x_cov_diag = x_coord
        scales = torch.tensor([2**i for i in range(min_deg, max_deg)], device=x.device)[:, None]
        shape = list(x.shape[:-1]) + [-1]
        y = torch.reshape(x[..., None, :] * scales, shape)
        y_var = torch.reshape(x_cov_diag[..., None, :] * scales**2, shape)
    else:
        x, x_cov = x_coord
        num_dims = x.shape[-1]
        basis = torch.cat([
            2**i * torch.eye(num_dims, device=x.device)
            for i in range(min_deg, max_deg)
        ], 1)
        y = torch.matmul(x, basis)
        # Get the diagonal of a covariance matrix (i.e., the variance). This is equivalent
        # to a batched torch.diag((basis.T @ covs) @ basis).
        y_var = (torch.matmul(x_cov, basis) * basis).sum(-2)
    return math.expected_sin(
        torch.cat([y, y + 0.5 * math.pi], -1),
        torch.cat([y_var] * 2, -1))[0]


# @torch.jit.script
def intepolate_calc_weight(x, corners):
    return x * corners * 2 - x - corners + 1
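
# Illustrative sketch (not part of the encoder itself): how the pieces above are typically
# combined into a mip-NeRF style integrated positional encoding -- cast cone segments into
# per-sample Gaussians, then encode (mean, covariance) with `integrated_pos_enc`. The shapes,
# depth range and the construction of `IntegratedPosEncoder` are assumptions here.
def _example_integrated_encoding(rays_o: torch.Tensor, rays_d: torch.Tensor,
                                 rays_r: torch.Tensor, n_samples: int = 64) -> torch.Tensor:
    # rays_o, rays_d: `Tensor(N, 3)`; rays_r: `Tensor(N, 1)` (base radii of the cones)
    encoder = IntegratedPosEncoder(3, 0, 'cone')      # hypothetical construction
    t_vals = torch.linspace(2.0, 6.0, n_samples + 1, device=rays_o.device)  # fencepost depths
    t_vals = t_vals.expand(*rays_o.shape[:-1], n_samples + 1)
    means, covs = encoder.cast_rays(t_vals, rays_o, rays_d, rays_r, diag=True)
    return integrated_pos_enc((means, covs), min_deg=0, max_deg=16, diag=True)  # (N, S, 96)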
class MultiresHashEncoder(InputEncoder):
    fast_op = True

    t_ind: torch.dtype
    layers: int
    coarse_levels: int
    layers_hashsize: list[int]
    layers_res: torch.Tensor
    """Tensor(L, D)"""
    local_corners: torch.Tensor
    """Tensor(C, D)"""
    layers_hashoffset: torch.Tensor
    """Tensor(L+1)"""
    hashtable: torch.nn.parameter.Parameter
    """Parameter(T, F)"""

    def __init__(self, chns: int, layers: int, log2_hashsize: int, features: int,
                 res0: int | list[int], scale_up: float = 2.0):
        super().__init__(chns, layers * features, (0., 1.))
        res0 = torch.tensor([res0] * chns if isinstance(res0, int) else res0)
        self.layers = layers
        self.features = features
        self.scale_up = scale_up
        self.max_hashsize = 2 ** log2_hashsize
        self.t_ind = torch.int if self.fast_op else torch.long

        layers_res: list[torch.Tensor] = []
        self.layers_hashsize: list[int] = []
        self.coarse_levels = 0
        layers_hashoffset: list[int] = [0]
        for i in range(layers):
            layers_res.append((res0 * scale_up ** i).to(self.t_ind))
            if layers_res[-1].max() > self.max_hashsize ** (1 / 3)\
                    or layers_res[-1].prod() > self.max_hashsize:
                self.layers_hashsize.append(self.max_hashsize)
            else:
                self.layers_hashsize.append(layers_res[-1].prod().item())
                self.coarse_levels = i + 1
            layers_hashoffset.append(layers_hashoffset[-1] + self.layers_hashsize[-1])

        self.register_temp("layers_res", torch.stack(layers_res, 0))
        self.register_temp("layers_hashoffset", torch.tensor(layers_hashoffset, dtype=self.t_ind))
        self.register_temp("local_corners", split_voxels_local(1, 2, dims=chns) + .5)

        # Initialize the hash table entries using the uniform distribution U(-1e-4, 1e-4) to
        # provide a small amount of randomness while encouraging initial predictions close to
        # zero [muller2022instant]
        self.hashtable = torch.nn.parameter.Parameter(
            (torch.rand(layers_hashoffset[-1], features, device=self.device) - .5) * 2e-4)

    @profile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode inputs using the multi-resolution hash encoder [muller2022instant]

        :param x `Tensor(N..., D)`: D-dim inputs
        :return `Tensor(N..., LF)`: encoded outputs
        """
        if self.fast_op:
            N_, D = x.shape[:-1], x.shape[-1]
            return multires_hash_encode(self.layers, self.coarse_levels, self.layers_res,
                                        self.layers_hashoffset, x.reshape(-1, D),
                                        self.hashtable)\
                .transpose(0, 1).reshape(*N_, -1)

        @profile("Calculate corners")
        def calc_corners(x) -> tuple[torch.Tensor, torch.Tensor]:
            grid_pos = x.unsqueeze(-2) * (self.layers_res - 1)  # (N..., L, D)
            grid_pos.unsqueeze_(-2)  # (N..., L, 1, D)
            grid_lo = torch.floor(grid_pos)
            grid_pos.sub_(grid_lo)
            corners = (grid_lo + self.local_corners).long()\
                .min(self.layers_res.unsqueeze(-2) - 1)  # (N..., L, C, D)
            return grid_pos, corners

        grid_pos, corners = calc_corners(x)  # (N..., L, 1, D), (N..., L, C, D)

        @profile("Calculate encoded")
        def calc_encoded(level: int) -> torch.Tensor:
            if level < self.coarse_levels:
                # Coarse levels are stored densely, so flatten the grid coordinates directly
                idx = to_flat_indices(corners[..., level, :, :], self.layers_res[level, None])
            else:
                # Fine levels fall back to the spatial hash
                idx = self._fast_hash(corners[..., level, :, :]) % self.max_hashsize
            idx.add_(self.layers_hashoffset[level, None])
            return self._linear_interp(grid_pos[..., level, :, :], self.hashtable[idx])

        result = torch.stack([calc_encoded(level) for level in range(self.layers)],
                             dim=-2)  # (N..., L, X)
        return result.flatten(-2)

    def _linear_interp(self, x: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
        """
        Multilinearly interpolate hash table entries at the voxel corners.

        :param x `Tensor(N..., L, 1, D)`: fractional positions inside each voxel
        :param corner_values `Tensor(N..., L, C, X)`: hash table entries at the corners
        :return `Tensor(N..., L, X)`: interpolated features
        """
        weights = (x * self.local_corners * 2 - x - self.local_corners + 1)\
            .prod(-1, keepdim=True)  # (N..., L, C, 1)
        return (weights * corner_values).sum(-2)  # (N..., L, X)

    def extra_repr(self) -> str:
        return f"{self.in_chns} -> {self.out_chns}({self.layers}x{self.features})"\
            f", resolution={self.layers_res[0].tolist()}*{self.scale_up}^L"\
            f", max_hashsize={self.max_hashsize}"

    @profile
    def _fast_hash(self, grid_pos: torch.Tensor) -> torch.Tensor:
        """
        Perform the fast spatial hash described in instant-ngp.

        :param grid_pos `Tensor(N..., D)`: integer grid positions
        :return `Tensor(N...)`: hash values
        """
        if grid_pos.shape[-1] > 7:
            raise ValueError("fast_hash can only hash up to 7 dimensions.")
        # While 1 is technically not a good prime for hashing (or a prime at all), it helps
        # memory coherence and is sufficient for our use case of obtaining a uniformly
        # colliding index from high-dimensional coordinates. [muller2022instant]
        primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737]
        result = grid_pos[..., 0] * primes[0]
        for i in range(1, grid_pos.shape[-1]):
            result.bitwise_xor_(grid_pos[..., i] * primes[i])
        return result
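
# Illustrative sketch of how the encoder above is typically instantiated and applied
# (instant-ngp style settings). The constructor relies on repo-specific helpers
# (`register_temp`, `split_voxels_local`, `self.device`) and, with `fast_op=True`, on the
# `multires_hash_encode` extension op, so treat this as a usage outline rather than a test.
def _example_hash_encoding(points: torch.Tensor) -> torch.Tensor:
    # points `Tensor(N, 3)` with coordinates already normalized to [0, 1]
    encoder = MultiresHashEncoder(chns=3, layers=16, log2_hashsize=19, features=2,
                                  res0=16, scale_up=1.5)
    return encoder(points)  # `Tensor(N, 32)`: 16 levels x 2 features per level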
class LayeredMultiresHashEncoder(InputEncoder):
    use_cpp = False

    layers: int
    coarse_levels: int
    layers_res: torch.Tensor
    """Tensor(L, D)"""
    local_corners: torch.Tensor
    """Tensor(C, D)"""
    layers_hashsize: list[int]
    layers_hashoffset: list[int]
    t_ind: torch.dtype

    def __init__(self, chns: int, layers: int, log2_hashsize: int, features: int,
                 res0: int | list[int], scale_up: float = 2.0, parts: int = 64):
        super().__init__(chns, layers * features, (0., 1.))
        res0 = torch.tensor([res0] * chns if isinstance(res0, int) else res0)
        self.layers = layers
        self.features = features
        self.scale_up = scale_up
        self.max_hashsize = 2 ** log2_hashsize // parts
        self.t_ind = torch.int if self.use_cpp else torch.long

        layers_res: list[torch.Tensor] = []
        self.layers_hashsize: list[int] = []
        self.layers_usehash: list[bool] = []
        self.coarse_levels = 0
        layers_hashoffset: list[int] = [0]
        for i in range(layers):
            layers_res.append(res0.to(self.t_ind) if i == 0
                              else (layers_res[-1] * scale_up).to(self.t_ind))
            if layers_res[-1].max() > self.max_hashsize ** (1 / 3)\
                    or layers_res[-1].prod() > self.max_hashsize:
                self.layers_hashsize.append(self.max_hashsize)
            else:
                self.layers_hashsize.append(layers_res[-1].prod().item())
                self.coarse_levels = i + 1
            layers_hashoffset.append(layers_hashoffset[-1] + self.layers_hashsize[-1])

        self.register_temp("layers_res", torch.stack(layers_res, 0))
        self.register_temp("layers_hashoffset", torch.tensor(layers_hashoffset, dtype=self.t_ind))

        # Initialize the hash table entries using the uniform distribution U(-1e-4, 1e-4) to
        # provide a small amount of randomness while encouraging initial predictions close to
        # zero [muller2022instant]
        self.hashtable = torch.nn.parameter.Parameter(
            (torch.rand(parts, layers_hashoffset[-1], features, device=self.device) - .5) * 2e-4)
        self.register_temp("local_corners", split_voxels_local(1, 2, dims=chns) + .5)

    @profile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode inputs using the multi-resolution hash encoder [muller2022instant],
        with a separate hash table for each of the P parts.

        :param x `Tensor(N..., P, D)`: D-dim inputs
        :return `Tensor(N..., P, LF)`: encoded outputs
        """
        if self.use_cpp:
            N_, P, D = x.shape[:-2], x.shape[-2], x.shape[-1]
            return torch.stack([
                multires_hash_encode(self.layers, self.coarse_levels, self.layers_res,
                                     self.layers_hashoffset, x[..., i, :].reshape(-1, D),
                                     self.hashtable[i]).reshape(*N_, -1)
                for i in range(P)
            ], dim=-2)

        @profile("Calculate corners")
        def calc_corners(x) -> tuple[torch.Tensor, torch.Tensor]:
            grid_pos = x.unsqueeze(-2) * (self.layers_res - 1)  # (N..., P, L, D)
            grid_pos.unsqueeze_(-2)  # (N..., P, L, 1, D)
            grid_lo = torch.floor(grid_pos)
            grid_pos.sub_(grid_lo)
            corners = (grid_lo + self.local_corners).long()\
                .min(self.layers_res.unsqueeze(-2) - 1)  # (N..., P, L, C, D)
            return grid_pos, corners

        grid_pos, corners = calc_corners(x)  # (N..., P, L, 1, D), (N..., P, L, C, D)

        @profile("Calculate encoded")
        def calc_encoded(level: int) -> torch.Tensor:
            if level < self.coarse_levels:
                # Coarse levels are stored densely, so flatten the grid coordinates directly
                idx = to_flat_indices(corners[..., level, :, :], self.layers_res[level, None])
            else:
                # Fine levels fall back to the spatial hash
                idx = self._fast_hash(corners[..., level, :, :]) % self.max_hashsize
            idx.add_(self.layers_hashoffset[level, None])
            part_idx = torch.arange(x.shape[-2], device=x.device)[:, None].broadcast_to(idx.shape)
            return self._linear_interp(grid_pos[..., level, :, :],
                                       self.hashtable[part_idx, idx])

        result = torch.stack([calc_encoded(level) for level in range(self.layers)],
                             dim=-2)  # (N..., P, L, X)
        return result.flatten(-2)

    def _linear_interp(self, x: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
        """
        Multilinearly interpolate hash table entries at the voxel corners.

        :param x `Tensor(N..., L, 1, D)`: fractional positions inside each voxel
        :param corner_values `Tensor(N..., L, C, X)`: hash table entries at the corners
        :return `Tensor(N..., L, X)`: interpolated features
        """
        weights = (x * self.local_corners * 2 - x - self.local_corners + 1)\
            .prod(-1, keepdim=True)  # (N..., L, C, 1)
        return (weights * corner_values).sum(-2)  # (N..., L, X)

    def extra_repr(self) -> str:
        return f"{self.in_chns} -> {self.out_chns}({self.layers}x{self.features})"\
            f", resolution={self.layers_res[0].tolist()}*{self.scale_up}^L"\
            f", max_hashsize={self.max_hashsize}"

    @profile
    def _fast_hash(self, grid_pos: torch.Tensor) -> torch.Tensor:
        """
        Perform the fast spatial hash described in instant-ngp.

        :param grid_pos `Tensor(N..., D)`: integer grid positions
        :return `Tensor(N...)`: hash values
        """
        if grid_pos.shape[-1] > 7:
            raise ValueError("fast_hash can only hash up to 7 dimensions.")
        # While 1 is technically not a good prime for hashing (or a prime at all), it helps
        # memory coherence and is sufficient for our use case of obtaining a uniformly
        # colliding index from high-dimensional coordinates. [muller2022instant]
        primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737]
        result = grid_pos[..., 0] * primes[0]
        for i in range(1, grid_pos.shape[-1]):
            result.bitwise_xor_(grid_pos[..., i] * primes[i])
        return result
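
# Side note on `_linear_interp` / `intepolate_calc_weight`: for corner coordinates c in {0, 1}
# and fractional positions x in [0, 1], the expression `x * c * 2 - x - c + 1` equals the
# standard multilinear weight `1 - |x - c|` (i.e. 1 - x for the low corner and x for the high
# corner). The sketch below just verifies that equivalence numerically; the actual corner
# layout comes from `split_voxels_local` in the repository.
def _check_interp_weight_identity():
    x = torch.rand(1000, 1)               # fractional positions in [0, 1]
    c = torch.tensor([[0., 1.]])          # the two corner coordinates along one axis
    assert torch.allclose(intepolate_calc_weight(x, c), 1 - (x - c).abs())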