fs_nerf.py 2.65 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from ..__common__ import *
from .field import *
from .color_decoder import *
from .density_decoder import *


class FsNeRF(nn.Module):
    """
    FS-NeRF core module.

    The P samples along the sample dimension are split into `n_fields` equal, contiguous
    groups. Each group is evaluated by its own sub-network (field + density decoder +
    color decoder), and the per-group results are concatenated back in sample order.
    """

    def __init__(self, x_chns: int, color_chns: int, depth: int, width: int,
                 skips: list[int], act: str, ln: bool, n_samples: int, n_fields: int):
        """
        Initialize a FS-NeRF core module.

        :param x_chns `int`: channels of input positions (D_x)
        :param color_chns `int`: channels of output colors (D_c)
        :param depth `int`: number of layers in each field network
        :param width `int`: width of each layer in each field network
        :param skips `[int]`: skip connections from input to specific layers in each field network
        :param act `str`: activation function in each field network
        :param ln `bool`: whether to enable layer normalization in each field network
        :param n_samples `int`: number of samples (P) along the sample dimension of the input
        :param n_fields `int`: number of sub-networks; each one handles
            `n_samples // n_fields` consecutive samples
        :raises ValueError: if `n_fields` is not positive or does not evenly divide `n_samples`
        """
        # Validate before building anything: a non-divisible split would make `forward`
        # silently ignore the trailing `n_samples % n_fields` samples.
        if n_fields < 1:
            raise ValueError(f"n_fields must be a positive integer, got {n_fields}")
        if n_samples % n_fields != 0:
            raise ValueError(
                f"n_samples ({n_samples}) must be divisible by n_fields ({n_fields})")
        # NOTE(review): presumably declares the input/output channel layout to the project's
        # base Module ("rgbd" = D_c color channels + 1 density channel) — confirm.
        super().__init__({"x": x_chns}, {"rgbd": 1 + color_chns})
        self.n_fields = n_fields
        self.samples_per_field = n_samples // n_fields
        self.subnets = torch.nn.ModuleList()
        for _ in range(n_fields):
            # Each field sees its whole group at once: `forward` flattens the group's
            # positions into one vector of `samples_per_field * D_x` channels.
            field = Field(x_chns * self.samples_per_field, [depth, width], skips, act, ln)
            density_decoder = DensityDecoder(field.out_chns, self.samples_per_field)
            color_decoder = BasicColorDecoder(field.out_chns, color_chns * self.samples_per_field)
            self.subnets.append(torch.nn.ModuleDict({
                "field": field,
                "density_decoder": density_decoder,
                "color_decoder": color_decoder
            }))

    # Stub method for type hints only.
    # NOTE(review): defining `__call__` in the class body shadows the base class's `__call__`;
    # this relies on the project's `nn.Module` routing calls to `forward` — confirm.
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Infer colors and densities from input samples.

        :param x `Tensor(B..., P, D_x)`: input positions
        :return `Tensor(B..., P, D_c + D_σ)`: output colors followed by densities
        """
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate each sub-network on its own contiguous group of samples.

        :param x `Tensor(B..., P, D_x)`: input positions
        :return `Tensor(B..., P, D_c + D_σ)`: colors (first D_c channels) followed by
            densities, concatenated along the last dimension
        """
        spf = self.samples_per_field
        densities = []
        colors = []
        for i, subnet in enumerate(self.subnets):
            # Group i covers samples [i * spf, (i + 1) * spf); flatten its positions into
            # a single feature vector for the group's field network.
            group = x[..., i * spf:(i + 1) * spf, :].flatten(-2)
            features = subnet["field"](group)
            # The decoders emit one flat vector per group; split it back into per-sample channels.
            densities.append(subnet["density_decoder"](features).unflatten(-1, (spf, -1)))
            # The second decoder argument is passed as None (presumably "no direction input").
            colors.append(subnet["color_decoder"](features, None).unflatten(-1, (spf, -1)))
        return torch.cat([torch.cat(colors, -2), torch.cat(densities, -2)], -1)