from typing import Tuple
import math
import torch
import torch.nn as nn
from ..my import net_modules
from ..my import util
from ..my import device
from ..my import color_mode


rand_gen = torch.Generator(device=device.GetDevice())
rand_gen.manual_seed(torch.seed())


def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate the intersections of each ray with each sphere (centered at the origin)

    :param p ```Tensor(B, 3)```: positions of rays
    :param v ```Tensor(B, 3)```: directions of rays
    :param r ```Tensor(N)```: radii of spheres
    :return ```Tensor(B, N, 3)```: points of intersection
    :return ```Tensor(B, N)```: depths of intersection along ray
    """
    # p, v: Expand to (B, 1, 3)
    p = p.unsqueeze(1)
    v = v.unsqueeze(1)
    # pp, vv, pv: (B, 1)
    pp = (p * p).sum(dim=2)
    vv = (v * v).sum(dim=2)
    pv = (p * v).sum(dim=2)
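    # Solve |p + t*v|^2 = r^2 for t:
    #   vv*t^2 + 2*pv*t + (pp - r^2) = 0
    # and keep the larger root, i.e. the far intersection along +v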
    depths = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv)
    return p + depths[..., None] * v, depths
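
# Usage sketch (hypothetical values, for reference only): intersect B=2 rays
# from the origin with N=4 concentric spheres:
#
#     p = torch.zeros(2, 3, device=device.GetDevice())
#     v = torch.tensor([[0., 0., 1.], [0., 1., 0.]], device=device.GetDevice())
#     r = torch.linspace(1, 4, 4, device=device.GetDevice())
#     pts, depths = RaySphereIntersect(p, v, r)  # (2, 4, 3), (2, 4)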


class Rendering(nn.Module):

    def __init__(self, *, raw_noise_std: float = 0.0, white_bg: bool = False):
        """
        Initialize a Rendering module
        """
        super().__init__()
        self.raw_noise_std = raw_noise_std
        self.white_bg = white_bg

    def forward(self, raw, z_vals, ret_extra: bool = False):
        """Transforms model's predictions to semantically meaningful values.

        Args:
          raw: [num_rays, num_samples along ray, 2|4]. Prediction from model.
          z_vals: [num_rays, num_samples along ray]. Integration time.
          ret_extra: If False, only color_map is returned.

        Returns:
          color_map: [num_rays, 1|3]. Estimated color of a ray.
          disp_map: [num_rays]. Disparity map. Inverse of depth map.
          acc_map: [num_rays]. Sum of weights along each ray.
          weights: [num_rays, num_samples]. Weights assigned to each sampled color.
          depth_map: [num_rays]. Estimated distance to object.
        """
        color, alpha = self.raw2color(raw, z_vals)

        # Compute weight for RGB of each sample along each ray.  A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
        one_minus_alpha = util.broadcast_cat(
            torch.cumprod(1 - alpha[..., :-1] + 1e-10, dim=-1),
            1.0, append=False)
        weights = alpha * one_minus_alpha  # (N_rays, N_samples)
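        # i.e. weights[i] = alpha[i] * prod_{j<i}(1 - alpha[j]); for example,
        # alpha = [0.5, 1.0] gives weights = [0.5, 0.5]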

        # (N_rays, 1|3), computed weighted color of each sample along each ray.
        color_map = torch.sum(weights[..., None] * color, dim=-2)

        # To composite onto a white background, use the accumulated alpha map.
        if self.white_bg or ret_extra:
            # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
            acc_map = torch.sum(weights, -1)
            if self.white_bg:
                color_map = color_map + (1. - acc_map[..., None])
        else:
            acc_map = None

        if not ret_extra:
            return color_map

        # Estimated depth map is expected distance.
        depth_map = torch.sum(weights * z_vals, dim=-1)

        # Disparity map is inverse depth.
        disp_map = torch.clamp_min(
            depth_map / torch.sum(weights, dim=-1), 1e-10).reciprocal()

        return color_map, disp_map, acc_map, weights, depth_map

    def raw2color(self, raw: torch.Tensor, z_vals: torch.Tensor):
        """
        Convert raw values inferred by the model to color and alpha

        :param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output
        :param z_vals ```Tensor(N.rays, N.samples)```: integration time
        :return ```Tensor(N.rays, N.samples, 1|3)```: color
        :return ```Tensor(N.rays, N.samples)```: alpha
        """

        def raw2alpha(raw, dists, act_fn=torch.relu):
            """
            Compute alpha values from the model's raw density predictions.
            The result lies in [0, 1).
            """
            return 1.0 - torch.exp(-act_fn(raw) * dists)

        # Compute 'distance' (in time) between each integration time along a ray.
        # The 'distance' from the last integration time is infinity.
        # dists: (N_rays, N_samples)
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        last_dist = z_vals[..., 0:1] * 0 + 1e10

        dists = torch.cat([
            dists, last_dist
        ], -1)

        # Extract RGB of each sample position along each ray.
        color = torch.sigmoid(raw[..., :-1])  # (N_rays, N_samples, 1|3)

        if self.raw_noise_std > 0.:
            # Add noise to model's predictions for density. Can be used to
            # regularize network during training (prevents floater artifacts).
            noise = torch.normal(0.0, self.raw_noise_std,
                                 size=raw[..., -1].size(),
                                 generator=rand_gen,
                                 device=device.GetDevice())
            alpha = raw2alpha(raw[..., -1] + noise, dists)
        else:
            alpha = raw2alpha(raw[..., -1], dists)

        return color, alpha


class Sampler(nn.Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool, lindisp: bool, inverse_r: bool):
        """
        Initialize a Sampler module

        :param depth_range: (near, far) depth range to sample
        :param n_samples: number of samples along each ray
        :param perturb_sample: whether to perturb (stratify) the sample depths
        :param spherical: whether to sample on concentric spheres and return spherical coordinates
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        :param inverse_r: whether to use inverse radius in the returned spherical coordinates
        """
        super().__init__()
        self.lindisp = lindisp
        if self.lindisp:
            depth_range = (1 / depth_range[0], 1 / depth_range[1])
        self.r = torch.linspace(depth_range[0], depth_range[1],
                                n_samples, device=device.GetDevice())
        step = (depth_range[1] - depth_range[0]) / (n_samples - 1)
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.inverse_r = inverse_r
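        # Bin bounds for stratified sampling: each sample may be jittered
        # within [lower, upper] around its nominal depth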
        self.upper = torch.clamp_min(self.r + step / 2, 0)
        self.lower = torch.clamp_min(self.r - step / 2, 0)

    def forward(self, rays_o, rays_d):
        """
        Sample points along rays. Return spherical or Cartesian coordinates,
        as specified by ```self.spherical```

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :return ```Tensor(B, N, 3)```: sampled points, in spherical coordinates if ```self.spherical```
        :return ```Tensor(B, N, 3)```: sampled points in Cartesian coordinates (only if ```self.spherical```)
        :return ```Tensor(B, N)```: corresponding depths along rays
        """
        if self.perturb_sample:
            # stratified samples in those intervals
            t_rand = torch.rand(self.r.size(),
                                generator=rand_gen,
                                device=device.GetDevice())
            r = self.lower + (self.upper - self.lower) * t_rand
        else:
            r = self.r
        if self.lindisp:
            r = torch.reciprocal(r)

        if self.spherical:
            pts, depths = RaySphereIntersect(rays_o, rays_d, r)
            sphers = util.CartesianToSpherical(pts, inverse_r=self.inverse_r)
            return sphers, pts, depths
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * r[..., None], r
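
# Usage sketch (hypothetical parameter values, for reference only):
#
#     sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=True,
#                       spherical=True, lindisp=True, inverse_r=True)
#     sphers, pts, depths = sampler(rays_o, rays_d)
#     # sphers, pts: (B, 32, 3); depths: (B, 32)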


class MslNet(nn.Module):

    def __init__(self, fc_params, sampler_params,
                 normalize_coord: bool,
                 dir_as_input: bool,
                 color: int = color_mode.RGB,
                 encode_to_dim: int = 0,
                 export_mode: bool = False):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for the fully-connected network
        :param sampler_params: parameters for the sampler
        :param normalize_coord: whether to normalize the spherical coordinates to [0, 2pi] before encoding
        :param dir_as_input: whether to also feed the encoded local view direction to the network
        :param color: color mode (one of color_mode's RGB, GRAY or YCbCr)
        :param encode_to_dim: encode input to the given number of dimensions
        :param export_mode: if True, forward() returns raw colors and alphas instead of rendered pixels
        """
        super().__init__()
        self.in_chns = 3
        self.input_encoder = net_modules.InputEncoder.Get(
            encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
        self.sampler = Sampler(**sampler_params)
        self.rendering = Rendering()
        self.export_mode = export_mode
        self.normalize_coord = normalize_coord
        self.dir_as_input = dir_as_input
        self.color = color
        if self.color == color_mode.YCbCr:
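            # Two-branch split in this mode (intent inferred from the shapes
            # below): net1 emits nf features plus the first two raw channels,
            # net2 maps those features to the remaining two, so raw has 4
            # channels in total (3 for YCbCr, 1 for density)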
            self.net1 = net_modules.FcNet(
                in_chns=fc_params['in_chns'],
                out_chns=fc_params['nf'] + 2,
                nf=fc_params['nf'],
                n_layers=fc_params['n_layers'] - 2)
            self.net2 = net_modules.FcNet(
                in_chns=fc_params['nf'],
                out_chns=2,
                nf=fc_params['nf'],
                n_layers=1)
            self.net = None
        elif self.dir_as_input:
            self.input_encoder2 = net_modules.InputEncoder.Get(4, 2)
            self.net1 = net_modules.FcNet(
                in_chns=fc_params['in_chns'],
                out_chns=fc_params['nf'],
                nf=fc_params['nf'],
                n_layers=fc_params['n_layers'])
            self.net2 = net_modules.FcNet(
                in_chns=fc_params['nf'] + self.input_encoder2.out_dim,
                out_chns=fc_params['out_chns'],
                nf=fc_params['nf'],
                n_layers=1)
            self.net = None
        else:
            self.net = net_modules.FcNet(**fc_params)
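
        # angle_range starts with (+1e5, -1e5) sentinels so that
        # update_normalize_range() can fold the running min/max of the sampled
        # angles into it; depth_range spans the sampler's depth bin bounds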
        if self.normalize_coord:
            self.register_buffer('angle_range', torch.tensor(
                [[1e5, 1e5], [-1e5, -1e5]]))
            self.register_buffer('depth_range', torch.tensor([
                self.sampler.lower[0], self.sampler.upper[-1]
            ]))

    def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor):
        """
        Update the running min/max range of sampled spherical angles with a batch of rays

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        """
        coords, _, _ = self.sampler(rays_o, rays_d)
        coords = coords[..., 1:].view(-1, 2)
        self.angle_range = torch.stack([
            torch.cat([coords, self.angle_range[0:1]]).amin(0),
            torch.cat([coords, self.angle_range[1:2]]).amax(0)
        ])

    def calc_local_dir(self, rays_d, coords, pts: torch.Tensor):
        """
        Compute the ray direction at each sample point, expressed as two spherical
        angles in the local tangent frame of the sphere through that point

        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param coords ```Tensor(B, N, 3)```: spherical coordinates of sample points
        :param pts ```Tensor(B, N, 3)```: sample points in Cartesian coordinates
        :return ```Tensor(B, N, 2)```: local directions as two spherical angles
        """
        local_z = pts / pts.norm(dim=-1, keepdim=True)
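        # Approximate tangent: offset the spherical coords by 0.1 degree along
        # the second angle and normalize the resulting Cartesian offset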
        local_x = util.SphericalToCartesian(
            coords + torch.tensor([0, 0.1 / 180 * math.pi, 0],
                                  device=coords.device)) - pts
        local_x = local_x / local_x.norm(dim=-1, keepdim=True)
        local_y = torch.cross(local_x, local_z, -1)
        local_rot = torch.stack([local_x, local_y, local_z], dim=-2) # (B, N, 3, 3)
        return util.CartesianToSpherical(torch.matmul(rays_d[:, None, None, :], local_rot)).squeeze(-2)[..., 1:3]
        
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                ret_depth: bool = False) -> torch.Tensor:
        """
        rays -> colors

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param ret_depth: if True, also return the rendered depth map
        :return ```Tensor(B, C)```: inferred images/pixels
        """
        coords, pts, depths = self.sampler(rays_o, rays_d)

        if self.dir_as_input:
            dirs = self.calc_local_dir(rays_d, coords, pts)

        if self.normalize_coord:  # Normalize coords to [0, 2pi]
            norm_range = torch.cat(
                [self.depth_range.view(2, 1), self.angle_range], 1)
            coords = (coords - norm_range[0]) / \
                (norm_range[1] - norm_range[0]) * 2 * math.pi
        encoded = self.input_encoder(coords)

        if self.color == color_mode.YCbCr:
            mid_output = self.net1(encoded)
            net2_output = self.net2(mid_output[..., :-2])
            raw = torch.cat([
                mid_output[..., -2:],
                net2_output
            ], -1)
        elif self.dir_as_input:
            encoded_dirs = self.input_encoder2(dirs)
            raw = self.net2(torch.cat([self.net1(encoded), encoded_dirs], -1))
        else:
            raw = self.net(encoded)

        if self.export_mode:
            colors, alphas = self.rendering.raw2color(raw, depths)
            return torch.cat([colors, alphas[..., None]], -1)

        if ret_depth:
            color_map, _, _, _, depth_map = self.rendering(
                raw, depths, ret_extra=True)
            return color_map, depth_map

        return self.rendering(raw, depths)
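

# Usage sketch (hypothetical parameter values, for reference only; MslNet fills
# in fc_params['in_chns'] and fc_params['out_chns'] itself):
#
#     net = MslNet(fc_params={'nf': 256, 'n_layers': 8},
#                  sampler_params=dict(depth_range=(1, 50), n_samples=32,
#                                      perturb_sample=True, spherical=True,
#                                      lindisp=True, inverse_r=True),
#                  normalize_coord=False, dir_as_input=False,
#                  color=color_mode.RGB, encode_to_dim=10)
#     colors = net(rays_o, rays_d)  # (B, 3)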