from typing import Dict, Optional

import torch

from modules import *
from .nerf import *


class NeRFAdvance(NeRF):
    """
    NeRF variant that renders with `DensityFirstVolumnRenderer` and builds its
    core units as `NerfAdvCore`.

    `_new_core_unit` reads these entries from `self.args`:
    - "density_net" (required): forwarded as `density_net_params`
    - "color_net" (required): forwarded as `color_net_params`
    - "specular_net" (optional, default None): forwarded as `specular_net_params`
    - "appearance" (optional, default "decomposite")
    - "density_color_connection" (optional, default False)
    """

    RendererClass = DensityFirstVolumnRenderer

    def __init__(self, args0: dict, args1: Optional[dict] = None):
        """
        :param args0 `dict`: first argument dict, forwarded unchanged to `NeRF.__init__`
        :param args1 `dict?`: second argument dict, forwarded to `NeRF.__init__`;
                              when omitted, a fresh empty dict is passed
        """
        # Build a new dict per call instead of using a shared `{}` default, so
        # instances never share (or leak mutations through) one default object.
        super().__init__(args0, {} if args1 is None else args1)

    def _new_core_unit(self):
        """
        Create one `NerfAdvCore` sized by this model's encoders and channel config.

        :return `NerfAdvCore`: the newly constructed core unit
        """
        return NerfAdvCore(
            x_chns=self.pot_encoder.out_dim,
            d_chns=self.dir_encoder.out_dim,
            density_chns=self.chns('density'),
            color_chns=self.chns('color'),
            # Indexed with [] on purpose: these two sub-network configs are mandatory.
            density_net_params=self.args["density_net"],
            color_net_params=self.args["color_net"],
            # Optional: `.get` yields None when no specular network is configured.
            specular_net_params=self.args.get("specular_net"),
            appearance=self.args.get("appearance", "decomposite"),
            density_color_connection=self.args.get("density_color_connection", False)
        )

    def infer(self, x: torch.Tensor, d: torch.Tensor, *outputs,
              extras: Optional[dict] = None, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Infer colors, energies and other values (specified by `outputs`) of samples
        given their encoded positions and directions, by delegating to `self.core`.

        :param x `Tensor(N, Ex)`: encoded positions
        :param d `Tensor(N, Ed)`: encoded directions
        :param outputs `str...`: which types of inferred data should be returned
        :param extras `dict?`: extra data needed by cores, unpacked as keyword
                               arguments to the core call; None means no extras
        :param kwargs: accepted but not used by this implementation
        :return `Dict[str, Tensor(N, *)]`: outputs of cores
        """
        if extras is None:
            extras = {}
        # `outputs` is handed to the core as a single tuple, not re-expanded.
        return self.core(x, d, outputs, **extras)