from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn.modules.linear import Identity

from utils.constants import *
from .generic import *
from .sampler import *
from .input_encoder import *
from .renderer import *


class NerfCore(nn.Module):
    """Fully-connected core network producing a color and/or density output.

    A coordinate MLP (``FcNet``) maps the coordinate input to a ``core_nf``-wide
    feature. The density head (if enabled) reads that feature directly.

    The color head (if enabled) has two forms:

    * without a direction input, a single ``FcLayer`` reads the feature;
    * with a direction input (``dir_chns > 0``), the feature first goes through
      an extra ``FcLayer``, is concatenated with ``dir`` and then passes
      through a two-layer head.

    Color is squashed with a sigmoid; density is returned as produced by its
    layer.
    """

    def __init__(self, *, coord_chns, density_chns, color_chns, core_nf, core_layers,
                 dir_chns=0, dir_nf=0, activation='relu', skips=None):
        """Build the core MLP and the optional density / color heads.

        Args:
            coord_chns: width of the coordinate input given to ``FcNet``.
            density_chns: width of the density output; ``0`` disables the
                density head (``forward`` then returns ``None`` for density).
            color_chns: width of the color output; ``0`` disables the color
                head (``forward`` then returns ``None`` for color).
            core_nf: hidden width of the core MLP; also the width of the
                feature consumed by both heads.
            core_layers: forwarded to ``FcNet`` as ``n_layers``.
            dir_chns: width of the direction input. When ``> 0`` (and a color
                head exists), ``forward`` requires ``dir``.
            dir_nf: hidden width of the direction-conditioned color layer.
                Only used when ``dir_chns > 0`` and ``color_chns > 0``.
            activation: forwarded to ``FcNet`` and to the hidden ``FcLayer``
                of the direction-conditioned color head.
            skips: forwarded to ``FcNet`` unchanged (presumably indices of
                layers with skip connections; confirm in ``FcNet``). ``None``
                means an empty list.
        """
        super().__init__()
        # Build a fresh list per instance: a literal ``[]`` default would be
        # one object shared by every NerfCore (and by FcNet if it keeps it).
        skips = [] if skips is None else list(skips)
        # Remembered so forward() can report a missing direction clearly.
        self.dir_chns = dir_chns

        # out_chns=0: the heads below consume the raw core_nf-wide output.
        self.core = FcNet(in_chns=coord_chns, out_chns=0, nf=core_nf,
                          n_layers=core_layers, skips=skips, activation=activation)
        self.density_out = FcLayer(core_nf, density_chns) if density_chns > 0 else None

        if color_chns == 0:
            self.feature_out = None
            self.color_out = None
        elif dir_chns > 0:
            # Direction-conditioned color: the feature is projected, then
            # concatenated with `dir` before the color layers.
            self.feature_out = FcLayer(core_nf, core_nf)
            self.color_out = nn.Sequential(
                FcLayer(core_nf + dir_chns, dir_nf, activation),
                FcLayer(dir_nf, color_chns)
            )
        else:
            self.feature_out = Identity()
            self.color_out = FcLayer(core_nf, color_chns)

    # `dir` shadows the builtin but is part of the public keyword interface.
    def forward(self, coord: torch.Tensor, dir: Optional[torch.Tensor] = None
                ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Evaluate the network.

        Args:
            coord: coordinate input for ``FcNet`` (its exact shape contract is
                defined by ``FcNet``; presumably ``(..., coord_chns)``).
            dir: direction input, concatenated onto the color feature along
                the last dimension. Required when the model was built with
                ``dir_chns > 0`` and a color head; presumably ``(..., dir_chns)``.

        Returns:
            ``(color, density)``. ``color`` has a sigmoid applied. Each entry is
            ``None`` when the corresponding head was disabled at construction.

        Raises:
            ValueError: if the color head expects a direction input but
                ``dir`` is ``None``.
        """
        core_output = self.core(coord)
        density = None if self.density_out is None else self.density_out(core_output)

        if self.color_out is None:
            return None, density

        if dir is None and self.dir_chns > 0:
            # Fail early with a clear message instead of an opaque shape
            # mismatch inside color_out (which expects core_nf + dir_chns).
            raise ValueError(
                f"NerfCore was built with dir_chns={self.dir_chns}; "
                "forward() requires a `dir` tensor to compute color")

        feature = self.feature_out(core_output)
        if dir is not None:
            feature = torch.cat([feature, dir], dim=-1)
        color = torch.sigmoid(self.color_out(feature))
        return color, density