from typing import Tuple, Union

import torch
import torch.nn as nn
from .modules import *
from ..my import color_mode
from ..my.simple_perf import SimplePerf


class NewMslNet(nn.Module):

    def __init__(self, fc_params, sampler_params,
                 normalize_coord: bool,
                 dir_as_input: bool,
                 n_nets: int = 2,
                 not_same_net: bool = False,
                 color: int = color_mode.RGB,
                 encode_to_dim: int = 0,
                 export_mode: bool = False):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for the fully-connected network
        :param sampler_params: parameters for the sampler
        :param normalize_coord: whether to normalize the spherical coordinates to [0, 2pi] before encoding
        :param dir_as_input: whether to use the ray direction as an additional input (not used in this class)
        :param n_nets: number of sub-networks the samples are split across
        :param not_same_net: if True, use two sub-networks of different sizes instead of n_nets identical ones
        :param color: color mode
        :param encode_to_dim: number of dimensions to encode the input to
        :param export_mode: if True, forward() returns per-sample colors and alphas instead of a rendered color map
        """
        super().__init__()
        self.in_chns = 3
        self.input_encoder = InputEncoder.Get(
            encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
        self.sampler = Sampler(**sampler_params)
        self.rendering = Rendering()
        self.export_mode = export_mode
        self.normalize_coord = normalize_coord

        # Build the sub-networks: either two FcNets of different capacities
        # (not_same_net) or n_nets identical FcNets built from the same fc_params.
        if not_same_net:
            self.n_nets = 2
            self.nets = nn.ModuleList([
                FcNet(**fc_params),
                FcNet(in_chns=fc_params['in_chns'],
                      out_chns=fc_params['out_chns'],
                      nf=128, n_layers=4)
            ])
        else:
            self.n_nets = n_nets
            self.nets = nn.ModuleList([
                FcNet(**fc_params) for _ in range(n_nets)
            ])
        self.n_samples = sampler_params['n_samples']

    def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor):
        # Expand self.angle_range (a (2, 2) tensor of per-axis min/max angles,
        # which must be initialized before the first call) to cover the angular
        # coordinates sampled from the given rays.
        coords, _, _ = self.sampler(rays_o, rays_d)
        coords = coords[..., 1:].view(-1, 2)  # keep only the two angular components
        self.angle_range = torch.stack([
            torch.cat([coords, self.angle_range[0:1]]).amin(0),
            torch.cat([coords, self.angle_range[1:2]]).amax(0)
        ])

    def sample_and_infer(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                         sampler: Sampler = None) -> Tuple[torch.Tensor, torch.Tensor]:
        if sampler is None:
            sampler = self.sampler
        # Sample spherical coordinates, points and depths along each ray
        coords, pts, depths = sampler(rays_o, rays_d)

        encoded = self.input_encoder(coords)

        # Split the encoded samples evenly among the sub-networks and
        # concatenate their raw outputs back along the sample dimension
        sn = sampler.samples // self.n_nets
        raw = torch.cat([
            self.nets[i](encoded[:, i * sn:(i + 1) * sn])
            for i in range(self.n_nets)
        ], 1)
        return raw, depths

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                ret_depth: bool = False, sampler: Sampler = None
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        rays -> colors

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param ret_depth: whether to also return the depth map
        :param sampler: sampler to use for this call (defaults to self.sampler)
        :return: ```Tensor(B, C)```, inferred images/pixels
        """
        raw, depths = self.sample_and_infer(rays_o, rays_d, sampler)
        if self.export_mode:
            colors, alphas = self.rendering.raw2color(raw, depths)
            return torch.cat([colors, alphas[..., None]], -1)

        if ret_depth:
            color_map, _, _, _, depth_map = self.rendering(
                raw, depths, ret_extra=True)
            return color_map, depth_map

        return self.rendering(raw, depths)
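

# The sketch below is a minimal, self-contained illustration (not part of the
# original module) of the sample-splitting scheme used in
# NewMslNet.sample_and_infer: the encoded samples along each ray are divided
# evenly among several small MLPs and the per-chunk outputs are concatenated
# back along the sample dimension. The MLP layout, batch size and dimensions
# are illustrative assumptions, not the actual FcNet/Sampler configuration
# used in this repository.
if __name__ == "__main__":
    n_nets, n_samples, enc_dim, out_chns = 2, 32, 63, 4
    nets = [
        nn.Sequential(nn.Linear(enc_dim, 128), nn.ReLU(),
                      nn.Linear(128, out_chns))
        for _ in range(n_nets)
    ]
    encoded = torch.rand(8, n_samples, enc_dim)  # (B, n_samples, enc_dim)
    sn = n_samples // n_nets                     # samples handled by each sub-network
    raw = torch.cat([
        nets[i](encoded[:, i * sn:(i + 1) * sn])
        for i in range(n_nets)
    ], 1)                                        # (B, n_samples, out_chns)
    print(raw.shape)                             # torch.Size([8, 32, 4])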