from typing import List
import torch
import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import util
from .my import device


class FcNet(nn.Module):

    def __init__(self, in_chns, out_chns, nf, n_layers):
        super().__init__()
        self.layers = list()
        self.layers.append(nn.Linear(in_chns, nf))
        self.layers.append(nn.LeakyReLU())
        for _ in range(1, n_layers):
            self.layers.append(nn.Linear(nf, nf))
            self.layers.append(nn.LeakyReLU())
        self.layers.append(nn.Linear(nf, out_chns))
        self.net = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.net(x)

class Rendering(nn.Module):

    def __init__(self, n_sphere_layers):
        super().__init__()
        self.n_sl = n_sphere_layers

    def forward(self, net, pos, dir):
        """
        [summary]

        :param pos: B x 3, position of a ray
        :param dir: B x 3, direction of a ray
        """
        

class MslNet(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):