import re
from typing import Iterable

import torch

# NOTE: Module, FcBlock, FcLayer, NetInput, NetOutput, Samples, union, split and
# ModelExporter are helpers defined elsewhere in this repository; import them
# from the appropriate local modules.


class NeRF(Module):

    def __init__(self, *, x_chns, density_chns, color_chns, nf, n_layers,
                 d_chns=0, d_nf=0, act='relu', skips=[], with_layer_norm=False,
                 density_out_act='relu', color_out_act='sigmoid', feature_layer=False):
        super().__init__()
        self.x_chns = x_chns
        self.d_chns = d_chns
        self.field = FcBlock(in_chns=x_chns, out_chns=None, nf=nf, n_layers=n_layers,
                             skips=skips, act=act, with_ln=with_layer_norm)
        self.density_out = FcLayer(nf, density_chns, density_out_act, with_ln=False) \
            if density_chns > 0 else None
        if color_chns == 0:
            self.color_out = None
        elif d_chns > 0:
            self.feature_layer = FcLayer(nf, nf, with_ln=False) if feature_layer else None
            self.color_out = FcBlock(in_chns=nf + d_chns, out_chns=color_chns,
                                     nf=d_nf or nf // 2, n_layers=1, act=act,
                                     out_act=color_out_act, with_ln=with_layer_norm)
            self.with_dir = True
        else:
            self.color_out = FcLayer(nf, color_chns, color_out_act, with_ln=False)
            self.with_dir = False

    def forward(self, inputs: NetInput, *outputs: str, field_out: torch.Tensor = None,
                **kwargs) -> NetOutput:
        ret = NetOutput()
        if field_out is None:
            field_out = self.field(inputs.x)
        if 'field_out' in outputs:
            ret.field_out = field_out
        if 'densities' in outputs and self.density_out:
            ret.densities = self.density_out(field_out)
        if 'colors' in outputs and self.color_out:
            if self.with_dir:
                # color_out expects nf + d_chns input channels, so the direction is
                # appended whether or not the extra feature layer is enabled
                h = self.feature_layer(field_out) if self.feature_layer else field_out
                h = union(h, inputs.d)
            else:
                h = field_out
            ret.colors = self.color_out(h)
        return ret

    def get_exporter(self):
        return ModelExporter(self.infer, "densities", "colors",
                             x=[self.x_chns], d=[self.d_chns])

    def infer(self, x: torch.Tensor, d: torch.Tensor = None, f: torch.Tensor = None):
        return tuple(self(NetInput(x, d, f), "colors", "densities").values())
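# A minimal usage sketch for NeRF. The channel counts below (positional-encoded
# position/direction sizes, skip position) are illustrative assumptions, not a
# configuration taken from this repository's experiments.
def _demo_nerf():
    net = NeRF(x_chns=63, density_chns=1, color_chns=3, nf=256, n_layers=8,
               d_chns=27, skips=[4])
    # 1024 random samples; NetInput(x, d, f) as used by infer() above
    out = net(NetInput(torch.rand(1024, 63), torch.rand(1024, 27), None),
              'densities', 'colors')
    return out.densities, out.colors  # (1024, 1) and (1024, 3)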
class NerfAdvCore(Module):

    def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
                 density_net: dict, color_net: dict, specular_net: dict = None,
                 appearance="decomposite", with_layer_norm=False, f_chns=0):
        """
        Create a NeRF-Adv core net. Required parameters for the sub-MLPs are "nf",
        "n_layers", "skips" and "act"; the remaining parameters are set automatically.

        :param x_chns `int`: channels of the input "position"
        :param d_chns `int`: channels of the input "direction"
        :param density_chns `int`: channels of the output "density"
        :param color_chns `int`: channels of the output "color"
        :param density_net `dict`: parameters for the density net
        :param color_net `dict`: parameters for the color net
        :param specular_net `dict`: (optional) parameters for the specular net, defaults to None
        :param appearance `str`: (optional) one of [decomposite|combined], defaults to "decomposite"
        :param with_layer_norm `bool`: (optional) whether to use layer normalization, defaults to False
        :param f_chns `int`: (optional) channels of the extra input "feature", defaults to 0
        """
        super().__init__()
        self.input_f = f_chns > 0
        self.density_chns = density_chns
        self.color_chns = color_chns
        self.specular_feature_chns = color_net["nf"] if specular_net else 0
        self.color_feature_chns = density_net["nf"]
        self.appearance = appearance
        self.density_net = FcBlock(**density_net,
                                   in_chns=x_chns + f_chns,
                                   out_chns=self.density_chns + self.color_feature_chns,
                                   out_act='relu',
                                   with_ln=with_layer_norm)
        if self.appearance == "newtype":
            self.specular_feature_chns = d_chns * 3
            self.color_net = FcBlock(**color_net,
                                     in_chns=x_chns + self.color_feature_chns,
                                     out_chns=self.color_chns + self.specular_feature_chns,
                                     with_ln=with_layer_norm)
            self.specular_net = "Placeholder"  # "newtype" computes speculars itself
        else:
            match = re.match(r"mlp_basis\((\d+)\)", self.appearance)
            if match is not None:
                basis_dim = int(match.group(1))
                self.color_net = FcBlock(**color_net,
                                         in_chns=x_chns + self.color_feature_chns,
                                         out_chns=self.color_chns * basis_dim)
            elif self.appearance == "decomposite":
                self.color_net = FcBlock(**color_net,
                                         in_chns=x_chns + self.color_feature_chns,
                                         out_chns=self.color_chns + self.specular_feature_chns,
                                         with_ln=with_layer_norm)
            else:
                if specular_net:
                    self.color_net = FcBlock(**color_net,
                                             in_chns=x_chns + self.color_feature_chns,
                                             out_chns=self.specular_feature_chns,
                                             with_ln=with_layer_norm)
                else:
                    self.color_net = FcBlock(**color_net,
                                             in_chns=x_chns + d_chns + self.color_feature_chns,
                                             out_chns=self.color_chns,
                                             with_ln=with_layer_norm)
            self.specular_net = FcBlock(**specular_net,
                                        in_chns=d_chns + self.specular_feature_chns,
                                        out_chns=self.color_chns,
                                        with_ln=with_layer_norm) if specular_net else None

    def forward(self, inputs: NetInput, *outputs: str, features: torch.Tensor = None,
                **kwargs) -> NetOutput:
        output_shape = inputs.shape
        ret: NetOutput = {}
        if 'densities' in outputs or 'features' in outputs:
            density_net_in = union(inputs.x, inputs.f) if self.input_f else inputs.x
            density_net_out: torch.Tensor = self.density_net(density_net_in)
            densities, features = split(density_net_out, self.density_chns, -1)
            if 'features' in outputs:
                ret['features'] = features
            if 'densities' in outputs:
                ret['densities'] = densities
        if 'colors' in outputs or 'speculars' in outputs or 'diffuses' in outputs:
            if 'densities' in ret:
                # Only evaluate colors for samples whose density is above threshold
                valid_mask = ret['densities'][..., 0].detach() >= 1e-4
                indices: tuple[torch.Tensor, ...] = valid_mask.nonzero(as_tuple=True)
                inputs, features = inputs[indices], features[indices]
            else:
                indices = None
            color_net_in = [inputs.x, features]
            if not self.specular_net:
                color_net_in.append(inputs.d)
            color_net_out: torch.Tensor = self.color_net(union(*color_net_in))
            diffuses = color_net_out[..., :self.color_chns]
            specular_features = color_net_out[..., -self.specular_feature_chns:]
            if self.appearance == "newtype":
                # Per-sample dot product of the predicted basis with the view direction
                speculars = torch.matmul(
                    specular_features.reshape(*inputs.shape, -1, inputs.d.shape[-1]),
                    inputs.d[..., None])[..., 0]
                # TODO: relu or not?
                diffuses = diffuses.relu()
                speculars = speculars.relu()
                colors = diffuses + speculars
            else:
                if not self.specular_net:
                    colors = diffuses
                    diffuses = None
                    speculars = None
                else:
                    specular_net_out = self.specular_net(union(inputs.d, specular_features))
                    if self.appearance == "decomposite":
                        speculars = specular_net_out
                        colors = diffuses + speculars
                    else:
                        diffuses = None
                        speculars = None
                        colors = specular_net_out
                colors = torch.sigmoid(colors)  # TODO: indent or not?

            def postprocess(data: torch.Tensor):
                # Scatter values of the valid samples back into a full-shape tensor
                return data.new_zeros(*output_shape, data.shape[-1]).index_put(indices, data) \
                    if indices is not None else data

            if 'colors' in outputs:
                ret['colors'] = postprocess(colors)
            if 'diffuses' in outputs and diffuses is not None:
                ret['diffuses'] = postprocess(diffuses)
            if 'speculars' in outputs and speculars is not None:
                ret['speculars'] = postprocess(speculars)
        return ret
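# A minimal usage sketch for NerfAdvCore in "decomposite" mode. The sub-MLP
# configurations are illustrative assumptions (only "nf", "n_layers", "skips"
# and "act" are required, per the docstring), not settings from the original
# experiments.
def _demo_nerf_adv_core():
    net = NerfAdvCore(
        x_chns=63, d_chns=27, density_chns=1, color_chns=3,
        density_net=dict(nf=256, n_layers=8, skips=[4], act='relu'),
        color_net=dict(nf=128, n_layers=2, skips=[], act='relu'),
        specular_net=dict(nf=128, n_layers=1, skips=[], act='relu'),
        appearance="decomposite")
    inputs = NetInput(torch.rand(1024, 63), torch.rand(1024, 27), None)
    # Requesting 'densities' lets forward() mask out low-density samples before
    # running the color nets; masked entries come back zero-filled.
    return net(inputs, 'densities', 'colors', 'diffuses', 'speculars')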
class MultiNerf(Module):

    @property
    def n_levels(self):
        return len(self.nets)

    def __init__(self, nets: Iterable[Module]):
        super().__init__()
        self.nets = list(nets)
        for i in range(len(self.nets)):
            self.add_module(f"Level {i}", self.nets[i])

    def set_frozen(self, level: int, on: bool):
        self.nets[level].train(not on)

    def forward(self, inputs: NetInput, *outputs: str, samples: Samples,
                **kwargs) -> NetOutput:
        L = samples.level
        if L == 0:
            return self.nets[L](inputs, *outputs)
        if samples.features is not None:
            return self.nets[L](NetInput(inputs.x, inputs.d, samples.features), *outputs)
        # Cascade the feature outputs of the lower levels to build the input
        # feature of level L, caching it in the samples for reuse
        features = self.nets[0](inputs, 'features')['features']
        for i in range(1, L):
            features = self.nets[i](NetInput(inputs.x, inputs.d, features),
                                    'features')['features']
        samples.features = features
        return self(inputs, *outputs, samples=samples)
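# A minimal construction sketch for MultiNerf with two cascaded levels. This
# assumes Samples exposes `level` and `features` as used in forward() above;
# the sub-net configuration is illustrative. Level 0 takes no input feature,
# while level 1 consumes the 128-channel feature produced by level 0
# (f_chns must equal the previous level's density_net "nf").
def _demo_multi_nerf():
    mlp = dict(nf=128, n_layers=4, skips=[], act='relu')
    nets = [NerfAdvCore(x_chns=63, d_chns=27, density_chns=1, color_chns=3,
                        density_net=dict(mlp), color_net=dict(mlp),
                        f_chns=0 if i == 0 else 128)
            for i in range(2)]
    # Calling the returned model requires a Samples object carrying the target level
    return MultiNerf(nets)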