from .__common__ import *
from .nerf import NeRF


class VNeRF(NeRF):
    """NeRF variant whose "x" input is augmented with voxel embeddings.

    The "x" input fed to the core unit is the voxel embedding looked up for
    each sample, concatenated with the encoded sample points along the last
    dimension. Every other input is handled by the parent class.
    """

    def _init_chns(self):
        """Initialise channels through the parent with ``x=3``."""
        # NOTE(review): presumably 3 is the xyz coordinate channel count --
        # confirm against NeRF._init_chns.
        super()._init_chns(x=3)

    def _init_space(self):
        """Create ``self.space`` from the args and add a voxel embedding to it.

        The embedding is created with ``args['n_featdim']`` as its argument.
        """
        self.space = Space.create(self.args)
        embedding_dim = self.args['n_featdim']
        self.space.create_voxel_embedding(embedding_dim)

    def _create_core_unit(self):
        """Build the parent's core unit with a widened ``x_chns``.

        ``x_chns`` is the x-encoder's output dimension plus
        ``args['n_featdim']``, matching the concatenation done in ``_input``.
        """
        combined_chns = self.x_encoder.out_dim + self.args['n_featdim']
        return super()._create_core_unit(x_chns=combined_chns)

    def _input(self, samples: Samples, what: str) -> torch.Tensor | None:
        """Return the network input named ``what`` for ``samples``.

        For ``what == "x"`` the result is the voxel embedding extracted for
        ``samples.voxel_indices`` concatenated (last dim) with the encoded
        ``samples.pts``. Any other ``what`` is delegated to the parent class.
        """
        # Guard clause: only the "x" input is specialised here.
        if what != "x":
            return super()._input(samples, what)

        # Embedding first, encoded points second -- the order fixes the
        # channel layout of the concatenated tensor.
        voxel_feats = self.space.extract_voxel_embedding(samples.voxel_indices)
        encoded_pts = self._encode("x", samples.pts)
        return torch.cat([voxel_feats, encoded_pts], dim=-1)