from .generic import *
from typing import Dict


class NerfCore(nn.Module):
    """Basic NeRF MLP: a shared coordinate core with optional density and
    (optionally direction-conditioned) color heads.
    """

    def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers,
                 dir_chns=0, dir_nf=0, act='relu', skips=None):
        """
        Create a basic NeRF core network.

        :param coord_chns `int`: channels of the input coordinate encoding
        :param density_chns `int`: channels of the density output (0 disables the head)
        :param color_chns `int`: channels of the color output (0 disables the head)
        :param core_nf `int`: width of the core MLP
        :param core_layers `int`: depth of the core MLP
        :param dir_chns `int`: channels of the direction encoding (0 means color does
            not depend on view direction)
        :param dir_nf `int`: width of the direction-conditioned color branch
        :param act `str`: activation for the FC layers
        :param skips `list[int]`: indices of core layers with skip connections;
            defaults to no skips (None avoids a shared mutable default argument)
        """
        super().__init__()
        skips = skips if skips is not None else []
        self.core = FcNet(in_chns=coord_chns, out_chns=None, nf=core_nf,
                          n_layers=core_layers, skips=skips, act=act)
        self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None
        if color_chns == 0:
            # No color head at all.
            self.feature_out = None
            self.color_out = None
        elif dir_chns > 0:
            # Color depends on view direction: project core features, then fuse
            # with the direction encoding in a small two-layer head.
            self.feature_out = FcLayer(core_nf, core_nf)
            self.color_out = nn.Sequential(
                FcLayer(core_nf + dir_chns, dir_nf, act),
                FcLayer(dir_nf, color_chns)
            )
        else:
            # Direction-independent color straight from the core features.
            self.feature_out = torch.nn.Identity()
            self.color_out = FcLayer(core_nf, color_chns)

    def forward(self, x: torch.Tensor, d: torch.Tensor,
                outputs: List[str]) -> Dict[str, torch.Tensor]:
        """
        Evaluate the network.

        :param x: encoded positions
        :param d: encoded view directions (may be None when the net was built
            with dir_chns=0)
        :param outputs: names of the outputs to compute; unknown names map to None
        :return: dict with requested keys among 'density' and 'color'
        """
        ret = {}
        core_output = self.core(x)
        if 'density' in outputs:
            ret['density'] = torch.relu(self.density_out(core_output)) \
                if self.density_out is not None else None
        if 'color' in outputs:
            if self.color_out is None:
                ret['color'] = None
            else:
                feature = self.feature_out(core_output)
                # BUGFIX: the original tested `dir is not None`, i.e. the *builtin*
                # `dir`, which is always truthy — so the direction tensor was
                # concatenated unconditionally and crashed the dir_chns=0 head.
                # Test the actual argument `d` instead.
                if d is not None:
                    feature = torch.cat([feature, d], dim=-1)
                ret['color'] = self.color_out(feature).sigmoid()
        for key in outputs:
            if key == 'density' or key == 'color':
                continue
            ret[key] = None
        return ret


class NerfAdvCore(nn.Module):

    def __init__(self, *, x_chns: int, d_chns: int, density_chns: int, color_chns: int,
                 density_net_params: dict, color_net_params: dict,
                 specular_net_params: dict = None,
                 appearance="decomposite",
                 density_color_connection=False):
        """
        Create a NeRF-Adv Core Net.
        Required parameters for the sub-mlps include: "nf", "n_layers", "skips"
        and "act". Other parameters will be properly set automatically.

        :param x_chns `int`: the channels of input "position"
        :param d_chns `int`: the channels of input "direction"
        :param density_chns `int`: the channels of output "density"
        :param color_chns `int`: the channels of output "color"
        :param density_net_params `dict`: parameters for the density net
        :param color_net_params `dict`: parameters for the color net
        :param specular_net_params `dict`: (optional) parameters for the optional
            specular net, defaults to None
        :param appearance `str`: (optional) options are [decomposite|combined],
            defaults to "decomposite"
        :param density_color_connection `bool`: (optional) whether to add
            connections between density net and color net, defaults to False
        """
        super().__init__()
        self.density_chns = density_chns
        self.color_chns = color_chns
        self.specular_feature_chns = color_net_params["nf"] if specular_net_params else 0
        self.color_feature_chns = density_net_params["nf"] if density_color_connection else 0
        self.appearance = appearance
        self.density_color_connection = density_color_connection
        self.density_net = FcNet(**density_net_params,
                                 in_chns=x_chns,
                                 out_chns=self.density_chns + self.color_feature_chns,
                                 out_act='relu')
        if self.appearance == "newtype":
            # "newtype": the color net predicts, per sample, a (3, d_chns) matrix
            # that is contracted with the direction encoding to get speculars.
            self.specular_feature_chns = d_chns * 3
            self.color_net = FcNet(**color_net_params,
                                   in_chns=x_chns + self.color_feature_chns,
                                   out_chns=self.color_chns + self.specular_feature_chns)
            # Truthy sentinel: `forward` checks `if not self.specular_net` to decide
            # whether to feed the direction into the color net.
            self.specular_net = "Placeholder"
        else:
            if self.appearance == "decomposite":
                # Diffuse colors plus features for the separate specular net.
                self.color_net = FcNet(**color_net_params,
                                       in_chns=x_chns + self.color_feature_chns,
                                       out_chns=self.color_chns + self.specular_feature_chns)
            else:
                if specular_net_params:
                    # "combined" with specular net: color net emits only features.
                    self.color_net = FcNet(**color_net_params,
                                           in_chns=x_chns + self.color_feature_chns,
                                           out_chns=self.specular_feature_chns)
                else:
                    # "combined" without specular net: direction goes straight
                    # into the color net.
                    self.color_net = FcNet(**color_net_params,
                                           in_chns=x_chns + d_chns + self.color_feature_chns,
                                           out_chns=self.color_chns)
            self.specular_net = FcNet(**specular_net_params,
                                      in_chns=d_chns + self.specular_feature_chns,
                                      out_chns=self.color_chns) if specular_net_params else None

    def forward(self, x: torch.Tensor, d: torch.Tensor, outputs: List[str], *,
                color_feats: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Evaluate the network.

        :param x: encoded positions, shape (..., x_chns)
        :param d: encoded directions, shape (..., d_chns)
        :param outputs: names of outputs to compute among
            'density', 'color_feat', 'color', 'diffuse', 'specular'
        :param color_feats: features from a previous density pass; required for the
            density->color connection when 'density' is not requested here
        :return: dict of requested outputs, reshaped back to the input batch shape
        """
        input_shape = x.shape[:-1]
        if len(input_shape) > 1:
            x = x.flatten(0, -2)
            d = d.flatten(0, -2)
        n = x.shape[0]
        c = self.color_chns
        ret: Dict[str, torch.Tensor] = {}
        if 'density' in outputs:
            density_net_out: torch.Tensor = self.density_net(x)
            ret['density'] = density_net_out[:, :self.density_chns]
            color_feats = density_net_out[:, self.density_chns:]
            if 'color_feat' in outputs:
                ret['color_feat'] = color_feats
        # BUGFIX: the original tested `'specluar' in outputs` (typo), so requesting
        # only 'specular' silently skipped the whole color computation.
        if 'color' in outputs or 'specular' in outputs:
            if 'density' in ret:
                # Only evaluate color for samples with non-negligible density.
                valid_mask = ret['density'][:, 0].detach() >= 1e-4
                indices = valid_mask.nonzero()[:, 0]
                x, d, color_feats = x[indices], d[indices], color_feats[indices]
            else:
                indices = None
            speculars = None
            color_net_in = [x]
            if not self.specular_net:
                color_net_in.append(d)
            if self.density_color_connection:
                color_net_in.append(color_feats)
            color_net_in = torch.cat(color_net_in, -1)
            color_net_out: torch.Tensor = self.color_net(color_net_in)
            diffuses = color_net_out[:, :c]
            specular_features = color_net_out[:, -self.specular_feature_chns:]
            if self.appearance == "newtype":
                # Per-sample (3, d_chns) matrix times direction vector.
                # BUGFIX: reshape with -1 instead of the pre-mask count `n`: after
                # density masking the tensor only has len(indices) rows.
                speculars = torch.bmm(
                    specular_features.reshape(-1, 3, d.shape[-1]),
                    d[..., None])[..., 0]
                # TODO relu or not?
                diffuses = diffuses.relu()
                speculars = speculars.relu()
                colors = diffuses + speculars
            else:
                if not self.specular_net:
                    colors = diffuses
                    diffuses = None
                else:
                    specular_net_in = torch.cat([d, specular_features], -1)
                    specular_net_out = self.specular_net(specular_net_in)
                    if self.appearance == "decomposite":
                        speculars = specular_net_out
                        colors = diffuses + speculars
                    else:
                        diffuses = None
                        colors = specular_net_out
                colors = torch.sigmoid(colors)  # TODO indent or not?
            if 'color' in outputs:
                # BUGFIX: `if indices else colors` applied Python truthiness to a
                # multi-element tensor (RuntimeError: ambiguous boolean value);
                # the intended check is whether masking happened at all.
                ret['color'] = colors.new_zeros(n, c).index_copy(0, indices, colors) \
                    if indices is not None else colors
            if 'diffuse' in outputs:
                ret['diffuse'] = diffuses.new_zeros(n, c).index_copy(0, indices, diffuses) \
                    if indices is not None and diffuses is not None else diffuses
            if 'specular' in outputs:
                ret['specular'] = speculars.new_zeros(n, c).index_copy(0, indices, speculars) \
                    if indices is not None and speculars is not None else speculars
        if len(input_shape) > 1:
            ret = {key: val.reshape(*input_shape, -1) for key, val in ret.items()}
        return ret