# msl_net.py
from typing import List, Tuple
from math import pi
import torch
import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import util
from .my import device


def CartesianToSpherical(cart: torch.Tensor) -> torch.Tensor:
    """
    Convert coordinates from Cartesian to Spherical

    The polar axis is +Y: ``theta`` is the azimuth in the X-Z plane
    (``atan2(z, x)``, shifted by 2*pi when negative so it is non-negative)
    and ``phi`` is the angle measured from the +Y axis (``acos(y / r)``).

    A zero-length input has no direction; it is mapped to
    ``(0, 0, pi / 2)`` instead of producing NaN.

    :param cart: ... x 3, coordinates in Cartesian (x, y, z)
    :return: ... x 3, coordinates in Spherical (r, theta, phi)
    """
    rho = torch.norm(cart, p=2, dim=-1)
    theta = torch.atan2(cart[..., 2], cart[..., 0])
    theta = theta + (theta < 0).type_as(theta) * (2 * pi)
    # Guard the division: 0 / 0 at the origin would otherwise yield NaN.
    safe_rho = torch.where(rho > 0, rho, torch.ones_like(rho))
    # Rounding can push |y / rho| marginally above 1, where acos is NaN.
    cos_phi = (cart[..., 1] / safe_rho).clamp(-1.0, 1.0)
    phi = torch.acos(cos_phi)
    return torch.stack([rho, theta, phi], dim=-1)


def SphericalToCartesian(spher: torch.Tensor) -> torch.Tensor:
    """
    Convert coordinates from Spherical to Cartesian

    Uses +Y as the polar axis: ``y = r * cos(phi)``, and ``theta`` is the
    azimuth in the X-Z plane.

    :param spher: ... x 3, coordinates in Spherical (r, theta, phi)
    :return: ... x 3, coordinates in Cartesian (x, y, z)
    """
    radius = spher[..., 0]
    angles = spher[..., 1:3]  # (theta, phi) along the last dim
    sin_theta, sin_phi = torch.sin(angles).unbind(dim=-1)
    cos_theta, cos_phi = torch.cos(angles).unbind(dim=-1)
    components = [
        radius * cos_theta * sin_phi,  # x
        radius * cos_phi,              # y
        radius * sin_theta * sin_phi,  # z
    ]
    return torch.stack(components, dim=-1)


def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    """
    Calculate the intersection of every ray with every sphere

    The spheres are centered at the origin. For each ray ``p + k * v`` the
    larger root ``k`` of ``|p + k * v|^2 = r^2`` is used. A negative
    discriminant (ray misses the sphere) yields NaN, and a zero-length
    direction divides by zero.

    :param p: B x 3, positions of rays
    :param v: B x 3, directions of rays
    :param r: B'(1D), radius of spheres
    :return: B x B' x 3, points of intersection
    """
    # Insert a sphere axis so rays broadcast against radii: B x 1 x 3
    origins = torch.unsqueeze(p, 1)
    directions = torch.unsqueeze(v, 1)
    # Dot products over the coordinate axis: each B x 1
    origin_sq = torch.sum(origins * origins, dim=2)
    direction_sq = torch.sum(directions * directions, dim=2)
    origin_dot_dir = torch.sum(origins * directions, dim=2)
    # Quarter discriminant of the quadratic in k: B x B'
    discriminant = origin_dot_dir * origin_dot_dir \
        - direction_sq * (origin_sq - r * r)
    # Larger root, i.e. the far intersection along the ray: B x B'
    k = (torch.sqrt(discriminant) - origin_dot_dir) / direction_sq
    return origins + torch.unsqueeze(k, 2) * directions


def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    """
    Calculate the spherical coordinates of the intersections of each ray
    with each sphere

    :param p: B x 3, positions of rays
    :param v: B x 3, directions of rays
    :param r: B'(1D), radius of spheres (forwarded to RaySphereIntersect)
    :return: B x B' x 3, spherical coordinates
    """
    # Intersect in Cartesian space first, then convert the hit points.
    return CartesianToSpherical(RaySphereIntersect(p, v, r))


class FcNet(nn.Module):
    """
    Fully-connected network: a stack of Linear + ReLU stages of width ``nf``
    followed by a final Linear projection to ``out_chns``.
    """

    def __init__(self, in_chns: int, out_chns: int, nf: int, n_layers: int):
        """
        Initialize a fully-connected net

        :param in_chns: number of input channels
        :param out_chns: number of output channels
        :param nf: number of features of every hidden layer
        :param n_layers: number of hidden (Linear + ReLU) stages; at least
            one stage is always created, even when ``n_layers`` < 1
        """
        super().__init__()
        # Plain list attribute; the modules are registered as sub-modules
        # through ``self.net`` below.
        self.layers = list()
        self.layers += [
            nn.Linear(in_chns, nf),
            #nn.LayerNorm([nf]),
            nn.ReLU()
        ]
        for _ in range(1, n_layers):
            self.layers += [
                nn.Linear(nf, nf),
                #nn.LayerNorm([nf]),
                nn.ReLU()
            ]
        # Output projection has no activation
        self.layers.append(nn.Linear(nf, out_chns))
        self.net = nn.Sequential(*self.layers)
        self.net.apply(self.init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the net

        :param x: ... x in_chns, input features
        :return: ... x out_chns, output features
        """
        return self.net(x)

    def init_weights(self, m):
        """
        Initialize a module in place: Xavier-normal weights and zero bias
        for ``nn.Linear``; every other module type is left untouched

        :param m: the module visited by ``nn.Module.apply``
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            # Linear layers created with bias=False have no bias tensor
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)


class Rendering(nn.Module):
    """
    Renders rays by querying a net where each ray crosses a stack of
    origin-centered sphere layers and alpha-blending the per-layer results.
    """

    def __init__(self, sphere_layers: List[float]):
        """
        Initialize a Rendering module

        :param sphere_layers: list of L floats, radius of sphere layers
        """
        super().__init__()
        # NOTE(review): stored as a plain tensor attribute, not a registered
        # buffer, so it stays on the device chosen here.
        self.sphere_layers = torch.tensor(
            sphere_layers, device=device.GetDevice())

    def forward(self, net: FcNet, p: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """
        Render rays by blending the colors the net predicts on each sphere
        layer

        Every ray is intersected with every sphere layer; the intersection
        (with its radius replaced by the reciprocal) is fed to ``net``, whose
        last output channel is used as alpha and the preceding ones as color.
        Layers are blended in the order of ``sphere_layers``: each later
        layer is composited over the running result.

        :param net: the full-connected net
        :param p: B x 3, positions of rays
        :param v: B x 3, directions of rays
        :return B x 1/3, view images by blended layers
        """
        n_layers = self.sphere_layers.size()[0]
        sp = RayToSpherical(p, v, self.sphere_layers)  # B x L x 3
        sp[..., 0] = 1 / sp[..., 0]                    # Radius to diopter
        # Query all (ray, layer) samples in one batch, then restore B x L x C
        color_alpha: torch.Tensor = net(
            sp.flatten(0, 1)).view(p.size()[0], n_layers, -1)
        if (color_alpha.size(-1) == 2):  # Grayscale
            c = color_alpha[..., 0:1]
            a = color_alpha[..., 1:2]
        else:                           # RGB
            c = color_alpha[..., 0:3]
            a = color_alpha[..., 3:4]
        # Start from layer 0, then composite every following layer over it
        blended = c[:, 0, :] * a[:, 0, :]
        for i in range(1, n_layers):
            blended = blended * (1 - a[:, i, :]) + c[:, i, :] * a[:, i, :]
        return blended


class MslNet(nn.Module):
    """
    Multi-sphere-layer net: maps a batch of view poses to rendered images by
    casting one ray per output pixel and blending the sphere layers.
    """

    def __init__(self, cam_params, sphere_layers: List[float], out_res: Tuple[int, int], gray=False):
        """
        Initialize a multi-sphere-layer net

        :param cam_params: intrinsic parameters of camera
        :param sphere_layers: list of L floats, radius of sphere layers
        :param out_res: resolution of output view image
        :param gray: if True the net outputs 2 channels (gray + alpha),
            otherwise 4 channels (RGB + alpha)
        """
        super().__init__()
        self.cam_params = cam_params
        self.out_res = out_res
        self.v_local = util.GetLocalViewRays(self.cam_params, out_res, flatten=True) \
            .to(device.GetDevice()) # N x 3
        #self.net = FCBlock(hidden_ch=64,
        #                   num_hidden_layers=4,
        #                   in_features=3,
        #                   out_features=2 if gray else 4,
        #                   outermost_linear=True)
        self.net = FcNet(in_chns=3, out_chns=2 if gray else 4, nf=256, n_layers=8)
        self.rendering = Rendering(sphere_layers)

    def forward(self, view_centers: torch.Tensor, view_rots: torch.Tensor) -> torch.Tensor:
        """
        T_view -> image

        :param view_centers: B x 3, centers of views
        :param view_rots: B x 3 x 3, rotation matrices of views
        :return: B x 1/3 x H_out x W_out, inferred images of views
        """
        # Transpose matrix so we can perform vec x mat
        view_rots_t = view_rots.permute(0, 2, 1)

        # p and v are B x N x 3 tensor
        # Every pixel ray of a view starts at that view's center
        p = view_centers.unsqueeze(1).expand(-1, self.v_local.size(0), -1)
        # Rotate the local ray directions for each view (broadcast over B)
        v = torch.matmul(self.v_local, view_rots_t)
        c: torch.Tensor = self.rendering(
            self.net, p.flatten(0, 1), v.flatten(0, 1))  # (BN) x 1/3
        # unflatten to B x H_out x W_out x C, then move channels first
        return c.view(view_centers.size(0), self.out_res[0],
                      self.out_res[1], -1).permute(0, 3, 1, 2)