density_decoder.py 463 Bytes
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from typing import TYPE_CHECKING

from ..__common__ import *

__all__ = ["DensityDecoder"]


class DensityDecoder(nn.Module):
    """Decode a feature tensor ``f`` into density values with one ``nn.FcLayer``.

    Args:
        f_chns: number of channels of the input feature ``f``.
        density_chns: number of channels of the produced density.
        **kwargs: accepted but ignored. Presumably this lets the decoder be
            built from a shared config carrying extra keys (TODO confirm).
    """

    def __init__(self, f_chns: int, density_chns: int, **kwargs):
        # NOTE(review): the project base ``nn.Module`` appears to take
        # (input channels, output channels) dicts keyed by tensor name
        # ("f" in, "density" out). Confirm against ``__common__``.
        super().__init__({"f": f_chns}, {"density": density_chns})
        self.net = nn.FcLayer(f_chns, density_chns)

    if TYPE_CHECKING:
        # Signature stub for static type checkers only. Defining it at runtime
        # would override the inherited ``__call__`` with a body that returns
        # ``None``, so calling the module would no longer reach ``forward``.
        def __call__(self, f: torch.Tensor) -> torch.Tensor:
            ...

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        """Return ``self.net(f)``, the density decoded from feature ``f``."""
        return self.net(f)