# snerf_x.py
from .__common__ import *
from .nerf import NeRF
from .utils import load


class SNeRFX(NeRF):
    """NeRF variant with forced spherical sampling and per-section sub-networks.

    The samples along each ray are split into sections whose sizes are kept in
    ``args["net_samples"]``; ``_multi_infer`` routes chunk ``chunk_id`` to
    ``self.cores[chunk_id]``. How ``NeRF`` maps those sections to chunk ids is
    defined in the parent class (not visible here) — presumably one chunk per
    section; TODO confirm.
    """

    def _preprocess_args(self):
        """Force spherical mode and derive the per-net sample split.

        Sets ``args0["spherical"] = True`` before delegating to the parent.
        If ``net_samples`` was not supplied, it is derived from
        ``space.balance_cut(0, n_nets)`` scaled by
        ``k = n_samples // space.steps[0]``. Here ``space`` is the space of the
        model loaded from ``args["cut_by"]`` when that key is present, otherwise
        ``self.space``, and ``n_nets`` is ``args["multi_nets"]`` (default 1).
        Finally ``args1["multi_nets"]`` is set to the number of sections.
        """
        self.args0["spherical"] = True
        super()._preprocess_args()

        if "net_samples" not in self.args:
            n_nets = self.args.get("multi_nets", 1)
            # Optionally reuse the partition of another saved model's space.
            if "cut_by" in self.args:
                cut_by_space = load(self.args["cut_by"]).space
            else:
                cut_by_space = self.space
            # Presumably samples per step along dim 0 of the space. Integer
            # division drops any remainder of n_samples, and k is 0 when
            # steps[0] > n_samples.
            k = self.args["n_samples"] // cut_by_space.steps[0].item()
            self.args0["net_samples"] = [val * k for val in cut_by_space.balance_cut(0, n_nets)]
        # NOTE(review): reads self.args right after writing self.args0 — this
        # assumes self.args reflects args0 (e.g. a merged view); confirm in NeRF.
        self.args1["multi_nets"] = len(self.args["net_samples"])

    @torch.no_grad()
    def split(self):
        """Split via the parent, then double every per-net sample count.

        Presumably the parent ``split`` doubles the space resolution, so each
        section needs twice as many samples — TODO confirm against ``NeRF.split``.

        Returns:
            Whatever the parent ``split`` returns.
        """
        ret = super().split()
        self.args0['net_samples'] = [val * 2 for val in self.args0['net_samples']]
        return ret

    @profile
    def _render(self, samples: Samples, *outputs: str, **extra_args) -> ReturnData:
        """Render via the parent, forcing ray marching to use the per-net sections.

        Any caller-supplied ``raymarching_chunk_size_or_sections`` is overridden
        by ``args["net_samples"]``, because the right operand of ``|`` wins.
        """
        return super()._render(
            samples, *outputs,
            **(extra_args | {"raymarching_chunk_size_or_sections": self.args["net_samples"]}))

    @profile
    def _multi_infer(self, inputs: NetInput, *outputs: str, chunk_id: int, **kwargs) -> NetOutput:
        """Evaluate ``inputs`` with the sub-network belonging to section ``chunk_id``."""
        return self.cores[chunk_id](inputs, *outputs, **kwargs)