# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # # To install the _ext library, run the following command: # > python setup.py build_ext --inplace ''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch ''' from __future__ import ( division, absolute_import, with_statement, print_function, unicode_literals, ) import torch import torch.nn.functional as F import numpy as np from torch.autograd import Function from torch.autograd.function import FunctionCtx, once_differentiable import clib._ext as _ext from utils.geometry import discretize_points from utils import math class BallRayIntersect(Function): @staticmethod def forward(ctx, radius, n_max, points, ray_start, ray_dir): inds, min_depth, max_depth = _ext.ball_intersect( ray_start.float(), ray_dir.float(), points.float(), radius, n_max) min_depth = min_depth.type_as(ray_start) max_depth = max_depth.type_as(ray_start) ctx.mark_non_differentiable(inds) ctx.mark_non_differentiable(min_depth) ctx.mark_non_differentiable(max_depth) return inds, min_depth, max_depth @staticmethod def backward(ctx, a, b, c): return None, None, None, None, None ball_ray_intersect = BallRayIntersect.apply def aabb_ray_intersect(voxelsize: float, n_max: int, points: torch.Tensor, ray_start: torch.Tensor, ray_dir: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ AABB-Ray intersect test :param voxelsize `float`: size of a voxel :param n_max `int`: maximum number of hits :param points `Tensor(M, 3)`: voxels' centers :param ray_start `Tensor(S, N, 3)`: rays' start positions :param ray_dir `Tensor(S, N, 3)`: rays' directions :return `Tensor(S, N, n_max)`: indices of intersected voxels or -1 :return `Tensor(S, N, n_max)`: min depths of every intersected voxels :return `Tensor(S, N, n_max)`: max depths of every intersected voxels """ # HACK: speed-up ray-voxel intersection by batching... G = min(2048, int(2e9 / points.numel())) # HACK: avoid out-of-memory S, N = ray_start.shape[:2] K = math.ceil(N / G) G, K = 1, N # HACK H = K * G if H > N: ray_start = torch.cat([ray_start, ray_start[:, :H - N]], 1) ray_dir = torch.cat([ray_dir, ray_dir[:, :H - N]], 1) ray_start = ray_start.reshape(S * G, K, 3) ray_dir = ray_dir.reshape(S * G, K, 3) points = points[None].expand(S * G, *points.size()).contiguous() inds, min_depth, max_depth = _ext.aabb_intersect( ray_start.float(), ray_dir.float(), points.float(), voxelsize, n_max) min_depth = min_depth.type_as(ray_start) max_depth = max_depth.type_as(ray_start) inds = inds.reshape(S, H, -1) min_depth = min_depth.reshape(S, H, -1) max_depth = max_depth.reshape(S, H, -1) if H > N: inds = inds[:, :N] min_depth = min_depth[:, :N] max_depth = max_depth[:, :N] return inds, min_depth, max_depth class SparseVoxelOctreeRayIntersect(Function): @staticmethod def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir): # HACK: avoid out-of-memory G = min(2048, int(2 * 10 ** 9 / (points.numel() + children.numel()))) S, N = ray_start.shape[:2] K = int(np.ceil(N / G)) G, K = 1, N # HACK H = K * G if H > N: ray_start = torch.cat([ray_start, ray_start[:, :H - N]], 1) ray_dir = torch.cat([ray_dir, ray_dir[:, :H - N]], 1) ray_start = ray_start.reshape(S * G, K, 3) ray_dir = ray_dir.reshape(S * G, K, 3) points = points[None].expand(S * G, *points.size()).contiguous() children = children[None].expand(S * G, *children.size()).contiguous() inds, min_depth, max_depth = _ext.svo_intersect( ray_start.float(), ray_dir.float(), points.float(), children.int(), voxelsize, n_max) min_depth = min_depth.type_as(ray_start) max_depth = max_depth.type_as(ray_start) inds = inds.reshape(S, H, -1) min_depth = min_depth.reshape(S, H, -1) max_depth = max_depth.reshape(S, H, -1) if H > N: inds = inds[:, :N] min_depth = min_depth[:, :N] max_depth = max_depth[:, :N] ctx.mark_non_differentiable(inds) ctx.mark_non_differentiable(min_depth) ctx.mark_non_differentiable(max_depth) return inds, min_depth, max_depth @staticmethod def backward(ctx, a, b, c): return None, None, None, None, None def octree_ray_intersect(voxelsize: float, n_max: int, points: torch.Tensor, children: torch.Tensor, ray_start: torch.Tensor, ray_dir: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Octree-Ray intersect test :param voxelsize `float`: size of a voxel :param n_max `int`: maximum number of hits :param points `Tensor(M, 3)`: voxels' centers :param children `Tensor(M, 9)`: flattened octree structure :param ray_start `Tensor(S, N, 3)`: rays' start positions :param ray_dir `Tensor(S, N, 3)`: rays' directions :return `Tensor(S, N, n_max)`: indices of intersected voxels or -1 :return `Tensor(S, N, n_max)`: min depths of every intersected voxels :return `Tensor(S, N, n_max)`: max depths of every intersected voxels """ return SparseVoxelOctreeRayIntersect.apply(voxelsize, n_max, points, children, ray_start, ray_dir) class TriangleRayIntersect(Function): @staticmethod def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start, ray_dir): # HACK: speed-up ray-voxel intersection by batching... G = min(2048, int(2 * 10 ** 9 / (3 * faces.numel()))) # HACK: avoid out-of-memory S, N = ray_start.shape[:2] K = int(np.ceil(N / G)) H = K * G if H > N: ray_start = torch.cat([ray_start, ray_start[:, :H - N]], 1) ray_dir = torch.cat([ray_dir, ray_dir[:, :H - N]], 1) ray_start = ray_start.reshape(S * G, K, 3) ray_dir = ray_dir.reshape(S * G, K, 3) face_points = F.embedding(faces.reshape(-1, 3), points.reshape(-1, 3)) face_points = face_points.unsqueeze(0).expand(S * G, *face_points.size()).contiguous() inds, depth, uv = _ext.triangle_intersect( ray_start.float(), ray_dir.float(), face_points.float(), cagesize, blur_ratio, n_max) depth = depth.type_as(ray_start) uv = uv.type_as(ray_start) inds = inds.reshape(S, H, -1) depth = depth.reshape(S, H, -1, 3) uv = uv.reshape(S, H, -1) if H > N: inds = inds[:, :N] depth = depth[:, :N] uv = uv[:, :N] ctx.mark_non_differentiable(inds) ctx.mark_non_differentiable(depth) ctx.mark_non_differentiable(uv) return inds, depth, uv @staticmethod def backward(ctx, a, b, c): return None, None, None, None, None, None triangle_ray_intersect = TriangleRayIntersect.apply class UniformRaySampling(Function): @staticmethod def forward(ctx, pts_idx, min_depth, max_depth, step_size, max_ray_length, deterministic=False): G, N, P = 256, pts_idx.size(0), pts_idx.size(1) H = int(np.ceil(N / G)) * G if H > N: pts_idx = torch.cat([pts_idx, pts_idx[:H - N]], 0) min_depth = torch.cat([min_depth, min_depth[:H - N]], 0) max_depth = torch.cat([max_depth, max_depth[:H - N]], 0) pts_idx = pts_idx.reshape(G, -1, P) min_depth = min_depth.reshape(G, -1, P) max_depth = max_depth.reshape(G, -1, P) # pre-generate noise max_steps = int(max_ray_length / step_size) max_steps = max_steps + min_depth.size(-1) * 2 noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps) if deterministic: noise += 0.5 else: noise = noise.uniform_() # call cuda function sampled_idx, sampled_depth, sampled_dists = _ext.uniform_ray_sampling( pts_idx, min_depth.float(), max_depth.float(), noise.float(), step_size, max_steps) sampled_depth = sampled_depth.type_as(min_depth) sampled_dists = sampled_dists.type_as(min_depth) sampled_idx = sampled_idx.reshape(H, -1) sampled_depth = sampled_depth.reshape(H, -1) sampled_dists = sampled_dists.reshape(H, -1) if H > N: sampled_idx = sampled_idx[: N] sampled_depth = sampled_depth[: N] sampled_dists = sampled_dists[: N] max_len = sampled_idx.ne(-1).sum(-1).max() sampled_idx = sampled_idx[:, :max_len] sampled_depth = sampled_depth[:, :max_len] sampled_dists = sampled_dists[:, :max_len] ctx.mark_non_differentiable(sampled_idx) ctx.mark_non_differentiable(sampled_depth) ctx.mark_non_differentiable(sampled_dists) return sampled_idx, sampled_depth, sampled_dists @staticmethod def backward(ctx, a, b, c): return None, None, None, None, None, None def uniform_ray_sampling(pts_idx: torch.Tensor, min_depth: torch.Tensor, max_depth: torch.Tensor, step_size: float, max_ray_length: float, deterministic: bool = False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Sample along rays uniformly :param pts_idx `Tensor(N, P)`: indices of voxels intersected with rays :param min_depth `Tensor(N, P)`: min depth of intersections of rays and voxels :param max_depth `Tensor(N, P)`: max depth of intersections of rays and voxels :param step_size `float`: size of sampling step :param max_ray_length `float`: maximum sampling depth along rays :param deterministic `bool`: (optional) sample deterministically (or randomly), defaults to False :return `Tensor(N, P')`: voxel indices of sampled points :return `Tensor(N, P')`: depth of sampled points :return `Tensor(N, P')`: length of sampled points """ return UniformRaySampling.apply(pts_idx, min_depth, max_depth, step_size, max_ray_length, deterministic) class InverseCDFRaySampling(Function): @staticmethod def forward(ctx, pts_idx, min_depth, max_depth, probs, steps, fixed_step_size=-1, deterministic=False): G, N, P = 200, pts_idx.size(0), pts_idx.size(1) H = int(np.ceil(N / G)) * G if H > N: pts_idx = torch.cat([pts_idx, pts_idx[:1].expand(H - N, P)], 0) min_depth = torch.cat([min_depth, min_depth[:1].expand(H - N, P)], 0) max_depth = torch.cat([max_depth, max_depth[:1].expand(H - N, P)], 0) probs = torch.cat([probs, probs[:1].expand(H - N, P)], 0) steps = torch.cat([steps, steps[:1].expand(H - N)], 0) # print(G, P, np.ceil(N / G), N, H, pts_idx.shape, min_depth.device) pts_idx = pts_idx.reshape(G, -1, P) min_depth = min_depth.reshape(G, -1, P) max_depth = max_depth.reshape(G, -1, P) probs = probs.reshape(G, -1, P) steps = steps.reshape(G, -1) # pre-generate noise max_steps = steps.ceil().long().max() + P noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps) if deterministic: noise += 0.5 else: noise = noise.uniform_().clamp(min=0.001, max=0.999) # in case # call cuda function chunk_size = 4 * G # to avoid oom? results = [ _ext.inverse_cdf_sampling( pts_idx[:, i:i + chunk_size].contiguous(), min_depth.float()[:, i:i + chunk_size].contiguous(), max_depth.float()[:, i:i + chunk_size].contiguous(), noise.float()[:, i:i + chunk_size].contiguous(), probs.float()[:, i:i + chunk_size].contiguous(), steps.float()[:, i:i + chunk_size].contiguous(), fixed_step_size) for i in range(0, min_depth.size(1), chunk_size) ] sampled_idx, sampled_depth, sampled_dists = [ torch.cat([r[i] for r in results], 1) for i in range(3) ] sampled_depth = sampled_depth.type_as(min_depth) sampled_dists = sampled_dists.type_as(min_depth) sampled_idx = sampled_idx.reshape(H, -1) sampled_depth = sampled_depth.reshape(H, -1) sampled_dists = sampled_dists.reshape(H, -1) if H > N: sampled_idx = sampled_idx[: N] sampled_depth = sampled_depth[: N] sampled_dists = sampled_dists[: N] max_len = sampled_idx.ne(-1).sum(-1).max() sampled_idx = sampled_idx[:, :max_len] sampled_depth = sampled_depth[:, :max_len] sampled_dists = sampled_dists[:, :max_len] ctx.mark_non_differentiable(sampled_idx) ctx.mark_non_differentiable(sampled_depth) ctx.mark_non_differentiable(sampled_dists) return sampled_idx, sampled_depth, sampled_dists @staticmethod def backward(ctx, a, b, c): return None, None, None, None, None, None, None def inverse_cdf_sampling(pts_idx: torch.Tensor, min_depth: torch.Tensor, max_depth: torch.Tensor, probs: torch.Tensor, steps: torch.Tensor, fixed_step_size: float = -1, deterministic: bool = False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Sample along rays by inverse CDF :param pts_idx `Tensor(N, P)`: indices of voxels intersected with rays :param min_depth `Tensor(N, P)`: min depth of intersections of rays and voxels :param max_depth `Tensor(N, P)`: max depth of intersections of rays and voxels :param probs `Tensor(N, P)`: :param steps `Tensor(N)`: :param fixed_step_size `float`: :param deterministic `bool`: (optional) sample deterministically (or randomly), defaults to False :return `Tensor(N, P')`: voxel indices of sampled points :return `Tensor(N, P')`: depth of sampled points :return `Tensor(N, P')`: length of sampled points """ return InverseCDFRaySampling.apply(pts_idx, min_depth, max_depth, probs, steps, fixed_step_size, deterministic) # back-up for ray point sampling @torch.no_grad() def _parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False): # uniform sampling _min_depth = min_depth.min(1)[0] _max_depth = max_depth.masked_fill(max_depth.eq(math.huge), 0).max(1)[0] max_ray_length = (_max_depth - _min_depth).max() delta = torch.arange(int(max_ray_length / MARCH_SIZE), device=min_depth.device, dtype=min_depth.dtype) delta = delta[None, :].expand(min_depth.size(0), delta.size(-1)) if deterministic: delta = delta + 0.5 else: delta = delta + delta.clone().uniform_().clamp(min=0.01, max=0.99) delta = delta * MARCH_SIZE sampled_depth = min_depth[:, :1] + delta sampled_idx = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1 sampled_idx = pts_idx.gather(1, sampled_idx) # include all boundary points sampled_depth = torch.cat([min_depth, max_depth, sampled_depth], -1) sampled_idx = torch.cat([pts_idx, pts_idx, sampled_idx], -1) # reorder sampled_depth, ordered_index = sampled_depth.sort(-1) sampled_idx = sampled_idx.gather(1, ordered_index) sampled_dists = sampled_depth[:, 1:] - sampled_depth[:, :-1] # distances sampled_depth = .5 * (sampled_depth[:, 1:] + sampled_depth[:, :-1]) # mid-points # remove all invalid depths min_ids = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1 max_ids = (sampled_depth[:, :, None] >= max_depth[:, None, :]).sum(-1) sampled_depth.masked_fill_( (max_ids.ne(min_ids)) | (sampled_depth > _max_depth[:, None]) | (sampled_dists == 0.0), math.huge) sampled_depth, ordered_index = sampled_depth.sort(-1) # sort again sampled_masks = sampled_depth.eq(math.huge) num_max_steps = (~sampled_masks).sum(-1).max() sampled_depth = sampled_depth[:, :num_max_steps] sampled_dists = sampled_dists.gather(1, ordered_index).masked_fill_( sampled_masks, 0.0)[:, :num_max_steps] sampled_idx = sampled_idx.gather(1, ordered_index).masked_fill_( sampled_masks, -1)[:, :num_max_steps] return sampled_idx, sampled_depth, sampled_dists @torch.no_grad() def parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False): chunk_size = 4096 full_size = min_depth.shape[0] if full_size <= chunk_size: return _parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=deterministic) outputs = zip(*[ _parallel_ray_sampling( MARCH_SIZE, pts_idx[i:i + chunk_size], min_depth[i:i + chunk_size], max_depth[i:i + chunk_size], deterministic=deterministic) for i in range(0, full_size, chunk_size)]) sampled_idx, sampled_depth, sampled_dists = outputs def padding_points(xs, pad): if len(xs) == 1: return xs[0] maxlen = max([x.size(1) for x in xs]) full_size = sum([x.size(0) for x in xs]) xt = xs[0].new_ones(full_size, maxlen).fill_(pad) st = 0 for i in range(len(xs)): xt[st: st + xs[i].size(0), :xs[i].size(1)] = xs[i] st += xs[i].size(0) return xt sampled_idx = padding_points(sampled_idx, -1) sampled_depth = padding_points(sampled_depth, math.huge) sampled_dists = padding_points(sampled_dists, 0.0) return sampled_idx, sampled_depth, sampled_dists def build_easy_octree(points: torch.Tensor, half_voxel: float) -> tuple[torch.Tensor, torch.Tensor]: """ Build an octree. :param points `Tensor(M, 3)`: centers of leaf voxels :param half_voxel `float`: half size of voxel :return `Tensor(M', 3)`: centers of all nodes in octree :return `Tensor(M', 9)`: flattened octree structure """ coords, residual = discretize_points(points, half_voxel) ranges = coords.max(0)[0] - coords.min(0)[0] depths = torch.log2(ranges.max().float()).ceil_().long() - 1 center = (coords.max(0)[0] + coords.min(0)[0]) / 2 centers, children = _ext.build_octree(center, coords, int(depths)) centers = centers.float() * half_voxel + residual # transform back to float return centers, children class MultiresHashEncode(Function): @staticmethod def forward(ctx: FunctionCtx, levels: int, coarse_levels: int, res_list: torch.Tensor, hash_table_offsets: torch.Tensor, x: torch.Tensor, hash_table: torch.Tensor, grad_enabled: bool) -> torch.Tensor: """ [summary] :param ctx `FunctionCtx`: [description] :param levels `int`: [description] :param coarse_levels `int`: [description] :param res_list `Tensor(L, D)`: [description] :param hash_table_offsets `Tensor(L+1)`: [description] :param x `Tensor(N, D)`: [description] :param hash_table `Tensor(T, F)`: [description] :return `Tensor(L, N, F)`: [description] """ x = x.contiguous() res_list = res_list.int().contiguous() hash_table_offsets = hash_table_offsets.int().contiguous() if grad_enabled and hash_table.requires_grad: encoded, weights, indices = _ext.multires_hash_encode_with_grad( levels, coarse_levels, x, res_list, hash_table, hash_table_offsets) ctx.save_for_backward(weights, indices.long()) ctx.hash_table_shape = hash_table.shape return encoded print(hash_table) return _ext.multires_hash_encode(levels, coarse_levels, x, res_list, hash_table, hash_table_offsets) @staticmethod @once_differentiable def backward(ctx: FunctionCtx, grad_output: torch.Tensor): """ [summary] :param ctx `FunctionCtx`: [description] :param grad_output `Tensor(L, N, F)`: [description] :return: [description] """ weights, indices = ctx.saved_tensors # (L, N, C) t = grad_output[..., None, :] * weights[..., None] # (L, N, C, F) grad_hash_table = grad_output.new_zeros(*ctx.hash_table_shape) grad_hash_table.index_put_([indices], t, accumulate=True) return None, None, None, None, None, grad_hash_table, None def multires_hash_encode(levels: int, coarse_levels: int, res_list: torch.Tensor, hash_table_offsets: torch.Tensor, x: torch.Tensor, hash_table: torch.Tensor) -> torch.Tensor: """ :param levels `int`: [description] :param coarse_levels `int`: [description] :param res_list `Tensor(L, D)`: [description] :param hash_table_offsets `Tensor(L+1)`: [description] :param x `Tensor(N, D)`: [description] :param hash_table `Tensor(T, F)`: [description] :return `Tensor(L, N, F)`: [description] """ return MultiresHashEncode.apply(levels, coarse_levels, res_list, hash_table_offsets, x, hash_table, torch.is_grad_enabled())