from .__common__ import *
from .model import Model


class NeRF(Model):
    """
    NeRF model assembled from a sampler, two frequency encoders, a `core.NeRF` network and
    a `VolumnRenderer`.

    When `Args.n_importance > 0`, a second ("fine") `core.NeRF` network and a `PdfSampler`
    are created as well, and `forward` renders in two stages: a coarse pass whose "weights"
    output drives the resampling, followed by a fine pass over the resampled points.
    """

    class Args(Model.Args):
        # NOTE(review): the list-valued defaults below are class-level objects; whether they
        # are copied per instance depends on `Model.Args`, which is not visible here.

        # Number of samples requested from the uniform sampler
        n_samples: int = 64
        # Forwarded as `mode` to both samplers; "xyz" and "xyz_disp" force `near` to 0.1
        sample_mode: str = "xyz"
        # Forwarded as `perturb` to the samplers, but only while the model is training
        perturb_sampling: bool = False
        # Depth, width and skips forwarded to the coarse `core.NeRF`
        depth: int = 8
        width: int = 256
        skips: list[int] = [4]
        # Activation name, `ln` flag and color decoder name shared by both `core.NeRF` networks
        act: str = "relu"
        ln: bool = False
        color_decoder: str = "NeRF"
        # Forwarded as `n_importance` to the PDF sampler; 0 disables the cascade entirely
        n_importance: int = 0
        # Depth, width and skips forwarded to the fine `core.NeRF`
        fine_depth: int = 8
        fine_width: int = 256
        fine_skips: list[int] = [4]
        # Second argument of the position / direction `FreqEncoder`
        xfreqs: int = 10
        dfreqs: int = 4
        # Forwarded as `raw_noise_std` to the renderer while training (0 otherwise)
        raw_noise_std: float = 0.
        # Sampling range passed to the uniform sampler as `range=(near, far)`
        near: float = 1.
        far: float = 10.
        # Forwarded as `white_bg` to the renderer
        white_bg: bool = False

    args: Args

    def __init__(self, args: Args):
        """
        Initialize a NeRF model

        :param args `NeRF.Args`: model arguments. Note that `args.near` is overwritten in place
            with 0.1 when `args.sample_mode` is "xyz" or "xyz_disp"
        """
        super().__init__(args)
        if args.sample_mode in ("xyz", "xyz_disp"):
            args.near = 0.1

        # Initialize components. Order matters: the encoders read `self.sampler.out_chns`,
        # and the core networks read the encoders' `out_chns`.
        self._init_sampler()
        self._init_encoders()
        self._init_core()
        self._init_renderer()

        # The cascade (PDF sampler + fine network) only exists when importance sampling is on
        if self.args.n_importance > 0:
            self._init_cascade()

    @profile
    def forward(self, rays: Rays, *outputs: str, **args) -> ReturnData:
        """
        Run the full pipeline on a batch of rays: sample, encode, infer, render.

        :param rays `Rays`: rays to render
        :param outputs `str...`: names of the outputs requested from the renderer
        :param args: per-call overrides, forwarded to `sample` and `render`
        :return `ReturnData`: result of `render` with `cascade=True`
        """
        samples = self.sample(rays, **args)
        x, d = self.encode(samples)
        rgbd = self.infer(x, d)
        return self.render(rays, samples, rgbd, *outputs, cascade=True, **args)

    def sample(self, rays: Rays, **kwargs) -> Samples:
        """
        Draw samples along the rays with the uniform sampler.

        :param rays `Rays`: rays to sample along
        :param kwargs: merged with `self.args` via `merge_with` before the arguments are read
        :return `Samples`: samples produced by `self.sampler`
        """
        args = self.args.merge_with(kwargs)
        # Perturbation is only applied while training
        perturb = args.perturb_sampling if self.training else False
        return self.sampler(rays, None, range=(args.near, args.far),
                            mode=args.sample_mode, n_samples=args.n_samples,
                            perturb=perturb)

    def encode(self, samples: Samples) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode sample positions and normalized directions.

        :param samples `Samples`: samples providing `pts` and `dirs`
        :return `Tensor, Tensor`: encoded `samples.pts` and encoded normalized `samples.dirs`
        """
        encoded_pts = self.x_encoder(samples.pts)
        encoded_dirs = self.d_encoder(math.normalize(samples.dirs))
        return encoded_pts, encoded_dirs

    def infer(self, x: torch.Tensor, d: torch.Tensor, *, fine: bool = False) -> torch.Tensor:
        """
        Evaluate the core network on encoded inputs.

        :param x `Tensor`: encoded positions
        :param d `Tensor`: encoded directions
        :param fine `bool`: use the fine network; ignored unless `self.args.n_importance > 0`
        :return `Tensor`: output of the selected `core.NeRF` network
        """
        use_fine = self.args.n_importance > 0 and fine
        network = self.fine_core if use_fine else self.core
        return network(x, d)

    def render(self, rays: Rays, samples: Samples, rgbd: torch.Tensor, *outputs: str,
               cascade: bool = False, **kwargs) -> ReturnData:
        """
        Render the inferred values, optionally followed by a fine pass.

        Without cascade, the renderer is applied once to `samples` and `rgbd`. With cascade
        (`cascade=True` and `n_importance > 0`), the coarse pass is rendered to obtain
        "weights", the PDF sampler resamples from them, the fine network is evaluated on the
        new samples and its rendering is returned. Outputs requested as "coarse_<name>" are
        taken from the coarse pass and added to the result under that same key.

        :param rays `Rays`: rays being rendered
        :param samples `Samples`: samples that `rgbd` was inferred from
        :param rgbd `Tensor`: values inferred for `samples`
        :param outputs `str...`: names of the requested outputs
        :param cascade `bool`: whether to run the fine pass when importance sampling is enabled
        :param kwargs: merged with `self.args` via `merge_with` before the arguments are read
        :return `ReturnData`: rendered outputs
        """
        args = self.args.merge_with(kwargs)
        # Noise is only injected while training
        noise_std = args.raw_noise_std if self.training else 0.

        if not (args.n_importance > 0 and cascade):
            return self.renderer(samples, rgbd, *outputs, white_bg=args.white_bg,
                                 raw_noise_std=noise_std)

        # NOTE(review): this check uses the merged `n_importance`, whereas `pdf_sampler` and
        # `fine_core` are only created when `self.args.n_importance > 0` at construction and
        # `infer` consults `self.args` - enabling the cascade per call looks unsupported; confirm.
        coarse_outputs = [
            item.removeprefix("coarse_") for item in outputs if item.startswith("coarse_")
        ]
        # "weights" is always requested because the PDF sampler consumes it below
        coarse_ret = self.renderer(samples, rgbd, "weights", *coarse_outputs,
                                   white_bg=args.white_bg, raw_noise_std=noise_std)
        samples = self.pdf_sampler(rays, None, samples.t_vals, coarse_ret["weights"][..., 0],
                                   mode=args.sample_mode,
                                   n_importance=args.n_importance,
                                   perturb=args.perturb_sampling if self.training else False,
                                   include_existed_samples=True)
        x, d = self.encode(samples)
        fine_rgbd = self.infer(x, d, fine=True)
        fine_ret = self.renderer(samples, fine_rgbd, *outputs, white_bg=args.white_bg,
                                 raw_noise_std=noise_std)
        # Requested coarse outputs the renderer did not return are silently skipped
        return fine_ret | {
            f"coarse_{key}": coarse_ret[key]
            for key in coarse_outputs
            if key in coarse_ret
        }

    def _init_encoders(self):
        """Create the position and direction frequency encoders (requires `self.sampler`)."""
        # NOTE(review): the meaning of the third positional argument (`True`) is defined by
        # `FreqEncoder`, which is not visible here.
        self.x_encoder = FreqEncoder(self.sampler.out_chns["x"], self.args.xfreqs, True)
        self.d_encoder = FreqEncoder(self.sampler.out_chns["d"], self.args.dfreqs, True)

    def _init_core(self):
        """Create the coarse `core.NeRF` network (requires the encoders)."""
        self.core = core.NeRF(self.x_encoder.out_chns, self.d_encoder.out_chns, self.color.chns,
                              self.args.depth, self.args.width, self.args.skips,
                              self.args.act, self.args.ln, self.args.color_decoder)

    def _init_sampler(self):
        """Create the uniform sampler used by `sample`."""
        self.sampler = UniformSampler()

    def _init_cascade(self):
        """Create the PDF sampler and the fine `core.NeRF` network (requires the encoders)."""
        self.pdf_sampler = PdfSampler()
        self.fine_core = core.NeRF(self.x_encoder.out_chns, self.d_encoder.out_chns, self.color.chns,
                                   self.args.fine_depth, self.args.fine_width,
                                   self.args.fine_skips, self.args.act, self.args.ln,
                                   self.args.color_decoder)

    def _init_renderer(self):
        """Create the renderer used by `render`."""
        self.renderer = VolumnRenderer()