from itertools import cycle
from math import ceil
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn

from utils.constants import *
from utils.perf import perf
from .generic import *
from .sampler import Samples


def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
    """
    Calculate energies from densities inferred by model.

    :param densities `Tensor(N..., 1)`: model's output densities
    :param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: the noise std used to regularize network during training
        (prevents floater artifacts), defaults to 0, which means no noise is added
    :return `Tensor(N..., 1)`: energies which block light rays
    """
    if raw_noise_std > 0:
        # Add noise to model's predictions for density. Can be used to
        # regularize network during training (prevents floater artifacts).
        # randn_like() keeps the noise on the same device/dtype as densities.
        densities = densities + torch.randn_like(densities) * raw_noise_std
    return densities * dists[..., None]


def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
    """
    Calculate alphas from densities inferred by model.

    :param densities `Tensor(N..., 1)`: model's output densities
    :param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: the noise std used to regularize network during training
        (prevents floater artifacts), defaults to 0, which means no noise is added
    :return `Tensor(N..., 1)`: alphas
    """
    energies = density2energy(densities, dists, raw_noise_std)
    return 1.0 - torch.exp(-energies)


class AlphaComposition(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, colors, alphas, bg=None):
        """
        Composite sampled colors along each ray using their alpha values.

        :param colors `Tensor(N, P, C)`: colors of samples along rays
        :param alphas `Tensor(N, P, 1)`: alphas of samples along rays
        :param bg `Tensor([N, ]C)`: background color to composite onto, defaults to None
        :return `dict`: composited colors and per-sample blend weights
        """
        # Compute weight for RGB of each sample along each ray. A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
        one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + TINY_FLOAT, dim=-2)
        one_minus_alpha = torch.cat([
            torch.ones_like(one_minus_alpha[..., :1, :]),
            one_minus_alpha
        ], dim=-2)
        weights = alphas * one_minus_alpha  # (N, P, 1)

        # (N, C), weighted sum of sample colors along each ray
        final_color = torch.sum(weights * colors, dim=-2)

        # To composite onto a background, use the accumulated alpha map.
        if bg is not None:
            # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
            acc_map = torch.sum(weights, dim=-2)  # (N, 1)
            final_color = final_color + bg * (1. - acc_map)
        return {
            'color': final_color,
            'weights': weights,
        }
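
# A minimal sanity-check sketch (illustrative only, not used by the renderer):
# for toy inputs, the blend weights produced by AlphaComposition should sum to
# the ray's total opacity 1 - exp(-sum_i sigma_i * delta_i), matching the usual
# volume rendering quadrature w_i = alpha_i * prod_{j<i} (1 - alpha_j). All
# tensors below are made-up toy data.
def _check_alpha_composition_sketch():
    densities = torch.tensor([[[0.5], [2.0], [8.0]]])  # (1, 3, 1)
    dists = torch.full((1, 3), 0.1)                    # (1, 3)
    alphas = density2alpha(densities, dists)           # (1, 3, 1)
    out = AlphaComposition()(torch.rand(1, 3, 3), alphas)
    total_opacity = 1 - torch.exp(-(densities * dists[..., None]).sum(-2))  # (1, 1)
    assert torch.allclose(out['weights'].sum(-2), total_opacity, atol=1e-4)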
class VolumnRenderer(nn.Module):

    class States:
        kernel: nn.Module
        samples: Samples
        hit_mask: torch.Tensor
        early_stop_tolerance: float
        N: int
        P: int

        colors: torch.Tensor
        diffuses: torch.Tensor
        speculars: torch.Tensor
        energies: torch.Tensor
        weights: torch.Tensor
        cum_energies: torch.Tensor
        exp_energies: torch.Tensor
        tot_evaluations: Dict[str, int]

        chunk: Tuple[slice, slice]
        cum_chunk: Tuple[slice, slice]
        cum_last: Tuple[slice, slice]
        chunk_id: int

        @property
        def start(self) -> int:
            return self.chunk[1].start

        @property
        def end(self) -> int:
            return self.chunk[1].stop

        def __init__(self, kernel: nn.Module, samples: Samples,
                     early_stop_tolerance: float) -> None:
            self.kernel = kernel
            self.samples = samples
            self.early_stop_tolerance = early_stop_tolerance

            N, P = samples.size
            self.hit_mask = samples.voxel_indices != -1  # (N, P)
            self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
            self.diffuses = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
            self.speculars = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
            self.energies = torch.zeros(N, P, 1, device=samples.device)
            self.weights = torch.zeros(N, P, 1, device=samples.device)
            self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
            self.exp_energies = torch.ones(N, P + 1, 1, device=samples.device)
            self.tot_evaluations = {}
            self.N, self.P = N, P
            self.chunk_id = -1

        def n_hits(self, start: int = None, end: int = None) -> int:
            if start is None:
                return self.hit_mask.count_nonzero().item()
            if end is None:
                return self.hit_mask[:, start].count_nonzero().item()
            return self.hit_mask[:, start:end].count_nonzero().item()

        def accumulate_tot_evaluations(self, key: str, n: int):
            if key not in self.tot_evaluations:
                self.tot_evaluations[key] = 0
            self.tot_evaluations[key] += n

        def next_chunk(self, *, length=None, end=None):
            start = 0 if not hasattr(self, "chunk") else self.end
            length = length or self.P
            end = min(end or start + length, self.P)
            self.chunk = slice(None), slice(start, end)
            self.cum_chunk = slice(None), slice(start + 1, end + 1)
            self.cum_last = slice(None), slice(start, start + 1)
            self.chunk_id += 1
            return self

    def __init__(self, **kwargs):
        super().__init__()

    @perf
    def forward(self, kernel: nn.Module, samples: Samples, extra_outputs: List[str] = [], *,
                raymarching_early_stop_tolerance: float = 0,
                raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
                **kwargs):
        """
        Perform volume rendering.

        :param kernel: render kernel
        :param samples `Samples(N, P)`: samples
        :param extra_outputs `list[str]`: extra items to be contained in the result dict.
            Optional values include 'depth', 'layers', 'states' and attribute names in
            class `States` (e.g. 'weights'). Defaults to []
        :param raymarching_early_stop_tolerance `float`: tolerance of raymarching early stop.
            Should be between 0 and 1 (0 means no early stop). Defaults to 0
        :param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split the
            raymarching process. Use a list of integers to specify the number of samples in
            every chunk, or a positive integer to specify the number of chunks. Use a negative
            integer to split by the number of hits in chunks, where the absolute value is the
            maximum number of hits in a chunk. 0 and `None` mean no splitting of the
            raymarching process. Defaults to `None`
        :return `dict`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
        """
        if samples.size[1] == 0:
            print("VolumnRenderer.forward(): # of samples is zero")
            return None

        s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance)

        if not raymarching_chunk_size_or_sections:
            raymarching_chunk_size_or_sections = [s.P]
        elif isinstance(raymarching_chunk_size_or_sections, int) and \
                raymarching_chunk_size_or_sections > 0:
            raymarching_chunk_size_or_sections = [
                ceil(s.P / raymarching_chunk_size_or_sections)
            ]

        if isinstance(raymarching_chunk_size_or_sections, list):
            chunk_sections = raymarching_chunk_size_or_sections
            for chunk_samples in cycle(chunk_sections):
                self._forward_chunk(s.next_chunk(length=chunk_samples))
                if s.end >= s.P:
                    break
        else:
            chunk_size = -raymarching_chunk_size_or_sections
            chunk_hits = s.n_hits(0)
            for i in range(1, s.P):
                n_hits = s.n_hits(i)
                if chunk_hits + n_hits > chunk_size:
                    self._forward_chunk(s.next_chunk(end=i))
                    n_hits = s.n_hits(i)  # recount: early stop may have cleared hits
                    chunk_hits = 0
                chunk_hits += n_hits
            self._forward_chunk(s.next_chunk())

        ret = {
            'color': torch.sum(s.colors * s.weights, 1),
            'tot_evaluations': s.tot_evaluations
        }
        for key in extra_outputs:
            if key == 'depth':
                ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
            elif key == 'diffuse':
                ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
            elif key == 'specular':
                ret['specular'] = torch.sum(s.speculars * s.weights, 1)
            elif key == 'layers':
                ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
            elif key == 'states':
                ret['states'] = s
            else:
                ret[key] = getattr(s, key)
        return ret

    def _calc_weights(self, s: States):
        """
        Calculate weights of samples in composited outputs.

        :param s `States`: states
        """
        s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
            + s.cum_energies[s.cum_last]
        s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
        s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]

    def _apply_early_stop(self, s: States):
        """
        Stop rays whose accumulated opacities are larger than a threshold.

        :param s `States`: states
        """
        if s.end < s.P and s.early_stop_tolerance > 0:
            rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
            s.hit_mask[rays_to_stop, s.end:] = 0

    def _forward_chunk(self, s: States) -> int:
        fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)  # (N')
        fi_idxs[1].add_(s.start)

        if fi_idxs[0].size(0) == 0:
            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
            return 0

        # fi_* means "filtered" by hit mask
        fi_samples = s.samples[fi_idxs]  # N -> N'

        # Infer densities and colors
        fi_outputs = s.kernel.render(fi_samples, 'color', 'density', 'specular', 'diffuse',
                                     chunk_id=s.chunk_id)
        s.colors.index_put_(fi_idxs, fi_outputs['color'])
        if fi_outputs['specular'] is not None:
            s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
        if fi_outputs['diffuse'] is not None:
            s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
        s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
        s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))

        self._calc_weights(s)
        self._apply_early_stop(s)
        return fi_idxs[0].size(0)

class DensityFirstVolumnRenderer(VolumnRenderer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _forward_chunk(self, s: VolumnRenderer.States) -> int:
        fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)  # (N')
        fi_idxs[1].add_(s.start)

        if fi_idxs[0].size(0) == 0:
            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
            return 0

        # fi_* means "filtered" by hit mask
        fi_samples = s.samples[fi_idxs]  # N -> N'

        # For all valid samples: encode X
        fi_encoded_x = s.kernel.encode_x(fi_samples)  # (N', Ex)

        # Infer densities (shape)
        fi_outputs = s.kernel.infer(fi_encoded_x, None, 'density', 'color_feat',
                                    chunk_id=s.chunk_id)
        s.energies.index_put_(fi_idxs, density2energy(fi_outputs['density'], fi_samples.dists))
        s.accumulate_tot_evaluations("density", fi_idxs[0].size(0))

        self._calc_weights(s)
        self._apply_early_stop(s)

        # Remove samples whose weights are less than a threshold
        s.hit_mask[s.chunk][s.weights[s.chunk][..., 0] < 0.01] = 0

        # Update "filtered" tensors
        fi_mask = s.hit_mask[fi_idxs]
        fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask])  # N' -> N"
        fi_encoded_x = fi_encoded_x[fi_mask]  # (N", Ex)
        fi_color_feats = fi_outputs['color_feat'][fi_mask]

        # For all valid samples: encode D
        fi_encoded_d = s.kernel.encode_d(s.samples[fi_idxs])  # (N", Ed)

        # Infer colors (appearance)
        fi_outputs = s.kernel.infer(fi_encoded_x, fi_encoded_d, 'color', 'specular', 'diffuse',
                                    chunk_id=s.chunk_id,
                                    extras={"color_feats": fi_color_feats})
        s.colors.index_put_(fi_idxs, fi_outputs['color'])
        if fi_outputs['specular'] is not None:
            s.speculars.index_put_(fi_idxs, fi_outputs['specular'])
        if fi_outputs['diffuse'] is not None:
            s.diffuses.index_put_(fi_idxs, fi_outputs['diffuse'])
        s.accumulate_tot_evaluations("color", fi_idxs[0].size(0))
        return fi_idxs[0].size(0)
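
# Why the density-first pass pays off (an illustrative sketch, not measured
# data): the appearance branch is typically the more expensive network, so
# pruning samples whose blend weight falls below 0.01 before inferring colors
# skips most appearance evaluations. The `tot_evaluations` dict returned by
# forward() exposes the split, e.g.:
#
#     renderer = DensityFirstVolumnRenderer()
#     out = renderer(kernel, samples, extra_outputs=['depth'],
#                    raymarching_chunk_size_or_sections=[64],
#                    raymarching_early_stop_tolerance=0.01)
#     # out['tot_evaluations'] -> {'density': <all hit samples>,
#     #                            'color': <only samples with weight >= 0.01>}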