import torch
from itertools import cycle
from typing import Dict, List, Set, Tuple, Union

from utils.type import NetInput, ReturnData
from .generic import *
from model.base import BaseModel
from utils import math
from utils.module import Module
from utils.perf import checkpoint, perf
from utils.samples import Samples


def density2energy(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
    """
    Calculate energies from densities inferred by model.

    :param densities `Tensor(N..., 1)`: model's output densities
    :param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: the noise std used to regularize the network during training
        (prevents floater artifacts), defaults to 0, meaning no noise is added
    :return `Tensor(N..., 1)`: energies which block light rays
    """
    if raw_noise_std > 0:
        # Add noise to the model's density predictions. Can be used to
        # regularize the network during training (prevents floater artifacts).
        densities = densities + torch.normal(0.0, raw_noise_std, densities.size(),
                                             device=densities.device)
    return densities * dists[..., None]


def density2alpha(densities: torch.Tensor, dists: torch.Tensor, raw_noise_std: float = 0):
    """
    Calculate alphas from densities inferred by model.

    :param densities `Tensor(N..., 1)`: model's output densities
    :param dists `Tensor(N...)`: integration times
    :param raw_noise_std `float`: the noise std used to regularize the network during training
        (prevents floater artifacts), defaults to 0, meaning no noise is added
    :return `Tensor(N..., 1)`: alphas
    """
    energies = density2energy(densities, dists, raw_noise_std)
    return 1.0 - torch.exp(-energies)


class AlphaComposition(Module):

    def __init__(self):
        super().__init__()

    def forward(self, colors, alphas, bg=None):
        """
        Composite the colors of samples along rays by alpha blending.

        :param colors `Tensor(N, P, C)`: colors of samples along rays
        :param alphas `Tensor(N, P, 1)`: alphas of samples along rays
        :param bg `Tensor([N, ]C)`: background color, defaults to None
        :return `dict`: blended colors (`Tensor(N, C)`) and blending weights (`Tensor(N, P, 1)`)
        """
        # Compute the weight for the RGB of each sample along each ray. A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
        one_minus_alpha = torch.cumprod(1 - alphas[..., :-1, :] + math.tiny, dim=-2)
        one_minus_alpha = torch.cat([
            torch.ones_like(one_minus_alpha[..., :1, :]),
            one_minus_alpha
        ], dim=-2)
        weights = alphas * one_minus_alpha  # (N, P, 1)

        # (N, C): weighted sum of the colors of samples along each ray
        final_color = torch.sum(weights * colors, dim=-2)

        # To composite onto a background color, use the accumulated alpha map.
        if bg is not None:
            # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
            acc_map = torch.sum(weights, dim=-2)  # (N, 1)
            final_color = final_color + bg * (1. - acc_map)

        return {
            'color': final_color,
            'weights': weights,
        }
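
# A minimal sketch of how the helpers above fit together (illustrative only;
# `model_densities`, `model_colors` and `dists` are hypothetical tensors whose
# shapes follow the docstrings above):
#
#     alphas = density2alpha(model_densities, dists)  # (N, P, 1), alpha_i = 1 - exp(-sigma_i * delta_i)
#     out = AlphaComposition()(model_colors, alphas, bg=torch.ones(3))
#     out['color']    # (N, 3)  blended ray colors
#     out['weights']  # (N, P, 1) per-sample blending weights, summing to <= 1 along P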


class VolumnRenderer(Module):

    class States:
        kernel: BaseModel
        samples: Samples
        early_stop_tolerance: float
        outputs: Set[str]
        hit_mask: torch.Tensor
        N: int
        P: int
        device: torch.device

        colors: torch.Tensor
        densities: torch.Tensor
        energies: torch.Tensor
        weights: torch.Tensor
        cum_energies: torch.Tensor
        exp_energies: torch.Tensor
        tot_evaluations: Dict[str, int]

        chunk: Tuple[slice, slice]
        cum_chunk: Tuple[slice, slice]
        cum_last: Tuple[slice, slice]
        chunk_id: int

        @property
        def start(self) -> int:
            return self.chunk[1].start

        @property
        def end(self) -> int:
            return self.chunk[1].stop

        def __init__(self, kernel: BaseModel, samples: Samples, early_stop_tolerance: float,
                     outputs: Set[str]) -> None:
            self.kernel = kernel
            self.samples = samples
            self.early_stop_tolerance = early_stop_tolerance
            self.outputs = outputs

            N, P = samples.size
            self.device = self.samples.device
            self.hit_mask = samples.voxel_indices != -1  # (N, P) | bool
            self.colors = torch.zeros(N, P, kernel.chns('color'), device=samples.device)
            self.densities = torch.zeros(N, P, 1, device=samples.device)
            self.energies = torch.zeros(N, P, 1, device=samples.device)
            self.weights = torch.zeros(N, P, 1, device=samples.device)
            self.cum_energies = torch.zeros(N, P + 1, 1, device=samples.device)
            self.exp_energies = torch.ones(N, P + 1, 1, device=samples.device)
            self.tot_evaluations = {}
            self.N, self.P = N, P
            self.chunk_id = -1

        def n_hits(self, index: Union[int, slice] = None) -> int:
            if not isinstance(self.hit_mask, torch.Tensor):
                if index is not None:
                    return self.N * self.colors[:, index].shape[1]
                return self.N * self.P
            if index is None:
                return self.hit_mask.count_nonzero().item()
            return self.hit_mask[:, index].count_nonzero().item()

        def accumulate_tot_evaluations(self, key: str, n: int):
            if key not in self.tot_evaluations:
                self.tot_evaluations[key] = 0
            self.tot_evaluations[key] += n

        def next_chunk(self, *, length=None, end=None):
            start = 0 if not hasattr(self, "chunk") else self.end
            length = length or self.P
            end = min(end or start + length, self.P)
            self.chunk = slice(None), slice(start, end)
            self.cum_chunk = slice(None), slice(start + 1, end + 1)
            self.cum_last = slice(None), slice(start, start + 1)
            self.chunk_id += 1
            return self

        def put(self, key: str, values: torch.Tensor,
                indices: Union[Tuple[torch.Tensor, torch.Tensor], Tuple[slice, slice]]):
            if not hasattr(self, key):
                new_tensor = torch.zeros(self.N, self.P, values.shape[-1], device=self.device)
                setattr(self, key, new_tensor)
            tensor: torch.Tensor = getattr(self, key)
            # if isinstance(indices[0], torch.Tensor):
            #     tensor.index_put_(indices, values)
            # else:
            tensor[indices] = values

    def __init__(self, **kwargs):
        super().__init__()

    @perf
    def forward(self, kernel: BaseModel, samples: Samples, *outputs: str,
                raymarching_early_stop_tolerance: float = 0,
                raymarching_chunk_size_or_sections: Union[int, List[int]] = None,
                **kwargs) -> ReturnData:
        """
        Perform volume rendering.

        :param kernel `BaseModel`: render kernel
        :param samples `Samples(N, P)`: samples
        :param outputs `str...`: items that should be contained in the result dict.
            Optional values include 'color', 'depth', 'layers', 'states' and attribute names
            in class `States` (e.g. 'weights'). Defaults to []
        :param raymarching_early_stop_tolerance `float`: tolerance of raymarching early stop.
            Should be between 0 and 1 (0 means no early stop). Defaults to 0
        :param raymarching_chunk_size_or_sections `int|list[int]`: indicates how to split the
            raymarching process. Use a list of integers to specify the number of samples in
            every chunk, or a positive integer to specify the number of chunks. Use a negative
            integer to split by the number of hits in chunks, where the absolute value is the
            maximum number of hits in a chunk. 0 and `None` mean the raymarching process is
            not split. Defaults to `None`
        :return `dict`: render result { 'color'[, 'depth', 'layers', 'states', ...] }
        """
        if samples.size[1] == 0:
            print("VolumnRenderer.forward(): # of samples is zero")
            return None

        infer_outputs = set()
        for key in outputs:
            if key == "color":
                infer_outputs.add("colors")
                infer_outputs.add("densities")
            elif key == "specular":
                infer_outputs.add("speculars")
                infer_outputs.add("densities")
            elif key == "diffuse":
                infer_outputs.add("diffuses")
                infer_outputs.add("densities")
            elif key == "depth":
                infer_outputs.add("densities")
            else:
                infer_outputs.add(key)
        s = VolumnRenderer.States(kernel, samples, raymarching_early_stop_tolerance,
                                  infer_outputs)

        checkpoint("Prepare states object")

        if not raymarching_chunk_size_or_sections:
            raymarching_chunk_size_or_sections = [s.P]
        elif isinstance(raymarching_chunk_size_or_sections, int) and \
                raymarching_chunk_size_or_sections > 0:
            raymarching_chunk_size_or_sections = [
                math.ceil(s.P / raymarching_chunk_size_or_sections)
            ]

        if isinstance(raymarching_chunk_size_or_sections, list):
            chunk_sections = raymarching_chunk_size_or_sections
            for chunk_samples in cycle(chunk_sections):
                self._forward_chunk(s.next_chunk(length=chunk_samples))
                if s.end >= s.P:
                    break
        else:
            chunk_size = -raymarching_chunk_size_or_sections
            chunk_hits = s.n_hits(0)
            for i in range(1, s.P):
                n_hits = s.n_hits(i)
                if chunk_hits + n_hits > chunk_size:
                    self._forward_chunk(s.next_chunk(end=i))
                    n_hits = s.n_hits(i)  # re-query: the hit mask may have changed due to early stop
                    chunk_hits = 0
                chunk_hits += n_hits
            self._forward_chunk(s.next_chunk())

        checkpoint("Run forward chunks")

        ret = {}
        for key in outputs:
            if key == 'color':
                ret['color'] = torch.sum(s.colors * s.weights, 1)
            elif key == 'depth':
                ret['depth'] = torch.sum(s.samples.depths[..., None] * s.weights, 1)
            elif key == 'diffuse' and hasattr(s, "diffuses"):
                ret['diffuse'] = torch.sum(s.diffuses * s.weights, 1)
            elif key == 'specular' and hasattr(s, "speculars"):
                ret['specular'] = torch.sum(s.speculars * s.weights, 1)
            elif key == 'layers':
                ret['layers'] = torch.cat([s.colors, 1 - torch.exp(-s.energies)], dim=-1)
            elif key == 'states':
                ret['states'] = s
            elif hasattr(s, key):
                ret[key] = getattr(s, key)

        checkpoint("Set return data")

        return ret
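
    # A hedged example of the `raymarching_chunk_size_or_sections` semantics, where
    # `renderer`, `kernel` and `samples` are placeholders (values are illustrative only):
    #
    #     renderer(kernel, samples, 'color')                                        # single chunk of all P samples
    #     renderer(kernel, samples, 'color', raymarching_chunk_size_or_sections=4)  # 4 chunks of ceil(P / 4) samples
    #     renderer(kernel, samples, 'color',
    #              raymarching_chunk_size_or_sections=[16, 32])                     # 16, 32, 16, 32, ... samples per chunk
    #     renderer(kernel, samples, 'color',
    #              raymarching_chunk_size_or_sections=-4096)                        # at most 4096 hits per chunk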

    @perf
    def _calc_weights(self, s: States):
        """
        Calculate weights of samples in composited outputs

        :param s `States`: states, whose `chunk` attribute selects the current chunk
        """
        s.energies[s.chunk] = density2energy(s.densities[s.chunk], s.samples.dists[s.chunk])
        s.cum_energies[s.cum_chunk] = torch.cumsum(s.energies[s.chunk], 1) \
            + s.cum_energies[s.cum_last]
        s.exp_energies[s.cum_chunk] = (-s.cum_energies[s.cum_chunk]).exp()
        s.weights[s.chunk] = s.exp_energies[s.chunk] - s.exp_energies[s.cum_chunk]

    @perf
    def _apply_early_stop(self, s: States):
        """
        Stop rays whose accumulated opacity is larger than a threshold

        :param s `States`: states, whose `chunk` attribute selects the current chunk
        """
        if s.end < s.P and s.early_stop_tolerance > 0 and isinstance(s.hit_mask, torch.Tensor):
            rays_to_stop = s.exp_energies[:, s.end, 0] < s.early_stop_tolerance
            s.hit_mask[rays_to_stop, s.end:] = 0

    @perf
    def _forward_chunk(self, s: States) -> None:
        if isinstance(s.hit_mask, torch.Tensor):
            fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)
            if fi_idxs[0].size(0) == 0:
                s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
                s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
                return
            fi_idxs[1].add_(s.start)
            s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
        else:
            fi_idxs = s.chunk

        fi_outputs = s.kernel.infer(*s.outputs, samples=s.samples[fi_idxs],
                                    chunk_id=s.chunk_id)
        for key, value in fi_outputs.items():
            s.put(key, value, fi_idxs)

        self._calc_weights(s)
        self._apply_early_stop(s)
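
# The weight computation in `_calc_weights` follows the standard volume rendering
# quadrature: with energies e_j = sigma_j * delta_j, the transmittance before sample i
# is T_i = exp(-sum_{j<i} e_j), and the blending weight of sample i is
#
#     w_i = T_i * (1 - exp(-e_i)) = T_i - T_{i+1},
#
# which is exactly `exp_energies[s.chunk] - exp_energies[s.cum_chunk]`. Keeping the
# running sums in `cum_energies`/`exp_energies` (with one extra leading entry) lets
# each chunk continue the recurrence from where the previous chunk stopped.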


class DensityFirstVolumnRenderer(VolumnRenderer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _forward_chunk(self, s: VolumnRenderer.States) -> None:
        fi_idxs: Tuple[torch.Tensor, ...] = s.hit_mask[s.chunk].nonzero(as_tuple=True)  # (N')
        fi_idxs[1].add_(s.start)

        if fi_idxs[0].size(0) == 0:
            s.cum_energies[s.cum_chunk] = s.cum_energies[s.cum_last]
            s.exp_energies[s.cum_chunk] = s.exp_energies[s.cum_last]
            return

        # fi_* means "filtered" by hit mask
        fi_samples = s.samples[fi_idxs]  # N -> N'

        # For all valid samples: encode X
        density_inputs = s.kernel.input(fi_samples, "x", "f")  # (N', Ex)

        # Infer densities (shape)
        density_outputs = s.kernel.infer('densities', 'features', samples=fi_samples,
                                         inputs=density_inputs, chunk_id=s.chunk_id)
        s.put('densities', density_outputs['densities'], fi_idxs)
        s.accumulate_tot_evaluations("densities", fi_idxs[0].size(0))

        self._calc_weights(s)
        self._apply_early_stop(s)

        # Remove samples whose weights are less than a threshold
        s.hit_mask[s.chunk][s.weights[s.chunk][..., 0] < 0.01] = 0

        # Update "filtered" tensors
        fi_mask = s.hit_mask[fi_idxs]
        fi_idxs = (fi_idxs[0][fi_mask], fi_idxs[1][fi_mask])  # N' -> N"
        fi_samples = s.samples[fi_idxs]  # N -> N"
        fi_features = density_outputs['features'][fi_mask]
        color_inputs = s.kernel.input(fi_samples, "d")  # (N")
        color_inputs.x = density_inputs.x[fi_mask]

        # Infer colors (appearance)
        outputs = s.outputs.copy()
        if 'densities' in outputs:
            outputs.remove('densities')
        color_outputs = s.kernel.infer(*outputs, samples=fi_samples, inputs=color_inputs,
                                       chunk_id=s.chunk_id, features=fi_features)
        # Debug visualization that tints each chunk a different color (disabled):
        # if s.chunk_id == 0:
        #     fi_colors[:] *= fi_colors.new_tensor([1, 0, 0])
        # elif s.chunk_id == 1:
        #     fi_colors[:] *= fi_colors.new_tensor([0, 1, 0])
        # elif s.chunk_id == 2:
        #     fi_colors[:] *= fi_colors.new_tensor([0, 0, 1])
        # else:
        #     fi_colors[:] *= fi_colors.new_tensor([1, 1, 0])
        for key, value in color_outputs.items():
            s.put(key, value, fi_idxs)
        s.accumulate_tot_evaluations("colors", fi_idxs[0].size(0))
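
# A minimal usage sketch (hedged: `kernel` stands for any `BaseModel` implementation
# and `samples` for the output of the project's ray sampler; both are placeholders,
# and the call goes through `Module.__call__` into `forward`):
#
#     renderer = DensityFirstVolumnRenderer()
#     ret = renderer(kernel, samples, 'color', 'depth',
#                    raymarching_early_stop_tolerance=1e-4,
#                    raymarching_chunk_size_or_sections=[64])
#     ret['color']  # (N, C) composited ray colors
#     ret['depth']  # (N, 1) weighted sample depths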