core.py 9.85 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
class NeRF(Module):
    """
    Fully-connected field network with optional density and color heads.

    A shared `FcBlock` ("field") maps `inputs.x` to an `nf`-channel feature. A density head
    and a color head are attached on top of it; either head is omitted when its channel
    count is zero. When `d_chns > 0` the color head additionally consumes `inputs.d`.
    """

    def __init__(self, *, x_chns, density_chns, color_chns, nf, n_layers,
                 d_chns=0, d_nf=0, act='relu', skips=None, with_layer_norm=False,
                 density_out_act='relu', color_out_act='sigmoid', feature_layer=False):
        """
        Build the field block and the output heads.

        :param x_chns: channels of the input `x` fed to the field block
        :param density_chns: channels of the density output; no density head if not positive
        :param color_chns: channels of the color output; no color head if zero
        :param nf: number of features of the field block (and input width of the heads)
        :param n_layers: number of layers of the field block
        :param d_chns: channels of the input `d`; when positive the color head takes `d` as well
        :param d_nf: features of the direction-aware color head, defaults to `nf // 2` when 0
        :param act: activation passed to the field block and the direction-aware color head
        :param skips: skip-connection spec forwarded to the field block, defaults to an empty list
        :param with_layer_norm: forwarded as `with_ln` to the field block and the
            direction-aware color head
        :param density_out_act: output activation of the density head
        :param color_out_act: output activation of the color head
        :param feature_layer: whether to insert an extra `nf -> nf` layer before the
            direction-aware color head (only used when `d_chns > 0` and `color_chns != 0`)
        """
        super().__init__()
        # A fresh list per instance avoids sharing one mutable default between instances.
        skips = [] if skips is None else skips

        self.x_chns = x_chns
        self.d_chns = d_chns
        # Defined on every code path so that `forward` never hits a missing attribute.
        self.with_dir = False
        self.feature_layer = None

        self.field = FcBlock(in_chns=x_chns, out_chns=None, nf=nf, n_layers=n_layers,
                             skips=skips, act=act, with_ln=with_layer_norm)
        self.density_out = FcLayer(nf, density_chns, density_out_act, with_ln=False) \
            if density_chns > 0 else None

        if color_chns == 0:
            self.color_out = None
        elif d_chns > 0:
            # Direction-aware color head: [feature, d] -> color.
            self.feature_layer = FcLayer(nf, nf, with_ln=False) if feature_layer else None
            self.color_out = FcBlock(in_chns=nf + d_chns, out_chns=color_chns,
                                     nf=d_nf or nf // 2, n_layers=1,
                                     act=act, out_act=color_out_act, with_ln=with_layer_norm)
            self.with_dir = True
        else:
            # Direction-free color head: a single layer on top of the field output.
            self.color_out = FcLayer(nf, color_chns, color_out_act, with_ln=False)

    def forward(self, inputs: NetInput, *outputs: str, field_out: torch.Tensor = None, **kwargs) -> NetOutput:
        """
        Evaluate the requested outputs.

        :param inputs: network input; `inputs.x` is read always (unless `field_out` is given)
            and `inputs.d` is read by the direction-aware color head
        :param outputs: names of requested outputs, any of "field_out", "densities", "colors"
        :param field_out: (optional) precomputed output of the field block; computed from
            `inputs.x` when omitted
        :return: a `NetOutput` holding the requested outputs whose heads exist
        """
        ret = NetOutput()
        if field_out is None:
            field_out = self.field(inputs.x)
        if 'field_out' in outputs:
            ret.field_out = field_out
        if 'densities' in outputs and self.density_out is not None:
            ret.densities = self.density_out(field_out)
        if 'colors' in outputs and self.color_out is not None:
            if self.with_dir:
                # Without the optional feature layer the raw field output is used directly;
                # the original left `h` unassigned in that case and crashed.
                h = self.feature_layer(field_out) if self.feature_layer else field_out
                h = union(h, inputs.d)
            else:
                h = field_out
            ret.colors = self.color_out(h)
        return ret

    def get_exporter(self):
        """Return a `ModelExporter` wrapping `infer` with outputs "densities" and "colors"."""
        return ModelExporter(self.infer, "densities", "colors", x=[self.x_chns], d=[self.d_chns])

    def infer(self, x: torch.Tensor, d: torch.Tensor = None, f: torch.Tensor = None):
        """
        Tensor-in / tuple-out entry point used for export.

        :return: the values of the output container for "colors" and "densities", as a tuple
        """
        # NOTE(review): `_forward` is not defined in this class; presumably provided by the
        # `Module` base class - confirm, otherwise this should call `self.forward`.
        return tuple(self._forward(NetInput(x, d, f), "colors", "densities").values())

Nianchen Deng's avatar
sync    
Nianchen Deng committed
49

Nianchen Deng's avatar
sync    
Nianchen Deng committed
50
class NerfAdvCore(Module):
    """
    Core network made of a density net, a color net and an optional specular net.

    The density net outputs densities together with a feature vector; the color net
    consumes that feature (plus `x`) and, depending on `appearance`, the result is combined
    with the output of the specular net.
    """

    def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
                 density_net: dict, color_net: dict, specular_net: dict = None,
                 appearance="decomposite", with_layer_norm=False, f_chns=0):
        """
        Create a NeRF-Adv Core Net.
        Required parameters for the sub-mlps include: "nf", "n_layers", "skips" and "act".
        Other parameters will be properly set automatically.

        :param x_chns `int`: the channels of input "position"
        :param d_chns `int`: the channels of input "direction"
        :param density_chns `int`: the channels of output "density"
        :param color_chns `int`: the channels of output "color"
        :param density_net `dict`: parameters for the density net
        :param color_net `dict`: parameters for the color net
        :param specular_net `dict`: (optional) parameters for the optional specular net, defaults to None
        :param appearance `str`: (optional) "decomposite", "newtype", "mlp_basis(N)" or any other
            value (handled as the combined mode), defaults to "decomposite"
        :param with_layer_norm `bool`: (optional) forwarded as `with_ln` to the sub-nets
            (except the color net of the "mlp_basis(N)" mode), defaults to False
        :param f_chns `int`: (optional) the channels of input "f" appended to the density net
            input, defaults to 0 (input "f" is not used)
        """
        super().__init__()
        self.input_f = f_chns > 0
        self.density_chns = density_chns
        self.color_chns = color_chns
        self.specular_feature_chns = color_net["nf"] if specular_net else 0
        self.color_feature_chns = density_net["nf"]
        self.appearance = appearance

        # The density net emits densities followed by the feature passed on to the color net.
        self.density_net = FcBlock(**density_net,
                                   in_chns=x_chns + f_chns,
                                   out_chns=self.density_chns + self.color_feature_chns,
                                   out_act='relu',
                                   with_ln=with_layer_norm)

        if self.appearance == "newtype":
            self.specular_feature_chns = d_chns * 3
            self.color_net = FcBlock(**color_net,
                                     in_chns=x_chns + self.color_feature_chns,
                                     out_chns=self.color_chns + self.specular_feature_chns,
                                     with_ln=with_layer_norm)
            # Truthy marker: `forward` only tests `self.specular_net` for truthiness in this mode.
            self.specular_net = "Placeholder"
        else:
            # Raw string keeps `\(` and `\d` as regex escapes rather than (invalid) string escapes.
            match = re.match(r"mlp_basis\((\d+)\)", self.appearance)
            if match is not None:
                basis_dim = int(match.group(1))
                self.color_net = FcBlock(**color_net,
                                         in_chns=x_chns + self.color_feature_chns,
                                         out_chns=self.color_chns * basis_dim)
            elif self.appearance == "decomposite":
                self.color_net = FcBlock(**color_net,
                                         in_chns=x_chns + self.color_feature_chns,
                                         out_chns=self.color_chns + self.specular_feature_chns,
                                         with_ln=with_layer_norm)
            else:
                if specular_net:
                    # The color net only produces the feature consumed by the specular net.
                    self.color_net = FcBlock(**color_net,
                                             in_chns=x_chns + self.color_feature_chns,
                                             out_chns=self.specular_feature_chns,
                                             with_ln=with_layer_norm)
                else:
                    # No specular net: the color net takes the direction itself.
                    self.color_net = FcBlock(**color_net,
                                             in_chns=x_chns + d_chns + self.color_feature_chns,
                                             out_chns=self.color_chns,
                                             with_ln=with_layer_norm)
            self.specular_net = FcBlock(**specular_net,
                                        in_chns=d_chns + self.specular_feature_chns,
                                        out_chns=self.color_chns,
                                        with_ln=with_layer_norm) if specular_net else None

    def forward(self, inputs: NetInput, *outputs: str, features: torch.Tensor = None, **kwargs) -> NetOutput:
        """
        Evaluate the requested outputs.

        :param inputs: network input; `x` is always read, `f` when the net was built with
            `f_chns > 0`, `d` by the appearance part
        :param outputs: names of requested outputs, any of "densities", "features", "colors",
            "diffuses", "speculars"
        :param features: (optional) precomputed features for the color net; overwritten when
            "densities" or "features" is requested
        :return: a dict mapping the requested output names to tensors ("diffuses" and
            "speculars" are omitted in modes that do not produce them)
        """
        output_shape = inputs.shape

        ret: NetOutput = {}

        if 'densities' in outputs or 'features' in outputs:
            density_net_in = union(inputs.x, inputs.f) if self.input_f else inputs.x
            density_net_out: torch.Tensor = self.density_net(density_net_in)
            densities, features = split(density_net_out, self.density_chns, -1)
            if 'features' in outputs:
                ret['features'] = features
            if 'densities' in outputs:
                ret['densities'] = densities

        # The original tested the misspelled key 'specluars', so requesting only "speculars"
        # silently returned nothing.
        if 'colors' in outputs or 'speculars' in outputs or 'diffuses' in outputs:
            if 'densities' in ret:
                # Only run the appearance nets on samples whose first density channel
                # reaches the threshold; the rest is filled with zeros in `postprocess`.
                valid_mask = ret['densities'][..., 0].detach() >= 1e-4
                indices: tuple[torch.Tensor, ...] = valid_mask.nonzero(as_tuple=True)
                inputs, features = inputs[indices], features[indices]
            else:
                indices = None

            color_net_in = [inputs.x, features]
            if not self.specular_net:
                color_net_in.append(inputs.d)
            color_net_out: torch.Tensor = self.color_net(union(*color_net_in))
            diffuses = color_net_out[..., :self.color_chns]
            specular_features = color_net_out[..., -self.specular_feature_chns:]

            if self.appearance == "newtype":
                # View the specular feature as a matrix applied to the direction vector.
                speculars = torch.matmul(
                    specular_features.reshape(*inputs.shape, -1, inputs.d.shape[-1]),
                    inputs.d[..., None])[..., 0]
                # TODO relu or not?
                diffuses = diffuses.relu()
                speculars = speculars.relu()
                colors = diffuses + speculars
            else:
                if not self.specular_net:
                    colors = diffuses
                    diffuses = None
                    speculars = None
                else:
                    specular_net_out = self.specular_net(union(inputs.d, specular_features))
                    if self.appearance == "decomposite":
                        speculars = specular_net_out
                        colors = diffuses + speculars
                    else:
                        diffuses = None
                        speculars = None
                        colors = specular_net_out
                colors = torch.sigmoid(colors)  # TODO indent or not?

            def postprocess(data: torch.Tensor):
                # Scatter the values of the valid samples back into a zero tensor of the
                # original sample shape; pass through when no masking was applied.
                return data.new_zeros(*output_shape, data.shape[-1]).index_put(indices, data)\
                    if indices is not None else data

            if 'colors' in outputs:
                ret['colors'] = postprocess(colors)
            if 'diffuses' in outputs and diffuses is not None:
                ret['diffuses'] = postprocess(diffuses)
            if 'speculars' in outputs and speculars is not None:
                ret['speculars'] = postprocess(speculars)
        return ret
Nianchen Deng's avatar
sync    
Nianchen Deng committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208


class MultiNerf(Module):
    """
    A stack of per-level networks where each level above 0 receives the "features"
    output of the previous level as its `f` input.
    """

    @property
    def n_levels(self):
        """Number of levels, i.e. number of wrapped networks."""
        return len(self.nets)

    def __init__(self, nets: Iterable[Module]):
        """
        :param nets: the networks of all levels, ordered from level 0 upwards
        """
        super().__init__()
        # Materialize once: `nets` may be any iterable, but it is indexed and measured below.
        self.nets = list(nets)
        # The original iterated over `len(nets)` (an int), which raises TypeError.
        for i, net in enumerate(self.nets):
            self.add_module(f"Level {i}", net)

    def set_frozen(self, level: int, on: bool):
        """
        Switch the network of `level` out of (`on=True`) or into (`on=False`) training mode.

        :param level: index of the level to toggle
        :param on: whether the level is frozen
        """
        # NOTE(review): the original looped over all nets and indexed each one with `level`
        # (`net[level]`), which fails for a plain module; assumed intent is to toggle only
        # the network at `level` - confirm.
        self.nets[level].train(not on)

    def forward(self, inputs: NetInput, *outputs: str, samples: Samples, **kwargs) -> NetOutput:
        """
        Evaluate the network of level `samples.level`.

        For levels above 0 the features of the lower levels are required; they are taken
        from `samples.features` when present, otherwise computed by chaining levels
        `0 .. level-1` and cached in `samples.features`.

        :param inputs: network input (`x`, `d`)
        :param outputs: names of the requested outputs, forwarded to the level's network
        :param samples: carries `level` and the optional cached `features`
        :return: the output of the selected level's network
        """
        L = samples.level
        if L == 0:
            return self.nets[L](inputs, *outputs)
        if samples.features is not None:
            return self.nets[L](NetInput(inputs.x, inputs.d, samples.features), *outputs)
        # Build the features level by level, then re-enter with the cache populated.
        features = self.nets[0](inputs, 'features')['features']
        for i in range(1, L):
            features = self.nets[i](NetInput(inputs.x, inputs.d, features), 'features')['features']
        samples.features = features
        return self(inputs, *outputs, samples=samples)