import torch

from .__common__ import *
from .nerf import NeRF


class MNeRF(NeRF):
    """
    Multi-scale NeRF.

    The core is a ``MultiNerf`` built from one sub-network per entry of
    ``self.args['core_params']``. Each sample batch is tagged with the
    ``level`` requested in the input data.
    """

    # Trainer looked up by name; presumably resolved to a trainer class
    # elsewhere in the project -- TODO confirm where this string is consumed.
    TrainerClass = "TrainMultiScale"

    def freeze(self, level: int):
        """
        Freeze scale ``level`` in every core.

        Calls ``set_frozen(level, True)`` on each object in ``self.cores``.
        What "frozen" means (e.g. excluded from gradient updates) is defined
        by the core class, which is not visible here.
        """
        for unit in self.cores:
            unit.set_frozen(level, True)

    def _create_core_unit(self):
        """
        Build the multi-scale core as a ``MultiNerf`` of per-level sub-networks.

        The parent's ``_create_core_unit`` is called once per entry of
        ``self.args['core_params']``, in order. Input width is set as follows:

        * The first sub-network receives ``self.x_encoder.out_dim`` channels.
        * Every later one receives ``self.x_encoder.out_dim`` plus the
          ``'nf'`` value of the preceding entry.

        Presumably each level consumes the encoded position concatenated with
        the previous level's features -- verify against ``MultiNerf``.
        """
        sub_nets = []
        chns_in = self.x_encoder.out_dim
        for params in self.args['core_params']:
            sub_nets.append(super()._create_core_unit(params, x_chns=chns_in))
            # Next level's input = encoded position + this level's 'nf' channels.
            chns_in = self.x_encoder.out_dim + params['nf']
        return MultiNerf(sub_nets)

    @profile
    def _sample(self, data: InputData, **extra_args) -> Samples:
        """
        Sample points via the parent, then tag the result with ``data["level"]``.

        The level is stored as ``Samples.level``. It is presumably read
        downstream to select which scale to evaluate -- TODO confirm.
        """
        result = super()._sample(data, **extra_args)
        result.level = data["level"]
        return result