import math
from typing import Tuple

import torch
import torch.nn as nn

from .modules import *
from ..my import util
from ..my import color_mode


class MslNet(nn.Module):

    def __init__(self, fc_params, sampler_params,
                 normalize_coord: bool,
                 dir_as_input: bool,
                 color: int = color_mode.RGB,
                 encode_to_dim: int = 0,
                 export_mode: bool = False):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for the fully-connected network
        :param sampler_params: parameters for the sampler
        :param normalize_coord: whether to normalize the spherical coords to [0, 2pi] before encoding
        :param dir_as_input: whether to feed the per-sample local view direction to the network as an extra input
        :param color: color mode
        :param encode_to_dim: number of dimensions to encode the input to
        :param export_mode: if True, forward() returns raw colors and alphas instead of rendered pixels
        """
        super().__init__()
        self.in_chns = 3
        self.input_encoder = InputEncoder.Get(encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
        self.sampler = Sampler(**sampler_params)
        self.rendering = Rendering()
        self.export_mode = export_mode
        self.normalize_coord = normalize_coord
        self.dir_as_input = dir_as_input
        self.color = color
        if self.color == color_mode.YCbCr:
            # Two-stage network: net1 produces nf features plus 2 raw output
            # channels; net2 predicts the remaining 2 raw channels from
            # net1's features.
            self.net1 = FcNet(in_chns=fc_params['in_chns'],
                              out_chns=fc_params['nf'] + 2,
                              nf=fc_params['nf'],
                              n_layers=fc_params['n_layers'] - 2)
            self.net2 = FcNet(in_chns=fc_params['nf'],
                              out_chns=2,
                              nf=fc_params['nf'],
                              n_layers=1)
            self.net = None
        elif self.dir_as_input:
            # Encode the 2D local view direction and concatenate it with
            # net1's features before the final layer (net2).
            self.input_encoder2 = InputEncoder.Get(4, 2)
            self.net1 = FcNet(in_chns=fc_params['in_chns'],
                              out_chns=fc_params['nf'],
                              nf=fc_params['nf'],
                              n_layers=fc_params['n_layers'])
            self.net2 = FcNet(in_chns=fc_params['nf'] + self.input_encoder2.out_dim,
                              out_chns=fc_params['out_chns'],
                              nf=fc_params['nf'],
                              n_layers=1)
            self.net = None
        else:
            self.net = FcNet(**fc_params)
        if self.normalize_coord:
            # Running min/max of the angular coords, initialized to an empty
            # range; expanded by update_normalize_range().
            self.register_buffer('angle_range', torch.tensor(
                [[1e5, 1e5], [-1e5, -1e5]]))
            self.register_buffer('depth_range', torch.tensor(
                [self.sampler.lower[0], self.sampler.upper[-1]]))
        self.n_samples = sampler_params['n_samples']

    def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor):
        """
        Expand the stored angular range to cover the sample coords of the given rays
        """
        coords, _, _ = self.sampler(rays_o, rays_d)
        coords = coords[..., 1:].view(-1, 2)
        self.angle_range = torch.stack([
            torch.cat([coords, self.angle_range[0:1]]).amin(0),
            torch.cat([coords, self.angle_range[1:2]]).amax(0)
        ])
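
    # Usage sketch (illustrative, not part of the original class): when
    # `normalize_coord` is True, `update_normalize_range` is expected to be
    # called over the training rays once, so that `angle_range` covers the
    # dataset before `sample_and_infer` maps coords into [0, 2pi]. The
    # `ray_loader` below is a hypothetical source of (rays_o, rays_d) batches:
    #
    #   with torch.no_grad():
    #       for rays_o, rays_d in ray_loader:
    #           net.update_normalize_range(rays_o, rays_d)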

    def calc_local_dir(self, rays_d, coords, pts: torch.Tensor):
        """
        Calculate the view direction expressed in each sample point's local spherical frame

        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param coords ```Tensor(B, N, 3)```: spherical coords of sample points
        :param pts ```Tensor(B, N, 3)```: Cartesian coords of sample points
        :return ```Tensor(B, N, 2)```: view direction as local spherical angles
        """
        local_z = pts / pts.norm(dim=-1, keepdim=True)
        # Tangent along the angular coordinate, approximated by a finite
        # difference of 0.1 degrees
        local_x = util.SphericalToCartesian(
            coords + torch.tensor([0, 0.1 / 180 * math.pi, 0], device=coords.device)) - pts
        local_x = local_x / local_x.norm(dim=-1, keepdim=True)
        local_y = torch.cross(local_x, local_z, dim=-1)
        local_rot = torch.stack([local_x, local_y, local_z], dim=-2)  # (B, N, 3, 3)
        return util.CartesianToSpherical(torch.matmul(
            rays_d[:, None, None, :], local_rot)).squeeze(-2)[..., 1:3]

    def sample_and_infer(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                         sampler: Sampler = None) -> Tuple[torch.Tensor, torch.Tensor]:
        if sampler is None:
            sampler = self.sampler
        coords, pts, depths = sampler(rays_o, rays_d)
        if self.dir_as_input:
            dirs = self.calc_local_dir(rays_d, coords, pts)
        if self.normalize_coord:
            # Normalize depth and angular coords to [0, 2pi] before encoding
            coord_range = torch.cat([self.depth_range.view(2, 1), self.angle_range], 1)
            coords = (coords - coord_range[0]) / (coord_range[1] - coord_range[0]) * 2 * math.pi
        encoded = self.input_encoder(coords)

        if self.color == color_mode.YCbCr:
            mid_output = self.net1(encoded)
            net2_output = self.net2(mid_output[..., :-2])
            raw = torch.cat([
                mid_output[..., -2:],
                net2_output
            ], -1)
        elif self.dir_as_input:
            encoded_dirs = self.input_encoder2(dirs)
            raw = self.net2(torch.cat([self.net1(encoded), encoded_dirs], -1))
        else:
            raw = self.net(encoded)
        return raw, depths

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                ret_depth: bool = False, sampler: Sampler = None) -> torch.Tensor:
        """
        rays -> colors

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param ret_depth: whether to also return the depth map
        :param sampler: sampler to use instead of the default one
        :return ```Tensor(B, C)```: inferred images/pixels
        """
        raw, depths = self.sample_and_infer(rays_o, rays_d, sampler)
        if self.export_mode:
            colors, alphas = self.rendering.raw2color(raw, depths)
            return torch.cat([colors, alphas[..., None]], -1)

        if ret_depth:
            color_map, _, _, _, depth_map = self.rendering(
                raw, depths, ret_extra=True)
            return color_map, depth_map

        return self.rendering(raw, depths)


class ExportNet(nn.Module):

    def __init__(self, net: MslNet):
        super().__init__()
        self.net = net

    def forward(self, encoded: torch.Tensor, depths: torch.Tensor) -> torch.Tensor:
        """
        Infer raw colors and alphas from already-encoded sample coords
        """
        raw = self.net.net(encoded)
        colors, alphas = self.net.rendering.raw2color(raw, depths)
        return torch.cat([colors, alphas[..., None]], -1)
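

# Illustrative sketch (not part of the original module): ExportNet takes
# pre-encoded sample coords and depths, which makes it straightforward to
# trace; one plausible use is ONNX export. The helper name, dummy batch size
# and I/O names below are assumptions for illustration only, and it assumes
# the single-network configuration (i.e. net.net is not None).
def export_msl_onnx(net: MslNet, path: str):
    export_net = ExportNet(net)
    # Dummy inputs: (B, n_samples, encoded_dim) coords and (B, n_samples) depths
    encoded = torch.rand(1, net.n_samples, net.input_encoder.out_dim)
    depths = torch.rand(1, net.n_samples)
    torch.onnx.export(export_net, (encoded, depths), path,
                      input_names=['encoded', 'depths'],
                      output_names=['color_alpha'])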