# mnerf_advance.py (1.38 KB) -- last commit: "sync" by Nianchen Deng
from .__common__ import *
from .mnerf import MNeRF


class MNeRFAdvance(MNeRF):
    """
    Advanced Multi-scale NeRF

    Extends `MNeRF` with a per-level sample count read from
    `args["n_samples_list"]` and a render pass that walks through the levels
    from 0 up to the level of the incoming samples, feeding the features
    rendered at one level into the next.
    """

    TrainerClass = "TrainMultiScale"

    def n_samples(self, level: int = -1) -> int:
        """
        Return the sample count configured for `level`.

        :param level: index into `args["n_samples_list"]`; the default `-1`
            selects the last entry of the list
        :return: the entry of `args["n_samples_list"]` at `level`
        """
        return self.args["n_samples_list"][level]

    def split(self):
        """
        Double every per-level sample count, then delegate to the parent's `split`.

        The doubled list is written to `self.args0["n_samples_list"]`; it is
        computed from the current `self.args["n_samples_list"]`, which is not
        modified here.
        NOTE(review): the relationship between `args0` and `args` is defined
        outside this class -- confirm against the base class.

        :return: whatever `super().split()` returns
        """
        self.args0["n_samples_list"] = [count * 2 for count in self.args["n_samples_list"]]
        return super().split()

    def _sample(self, data: InputData, **extra_args) -> Samples:
        """
        Sample through the parent's `_sample` with a level-dependent sample count.

        The `n_samples` keyword is forced to `self.n_samples(data["level"]) + 1`,
        overriding any `n_samples` present in `extra_args`.
        NOTE(review): the `+ 1` presumably yields one more point than the
        configured count so that strided sub-sampling in `_render` keeps both
        end points -- confirm against the parent's sampler.

        :param data: input data; must provide a `"level"` entry
        :param extra_args: additional keyword arguments forwarded to the parent
        :return: the samples produced by the parent's `_sample`
        """
        # Merge instead of passing `n_samples=` directly: a duplicate keyword
        # would raise TypeError if the caller already supplied `n_samples`.
        sample_args = extra_args | {"n_samples": self.n_samples(data["level"]) + 1}
        return super()._sample(data, **sample_args)

    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """
        Render `samples` level by level, from level 0 up to `samples.level`.

        For every level below `samples.level`, the parent's `_render` is asked
        for `'features'` only; those features are interpolated onto the next
        level's subset of samples and attached to it. The final call renders
        the requested `outputs` at level `samples.level`.

        Each level's subset is taken with a stride along the second axis of
        `samples`, where the stride is the integer ratio between the sample
        count of `samples.level` and that of the level being rendered.
        NOTE(review): assumes the second axis of `samples` indexes the samples
        along a ray and that the counts divide evenly -- confirm.

        :param samples: samples produced for level `samples.level`
        :param outputs: names of the outputs requested from the final render
        :param extra_args: additional keyword arguments forwarded to every
            parent `_render` call
        :return: the result of the parent's `_render` at level `samples.level`
        """
        target_level = samples.level
        n_samples_list = self.args["n_samples_list"]
        target_count = n_samples_list[target_level]

        # Start from the level-0 subset; it carries no features from a previous level.
        curr_samples = samples[:, ::target_count // n_samples_list[0]]
        curr_samples.level = 0
        curr_samples.features = None

        for next_level in range(1, target_level + 1):
            # Intermediate levels only need to produce features for the next level.
            render_out = super()._render(curr_samples, "features", **extra_args)
            next_samples = samples[:, ::target_count // n_samples_list[next_level]]
            features = curr_samples.interpolate(next_samples, render_out["features"])
            curr_samples = next_samples
            curr_samples.level = next_level
            curr_samples.features = features

        return super()._render(curr_samples, *outputs, **extra_args)