from .__common__ import *
from .mnerf import MNeRF


class MNeRFAdvance(MNeRF):
    """Multi-scale NeRF driven by a per-level sample schedule.

    The number of samples used at each level is read from
    ``self.args["n_samples_list"]``. Rendering at a given level proceeds
    coarse-to-fine: every coarser level is evaluated on a strided subset of
    the finest-level samples, and its ``'features'`` output feeds the next
    finer level.
    """

    TrainerClass = "TrainMultiScale"

    def n_samples(self, level: int = -1) -> int:
        """Return the scheduled sample count for ``level`` (default: the last entry)."""
        schedule = self.args["n_samples_list"]
        return schedule[level]

    def split(self):
        """Store doubled per-level sample counts in ``self.args0``, then run the parent ``split``.

        NOTE(review): the counts are read from ``self.args`` but written to
        ``self.args0``; presumably ``args0`` configures the model produced by
        the split — confirm against ``MNeRF.split``.
        """
        doubled = [count * 2 for count in self.args["n_samples_list"]]
        self.args0["n_samples_list"] = doubled
        return super().split()

    def _sample(self, data: InputData, **extra_args) -> Samples:
        """Delegate to the parent sampler with ``n_samples`` forced to this level's count plus one.

        The extra point presumably lets strided subsets line up in ``_render``:
        ``n*k + 1`` samples taken with stride ``k`` leave exactly ``n + 1``,
        i.e. the coarser level's count plus one (assumes ``samples[:, ::k]``
        strides along the per-ray sample axis — TODO confirm).
        """
        count = self.n_samples(data["level"]) + 1
        overrides = {**extra_args, "n_samples": count}
        return super()._sample(data, **overrides)

    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """Render ``samples`` coarse-to-fine, ending at level ``samples.level``.

        Every level below ``samples.level`` is rendered for ``'features'``
        only, on a strided subset of ``samples``. That result is passed to
        ``interpolate`` together with the next finer subset, and the returned
        features are attached to that subset before it is rendered in turn.
        The final level is rendered for the requested ``outputs``, and its
        result is returned.
        """
        top = samples.level

        # Stride that picks the level-0 subset out of the top level's samples.
        # Integer division: assumes the counts in n_samples_list divide evenly — TODO confirm.
        stride = self.args["n_samples_list"][top] // self.args["n_samples_list"][0]
        current = samples[:, ::stride]
        # NOTE(review): assumes slicing yields a new Samples object; otherwise these
        # assignments would also modify the caller's `samples` — confirm.
        current.level = 0
        current.features = None

        for finer in range(1, top + 1):
            coarse_out = super()._render(current, 'features', **extra_args)
            stride = self.args["n_samples_list"][top] // self.args["n_samples_list"][finer]
            finer_samples = samples[:, ::stride]
            carried = current.interpolate(finer_samples, coarse_out['features'])
            current = finer_samples
            current.level = finer
            current.features = carried

        return super()._render(current, *outputs, **extra_args)