msl_net_new_export.py 1.38 KB
Newer Older
Nianchen Deng's avatar
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import Tuple
import math
import torch
import torch.nn as nn
from ..my import net_modules
from ..my import util
from ..my import device
from ..my import color_mode
from .msl_net_new import NewMslNet


class Sampler(nn.Module):
    """Standalone wrapper for the sampling + input-encoding stage of a ``NewMslNet``.

    Chains ``net.sampler`` and ``net.input_encoder`` so this stage can be run
    (presumably traced/exported) as its own module.
    """

    def __init__(self, net: NewMslNet):
        super().__init__()
        self.net = net

    def forward(self, rays_o: torch.Tensor,
                rays_d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample along the given rays and encode the sample coordinates.

        Args:
            rays_o: ray origins, passed unchanged to ``net.sampler``.
            rays_d: ray directions, passed unchanged to ``net.sampler``.

        Returns:
            A pair ``(encoded, depths)``:
            ``encoded`` is the sampler's coordinates run through
            ``net.input_encoder``.
            ``depths`` is returned exactly as produced by ``net.sampler``.
        """
        # net.sampler yields (coords, pts, depths); the middle value is not
        # needed downstream, so it is discarded.
        coords, _, depths = self.net.sampler(rays_o, rays_d)
        return self.net.input_encoder(coords), depths


class FcNet1(nn.Module):
    """Wrapper that runs the first sub-network of a ``NewMslNet``.

    Evaluates ``net.nets[0]`` on the first ``n_samples // 2`` entries along
    dim 1 of the encoded input.
    """

    def __init__(self, net: NewMslNet):
        super().__init__()
        self.net = net

    def forward(self, encoded: torch.Tensor) -> torch.Tensor:
        """Run ``net.nets[0]`` on samples ``[0, n_samples // 2)`` of dim 1.

        Args:
            encoded: encoded samples; dim 1 is sliced as the sample axis
                (assumed layout (rays, n_samples, ...) — TODO confirm).

        Returns:
            The raw output of ``net.nets[0]`` as a single tensor.
        """
        # Return the tensor itself, not a 1-tuple, so the output matches
        # FcNet2 and can be passed directly as ``raw1`` to CatNet.
        return self.net.nets[0](encoded[:, :self.net.n_samples // 2])


class FcNet2(nn.Module):
    """Wrapper that runs the second sub-network of a ``NewMslNet``.

    Evaluates ``net.nets[1]`` on the samples from index ``n_samples // 2``
    onward along dim 1 of the encoded input.
    """

    def __init__(self, net: NewMslNet):
        super().__init__()
        self.net = net

    def forward(self, encoded: torch.Tensor) -> torch.Tensor:
        """Run ``net.nets[1]`` on samples ``[n_samples // 2, end)`` of dim 1.

        Args:
            encoded: encoded samples; dim 1 is sliced as the sample axis
                (assumed layout (rays, n_samples, ...) — TODO confirm).

        Returns:
            The raw output of ``net.nets[1]``.
        """
        split_at = self.net.n_samples // 2
        tail_samples = encoded[:, split_at:]
        second_net = self.net.nets[1]
        return second_net(tail_samples)


class CatNet(nn.Module):
    """Wrapper that merges both sub-network outputs and decodes them to colors.

    Joins the two raw outputs along dim 1 and hands the result, together with
    the sample depths, to ``net.rendering.raw2color``.
    """

    def __init__(self, net: NewMslNet):
        super().__init__()
        self.net = net

    def forward(self, raw1: torch.Tensor, raw2: torch.Tensor, depths: torch.Tensor) -> torch.Tensor:
        """Decode the concatenated raw outputs into colors with alpha appended.

        Args:
            raw1: output of the first sub-network (FcNet1).
            raw2: output of the second sub-network (FcNet2).
            depths: sample depths, forwarded to ``raw2color``.

        Returns:
            The colors from ``raw2color`` with the alphas appended as one
            extra channel on the last dimension.
        """
        merged_raw = torch.cat((raw1, raw2), dim=1)
        rgb, alpha = self.net.rendering.raw2color(merged_raw, depths)
        # Give alpha a trailing singleton dim so it can be stacked as a channel.
        alpha_channel = alpha.unsqueeze(-1)
        return torch.cat((rgb, alpha_channel), dim=-1)