spher_net.py 1.61 KB
Newer Older
1
2
import torch
import torch.nn as nn
Nianchen Deng's avatar
sync    
Nianchen Deng committed
3
4
from ..my import net_modules
from ..my import util
5
6
7
8


class SpherNet(nn.Module):
    """Network mapping rays to colors, built on ``net_modules.FcNet``.

    Ray directions are converted with ``util.CartesianToSpherical`` and
    components ``[1:3]`` of the result are kept. When ``translation`` is
    enabled, the ray origins are concatenated in front of them. The resulting
    per-ray vector goes through an input encoder and then the FC network.
    """

    def __init__(self, fc_params: dict,
                 gray: bool = False,
                 translation: bool = False,
                 encode_to_dim: int = 0):
        """
        Initialize a sphere net

        :param fc_params: parameters for full-connection network. The mapping
            is copied (the caller's object is not modified), and
            ``in_chns``/``out_chns`` are set on the copy, overriding any
            values supplied by the caller.
        :param gray: whether grayscale mode (1 output channel instead of 3)
        :param translation: whether support translation of view, i.e. feed
            ray origins to the network in addition to ray directions
        :param encode_to_dim: encode input to number of dimensions
        """
        super().__init__()
        # 3 origin coordinates + 2 direction components, or only the 2
        # direction components when translation is disabled.
        self.in_chns = 5 if translation else 2
        self.input_encoder = net_modules.InputEncoder.Get(
            encode_to_dim, self.in_chns)
        # Work on a copy so a config dict shared between several networks
        # (or saved for later) is not silently extended with derived keys.
        net_params = dict(fc_params)
        net_params['in_chns'] = self.input_encoder.out_dim
        net_params['out_chns'] = 1 if gray else 3
        self.net = net_modules.FcNet(**net_params)

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> torch.Tensor:
        """
        rays -> colors

        :param rays_o ```Tensor(B, ..., 3)```: rays' origin (only read when
            the net was built with ``translation=True``)
        :param rays_d ```Tensor(B, ..., 3)```: rays' direction
        :return: Tensor(B, 1|3, ...), inferred images/pixels
        """
        # Flatten all leading dims into one. reshape (unlike view) also accepts
        # non-contiguous inputs such as slices of a larger ray buffer.
        dirs = rays_d.reshape(-1, 3)
        spher = util.CartesianToSpherical(dirs)[..., 1:3]  # (N, 2)
        if self.in_chns == 5:
            net_in = torch.cat([rays_o.reshape(-1, 3), spher], dim=-1)  # (N, 5)
        else:
            net_in = spher

        colors: torch.Tensor = self.net(self.input_encoder(net_in))

        # Restore the input's leading dims (last dim becomes the color
        # channels), then move the channel dim to position 1.
        out_shape = list(rays_d.size())
        out_shape[-1] = -1
        return colors.reshape(out_shape).movedim(-1, 1)