import torch
import torch.nn as nn
import torch.nn.functional as nn_f

from utils.constants import *
from .generic import *


class AlphaComposition(nn.Module):
    """Front-to-back alpha compositing of per-sample colors along each ray."""

    def __init__(self):
        super().__init__()

    def forward(self, colors, alphas, bg=None):
        """Composite sample colors along the sample axis using their alphas.

        Args:
            colors: per-sample colors; the sample axis is ``-2`` and the
                channel axis is ``-1`` (e.g. ``[num_rays, num_samples, 1|3]``).
            alphas: per-sample opacities; the sample axis is ``-1``
                (e.g. ``[num_rays, num_samples]``).
            bg: optional background color blended in with whatever
                opacity is left after all samples. Must broadcast against
                the composited color. ``None`` skips the blending.

        Returns:
            dict with
              ``'color'``: composited color, sample axis reduced.
              ``'weights'``: contribution of every sample, same shape as
              ``alphas``.
        """
        # Transmittance in front of each sample: the running product of
        # (1 - alpha) over all *previous* samples, i.e. the fraction of the
        # ray that has not been absorbed before reaching this sample.
        transmittance = torch.cumprod(
            1 - alphas[..., :-1] + TINY_FLOAT, dim=-1)
        # Nothing lies in front of the first sample, so its transmittance
        # is 1. The ones column is derived from `alphas` (not from the
        # cumprod result) so that a single-sample ray, whose cumprod is
        # empty, still yields weights of shape (..., 1).
        transmittance = torch.cat([
            torch.ones_like(alphas[..., :1]),
            transmittance,
        ], dim=-1)
        weights = alphas * transmittance  # (N_rays, N)

        # (N_rays, 1|3), weighted sum of the sample colors along each ray.
        final_color = torch.sum(weights[..., None] * colors, dim=-2)

        if bg is not None:
            # Sum of weights along each ray. This value is in [0, 1] up to
            # numerical error; the remainder is assigned to the background.
            acc_map = torch.sum(weights, -1)
            final_color = final_color + bg * (1. - acc_map[..., None])

        return {
            'color': final_color,
            'weights': weights,
        }


class VolumnRenderer(nn.Module):
    """Turn per-sample model outputs into per-ray colors (and depth)."""

    def __init__(self, *, raw_noise_std=0.0, sigma_as_density=True):
        """
        Initialize a Rendering module

        :param raw_noise_std: standard deviation of the Gaussian noise added
            to the densities in `density2alpha`; `0` disables the noise
        :param sigma_as_density: if true, `sigmas` passed to `forward` are
            converted to alphas by `density2alpha`; otherwise they are
            squashed with a sigmoid
        """
        super().__init__()
        self.alpha_composition = AlphaComposition()
        self.sigma_as_density = sigma_as_density
        self.raw_noise_std = raw_noise_std

    def forward(self, colors, sigmas, z_vals, bg_color=None, ret_depth=False,
                debug=False):
        """Transforms model's predictions to semantically meaningful values.

        Args:
            colors: [num_rays, num_samples along ray, 1|3]. Predicted color
                from model.
            sigmas: [num_rays, num_samples along ray]. Predicted density
                from model (or pre-sigmoid alpha when `sigma_as_density`
                is false).
            z_vals: [num_rays, num_samples along ray]. Integration time.
            bg_color: optional background color forwarded to the alpha
                composition.
            ret_depth: if true, also return the weighted sum of `z_vals`.
            debug: if true, also return the per-sample colors and alphas.

        Returns:
            dict with
              'color': [num_rays, 1|3]. Composited color of each ray.
              'weights': [num_rays, num_samples]. Weights assigned to each
                  sampled color.
              'depth': [num_rays]. Sum of `weights * z_vals` along each
                  ray. Only present when `ret_depth` is true.
              'layers': [num_rays, num_samples, C + 1]. `colors` with the
                  alphas appended as an extra channel. Only present when
                  `debug` is true.
        """
        if self.sigma_as_density:
            alphas = self.density2alpha(sigmas, z_vals)
        else:
            # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
            alphas = torch.sigmoid(sigmas)

        ret = self.alpha_composition(colors, alphas, bg_color)
        if ret_depth:
            ret['depth'] = torch.sum(ret['weights'] * z_vals, dim=-1)
        if debug:
            ret['layers'] = torch.cat([colors, alphas[..., None]], dim=-1)
        return ret

    def density2alpha(self, densities: torch.Tensor, z_vals: torch.Tensor):
        """
        Raw value inferred from model to color and alpha

        :param densities `Tensor(N.rays, N.samples)`: model's output density
        :param z_vals `Tensor(N.rays, N.samples)`: integration time
        :return `Tensor(N.rays, N.samples)`: alpha
        """
        # Compute 'distance' (in time) between each integration time along
        # a ray. The last sample has no successor, so it is given an
        # interval of TINY_FLOAT (presumably a very small positive constant
        # - confirm in utils.constants), which makes its alpha close to 0.
        # dists: (N_rays, N)
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        last_dist = torch.zeros_like(z_vals[..., 0:1]) + TINY_FLOAT
        dists = torch.cat([dists, last_dist], -1)

        if self.raw_noise_std > 0.:
            # Add noise to model's predictions for density. Can be used to
            # regularize network during training (prevents floater
            # artifacts). randn_like creates the noise on the same device
            # and with the same dtype as `densities`; torch.normal with an
            # explicit size would allocate on the default device and fail
            # for CUDA inputs.
            noise = torch.randn_like(densities) * self.raw_noise_std
            densities = densities + noise

        # alpha = 1 - exp(-max(density, 0) * dist)
        return 1.0 - torch.exp(-torch.relu(densities) * dists)