fs_nerf.py 2.5 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from .__common__ import *
from .model import Model


class FsNeRF(Model):
    """
    FS-NeRF model built from four stages:
    `sample` -> `encode` -> `infer` -> `render`.

    `forward` chains the stages. Each stage is also callable on its own.
    """

    class Args(Model.Args):
        n_samples: int = 64  # passed to the sampler and to core.FsNeRF
        perturb_sampling: bool = False  # passed to the sampler as `perturb`, only while training
        with_radius: bool = False  # when False, the encoder input is one channel narrower
        n_fields: int = 1  # forwarded to core.FsNeRF
        depth: int = 8  # forwarded to core.FsNeRF
        width: int = 256  # forwarded to core.FsNeRF
        # NOTE(review): mutable list used as a class-level default; presumably
        # Model.Args copies it per instance -- confirm before mutating `skips`.
        skips: list[int] = [4]  # forwarded to core.FsNeRF
        act: str = "relu"  # forwarded to core.FsNeRF
        ln: bool = False  # forwarded to core.FsNeRF
        xfreqs: int = 6  # frequency count handed to FreqEncoder
        raw_noise_std: float = 0.  # passed to the renderer while training; 0. otherwise
        near: float = 1.  # lower end of the sampling range
        far: float = 10.  # upper end of the sampling range
        white_bg: bool = False  # forwarded to the renderer

    args: Args

    def __init__(self, args: Args):
        """
        Build an FS-NeRF model.

        :param args `Args`: model configuration
        """
        super().__init__(args)

        # Order matters: the encoder reads the sampler's channel count, and the
        # core reads the encoder's output channel count.
        self._init_sampler()
        self._init_encoders()
        self._init_core()
        self._init_renderer()

    @profile
    def forward(self, rays: Rays, *outputs: str, **kwargs) -> ReturnData:
        """
        Run the full sample -> encode -> infer -> render pipeline on `rays`.

        :param rays `Rays`: rays to evaluate
        :param outputs `str`: output names passed through to the renderer
        :param kwargs: per-call options, merged into `self.args` via `merge_with`
                       by both `sample` and `render`
        :return `ReturnData`: the renderer's result
        """
        ray_samples = self.sample(rays, **kwargs)
        features = self.encode(ray_samples)
        raw = self.infer(features)
        return self.render(ray_samples, raw, *outputs, **kwargs)

    def sample(self, rays: Rays, **kwargs) -> Samples:
        """
        Draw sample points along `rays` with the sampler's "spherical_radius" mode.

        The sampling range is `(near, far)`. `perturb` receives `perturb_sampling`
        in training mode and `False` otherwise.

        :param rays `Rays`: rays to sample
        :param kwargs: per-call overrides merged into `self.args`
        :return `Samples`: the sampler's output
        """
        opts = self.args.merge_with(kwargs)
        if self.training:
            perturb = opts.perturb_sampling
        else:
            perturb = False
        # The second positional argument is passed as None; its meaning is
        # defined by UniformSampler (TODO confirm).
        return self.sampler(rays, None,
                            range=(opts.near, opts.far),
                            mode="spherical_radius",
                            n_samples=opts.n_samples,
                            perturb=perturb)

    def encode(self, samples: Samples) -> torch.Tensor:
        """
        Frequency-encode the sample positions.

        Only the trailing `x_encoder.in_chns` channels of `samples.pts` are fed
        to the encoder. When `with_radius` is off, that excludes the leading channel.

        :param samples `Samples`: output of `sample`
        :return `Tensor`: encoded features
        """
        n_in = self.x_encoder.in_chns
        positions = samples.pts[..., -n_in:]
        return self.x_encoder(positions)

    def infer(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the core network on encoded features.

        :param x `Tensor`: output of `encode`
        :return `Tensor`: raw network output consumed by `render`
        """
        out = self.core(x)
        return out

    def render(self, samples: Samples, rgbd: torch.Tensor, *outputs: str, **kwargs) -> ReturnData:
        """
        Composite the network output along each ray with the volume renderer.

        `raw_noise_std` is forwarded only in training mode; otherwise 0. is used.

        :param samples `Samples`: output of `sample`
        :param rgbd `Tensor`: output of `infer`
        :param outputs `str`: output names passed through to the renderer
        :param kwargs: per-call overrides merged into `self.args`
        :return `ReturnData`: the renderer's result
        """
        opts = self.args.merge_with(kwargs)
        noise_std = 0.
        if self.training:
            noise_std = opts.raw_noise_std
        return self.renderer(samples, rgbd, *outputs,
                             white_bg=opts.white_bg,
                             raw_noise_std=noise_std)

    def _init_sampler(self):
        """Create the uniform ray sampler."""
        self.sampler = UniformSampler()

    def _init_encoders(self):
        """Create the positional frequency encoder sized from the sampler's "x" channels."""
        in_chns = self.sampler.out_chns["x"]
        if not self.args.with_radius:
            # Drop one input channel. `encode` keeps the trailing channels, so the
            # leading one is excluded (presumably the radius -- TODO confirm).
            in_chns -= 1
        # Third positional argument False: meaning defined by FreqEncoder (TODO confirm).
        self.x_encoder = FreqEncoder(in_chns, self.args.xfreqs, False)

    def _init_core(self):
        """Create the core FS-NeRF network from the encoder width and `Args`."""
        cfg = self.args
        self.core = core.FsNeRF(self.x_encoder.out_chns, self.color.chns,
                                cfg.depth, cfg.width, cfg.skips,
                                cfg.act, cfg.ln, cfg.n_samples, cfg.n_fields)

    def _init_renderer(self):
        """Create the volume renderer."""
        self.renderer = VolumnRenderer()