msl_net.py 9.51 KB
Newer Older
BobYeah's avatar
BobYeah committed
1
from typing import Tuple
BobYeah's avatar
sync    
BobYeah committed
2
3
import torch
import torch.nn as nn
BobYeah's avatar
BobYeah committed
4
from .my import net_modules
BobYeah's avatar
sync    
BobYeah committed
5
6
from .my import util
from .my import device
Nianchen Deng's avatar
sync    
Nianchen Deng committed
7
8
from .my import color_mode

BobYeah's avatar
sync    
BobYeah committed
9

BobYeah's avatar
BobYeah committed
10
11
12
# Module-wide random generator shared by the sampler and the renderer.
# NOTE(review): torch.seed() also reseeds torch's global default generator with
# a non-deterministic value as a side effect of importing this module.
rand_gen = torch.Generator(device=device.GetDevice()).manual_seed(torch.seed())

BobYeah's avatar
sync    
BobYeah committed
13

14
15
16
17
def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor,
                       r: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate intersections of each rays and each spheres

    The spheres are centered at the origin. For every ray/sphere pair the
    quadratic ``|p + t * v|^2 = r^2`` is solved and the root taken with the
    positive square root is returned. If a ray misses a sphere (negative
    discriminant) the corresponding entries are NaN.

    :param p ```Tensor(B, 3)```: positions of rays
    :param v ```Tensor(B, 3)```: directions of rays
    :param r ```Tensor(N)```: radius of spheres
    :return ```Tensor(B, N, 3)```: points of intersection
    :return ```Tensor(B, N)```: depths of intersection along ray
    """
    # p, v: Expand to (B, 1, 3) so they broadcast against the N spheres
    p = p.unsqueeze(1)
    v = v.unsqueeze(1)
    # pp, vv, pv: (B, 1) dot products
    pp = (p * p).sum(dim=2)
    vv = (v * v).sum(dim=2)
    pv = (p * v).sum(dim=2)
    # (B, 1) terms broadcast against r (N) -> depths: (B, N)
    discriminant = pv * pv - vv * (pp - r * r)
    depths = (discriminant.sqrt() - pv) / vv
    return p + depths[..., None] * v, depths
33
34


BobYeah's avatar
sync    
BobYeah committed
35
36
class Rendering(nn.Module):
    """
    Composites per-sample colors and densities predicted along each ray into
    per-ray values (color and, optionally, disparity / accumulation / depth).
    """

    def __init__(self, *, raw_noise_std: float = 0.0, white_bg: bool = False):
        """
        Initialize a Rendering module

        :param raw_noise_std: standard deviation of the Gaussian noise added to
            the raw density channel; no noise is added unless it is > 0
        :param white_bg: composite the rendered color onto a white background
        """
        super().__init__()
        self.raw_noise_std = raw_noise_std
        self.white_bg = white_bg

    def forward(self, raw, z_vals, ret_extra: bool = False):
        """Transforms model's predictions to semantically meaningful values.

        Args:
          raw: [num_rays, num_samples along ray, 2|4]. Prediction from model.
          z_vals: [num_rays, num_samples along ray]. Integration time.
          ret_extra: if True, return the extra maps listed below in addition
            to the color map.

        Returns:
          Only ``color_map`` when ``ret_extra`` is False, otherwise the tuple
          ``(color_map, disp_map, acc_map, weights, depth_map)``:

          color_map: [num_rays, 1|3]. Estimated color of a ray.
          disp_map: [num_rays]. Disparity map. Inverse of depth map.
          acc_map: [num_rays]. Sum of weights along each ray.
          weights: [num_rays, num_samples]. Weights assigned to each sampled color.
          depth_map: [num_rays]. Estimated distance to object.
        """
        color, alpha = self.raw2color(raw, z_vals)

        # Compute weight for the color of each sample along each ray. A cumprod()
        # is used to express the idea of the ray not having reflected up to this
        # sample yet.
        # NOTE(review): util.broadcast_cat is presumably prepending 1.0 along the
        # last dim (append=False) so the first sample is unoccluded - confirm.
        one_minus_alpha = util.broadcast_cat(
            torch.cumprod(1 - alpha[..., :-1] + 1e-10, dim=-1),
            1.0, append=False)
        weights = alpha * one_minus_alpha  # (N_rays, N_samples)

        # (N_rays, 1|3), weighted color of each sample summed along each ray.
        color_map = torch.sum(weights[..., None] * color, dim=-2)

        acc_map = None
        if self.white_bg or ret_extra:
            # Sum of weights along each ray. This value is in [0, 1] up to
            # numerical error.
            acc_map = torch.sum(weights, -1)
            if self.white_bg:
                # To composite onto a white background, use the accumulated
                # alpha map.
                color_map = color_map + (1. - acc_map[..., None])

        if not ret_extra:
            return color_map

        # Estimated depth map is expected distance.
        depth_map = torch.sum(weights * z_vals, dim=-1)

        # Disparity map is inverse depth. acc_map is always computed on this
        # path (ret_extra is True), so it is reused as the normalizer.
        disp_map = torch.clamp_min(depth_map / acc_map, 1e-10).reciprocal()

        return color_map, disp_map, acc_map, weights, depth_map

    def raw2color(self, raw: torch.Tensor,
                  z_vals: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Raw value inferred from model to color and alpha

        The last channel of ``raw`` is treated as density, all preceding
        channels as color.

        :param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output
        :param z_vals ```Tensor(N.rays, N.samples)```: integration time
        :return ```Tensor(N.rays, N.samples, 1|3)```: color
        :return ```Tensor(N.rays, N.samples)```: alpha
        """
        # Compute 'distance' (in time) between each integration time along a ray.
        # The 'distance' from the last integration time is infinity (1e10).
        # `z_vals[..., 0:1] * 0` keeps the dtype/device/leading dims of z_vals.
        # dists: (N_rays, N_samples)
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        last_dist = z_vals[..., 0:1] * 0 + 1e10
        dists = torch.cat([dists, last_dist], -1)

        # Extract color of each sample position along each ray.
        color = torch.sigmoid(raw[..., :-1])  # (N_rays, N_samples, 1|3)

        # Density is always the LAST channel; indexing it with -1 (instead of a
        # fixed 3) keeps the 2-channel grayscale output working.
        density = raw[..., -1]
        if self.raw_noise_std > 0.:
            # Add noise to model's predictions for density. Can be used to
            # regularize network during training (prevents floater artifacts).
            # `generator` is keyword-only in torch.normal; the noise is created
            # with the dtype/device of `raw` so it can be added directly.
            noise = torch.normal(0.0, self.raw_noise_std, density.shape,
                                 generator=rand_gen,
                                 dtype=raw.dtype, device=raw.device)
            density = density + noise

        # alpha = 1 - exp(-relu(density) * dist); this value lies in [0, 1].
        alpha = -torch.exp(-torch.relu(density) * dists) + 1.0

        return color, alpha

BobYeah's avatar
BobYeah committed
134
135
136
137

class Sampler(nn.Module):
    """
    Generates sample points along rays at a fixed set of depths (or sphere
    radii), optionally perturbed at every call.
    """

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool, lindisp: bool, inverse_r: bool):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param spherical: if True, intersect rays with origin-centered spheres and
            return spherical coordinates instead of Cartesian points
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        :param inverse_r: forwarded to ```util.CartesianToSpherical``` in spherical mode
        """
        super().__init__()
        if lindisp:
            # Linear in 1/depth, then inverted back to depth
            self.r = 1 / torch.linspace(1 / depth_range[0], 1 / depth_range[1],
                                        n_samples, device=device.GetDevice())
        else:
            self.r = torch.linspace(depth_range[0], depth_range[1],
                                    n_samples, device=device.GetDevice())
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.inverse_r = inverse_r
        if perturb_sample:
            # Interval bounds around each sample: midpoints between neighbours,
            # closed off by the first / last sample itself.
            mids = .5 * (self.r[1:] + self.r[:-1])
            self.upper = torch.cat([mids, self.r[-1:]], -1)
            self.lower = torch.cat([self.r[:1], mids], -1)

    def forward(self, rays_o, rays_d):
        """
        Sample points along rays. return Spherical or Cartesian coordinates,
        specified by ```self.spherical```

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :return ```Tensor(B, N, 3)```: sampled points
        :return ```Tensor(B, N)```: corresponding depths along rays in spherical
            mode; in Cartesian mode the shared ```Tensor(N)``` of sample depths
        """
        if self.perturb_sample:
            # stratified samples in those intervals; one random offset per
            # sample index, shared by every ray of the batch
            t_rand = torch.rand(self.r.size(),
                                generator=rand_gen,
                                device=device.GetDevice())
            r = self.lower + (self.upper - self.lower) * t_rand
        else:
            r = self.r

        if self.spherical:
            pts, depths = RaySphereIntersect(rays_o, rays_d, r)
            sphers = util.CartesianToSpherical(pts, inverse_r=self.inverse_r)
            return sphers, depths
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * r[..., None], r
187

BobYeah's avatar
sync    
BobYeah committed
188
189
190

class MslNet(nn.Module):
    """
    Multi-sphere-layer net: samples points along rays, encodes them, infers raw
    color/density with a fully-connected network and renders them to pixels.
    """

    def __init__(self, fc_params, sampler_params,
                 color: int = color_mode.RGB,
                 encode_to_dim: int = 0,
                 export_mode: bool = False):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for full-connection network
        :param sampler_params: parameters for sampler
        :param color: color mode, one of the ```color_mode``` constants
        :param encode_to_dim: encode input to number of dimensions
        :param export_mode: if True, ```forward``` returns per-sample colors and
            alphas instead of rendered pixels
        """
        super().__init__()
        self.in_chns = 3
        self.input_encoder = net_modules.InputEncoder.Get(
            encode_to_dim, self.in_chns)
        # Work on a shallow copy so the caller's dict is not modified
        fc_params = dict(fc_params)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
        self.sampler = Sampler(**sampler_params)
        self.rendering = Rendering()
        self.export_mode = export_mode
        if color == color_mode.YCbCr:
            # Two-stage network: net1 emits nf features plus 2 channels, net2
            # maps the nf features to the remaining 2 channels (see forward).
            self.net1 = net_modules.FcNet(
                in_chns=fc_params['in_chns'],
                out_chns=fc_params['nf'] + 2,
                nf=fc_params['nf'],
                n_layers=fc_params['n_layers'] - 2)
            self.net2 = net_modules.FcNet(
                in_chns=fc_params['nf'],
                out_chns=2,
                nf=fc_params['nf'],
                n_layers=1)
            self.net = None
        else:
            self.net = net_modules.FcNet(**fc_params)

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                ret_depth: bool = False):
        """
        rays -> colors

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param ret_depth: if True (and not in export mode), also return the depth map
        :return: ```Tensor(B, C)```, inferred images/pixels; in export mode the
            per-sample colors concatenated with alphas; with ```ret_depth``` a
            tuple ```(color_map, depth_map)```
        """
        coords, depths = self.sampler(rays_o, rays_d)
        encoded = self.input_encoder(coords)

        # Explicit None check: self.net is set to None only in the two-stage
        # (YCbCr) configuration; do not rely on module truthiness.
        if self.net is None:
            mid_output = self.net1(encoded)
            net2_output = self.net2(mid_output[..., :-2])
            # Last 2 channels of net1 first, then the 2 channels of net2
            raw = torch.cat([
                mid_output[..., -2:],
                net2_output
            ], -1)
        else:
            raw = self.net(encoded)

        if self.export_mode:
            # Skip compositing; expose per-sample color and alpha
            colors, alphas = self.rendering.raw2color(raw, depths)
            return torch.cat([colors, alphas[..., None]], -1)

        if ret_depth:
            color_map, _, _, _, depth_map = self.rendering(
                raw, depths, ret_extra=True)
            return color_map, depth_map

        return self.rendering(raw, depths)