snerf_fast.py 4.68 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
import torch
import torch.nn as nn
Nianchen Deng's avatar
Nianchen Deng committed
3
from modules import *
Nianchen Deng's avatar
sync    
Nianchen Deng committed
4
5
6
7
8
9
from utils import sphere
from utils import color


class SnerfFast(nn.Module):
    """
    Spherical NeRF variant that splits the samples along each ray into ``n_parts``
    contiguous groups and evaluates every group with its own MLP. All samples of a
    group are packed into a single MLP input vector, so one MLP call per group
    produces the colors and densities of all of that group's samples.
    """

    def __init__(self, fc_params, sampler_params, *,
                 n_parts: int = 1,
                 c: int = color.RGB,
                 pos_encode: int = 0,
                 dir_encode: int = None,
                 spherical_dir: bool = False, **kwargs):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for the fully-connected networks; the keys read here
                          are ``'nf'``, ``'n_layers'`` and ``'activation'``
        :param sampler_params: keyword arguments for ``Sampler``; must contain ``'n_samples'``.
                               ``'spherical'`` is always forced to ``True``; the caller's dict
                               is left unmodified
        :param n_parts: number of MLPs the samples along a ray are split between;
                        ``sampler_params['n_samples']`` must be a positive multiple of it
        :param c: color mode
        :param pos_encode: forwarded to ``InputEncoder.Get`` to encode the 2-channel sample coords
        :param dir_encode: forwarded to ``InputEncoder.Get`` to encode directions;
                           ``None`` disables direction input
        :param spherical_dir: if ``True``, per-sample local directions (2 channels, from
                              ``sphere.calc_local_dir``) are encoded instead of the raw
                              3-channel ray directions
        :param kwargs: accepted and ignored
        :raises ValueError: if ``n_parts`` < 1 or ``n_samples`` is not divisible by ``n_parts``
        """
        super().__init__()
        self.color = c
        self.spherical_dir = spherical_dir
        self.n_samples = sampler_params['n_samples']
        self.n_parts = n_parts
        # Every sample slot must be covered by exactly one part; otherwise the
        # trailing slots of the output buffers in forward() would stay uninitialized.
        if n_parts < 1 or self.n_samples % n_parts != 0:
            raise ValueError(
                f"n_samples ({self.n_samples}) must be a positive multiple of n_parts ({n_parts})")
        self.samples_per_part = self.n_samples // self.n_parts
        self.coord_chns = 2
        self.color_chns = color.chns(self.color)
        self.pos_encoder = InputEncoder.Get(pos_encode, self.coord_chns)

        if dir_encode is not None:
            self.dir_encoder = InputEncoder.Get(dir_encode, 2 if self.spherical_dir else 3)
            # Spherical dirs are per-sample, so each part's MLP receives the encoded
            # dirs of all of its samples; raw ray dirs are shared by all samples.
            self.dir_chns_per_part = self.dir_encoder.out_dim * \
                (self.samples_per_part if self.spherical_dir else 1)
        else:
            self.dir_encoder = None
            self.dir_chns_per_part = 0

        self.nets = [
            NerfCore(coord_chns=self.pos_encoder.out_dim * self.samples_per_part,
                     density_chns=self.samples_per_part,
                     color_chns=self.color_chns * self.samples_per_part,
                     core_nf=fc_params['nf'],
                     core_layers=fc_params['n_layers'],
                     dir_chns=self.dir_chns_per_part,
                     dir_nf=fc_params['nf'] // 2,
                     activation=fc_params['activation'])
            for _ in range(self.n_parts)
        ]
        # Registered one by one (rather than via nn.ModuleList) to keep the
        # existing "mlp_{i}" state-dict keys stable for saved checkpoints.
        for i, net in enumerate(self.nets):
            self.add_module(f"mlp_{i:d}", net)
        # Copy instead of mutating the caller's config dict.
        self.sampler = Sampler(**{**sampler_params, 'spherical': True})
        self.rendering = VolumnRenderer()

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                ret_depth=False, debug=False) -> torch.Tensor:
        """
        rays -> colors

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param ret_depth: forwarded to the volume renderer
        :param debug: forwarded to the volume renderer; when true, the renderer's result is
                      additionally given ``'sample_densities'`` and ``'sample_depths'`` entries
        :return: the volume renderer's result, returned as-is (documented as `Tensor(B, C)`,
                 inferred images/pixels; NOTE(review): with ``debug`` it is indexed like a
                 dict — confirm the renderer's return type)
        """
        coords, depths, _, pts = self.sampler(rays_o, rays_d)
        coords_encoded = self.pos_encoder(coords[..., -self.coord_chns:])

        if self.dir_encoder is None:
            dirs_encoded = None
        elif self.spherical_dir:
            dirs_encoded = self.dir_encoder(sphere.calc_local_dir(rays_d, coords, pts))
        else:
            dirs_encoded = self.dir_encoder(rays_d)
        # Only spherical (per-sample) dirs are split between parts; shared ray dirs
        # (or None when direction input is disabled) go to every part unchanged.
        split_dirs = dirs_encoded is not None and self.spherical_dir

        n_rays = rays_o.size(0)
        # Allocate on the device the MLP inputs live on, so outputs can be written in place.
        out_device = coords_encoded.device
        densities = torch.empty(n_rays, self.n_samples, device=out_device)
        colors = torch.empty(n_rays, self.n_samples, self.color_chns, device=out_device)
        for i, net in enumerate(self.nets):
            s = slice(i * self.samples_per_part, (i + 1) * self.samples_per_part)
            part_dirs = dirs_encoded[:, s].flatten(1, 2) if split_dirs else dirs_encoded
            c, d = net(coords_encoded[:, s].flatten(1, 2), part_dirs)
            # Explicit sizes (not -1) so an empty batch reshapes without ambiguity.
            colors[:, s] = c.view(n_rays, self.samples_per_part, self.color_chns)
            densities[:, s] = d
        ret = self.rendering(colors, densities, depths, ret_depth=ret_depth, debug=debug)
        if debug:
            ret['sample_densities'] = densities
            ret['sample_depths'] = depths
        return ret
Nianchen Deng's avatar
sync    
Nianchen Deng committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111


class SnerfFastExport(nn.Module):
    """
    Export-oriented wrapper around a ``SnerfFast``.

    It takes already-encoded sample coordinates plus per-sample depths and returns
    colors concatenated with alphas along the last axis. The output's last dimension
    is therefore ``color_chns + 1``.
    """

    def __init__(self, net: SnerfFast):
        super().__init__()
        self.net = net

    def forward(self, coords_encoded, z_vals):
        """
        Run every part's MLP on its slice of samples and convert densities to alphas.

        :param coords_encoded: encoded sample coordinates, sliced along dim 1 per part and
                               flattened over dims 1-2 (presumably (B, n_samples, enc_dim) —
                               TODO confirm)
        :param z_vals: per-sample depths, forwarded to ``rendering.density2alpha``
        :return: colors and alphas concatenated on the last axis
        """
        model = self.net
        span = model.samples_per_part
        color_chunks = []
        density_chunks = []
        for part in range(model.n_parts):
            window = slice(part * span, (part + 1) * span)
            # Fallback to a single ``net`` attribute when ``nets`` is None (kept as-is).
            mlp = model.net if model.nets is None else model.nets[part]
            # NOTE(review): no direction input is passed to the MLP here; presumably this
            # export path targets models built without dir encoding — confirm.
            rgb, sigma = mlp(coords_encoded[:, window].flatten(1, 2))
            color_chunks.append(rgb.view(-1, span, model.color_chns))
            density_chunks.append(sigma)
        all_colors = torch.cat(color_chunks, 1)
        all_densities = torch.cat(density_chunks, 1)
        alphas = model.rendering.density2alpha(all_densities, z_vals)
        return torch.cat([all_colors, alphas.unsqueeze(-1)], -1)