msl_net_new.py 10.7 KB
Newer Older
BobYeah's avatar
BobYeah committed
1
from typing import Tuple
Nianchen Deng's avatar
sync    
Nianchen Deng committed
2
import math
BobYeah's avatar
sync    
BobYeah committed
3
4
import torch
import torch.nn as nn
Nianchen Deng's avatar
sync    
Nianchen Deng committed
5
6
7
8
from ..my import net_modules
from ..my import util
from ..my import device
from ..my import color_mode
Nianchen Deng's avatar
sync    
Nianchen Deng committed
9

BobYeah's avatar
sync    
BobYeah committed
10

BobYeah's avatar
BobYeah committed
11
12
13
# Module-level random generator, created on the device returned by
# device.GetDevice(). It is used in this file for the density noise in
# Rendering.raw2color and for the stratified sample offsets in Sampler.forward.
# NOTE(review): per the PyTorch docs, torch.seed() also re-seeds the global
# default generator with a non-deterministic value -- confirm that this
# import-time side effect is intended.
rand_gen = torch.Generator(device=device.GetDevice())
rand_gen.manual_seed(torch.seed())

BobYeah's avatar
sync    
BobYeah committed
14

15
16
17
18
def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor,
                       r: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate intersections of each rays and each spheres

    The spheres are centered at the origin. For every ray ``p + t * v`` and
    every radius the quadratic ``|p + t * v|^2 = r^2`` is solved and the larger
    root is returned. If a ray misses a sphere the discriminant is negative
    and the corresponding entries are NaN.

    :param p ```Tensor(B, 3)```: positions of rays
    :param v ```Tensor(B, 3)```: directions of rays
    :param r ```Tensor(N)```: radius of spheres
    :return ```Tensor(B, N, 3)```: points of intersection
    :return ```Tensor(B, N)```: depths of intersection along ray
    """
    # p, v: Expand to (B, 1, 3) so they broadcast against the N radii
    p = p.unsqueeze(1)
    v = v.unsqueeze(1)
    # pp, vv, pv: (B, 1)
    pp = (p * p).sum(dim=2)
    vv = (v * v).sum(dim=2)
    pv = (p * v).sum(dim=2)
    # Larger root of vv * t^2 + 2 * pv * t + (pp - r^2) = 0, shape (B, N)
    discriminant = pv * pv - vv * (pp - r * r)
    depths = (discriminant.sqrt() - pv) / vv
    return p + depths[..., None] * v, depths
34
35


BobYeah's avatar
sync    
BobYeah committed
36
37
class Rendering(nn.Module):
    """
    Volume rendering: composite per-sample colors and densities along each ray
    """

    def __init__(self, *, raw_noise_std: float = 0.0, white_bg: bool = False):
        """
        Initialize a Rendering module

        :param raw_noise_std: std of the Gaussian noise added to the raw density
            before activation; noise is only applied when this is > 0
        :param white_bg: if True, composite the rendered color onto a white background
        """
        super().__init__()
        self.raw_noise_std = raw_noise_std
        self.white_bg = white_bg

    def forward(self, raw, z_vals, ret_extra: bool = False):
        """
        Transforms model's predictions to semantically meaningful values.

        :param raw ```Tensor(N.rays, N.samples, 2|4)```: prediction from model
        :param z_vals ```Tensor(N.rays, N.samples)```: integration time
        :param ret_extra: if False only the color map is returned, otherwise
            the 5-tuple described below
        :return: ``color_map`` ```Tensor(N.rays, 1|3)``` when ``ret_extra`` is False;
            otherwise ``(color_map, disp_map, acc_map, weights, depth_map)`` where
            ``disp_map`` ```Tensor(N.rays)``` is the inverse of the weight-normalized depth,
            ``acc_map`` ```Tensor(N.rays)``` is the sum of weights along each ray,
            ``weights`` ```Tensor(N.rays, N.samples)``` are the per-sample weights and
            ``depth_map`` ```Tensor(N.rays)``` is the weighted sum of ``z_vals``
        """
        color, alpha = self.raw2color(raw, z_vals)

        # Compute weight for RGB of each sample along each ray.  A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
        # NOTE(review): util.broadcast_cat presumably prepends a 1.0 along the
        # last dim so the first sample is fully visible -- confirm against util.
        one_minus_alpha = util.broadcast_cat(
            torch.cumprod(1 - alpha[..., :-1] + 1e-10, dim=-1),
            1.0, append=False)
        weights = alpha * one_minus_alpha  # (N_rays, N_samples)

        # (N_rays, 1|3), computed weighted color of each sample along each ray.
        color_map = torch.sum(weights[..., None] * color, dim=-2)

        # The accumulated alpha is only needed for the white background
        # composite or when the extra outputs are requested.
        acc_map = None
        if self.white_bg or ret_extra:
            # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
            acc_map = torch.sum(weights, -1)
            if self.white_bg:
                color_map = color_map + (1. - acc_map[..., None])

        if not ret_extra:
            return color_map

        # Estimated depth map is expected distance.
        depth_map = torch.sum(weights * z_vals, dim=-1)

        # Disparity map is inverse depth; acc_map already holds sum(weights).
        disp_map = torch.clamp_min(depth_map / acc_map, 1e-10).reciprocal()

        return color_map, disp_map, acc_map, weights, depth_map

    def raw2color(self, raw: torch.Tensor, z_vals: torch.Tensor):
        """
        Raw value inferred from model to color and alpha

        :param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output
        :param z_vals ```Tensor(N.rays, N.samples)```: integration time
        :return ```Tensor(N.rays, N.samples, 1|3)```: color
        :return ```Tensor(N.rays, N.samples)```: alpha
        """

        def raw2alpha(raw, dists, act_fn=torch.relu):
            """
            Function for computing alpha from the raw density prediction.
            The result lies in [0, 1].
            """
            return -torch.exp(-act_fn(raw) * dists) + 1.0

        # Compute 'distance' (in time) between each integration time along a ray.
        # The 'distance' from the last integration time is infinity (1e10).
        # dists: (N_rays, N_samples)
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        last_dist = z_vals[..., 0:1] * 0 + 1e10
        dists = torch.cat([dists, last_dist], -1)

        # Extract color of each sample position along each ray.
        color = torch.sigmoid(raw[..., :-1])  # (N_rays, N_samples, 1|3)

        # Density is always the last channel, for both gray (2) and RGB (4) output.
        density = raw[..., -1]
        if self.raw_noise_std > 0.:
            # Add noise to model's predictions for density. Can be used to
            # regularize network during training (prevents floater artifacts).
            # The generator is passed by keyword and the noise is allocated on
            # the generator's own device so the two always match.
            noise = torch.normal(0.0, self.raw_noise_std, density.size(),
                                 generator=rand_gen, device=rand_gen.device)
            density = density + noise
        alpha = raw2alpha(density, dists)

        return color, alpha

BobYeah's avatar
BobYeah committed
135
136
137
138

class Sampler(nn.Module):
    """
    Sample points along rays, either at fixed depths or on origin-centered spheres
    """

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool, lindisp: bool, inverse_r: bool):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param spherical: if True, sample on spheres and also return spherical coordinates
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        :param inverse_r: forwarded to ``util.CartesianToSpherical`` in spherical mode
        """
        super().__init__()
        self.lindisp = lindisp
        if self.lindisp:
            # Work in inverse depth; forward() converts back with reciprocal()
            depth_range = (1 / depth_range[0], 1 / depth_range[1])
        self.r = torch.linspace(depth_range[0], depth_range[1],
                                n_samples, device=device.GetDevice())
        step = self._sample_step(depth_range, n_samples)
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.inverse_r = inverse_r
        # Bounds of the interval around every sample used for perturbation,
        # clamped so they never become negative.
        self.upper = torch.clamp_min(self.r + step / 2, 0)
        self.lower = torch.clamp_min(self.r - step / 2, 0)

    @staticmethod
    def _sample_step(depth_range: Tuple[float, float], n_samples: int) -> float:
        """
        Spacing between two consecutive samples of the (possibly inverted) range.

        With fewer than two samples there is no spacing, so 0.0 is returned
        instead of dividing by zero.
        """
        if n_samples < 2:
            return 0.0
        return (depth_range[1] - depth_range[0]) / (n_samples - 1)

    def forward(self, rays_o, rays_d):
        """
        Sample points along rays. return Spherical or Cartesian coordinates,
        specified by ```self.spherical```

        Note that the number of returned values depends on the mode.

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :return: if ``self.spherical``: ``(sphers, pts, depths)`` where ``sphers`` is the
            output of ``util.CartesianToSpherical(pts)``, ``pts`` is ```Tensor(B, N, 3)```
            and ``depths`` is ```Tensor(B, N)```;
            otherwise ``(pts, r)`` where ``pts`` is ```Tensor(B, N, 3)``` and ``r`` is the
            ```Tensor(N)``` of sample depths shared by all rays
        """
        if self.perturb_sample:
            # stratified samples in those intervals
            t_rand = torch.rand(self.r.size(),
                                generator=rand_gen,
                                device=device.GetDevice())
            r = self.lower + (self.upper - self.lower) * t_rand
        else:
            r = self.r
        if self.lindisp:
            # Samples were generated in inverse depth; convert back to depth
            r = torch.reciprocal(r)

        if self.spherical:
            pts, depths = RaySphereIntersect(rays_o, rays_d, r)
            sphers = util.CartesianToSpherical(pts, inverse_r=self.inverse_r)
            return sphers, pts, depths
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * r[..., None], r
188

BobYeah's avatar
sync    
BobYeah committed
189

Nianchen Deng's avatar
sync    
Nianchen Deng committed
190
class NewMslNet(nn.Module):
    """
    Multi-sphere-layer net which infers the near and far half of the samples
    with two separate sub-networks
    """

    def __init__(self, fc_params, sampler_params,
                 normalize_coord: bool,
                 dir_as_input: bool,
                 not_same_net: bool = False,
                 color: int = color_mode.RGB,
                 encode_to_dim: int = 0,
                 export_mode: bool = False):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for full-connection network; ``in_chns`` and
            ``out_chns`` are filled in here on a private copy
        :param sampler_params: parameters for sampler
        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
        :param dir_as_input: not used by this class
        :param not_same_net: if True the second sub-network is built with
            ``nf=128, n_layers=4`` instead of ``fc_params``
        :param color: color mode
        :param encode_to_dim: encode input to number of dimensions
        :param export_mode: if True, ``forward`` returns per-sample colors and alphas
            instead of rendered pixels
        """
        super().__init__()
        # Work on a copy so the caller's dict is not modified
        fc_params = dict(fc_params)
        self.in_chns = 3
        self.input_encoder = net_modules.InputEncoder.Get(
            encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
        self.sampler = Sampler(**sampler_params)
        self.rendering = Rendering()
        self.export_mode = export_mode
        self.normalize_coord = normalize_coord
        self.nets = nn.ModuleList([
            net_modules.FcNet(**fc_params),
            net_modules.FcNet(in_chns=fc_params['in_chns'],
                              out_chns=fc_params['out_chns'],
                              nf=128, n_layers=4) if not_same_net
            else net_modules.FcNet(**fc_params)
        ])
        self.n_samples = sampler_params['n_samples']
        if self.normalize_coord:
            # Row 0 holds minimums, row 1 holds maximums. The angle range starts
            # inverted so the first update_normalize_range() call replaces it.
            self.register_buffer('angle_range', torch.tensor(
                [[1e5, 1e5], [-1e5, -1e5]]))
            # One column per half of the samples (near half, far half)
            self.register_buffer('depth_range', torch.tensor([
                [self.sampler.lower[0], self.sampler.lower[self.n_samples // 2]],
                [self.sampler.upper[self.n_samples // 2 - 1], self.sampler.upper[-1]]
            ]))

    def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor):
        """
        Extend the stored angle range so that it covers the given rays

        Reads ``self.angle_range``, which is only registered when
        ``normalize_coord`` is True.

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        """
        coords, _, _ = self.sampler(rays_o, rays_d)
        # Drop the first component, keep the two angular components
        coords = coords[..., 1:].view(-1, 2)
        self.angle_range = torch.stack([
            torch.cat([coords, self.angle_range[0:1]]).amin(0),
            torch.cat([coords, self.angle_range[1:2]]).amax(0)
        ])

    def _normalized(self, coords: torch.Tensor, layer: int) -> torch.Tensor:
        """
        Map coords of one half of the samples to [0, 2pi] using the stored ranges

        :param coords: coords of the half selected by ``layer``
        :param layer: 0 for the first half of the samples, 1 for the second half
        """
        bounds = torch.cat(
            [self.depth_range[:, layer:layer + 1], self.angle_range], 1)
        return (coords - bounds[0]) / (bounds[1] - bounds[0]) * 2 * math.pi

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                ret_depth: bool = False):
        """
        rays -> colors

        The sampler is expected to return three values, which it only does in
        spherical mode.

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param ret_depth: if True, also return the depth map
        :return: in export mode, per-sample colors and alphas concatenated along the
            last dim; else ``(color_map, depth_map)`` if ``ret_depth``;
            else ```Tensor(B, C)```, inferred images/pixels
        """
        coords, pts, depths = self.sampler(rays_o, rays_d)
        half = self.n_samples // 2
        if self.normalize_coord:  # Normalize coords to [0, 2pi]
            coords[:, :half] = self._normalized(coords[:, :half], 0)
            coords[:, half:] = self._normalized(coords[:, half:], 1)

        encoded = self.input_encoder(coords)

        # The first half of the samples goes through nets[0], the rest through nets[1]
        raw = torch.cat([
            self.nets[0](encoded[:, :half]),
            self.nets[1](encoded[:, half:]),
        ], 1)
        if self.export_mode:
            colors, alphas = self.rendering.raw2color(raw, depths)
            return torch.cat([colors, alphas[..., None]], -1)

        if ret_depth:
            color_map, _, _, _, depth_map = self.rendering(
                raw, depths, ret_extra=True)
            return color_map, depth_map

        return self.rendering(raw, depths)