msl_net.py 1.07 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from typing import List
import torch
import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import util
from .my import device


class FcNet(nn.Module):
    """Fully-connected MLP with LeakyReLU activations.

    Layout: ``Linear(in_chns, nf)`` then ``n_layers - 1`` further
    ``Linear(nf, nf)`` layers, every hidden ``Linear`` followed by a
    ``LeakyReLU``, and a final ``Linear(nf, out_chns)`` with no activation.
    Any ``n_layers <= 1`` still produces exactly one hidden stage.

    :param in_chns: number of input features
    :param out_chns: number of output features
    :param nf: width of every hidden layer
    :param n_layers: number of hidden (Linear + LeakyReLU) stages
    """

    def __init__(self, in_chns, out_chns, nf, n_layers):
        super().__init__()
        # Input width of each hidden stage: first maps in_chns -> nf,
        # the remaining ones map nf -> nf.
        stage_inputs = [in_chns] + [nf] * max(n_layers - 1, 0)
        modules = []
        for width in stage_inputs:
            # List literal evaluates left to right: Linear is created before
            # its activation, matching the original construction order.
            modules += [nn.Linear(width, nf), nn.LeakyReLU()]
        modules.append(nn.Linear(nf, out_chns))
        # Plain Python list kept for inspection; only ``net`` registers the
        # submodules, so state_dict keys remain ``net.<idx>.*``.
        self.layers = modules
        self.net = nn.Sequential(*self.layers)

    def forward(self, x):
        """Run ``x`` through the MLP (``nn.Linear`` acts on the last dimension)."""
        return self.net(x)

class Rendering(nn.Module):
    """Ray renderer configured with a number of sphere layers.

    Only the configuration is stored so far; ``forward`` is not implemented
    yet and raises instead of silently returning ``None``.
    """

    def __init__(self, n_sphere_layers):
        """
        :param n_sphere_layers: number of sphere layers, stored as ``self.n_sl``
        """
        super().__init__()
        self.n_sl = n_sphere_layers

    def forward(self, net, pos, dir):  # `dir` shadows a builtin but is kept for keyword callers
        """
        Render rays by querying ``net``.

        :param net: network to query (presumably an ``nn.Module`` such as
                    ``FcNet`` -- TODO confirm)
        :param pos: B x 3, position of a ray
        :param dir: B x 3, direction of a ray
        :raises NotImplementedError: always; the rendering step is not written yet
        """
        # The previous stub fell through and returned None, which only failed
        # later at the caller; fail loudly at the source instead.
        raise NotImplementedError("Rendering.forward is not implemented yet")
        

class MslNet(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):