nerf_depth.py
import torch
import torch.nn as nn
from modules import *
from utils import color

'''
The first step towards depth-guided acceleration:
sample along each ray according to the raw depth input.
'''


class NerfDepth(nn.Module):

    def __init__(self, fc_params, sampler_params,
                 c: int = color.RGB,
                 pos_encode: int = 0,
                 n_bins: int = 128,
                 include_neighbor_bins=True):
        super().__init__()
        self.color = c
        self.n_samples = sampler_params['n_samples']
        self.coord_chns = 3
        self.color_chns = color.chns(self.color)
        # Positional encoder applied to the sampled 3D coordinates
        self.pos_encoder = InputEncoder.Get(pos_encode, self.coord_chns)
        # Core NeRF MLP predicting density and color from encoded positions
        self.mlp = NerfCore(coord_chns=self.pos_encoder.out_dim,
                            density_chns=1,
                            color_chns=self.color_chns,
                            core_nf=fc_params['nf'],
                            core_layers=fc_params['n_layers'],
                            act=fc_params['activation'],
                            skips=fc_params['skips'])
        # Adaptive sampler guided by the per-ray raw depth and its depth bins
        self.sampler = AdaptiveSampler(**sampler_params, n_bins=n_bins,
                                       include_neighbor_bins=include_neighbor_bins)
        # Volume renderer compositing colors and densities along each ray
        self.rendering = VolumnRenderer()

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                rays_depth: torch.Tensor, rays_bins: torch.Tensor,
                ret_depth=False, debug=False) -> torch.Tensor:
        """
        rays -> colors

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param rays_depth `Tensor(B)`: rays' depth
        :return: `Tensor(B, C)`, inferred images/pixels
        """
        # Sample points along each ray, guided by the raw depth and depth bins
        coords, pts, depths, _ = self.sampler(rays_o, rays_d, rays_depth, rays_bins)
        encoded_position = self.pos_encoder(coords)
        colors, densities = self.mlp(encoded_position)
        # Composite the per-sample colors and densities into per-ray colors
        return self.rendering(colors, densities[..., 0], depths,
                              ret_depth=ret_depth, debug=debug)