# nsvf.py
from typing import Optional

import torch
import torch.nn as nn

from modules import *  # provides InputEncoder, NerfCore, OctTreeSpace, VoxelSampler, VolumnRenderer
from utils import color


class NSVF(nn.Module):

    def __init__(self, fc_params, sampler_params, *,
                 c: int = color.RGB,
                 n_featdim: int = 32,
                 n_pos_encode: int = 0,
                 n_dir_encode: Optional[int] = None,
                 **kwargs):
        """
        Initialize an NSVF model

        :param fc_params `dict`: parameters for the fully-connected core network
        :param sampler_params `dict`: parameters for the voxel sampler
        :param c `int`: color mode
        :param n_featdim `int`: number of channels of per-voxel feature vectors
        :param n_pos_encode `int`: number of dimensions to encode position into
        :param n_dir_encode `int`: number of dimensions to encode direction into; `None` means direction is ignored
        """
        super().__init__()
        self.color = c
        self.coord_chns = n_featdim
        self.color_chns = color.chns(self.color)

        # The position encoder operates on feature vectors interpolated from
        # the voxel grid rather than on raw 3D coordinates
        self.pos_encoder = InputEncoder.Get(n_pos_encode, self.coord_chns)
        if n_dir_encode is not None:
            self.dir_chns = 3
            self.dir_encoder = InputEncoder.Get(n_dir_encode, self.dir_chns)
        else:
            self.dir_chns = 0
            self.dir_encoder = None
        self.core = NerfCore(coord_chns=self.pos_encoder.out_dim,
                             density_chns=1,
                             color_chns=self.color_chns,
                             core_nf=fc_params['nf'],
                             core_layers=fc_params['n_layers'],
                             dir_chns=self.dir_encoder.out_dim if self.dir_encoder else 0,
                             dir_nf=fc_params['nf'] // 2,
                             activation=fc_params['activation'],
                             skips=fc_params['skips'])

        # Sparse voxel octree describing the occupied regions of the scene;
        # the sampler draws samples only inside occupied voxels
        self.space = OctTreeSpace()
        sampler_params['space'] = self.space
        self.sampler = VoxelSampler(**sampler_params)
        self.rendering = VolumnRenderer()

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
                ret_depth=False, debug=False) -> torch.Tensor:
        """
        Render rays to colors

        :param rays_o `Tensor(B, 3)`: rays' origins
        :param rays_d `Tensor(B, 3)`: rays' directions
        :param ret_depth `bool`: whether to also return inferred depths
        :return `Tensor(B, C)`: inferred images/pixels
        """
        # Sample along each ray inside occupied voxels; `feats` are interpolated
        # voxel features, `z_s`/`dz_s` are sample depths and depth intervals
        feats, dirs, z_s, dz_s = self.sampler(rays_o, rays_d)
        feats_encoded = self.pos_encoder(feats)
        # A ray's view direction is shared by all of its samples, so encode it
        # once per ray and broadcast across the sample dimension
        dirs_encoded = self.dir_encoder(rays_d)[:, None].expand(-1, z_s.size(-1), -1) \
            if self.dir_encoder is not None else None
        colors, densities = self.core(feats_encoded, dirs_encoded)
        # Volume-render per-sample colors/densities into per-ray values
        return self.rendering(colors, densities[..., 0], z_s, dz_s,
                              ret_depth=ret_depth, debug=debug)
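

# Usage sketch: constructing the model and rendering a batch of rays. The
# `fc_params` keys follow those consumed in `__init__` (values here are
# illustrative); `sampler_params` keys depend on VoxelSampler's actual
# signature in `modules` and must be filled in accordingly.
if __name__ == "__main__":
    model = NSVF(
        fc_params={'nf': 256, 'n_layers': 8, 'activation': 'relu', 'skips': [4]},
        sampler_params={},  # fill in VoxelSampler's parameters here
        n_featdim=32,
        n_pos_encode=10,
        n_dir_encode=4)
    rays_o = torch.zeros(1024, 3)  # ray origins
    rays_d = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1)  # unit ray directions
    colors = model(rays_o, rays_d, ret_depth=False)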