from typing import Dict, Optional

import torch

from modules import *
from .nerf import *


class NeRFAdvance(NeRF):
    """NeRF variant whose core unit is a ``NerfAdvCore``.

    Rendering uses ``DensityFirstVolumnRenderer``. The core's sub-network
    parameters and options are read from ``self.args`` (see
    ``_new_core_unit``).
    """

    RendererClass = DensityFirstVolumnRenderer

    def __init__(self, args0: dict, args1: Optional[dict] = None):
        """
        :param args0 `dict`: model arguments, forwarded to ``NeRF.__init__``
        :param args1 `dict`: additional arguments, forwarded to ``NeRF.__init__``;
            defaults to a fresh empty dict
        """
        # Build a fresh dict per call instead of sharing one mutable default
        # object between all instances (the parent may keep or mutate it).
        super().__init__(args0, {} if args1 is None else args1)

    def _new_core_unit(self):
        """Create the ``NerfAdvCore`` used by this model.

        Required ``self.args`` keys: ``"density_net"``, ``"color_net"``
        (a missing key raises ``KeyError``). Optional keys:
        ``"specular_net"`` (default ``None``), ``"appearance"`` (default
        ``"decomposite"``), ``"density_color_connection"`` (default ``False``).

        :return: a new ``NerfAdvCore`` sized from the position/direction
            encoders' ``out_dim`` and ``self.chns('density')`` /
            ``self.chns('color')``
        """
        return NerfAdvCore(
            x_chns=self.pot_encoder.out_dim,
            d_chns=self.dir_encoder.out_dim,
            density_chns=self.chns('density'),
            color_chns=self.chns('color'),
            density_net_params=self.args["density_net"],
            color_net_params=self.args["color_net"],
            specular_net_params=self.args.get("specular_net"),
            appearance=self.args.get("appearance", "decomposite"),
            density_color_connection=self.args.get("density_color_connection", False),
        )

    def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs: str,
              extras: Optional[dict] = None, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Infer colors, energies and other values (specified by `outputs`) of samples
        given their encoded positions and directions. Invalid items are expected
        to have been filtered out already (NOTE(review): no filtering happens in
        this method — confirm it is done by the caller or the core).

        :param x `Tensor(N, Ex)`: encoded positions
        :param d `Tensor(N, Ed)`: encoded directions
        :param outputs `str...`: which types of inferred data should be returned
        :param extras `dict`: extra data needed by cores, unpacked as keyword
            arguments into the core call; defaults to no extras
        :return `Dict[str, Tensor(N, *)]`: outputs of cores
        """
        # `kwargs` is accepted for interface compatibility and ignored here.
        # `outputs` is handed to the core as a single tuple argument.
        return self.core(x, d, outputs, **(extras or {}))