msl_net.py 6.09 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from .modules import *
from ..my import util
from ..my import color_mode
Nianchen Deng's avatar
sync    
Nianchen Deng committed
7

BobYeah's avatar
sync    
BobYeah committed
8
9
10

class MslNet(nn.Module):
    """
    Multi-sphere-layer network.

    Rays are sampled by a ``Sampler``, the sampled coordinates are passed
    through an input encoder and one or two fully-connected networks, and the
    raw per-sample outputs are composited by a ``Rendering`` module.
    """

    def __init__(self, fc_params, sampler_params,
                 normalize_coord: bool,
                 dir_as_input: bool,
                 color: int = color_mode.RGB,
                 encode_to_dim: int = 0,
                 export_mode: bool = False):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for full-connection network; the keys
            ``in_chns`` and ``out_chns`` are computed here, ``nf`` and
            ``n_layers`` are read (the passed dict itself is left untouched)
        :param sampler_params: parameters for sampler (must contain ``n_samples``)
        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
        :param dir_as_input: whether feed the encoded local ray direction to the
            second network (ignored when ``color`` is ``color_mode.YCbCr``)
        :param color: color mode
        :param encode_to_dim: encode input to number of dimensions
        :param export_mode: if set, ``forward`` returns per-sample colors and
            alphas instead of rendered pixels
        """
        super().__init__()
        # Work on a shallow copy so the caller's dict is not modified by the
        # 'in_chns' / 'out_chns' assignments below.
        fc_params = dict(fc_params)

        self.in_chns = 3
        self.input_encoder = InputEncoder.Get(encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
        self.sampler = Sampler(**sampler_params)
        self.rendering = Rendering()
        self.export_mode = export_mode
        self.normalize_coord = normalize_coord
        self.dir_as_input = dir_as_input
        self.color = color

        if self.color == color_mode.YCbCr:
            # Two-stage net: net1 emits `nf` features plus 2 output channels,
            # net2 turns the `nf` features into the remaining 2 channels.
            self.net1 = FcNet(
                in_chns=fc_params['in_chns'],
                out_chns=fc_params['nf'] + 2,
                nf=fc_params['nf'],
                n_layers=fc_params['n_layers'] - 2)
            self.net2 = FcNet(
                in_chns=fc_params['nf'],
                out_chns=2,
                nf=fc_params['nf'],
                n_layers=1)
            self.net = None
        elif self.dir_as_input:
            # net1 produces `nf` features from the encoded coords; net2 maps
            # those features concatenated with the encoded direction to the output.
            self.input_encoder2 = InputEncoder.Get(4, 2)
            self.net1 = FcNet(
                in_chns=fc_params['in_chns'],
                out_chns=fc_params['nf'],
                nf=fc_params['nf'],
                n_layers=fc_params['n_layers'])
            self.net2 = FcNet(
                in_chns=fc_params['nf'] + self.input_encoder2.out_dim,
                out_chns=fc_params['out_chns'],
                nf=fc_params['nf'],
                n_layers=1)
            self.net = None
        else:
            self.net = FcNet(**fc_params)

        if self.normalize_coord:
            # Row 0 holds the running minimum and row 1 the running maximum of
            # the two angular coords. They start at +1e5 / -1e5 so that the
            # first call to `update_normalize_range` replaces them.
            self.register_buffer('angle_range', torch.tensor(
                [[1e5, 1e5], [-1e5, -1e5]]))
            self.register_buffer('depth_range', torch.tensor([
                self.sampler.lower[0], self.sampler.upper[-1]
            ]))
        self.n_samples = sampler_params['n_samples']

    def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor):
        """
        Extend the stored angular range so that it covers the samples of the given rays

        The ``angle_range`` buffer is only registered when the net was created
        with ``normalize_coord=True``; calling this method otherwise fails on
        the missing attribute.

        :param rays_o: rays' origin
        :param rays_d: rays' direction
        """
        coords, _, _ = self.sampler(rays_o, rays_d)
        # Drop component 0 and flatten the remaining two components per sample
        angles = coords[..., 1:].view(-1, 2)
        self.angle_range = torch.stack([
            torch.cat([angles, self.angle_range[0:1]]).amin(0),
            torch.cat([angles, self.angle_range[1:2]]).amax(0)
        ])

    def calc_local_dir(self, rays_d, coords, pts: torch.Tensor):
        """
        Calculate two angular components of the ray directions using a frame
        built at every sample point

        The frame is made of ``local_z`` (normalized ``pts``), ``local_x``
        (normalized offset from ``pts`` to the point obtained by adding 0.1
        degree to the second component of ``coords``) and ``local_y``
        (cross product of ``local_x`` and ``local_z``).

        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param coords ```Tensor(B, N, 3)```: coords of sample points, as accepted
            by ``util.SphericalToCartesian``
        :param pts ```Tensor(B, N, 3)```: cartesian positions of sample points
        :return ```Tensor(B, N, 2)```: components 1 and 2 of
            ``util.CartesianToSpherical`` applied to the transformed directions
        """
        local_z = pts / pts.norm(dim=-1, keepdim=True)
        # Finite-difference step of 0.1 degree along the second coord component.
        # The offset tensor is deliberately left at the default dtype.
        angle_step = torch.tensor([0, 0.1 / 180 * math.pi, 0],
                                  device=coords.device)
        local_x = util.SphericalToCartesian(coords + angle_step) - pts
        local_x = local_x / local_x.norm(dim=-1, keepdim=True)
        local_y = torch.cross(local_x, local_z, -1)
        local_rot = torch.stack(
            [local_x, local_y, local_z], dim=-2)  # (B, N, 3, 3)
        # NOTE(review): this multiplies the direction as a row vector by a
        # matrix whose rows are the frame axes; projecting onto the axes would
        # need the transpose - confirm which one is intended.
        return util.CartesianToSpherical(torch.matmul(
            rays_d[:, None, None, :], local_rot)).squeeze(-2)[..., 1:3]

    def sample_and_infer(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                         sampler: Optional[Sampler] = None
                         ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample points along rays and infer the raw network output for them

        :param rays_o: rays' origin
        :param rays_d: rays' direction
        :param sampler: sampler used instead of ``self.sampler``, defaults to None
        :return: tuple of raw network output and depths of the samples
        """
        if sampler is None:
            sampler = self.sampler
        coords, pts, depths = sampler(rays_o, rays_d)

        if self.normalize_coord:  # Normalize coords to [0, 2pi]
            # Row 0: lower bounds, row 1: upper bounds of (depth, angle, angle).
            # NOTE(review): before `update_normalize_range` has run the angle
            # bounds still hold their +-1e5 initial values.
            coord_range = torch.cat(
                [self.depth_range.view(2, 1), self.angle_range], 1)
            lower, upper = coord_range[0], coord_range[1]
            net_coords = (coords - lower) / (upper - lower) * 2 * math.pi
        else:
            net_coords = coords
        encoded = self.input_encoder(net_coords)

        if self.color == color_mode.YCbCr:
            mid_output = self.net1(encoded)
            # The trailing 2 channels of net1 go straight to the output, the
            # leading channels are the features consumed by net2.
            net2_output = self.net2(mid_output[..., :-2])
            raw = torch.cat([
                mid_output[..., -2:],
                net2_output
            ], -1)
        elif self.dir_as_input:
            # Directions are derived from the un-normalized coords and are only
            # needed in this branch.
            dirs = self.calc_local_dir(rays_d, coords, pts)
            encoded_dirs = self.input_encoder2(dirs)
            raw = self.net2(torch.cat([self.net1(encoded), encoded_dirs], -1))
        else:
            raw = self.net(encoded)
        return raw, depths

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                ret_depth: bool = False, sampler: Optional[Sampler] = None
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        rays -> colors

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param ret_depth: whether return the depth map along with the colors
            (has no effect in export mode)
        :param sampler: sampler used instead of ``self.sampler``, defaults to None
        :return: ```Tensor(B, C)```, inferred images/pixels; a tuple of color
            map and depth map if ``ret_depth`` is set; per-sample colors
            concatenated with alphas in export mode
        """
        raw, depths = self.sample_and_infer(rays_o, rays_d, sampler)
        if self.export_mode:
            colors, alphas = self.rendering.raw2color(raw, depths)
            return torch.cat([colors, alphas[..., None]], -1)

        if ret_depth:
            color_map, _, _, _, depth_map = self.rendering(
                raw, depths, ret_extra=True)
            return color_map, depth_map

        return self.rendering(raw, depths)
Nianchen Deng's avatar
Nianchen Deng committed
150
151
152
153
154
155
156
157
158
159
160
161


class ExportNet(nn.Module):
    """
    Thin wrapper around a ``MslNet`` that starts from already encoded inputs

    Only ``net.net`` and ``net.rendering`` of the wrapped model are used, so
    the wrapped model must have been built in a configuration where ``net.net``
    is an actual network (it is ``None`` in the YCbCr and dir-as-input setups).
    """

    def __init__(self, net: MslNet):
        """
        :param net: the model whose network and renderer are reused
        """
        super().__init__()
        self.net = net

    def forward(self, encoded: torch.Tensor, depths: torch.Tensor) -> torch.Tensor:
        """
        encoded samples -> per-sample colors and alphas

        :param encoded: encoded inputs fed directly to the wrapped network
        :param depths: depths handed to ``raw2color`` together with the raw output
        :return: colors with the alphas appended as the last channel
        """
        wrapped = self.net
        raw_output = wrapped.net(encoded)
        rgb, opacity = wrapped.rendering.raw2color(raw_output, depths)
        # Give the alphas a trailing channel axis so they can be appended
        opacity_chn = opacity.unsqueeze(-1)
        return torch.cat((rgb, opacity_chn), dim=-1)