# msl_net.py
from typing import Tuple
import torch
import torch.nn as nn
from .my import net_modules
from .my import util
from .my import device

rand_gen = torch.Generator(device=device.GetDevice())
rand_gen.manual_seed(torch.seed())
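# Shared generator for the stratified sampling and density noise below; seeding
# it from torch.seed() keeps it aligned with PyTorch's current global seed.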


def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate intersections of each ray with each sphere

    :param p ```Tensor(B, 3)```: positions of rays
    :param v ```Tensor(B, 3)```: directions of rays
    :param r ```Tensor(N)```: radii of spheres
    :return ```Tensor(B, N, 3)```: points of intersection
    :return ```Tensor(B, N)```: depths of intersection along ray
    """
    # p, v: Expand to (B, 1, 3)
    p = p.unsqueeze(1)
    v = v.unsqueeze(1)
    # pp, vv, pv: (B, 1)
    pp = (p * p).sum(dim=2)
    vv = (v * v).sum(dim=2)
    pv = (p * v).sum(dim=2)
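    # A point p + t*v lies on the sphere of radius r iff |p + t*v|^2 = r^2, i.e.
    # vv*t^2 + 2*pv*t + (pp - r^2) = 0. Taking the larger root of this quadratic
    # gives the intersection along the ray direction: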
    depths = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv)
    return p + depths[..., None] * v, depths


class Rendering(nn.Module):

    def __init__(self, *, raw_noise_std: float = 0.0, white_bg: bool = False):
        """
        Initialize a Rendering module
        """
        super().__init__()
        self.raw_noise_std = raw_noise_std
        self.white_bg = white_bg

    def forward(self, raw, z_vals, ret_extra: bool = False):
        """Transforms model's predictions to semantically meaningful values.

        Args:
          raw: [num_rays, num_samples along ray, 2|4]. Prediction from model.
          z_vals: [num_rays, num_samples along ray]. Integration time.
          ret_extra: if False, return only color_map.

        Returns:
          color_map: [num_rays, 1|3]. Estimated color of a ray.
          disp_map: [num_rays]. Disparity map. Inverse of depth map.
          acc_map: [num_rays]. Sum of weights along each ray.
          weights: [num_rays, num_samples]. Weights assigned to each sampled color.
          depth_map: [num_rays]. Estimated distance to object.
        """
        color, alpha = self.raw2color(raw, z_vals)

        # Compute weight for RGB of each sample along each ray.  A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
        one_minus_alpha = util.broadcast_cat(
            torch.cumprod(1 - alpha[..., :-1] + 1e-10, dim=-1),
            1.0, append=False)
        weights = alpha * one_minus_alpha  # (N_rays, N_samples)
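        # i.e. weights[i] = alpha[i] * prod_{j<i} (1 - alpha[j]): the probability
        # that the ray reaches sample i and terminates there (the 1e-10 keeps the
        # cumulative product from collapsing to exactly zero).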

        # (N_rays, 1|3), computed weighted color of each sample along each ray.
        color_map = torch.sum(weights[..., None] * color, dim=-2)

        # To composite onto a white background, use the accumulated alpha map.
        if self.white_bg or ret_extra:
            # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
            acc_map = torch.sum(weights, -1)
            if self.white_bg:
                color_map = color_map + (1. - acc_map[..., None])
        else:
            acc_map = None

        if not ret_extra:
            return color_map

        # Estimated depth map is expected distance.
        depth_map = torch.sum(weights * z_vals, dim=-1)

        # Disparity map is inverse depth.
        disp_map = torch.clamp_min(
            depth_map / torch.sum(weights, dim=-1), 1e-10).reciprocal()
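        # (the weighted depth is divided by the accumulated weight to form a true
        # expectation; the clamp keeps the reciprocal finite on empty rays)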

        return color_map, disp_map, acc_map, weights, depth_map

    def raw2color(self, raw: torch.Tensor, z_vals: torch.Tensor):
        """
        Raw value inferred from model to color and alpha

        :param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output
        :param z_vals ```Tensor(N.rays, N.samples)```: integration time
        :return ```Tensor(N.rays, N.samples, 1|3)```: color
        :return ```Tensor(N.rays, N.samples)```: alpha
        """

        def raw2alpha(raw, dists, act_fn=torch.relu):
            """
            Compute alpha (opacity) from the model's raw density prediction.
            The result lies in [0, 1).
            """
            return 1.0 - torch.exp(-act_fn(raw) * dists)
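        # alpha = 1 - exp(-sigma * delta): opacity contributed by density sigma
        # over an interval of length delta (the standard NeRF conversion).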

        # Compute 'distance' (in time) between each integration time along a ray.
        # The 'distance' from the last integration time is infinity.
        # dists: (N_rays, N_samples)
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        last_dist = torch.full_like(z_vals[..., :1], 1e10)
        dists = torch.cat([dists, last_dist], -1)

        # Extract RGB of each sample position along each ray.
        color = torch.sigmoid(raw[..., :-1])  # (N_rays, N_samples, 1|3)

        if self.raw_noise_std > 0.:
            # Add noise to model's predictions for density. Can be used to
            # regularize network during training (prevents floater artifacts).
            noise = torch.randn(raw[..., -1].size(), generator=rand_gen,
                                device=device.GetDevice()) * self.raw_noise_std
            alpha = raw2alpha(raw[..., -1] + noise, dists)
        else:
            alpha = raw2alpha(raw[..., -1], dists)

        return color, alpha
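
# Usage sketch (shapes are illustrative, not prescribed by this module):
#   rendering = Rendering(raw_noise_std=0.0, white_bg=False)
#   raw = torch.rand(1024, 64, 4)          # (N_rays, N_samples, RGB + density)
#   z_vals = torch.linspace(1., 10., 64).expand(1024, 64)
#   color_map = rendering(raw, z_vals)     # (N_rays, 3)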


class Sampler(nn.Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool, lindisp: bool, inverse_r: bool):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param spherical: If True, sample on concentric spheres and return spherical coordinates
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        :param inverse_r: If True, use the inverse of radius as the radial spherical coordinate
        """
        super().__init__()
        if lindisp:
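            # sample evenly in disparity (1/r), so depths concentrate near the
            # near end of the range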
            self.r = 1 / torch.linspace(1 / depth_range[0], 1 / depth_range[1],
                                        n_samples, device=device.GetDevice())
        else:
            self.r = torch.linspace(depth_range[0], depth_range[1],
                                    n_samples, device=device.GetDevice())
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.inverse_r = inverse_r
        if perturb_sample:
            mids = .5 * (self.r[1:] + self.r[:-1])
            self.upper = torch.cat([mids, self.r[-1:]], -1)
            self.lower = torch.cat([self.r[:1], mids], -1)
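            # lower/upper bound the bin around each base depth; forward() then
            # jitters every sample uniformly within its own bin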

    def forward(self, rays_o, rays_d):
        """
        Sample points along rays. Return spherical or Cartesian coordinates,
        depending on ```self.spherical```

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :return ```Tensor(B, N, 3)```: sampled points
        :return ```Tensor(B, N)```: corresponding depths along rays
        """
        if self.perturb_sample:
            # stratified samples in those intervals
            t_rand = torch.rand(self.r.size(),
                                generator=rand_gen,
                                device=device.GetDevice())
            r = self.lower + (self.upper - self.lower) * t_rand
        else:
            r = self.r

        if self.spherical:
            pts, depths = RaySphereIntersect(rays_o, rays_d, r)
            sphers = util.CartesianToSpherical(pts, inverse_r=self.inverse_r)
            return sphers, depths
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * r[..., None], r
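
# Usage sketch (hypothetical parameter values):
#   sampler = Sampler(depth_range=(1., 50.), n_samples=64, perturb_sample=True,
#                     spherical=True, lindisp=True, inverse_r=True)
#   pts, depths = sampler(rays_o, rays_d)  # (B, 64, 3), (B, 64)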


class MslNet(nn.Module):

    def __init__(self, fc_params, sampler_params,
                 gray=False,
                 encode_to_dim: int = 0,
                 export_mode: bool = False):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for the fully-connected network
        :param sampler_params: parameters for sampler
        :param gray: if True, render in grayscale (1 output channel) instead of RGB
        :param encode_to_dim: encode input to number of dimensions
        :param export_mode: if True, forward() returns per-sample colors and alphas instead of composited pixels
        """
        super().__init__()
        self.in_chns = 3
        self.input_encoder = net_modules.InputEncoder.Get(
            encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if gray else 4
        self.sampler = Sampler(**sampler_params)
        self.net = net_modules.FcNet(**fc_params)
        self.rendering = Rendering()
        self.export_mode = export_mode

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                ret_depth: bool = False):
        """
        rays -> colors

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :return ```Tensor(B, C)```: inferred images/pixels (plus ```Tensor(B)``` depth map if ```ret_depth``` is True)
        """
        coords, depths = self.sampler(rays_o, rays_d)
        encoded = self.input_encoder(coords)

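        # In export mode, skip volume rendering and return per-sample colors and
        # alphas directly, so compositing can happen outside this module.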
        if self.export_mode:
            colors, alphas = self.rendering.raw2color(self.net(encoded), depths)
            return torch.cat([colors, alphas[..., None]], -1)

        if ret_depth:
            color_map, _, _, _, depth_map = self.rendering(
                self.net(encoded), depths, ret_extra=True)
            return color_map, depth_map
        
        return self.rendering(self.net(encoded), depths)
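
# Usage sketch (hypothetical configuration; the contents of ``fc_params`` beyond
# the ``in_chns``/``out_chns`` keys filled in by the constructor depend on FcNet):
#   net = MslNet(fc_params={...},
#                sampler_params=dict(depth_range=(1., 50.), n_samples=64,
#                                    perturb_sample=True, spherical=True,
#                                    lindisp=True, inverse_r=True),
#                gray=False, encode_to_dim=10).to(device.GetDevice())
#   colors = net(rays_o, rays_d)  # (B, 3)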