from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from .modules import *
from ..my import color_mode
from ..my.simple_perf import SimplePerf


class NewMslNet(nn.Module):
    """Multi-sphere-layer network that splits the samples of each ray among several sub-networks."""

    def __init__(self, fc_params, sampler_params,
                 normalize_coord: bool,
                 dir_as_input: bool,
                 n_nets: int = 2,
                 not_same_net: bool = False,
                 color: int = color_mode.RGB,
                 encode_to_dim: int = 0,
                 export_mode: bool = False):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for full-connection network; NOTE: the keys
            ``'in_chns'`` and ``'out_chns'`` are written into this dict in place
        :param sampler_params: parameters for sampler, must contain ``'n_samples'``
        :param normalize_coord: whether normalize the spherical coords to [0, 2pi] before encode
            (only stored on the instance by the code in this class)
        :param dir_as_input: not used by the code in this class
        :param n_nets: number of sub-networks sharing ``fc_params``
            (ignored when ``not_same_net`` is true)
        :param not_same_net: if true, use exactly two sub-networks: one built from
            ``fc_params`` and one with ``nf=128, n_layers=4``
        :param color: color mode
        :param encode_to_dim: encode input to number of dimensions
        :param export_mode: if true, ``forward`` returns the output of
            ``Rendering.raw2color`` (colors and alphas concatenated) instead of rendered results
        """
        super().__init__()
        self.in_chns = 3
        self.input_encoder = InputEncoder.Get(encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
        self.sampler = Sampler(**sampler_params)
        self.rendering = Rendering()
        self.export_mode = export_mode
        self.normalize_coord = normalize_coord
        # Running [min, max] of the sampled angles; stays None until the first
        # call of update_normalize_range().
        self.angle_range = None

        if not_same_net:
            self.n_nets = 2
            self.nets = nn.ModuleList([
                FcNet(**fc_params),
                FcNet(in_chns=fc_params['in_chns'],
                      out_chns=fc_params['out_chns'],
                      nf=128, n_layers=4)
            ])
        else:
            self.n_nets = n_nets
            self.nets = nn.ModuleList([
                FcNet(**fc_params) for _ in range(n_nets)
            ])
        self.n_samples = sampler_params['n_samples']

    def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor) -> None:
        """
        Extend ``self.angle_range`` so that it covers the samples of the given rays

        After the call ``self.angle_range`` is a ``Tensor(2, 2)``: row 0 holds the
        minimum and row 1 the maximum of the last two coordinate components over
        all samples seen so far.

        :param rays_o: rays' origin, passed to the sampler
        :param rays_d: rays' direction, passed to the sampler
        """
        coords, _, _ = self.sampler(rays_o, rays_d)
        # Drop the first component and flatten to a list of 2-component rows;
        # reshape (rather than view) also accepts non-contiguous sampler output.
        angles = coords[..., 1:].reshape(-1, 2)
        lower, upper = angles, angles
        # getattr keeps this working for instances that never ran __init__
        # (the attribute used to be missing before the first update).
        prev_range = getattr(self, 'angle_range', None)
        if prev_range is not None:
            lower = torch.cat([angles, prev_range[0:1]])
            upper = torch.cat([angles, prev_range[1:2]])
        self.angle_range = torch.stack([lower.amin(0), upper.amax(0)])

    def sample_and_infer(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                         sampler: Optional[Sampler] = None
                         ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample points along rays and run every sub-network on its share of samples

        :param rays_o: rays' origin, passed to the sampler
        :param rays_d: rays' direction, passed to the sampler
        :param sampler: sampler to use instead of ``self.sampler``
        :return: ``(raw, depths)``: sub-network outputs concatenated along dim 1,
            and the depths returned by the sampler
        """
        if sampler is None:
            sampler = self.sampler
        coords, _, depths = sampler(rays_o, rays_d)
        encoded = self.input_encoder(coords)

        # Net i handles samples [i * sn, (i + 1) * sn) along dim 1.
        # NOTE(review): samples beyond n_nets * sn are dropped when the sample
        # count is not divisible by n_nets - confirm this is intended.
        sn = sampler.samples // self.n_nets
        raw = torch.cat([
            net(encoded[:, i * sn:(i + 1) * sn])
            for i, net in enumerate(self.nets)
        ], 1)
        return raw, depths

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                ret_depth: bool = False,
                sampler: Optional[Sampler] = None
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        rays -> colors

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param ret_depth: if true (and not in export mode), also return the depth map
        :param sampler: sampler to use instead of ``self.sampler``
        :return: ```Tensor(B, C)```, inferred images/pixels; a ``(color_map, depth_map)``
            tuple when ``ret_depth`` is true; in export mode, the colors and alphas from
            ``raw2color`` concatenated along the last dim
        """
        raw, depths = self.sample_and_infer(rays_o, rays_d, sampler)
        if self.export_mode:
            # Export mode takes precedence over ret_depth.
            colors, alphas = self.rendering.raw2color(raw, depths)
            return torch.cat([colors, alphas[..., None]], -1)

        if ret_depth:
            color_map, _, _, _, depth_map = self.rendering(
                raw, depths, ret_extra=True)
            return color_map, depth_map

        return self.rendering(raw, depths)