mnerf.py 782 Bytes
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from .__common__ import *
from .nerf import NeRF


class MNeRF(NeRF):
    """
    Multi-scale NeRF.

    The core is a ``MultiNerf`` built from one sub-network per entry of
    ``self.args['core_params']``. Every sub-network after the first receives
    the encoded position plus ``nf`` extra channels taken from the previous
    entry's parameters (presumably the previous level's output features --
    TODO confirm against ``MultiNerf``).
    """

    # Name of the trainer used for this model; presumably resolved by name
    # elsewhere in the project (the lookup is not visible here).
    TrainerClass = "TrainMultiScale"

    def freeze(self, level: int) -> None:
        """
        Freeze scale ``level`` in every core.

        :param level: level index, forwarded unchanged as
            ``core.set_frozen(level, True)`` to each core in ``self.cores``
        """
        for core in self.cores:
            core.set_frozen(level, True)

    def _create_core_unit(self):
        """
        Build the multi-level core network.

        One sub-network is created per entry of ``self.args['core_params']``
        through the parent's ``_create_core_unit(core_params, x_chns=...)``.
        The first sub-network takes ``self.x_encoder.out_dim`` input channels.
        Each later one takes ``self.x_encoder.out_dim`` plus the previous
        entry's ``'nf'`` value.

        :return: a ``MultiNerf`` wrapping the sub-networks in configuration
            order
        """
        nets = []
        in_chns = self.x_encoder.out_dim
        for core_params in self.args['core_params']:
            nets.append(super()._create_core_unit(core_params, x_chns=in_chns))
            # The next level's input width is the encoded position plus this
            # level's 'nf' channels (assumed to be its feature output).
            in_chns = self.x_encoder.out_dim + core_params['nf']
        return MultiNerf(nets)

    @profile
    def _sample(self, data: InputData, **extra_args) -> Samples:
        """
        Sample like the parent class, then tag the samples with the level.

        :param data: input batch; it is indexed with ``"level"``, so that
            entry must be present
        :param extra_args: forwarded unchanged to the parent's ``_sample``
        :return: the parent's samples with ``level`` set to ``data["level"]``
        """
        samples = super()._sample(data, **extra_args)
        samples.level = data["level"]
        return samples