vnerf.py 737 Bytes
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from .__common__ import *
from .nerf import NeRF


class VNeRF(NeRF):
    """NeRF variant whose "x" input is augmented with a learned voxel feature.

    The "x" input fed to the core unit is the concatenation of a per-voxel
    embedding of width ``args['n_featdim']`` (looked up through
    ``samples.voxel_indices``) and the encoded sample position. All other
    inputs are handled by ``NeRF``.
    """

    def _init_chns(self):
        """Declare the channel layout, fixing the raw "x" input at 3 channels."""
        # NOTE(review): 3 presumably means 3-D sample positions -- confirm
        # against NeRF._init_chns.
        super()._init_chns(x=3)

    def _init_space(self):
        """Create ``self.space`` and attach a voxel embedding to it.

        The embedding width is ``self.args['n_featdim']``; it must match the
        extra "x" channels reserved in ``_create_core_unit``.
        """
        self.space = Space.create(self.args)
        self.space.create_voxel_embedding(self.args['n_featdim'])

    def _create_core_unit(self):
        """Build the core unit with an "x" input wide enough for both parts.

        ``_input("x")`` concatenates the voxel embedding (``n_featdim``
        channels) with the encoded position (``x_encoder.out_dim`` channels),
        so ``x_chns`` is the sum of the two widths.
        """
        return super()._create_core_unit(
            x_chns=self.x_encoder.out_dim + self.args['n_featdim'])

    def _input(self, samples: Samples, what: str) -> torch.Tensor | None:
        """Build the network input named ``what`` for ``samples``.

        Args:
            samples: sample batch; for ``what == "x"`` its ``voxel_indices``
                and ``pts`` fields are read.
            what: name of the input to build.

        Returns:
            For ``"x"``: the voxel embedding selected by
            ``samples.voxel_indices`` concatenated along the last dimension
            with ``self._encode("x", samples.pts)`` (embedding first).
            Any other name is delegated to ``NeRF._input``.
        """
        if what != "x":
            return super()._input(samples, what)
        # Keep this order stable: it decides which core-unit input channels
        # receive the voxel feature vs. the encoded position, so changing it
        # would invalidate previously trained weights.
        voxel_feats = self.space.extract_voxel_embedding(samples.voxel_indices)
        encoded_pts = self._encode("x", samples.pts)
        return torch.cat([voxel_feats, encoded_pts], dim=-1)