renderer.py 3.79 KB
Newer Older
Nianchen Deng's avatar
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn
import torch.nn.functional as nn_f
from utils.constants import *
from .generic import *


class AlphaComposition(nn.Module):
    """Front-to-back alpha compositing of per-sample colors along each ray."""

    def __init__(self):
        super().__init__()

    def forward(self, colors, alphas, bg=None):
        """
        Blend the samples of every ray into a single color

        :param colors `Tensor(N.rays, N.samples, 1|3)`: color of each sample
        :param alphas `Tensor(N.rays, N.samples)`: opacity of each sample
        :param bg: optional background color; anything that broadcasts against
            the composited color `Tensor(N.rays, 1|3)`. `None` skips background
            blending
        :return: dict with
            'color' `Tensor(N.rays, 1|3)`: composited color of each ray,
            'weights' `Tensor(N.rays, N.samples)`: blending weight of each sample
        """
        # Transmittance in front of each sample: the running product of
        # (1 - alpha) over all *preceding* samples. TINY_FLOAT is a module-level
        # constant defined outside this class (presumably a small positive
        # epsilon that keeps the product from collapsing to exactly zero).
        transmittance = torch.cumprod(1 - alphas[..., :-1] + TINY_FLOAT, dim=-1)

        # Nothing lies in front of the first sample, so its transmittance is 1.
        # The ones are shaped after `alphas` (not after the cumprod result) so
        # that a ray with a single sample still gets its leading 1 instead of
        # an empty tensor, which would silently produce empty weights.
        transmittance = torch.cat([
            torch.ones_like(alphas[..., :1]),
            transmittance
        ], dim=-1)
        weights = alphas * transmittance  # (N_rays, N)

        # (N_rays, 1|3), weighted sum of the sample colors along each ray.
        final_color = torch.sum(weights[..., None] * colors, dim=-2)

        # Composite onto the given background using the accumulated alpha map.
        if bg is not None:
            # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
            acc_map = torch.sum(weights, -1)
            final_color = final_color + bg * (1. - acc_map[..., None])

        return {
            'color': final_color,
            'weights': weights,
        }


class VolumnRenderer(nn.Module):
    """Volume renderer: converts per-sample predictions to alphas and composites them."""

    def __init__(self, *, raw_noise_std=0.0, sigma_as_density=True):
        """
        Initialize a Rendering module

        :param raw_noise_std: standard deviation of the Gaussian noise added to
            the densities in `density2alpha`; values <= 0 disable the noise
        :param sigma_as_density: if `True`, `sigmas` passed to `forward` are
            converted to alpha by `density2alpha`; otherwise they are squashed
            to alpha with a sigmoid
        """
        super().__init__()
        self.alpha_composition = AlphaComposition()
        self.sigma_as_density = sigma_as_density
        self.raw_noise_std = raw_noise_std

    def forward(self, colors, sigmas, z_vals, bg_color=None, ret_depth=False, debug=False):
        """Transforms model's predictions to semantically meaningful values.

        Args:
          colors: [num_rays, num_samples along ray, 1|3]. Predicted color from model.
          sigmas: [num_rays, num_samples along ray]. Predicted density from model
            (or a pre-sigmoid alpha value when `sigma_as_density` is False).
          z_vals: [num_rays, num_samples along ray]. Integration time.
          bg_color: optional background color forwarded to the alpha composition.
          ret_depth: if True, also return the weighted depth of each ray.
          debug: if True, also return the per-sample color+alpha layers.

        Returns:
          A dict with the following entries:
          color: [num_rays, 1|3]. Composited color of a ray.
          weights: [num_rays, num_samples]. Weights assigned to each sampled color.
          depth: [num_rays]. Sum of `weights * z_vals`; only if `ret_depth`.
          layers: [num_rays, num_samples, (1|3) + 1]. `colors` concatenated with
            the alphas along the last axis; only if `debug`.
        """
        if self.sigma_as_density:
            alphas = self.density2alpha(sigmas, z_vals)
        else:
            # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
            alphas = torch.sigmoid(sigmas)
        ret = self.alpha_composition(colors, alphas, bg_color)
        if ret_depth:
            ret['depth'] = torch.sum(ret['weights'] * z_vals, dim=-1)
        if debug:
            ret['layers'] = torch.cat([colors, alphas[..., None]], dim=-1)
        return ret

    def density2alpha(self, densities: torch.Tensor, z_vals: torch.Tensor) -> torch.Tensor:
        """
        Convert the density inferred by the model to alpha

        :param densities `Tensor(N.rays, N.samples)`: model's output density
        :param z_vals `Tensor(N.rays, N.samples)`: integration time
        :return `Tensor(N.rays, N.samples)`: alpha
        """

        # Compute 'distance' (in time) between each integration time along a ray.
        # The last sample has no successor, so it is assigned a near-zero
        # interval of TINY_FLOAT (a constant defined outside this class), which
        # makes its alpha effectively zero.
        # NOTE(review): the previous comment called this last distance
        # "infinity"; the code has always used TINY_FLOAT - confirm the intent.
        # dists: (N_rays, N)
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        last_dist = torch.zeros_like(z_vals[..., 0:1]) + TINY_FLOAT
        dists = torch.cat([dists, last_dist], -1)

        if self.raw_noise_std > 0.:
            # Add noise to model's predictions for density. Can be used to
            # regularize network during training (prevents floater artifacts).
            # randn_like allocates the noise on the same device and with the
            # same dtype as `densities`, so this also works for CUDA inputs.
            # NOTE(review): the noise is applied whenever raw_noise_std > 0,
            # independent of self.training - confirm callers disable it for
            # evaluation.
            noise = torch.randn_like(densities) * self.raw_noise_std
            densities = densities + noise

        # alpha = 1 - exp(-max(density, 0) * dist)
        return 1.0 - torch.exp(-torch.relu(densities) * dists)