from .__common__ import *
from .nerf import NeRF


class VNeRF(NeRF):
    """NeRF variant whose "x" input is a learned voxel embedding joined with the encoded points.

    Compared with the base ``NeRF``, this class:
      * builds ``self.space`` via ``Space.create`` and attaches a voxel embedding of
        width ``args['n_featdim']`` to it;
      * widens the core unit's ``x`` input by ``args['n_featdim']`` channels;
      * builds the "x" input by concatenating the looked-up voxel embedding with
        the encoded sample points along the last axis.

    NOTE(review): ``Space``, ``Samples`` and ``torch`` are not imported explicitly
    here; they presumably come from the ``.__common__`` star import — confirm.
    """

    def _init_chns(self):
        """Initialise channel bookkeeping via the base class with ``x=3``.

        Presumably 3 is the dimensionality of the point coordinates — TODO confirm
        against ``NeRF._init_chns``.
        """
        super()._init_chns(x=3)

    def _init_space(self):
        """Create ``self.space`` from ``self.args`` and attach a voxel embedding to it.

        The embedding is created with ``self.args['n_featdim']``, which is assumed
        to be its per-voxel feature width (see ``_create_core_unit``).
        """
        space = Space.create(self.args)
        self.space = space
        space.create_voxel_embedding(self.args['n_featdim'])

    def _create_core_unit(self):
        """Build the core unit through the base class with a widened ``x`` input.

        The ``x`` input width is the positional encoder's output width plus
        ``args['n_featdim']``. This matches the concatenation performed in
        ``_input``.
        """
        widened_x_chns = self.x_encoder.out_dim + self.args['n_featdim']
        return super()._create_core_unit(x_chns=widened_x_chns)

    def _input(self, samples: Samples, what: str) -> torch.Tensor | None:
        """Return the network input named ``what`` for ``samples``.

        Args:
            samples: Sample batch. It must expose ``voxel_indices`` and ``pts``
                when ``what == "x"``.
            what: Name of the requested input. Anything other than ``"x"`` is
                delegated unchanged to the base class.

        Returns:
            For ``"x"``, the voxel embedding looked up at ``samples.voxel_indices``
            followed by the encoding of ``samples.pts``, joined on the last axis.
            Otherwise, whatever ``NeRF._input`` returns, which may be ``None``
            per the annotation.
        """
        if what != "x":
            return super()._input(samples, what)
        # The embedding is looked up before the points are encoded, so the
        # learned voxel features occupy the leading channels.
        voxel_feats = self.space.extract_voxel_embedding(samples.voxel_indices)
        encoded_pts = self._encode("x", samples.pts)
        return torch.cat([voxel_feats, encoded_pts], dim=-1)