from typing import Tuple

import math
import torch
import torch.nn as nn

from ..my import net_modules
from ..my import util
from ..my import device
from ..my import color_mode

rand_gen = torch.Generator(device=device.GetDevice())
rand_gen.manual_seed(torch.seed())


def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor,
                       r: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate the intersection of each ray with each sphere

    :param p ```Tensor(B, 3)```: positions of rays
    :param v ```Tensor(B, 3)```: directions of rays
    :param r ```Tensor(N)```: radii of spheres
    :return ```Tensor(B, N, 3)```: points of intersection
    :return ```Tensor(B, N)```: depths of intersection along ray
    """
    # p, v: Expand to (B, 1, 3)
    p = p.unsqueeze(1)
    v = v.unsqueeze(1)
    # pp, vv, pv: (B, 1)
    pp = (p * p).sum(dim=2)
    vv = (v * v).sum(dim=2)
    pv = (p * v).sum(dim=2)
    # Solve ||p + t * v||^2 = r^2 for t, i.e. vv * t^2 + 2 * pv * t + (pp - r^2) = 0,
    # keeping the larger (forward) root.
    depths = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv)
    return p + depths[..., None] * v, depths


class Rendering(nn.Module):

    def __init__(self, *, raw_noise_std: float = 0.0, white_bg: bool = False):
        """
        Initialize a Rendering module

        :param raw_noise_std: std of the noise added to the predicted density (0 disables it)
        :param white_bg: whether to composite the rendered color onto a white background
        """
        super().__init__()
        self.raw_noise_std = raw_noise_std
        self.white_bg = white_bg

    def forward(self, raw, z_vals, ret_extra: bool = False):
        """Transforms model's predictions to semantically meaningful values.

        Args:
            raw: [num_rays, num_samples along ray, 4]. Prediction from model.
            z_vals: [num_rays, num_samples along ray]. Integration time.
            ret_extra: If True, also return disparity, accumulation, weights and depth.

        Returns:
            rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
            disp_map: [num_rays]. Disparity map. Inverse of depth map.
            acc_map: [num_rays]. Sum of weights along each ray.
            weights: [num_rays, num_samples]. Weights assigned to each sampled color.
            depth_map: [num_rays]. Estimated distance to object.
            Only rgb_map is returned when ret_extra is False.
        """
        color, alpha = self.raw2color(raw, z_vals)

        # Compute weight for RGB of each sample along each ray. A cumprod() is
        # used to express the idea of the ray not having reflected up to this
        # sample yet.
        one_minus_alpha = util.broadcast_cat(
            torch.cumprod(1 - alpha[..., :-1] + 1e-10, dim=-1),
            1.0, append=False)
        weights = alpha * one_minus_alpha  # (N_rays, N_samples)

        # (N_rays, 1|3), computed weighted color of each sample along each ray.
        color_map = torch.sum(weights[..., None] * color, dim=-2)

        # To composite onto a white background, use the accumulated alpha map.
        if self.white_bg or ret_extra:
            # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
            acc_map = torch.sum(weights, -1)
            if self.white_bg:
                color_map = color_map + (1. - acc_map[..., None])
        else:
            acc_map = None

        if not ret_extra:
            return color_map

        # Estimated depth map is expected distance.
        depth_map = torch.sum(weights * z_vals, dim=-1)
        # Disparity map is inverse depth.
        disp_map = torch.clamp_min(
            depth_map / torch.sum(weights, dim=-1), 1e-10).reciprocal()

        return color_map, disp_map, acc_map, weights, depth_map

    def raw2color(self, raw: torch.Tensor, z_vals: torch.Tensor):
        """
        Convert raw values inferred from the model to color and alpha

        :param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output
        :param z_vals ```Tensor(N.rays, N.samples)```: integration time
        :return ```Tensor(N.rays, N.samples, 1|3)```: color
        :return ```Tensor(N.rays, N.samples)```: alpha
        """

        def raw2alpha(raw, dists, act_fn=torch.relu):
            """
            Compute alpha from the model's predicted density:
            alpha = 1 - exp(-act_fn(raw) * dists), which lies in [0, 1).
            """
            return -torch.exp(-act_fn(raw) * dists) + 1.0

        # Compute 'distance' (in time) between each integration time along a ray.
        # The 'distance' from the last integration time is infinity.
        # dists: (N_rays, N_samples)
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        last_dist = z_vals[..., 0:1] * 0 + 1e10
        dists = torch.cat([dists, last_dist], -1)

        # Extract RGB of each sample position along each ray.
        color = torch.sigmoid(raw[..., :-1])  # (N_rays, N_samples, 1|3)

        if self.raw_noise_std > 0.:
            # Add noise to model's predictions for density. Can be used to
            # regularize network during training (prevents floater artifacts).
            noise = torch.normal(0.0, self.raw_noise_std, raw[..., -1].size(),
                                 generator=rand_gen, device=device.GetDevice())
            alpha = raw2alpha(raw[..., -1] + noise, dists)
        else:
            alpha = raw2alpha(raw[..., -1], dists)

        return color, alpha


class Sampler(nn.Module):

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool, lindisp: bool,
                 inverse_r: bool):
        """
        Initialize a Sampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param spherical: if True, sample by intersecting rays with concentric spheres
                          and return spherical coordinates
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        :param inverse_r: use inverse radius in the returned spherical coordinates
        """
        super().__init__()
        self.lindisp = lindisp
        if self.lindisp:
            depth_range = (1 / depth_range[0], 1 / depth_range[1])
        self.r = torch.linspace(depth_range[0], depth_range[1], n_samples,
                                device=device.GetDevice())
        step = (depth_range[1] - depth_range[0]) / (n_samples - 1)
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.inverse_r = inverse_r
        self.upper = torch.clamp_min(self.r + step / 2, 0)
        self.lower = torch.clamp_min(self.r - step / 2, 0)

    def forward(self, rays_o, rays_d):
        """
        Sample points along rays. Return spherical or Cartesian coordinates,
        as specified by ```self.spherical```

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :return: in spherical mode, ```(Tensor(B, N, 3), Tensor(B, N, 3), Tensor(B, N))```:
                 spherical coordinates, Cartesian points and depths of the samples;
                 otherwise ```(Tensor(B, N, 3), Tensor(N))```: sampled points and depths
        """
        if self.perturb_sample:
            # Stratified samples in those intervals
            t_rand = torch.rand(self.r.size(),
                                generator=rand_gen,
                                device=device.GetDevice())
            r = self.lower + (self.upper - self.lower) * t_rand
        else:
            r = self.r
        if self.lindisp:
            r = torch.reciprocal(r)

        if self.spherical:
            pts, depths = RaySphereIntersect(rays_o, rays_d, r)
            sphers = util.CartesianToSpherical(pts, inverse_r=self.inverse_r)
            return sphers, pts, depths
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * r[..., None], r


# x>0, y>0 -> (y, -x)
# x<0, y>0 -> (-y, x)
# x<0, y<0 -> (y, -x)
# x>0, y<0 -> (-y, x)
class MslNet(nn.Module):

    def __init__(self, fc_params, sampler_params,
                 normalize_coord: bool,
                 dir_as_input: bool,
                 color: int = color_mode.RGB,
                 encode_to_dim: int = 0,
                 export_mode: bool = False):
        """
        Initialize a multi-sphere-layer net

        :param fc_params: parameters for the fully-connected network
        :param sampler_params: parameters for the sampler
        :param normalize_coord: whether to normalize the spherical coords to [0, 2pi] before encoding
        :param dir_as_input: whether to feed the encoded local view direction to a second sub-network
        :param color: color mode
        :param encode_to_dim: encode input to number of dimensions
        :param export_mode: if True, forward() returns per-sample colors and alphas instead of rendered pixels
        """
        super().__init__()
        self.in_chns = 3
        self.input_encoder = net_modules.InputEncoder.Get(
            encode_to_dim, self.in_chns)
        fc_params['in_chns'] = self.input_encoder.out_dim
        fc_params['out_chns'] = 2 if color == color_mode.GRAY else 4
        self.sampler = Sampler(**sampler_params)
        self.rendering = Rendering()
        self.export_mode = export_mode
        self.normalize_coord = normalize_coord
        self.dir_as_input = dir_as_input
        self.color = color
        if self.color == color_mode.YCbCr:
            self.net1 = net_modules.FcNet(
                in_chns=fc_params['in_chns'],
                out_chns=fc_params['nf'] + 2,
                nf=fc_params['nf'],
                n_layers=fc_params['n_layers'] - 2)
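            # In YCbCr mode the prediction is split across two sub-networks:
            # net1 outputs nf intermediate features plus two raw channels, and
            # net2 maps those features to the remaining color channel and the
            # density; forward() re-assembles the four raw channels.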
            self.net2 = net_modules.FcNet(
                in_chns=fc_params['nf'],
                out_chns=2,
                nf=fc_params['nf'],
                n_layers=1)
            self.net = None
        elif self.dir_as_input:
            self.input_encoder2 = net_modules.InputEncoder.Get(4, 2)
            self.net1 = net_modules.FcNet(
                in_chns=fc_params['in_chns'],
                out_chns=fc_params['nf'],
                nf=fc_params['nf'],
                n_layers=fc_params['n_layers'])
            self.net2 = net_modules.FcNet(
                in_chns=fc_params['nf'] + self.input_encoder2.out_dim,
                out_chns=fc_params['out_chns'],
                nf=fc_params['nf'],
                n_layers=1)
            self.net = None
        else:
            self.net = net_modules.FcNet(**fc_params)

        if self.normalize_coord:
            self.register_buffer('angle_range', torch.tensor(
                [[1e5, 1e5], [-1e5, -1e5]]))
            self.register_buffer('depth_range', torch.tensor([
                self.sampler.lower[0],
                self.sampler.upper[-1]
            ]))

    def update_normalize_range(self, rays_o: torch.Tensor, rays_d: torch.Tensor):
        """
        Expand the stored angle range to cover the angular coordinates of the
        points sampled from the given rays (used for coordinate normalization).

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        """
        coords, _, _ = self.sampler(rays_o, rays_d)
        coords = coords[..., 1:].view(-1, 2)
        self.angle_range = torch.stack([
            torch.cat([coords, self.angle_range[0:1]]).amin(0),
            torch.cat([coords, self.angle_range[1:2]]).amax(0)
        ])

    def calc_local_dir(self, rays_d, coords, pts: torch.Tensor):
        """
        Calculate the ray direction in the local tangent frame of each sample
        point, expressed as two spherical angles.

        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param coords ```Tensor(B, N, 3)```: spherical coordinates of sample points
        :param pts ```Tensor(B, N, 3)```: Cartesian coordinates of sample points
        :return ```Tensor(B, N, 2)```: local view direction as spherical angles
        """
        # Local frame: z points radially outward, x follows a small azimuth
        # offset on the sphere, y completes the right-handed basis.
        local_z = pts / pts.norm(dim=-1, keepdim=True)
        local_x = util.SphericalToCartesian(
            coords + torch.tensor([0, 0.1 / 180 * math.pi, 0], device=coords.device)) - pts
        local_x = local_x / local_x.norm(dim=-1, keepdim=True)
        local_y = torch.cross(local_x, local_z, -1)
        local_rot = torch.stack([local_x, local_y, local_z], dim=-2)  # (B, N, 3, 3)
        return util.CartesianToSpherical(
            torch.matmul(rays_d[:, None, None, :], local_rot)).squeeze(-2)[..., 1:3]

    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                ret_depth: bool = False) -> torch.Tensor:
        """
        rays -> colors

        :param rays_o ```Tensor(B, 3)```: rays' origin
        :param rays_d ```Tensor(B, 3)```: rays' direction
        :param ret_depth: if True, also return the estimated depth map
        :return ```Tensor(B, C)```: inferred images/pixels
        """
        coords, pts, depths = self.sampler(rays_o, rays_d)
        if self.dir_as_input:
            dirs = self.calc_local_dir(rays_d, coords, pts)
        if self.normalize_coord:
            # Normalize coords to [0, 2pi]
            norm_range = torch.cat([self.depth_range.view(2, 1), self.angle_range], 1)
            coords = (coords - norm_range[0]) / (norm_range[1] - norm_range[0]) * 2 * math.pi
        encoded = self.input_encoder(coords)

        if self.color == color_mode.YCbCr:
            mid_output = self.net1(encoded)
            net2_output = self.net2(mid_output[..., :-2])
            raw = torch.cat([mid_output[..., -2:], net2_output], -1)
        elif self.dir_as_input:
            encoded_dirs = self.input_encoder2(dirs)
            #print(encoded.size(), self.net1(encoded).size(), encoded_dirs.size())
            raw = self.net2(torch.cat([self.net1(encoded), encoded_dirs], -1))
        else:
            raw = self.net(encoded)

        if self.export_mode:
            colors, alphas = self.rendering.raw2color(raw, depths)
            return torch.cat([colors, alphas[..., None]], -1)

        if ret_depth:
            color_map, _, _, _, depth_map = self.rendering(
                raw, depths, ret_extra=True)
            return color_map, depth_map

        return self.rendering(raw, depths)
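

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how an MslNet might be constructed and
# queried through this module. The parameter values below are assumptions for
# demonstration; the actual configurations are defined by the training and
# evaluation scripts that import this module.
#
#   fc_params = {'nf': 256, 'n_layers': 8}                  # hypothetical sizes
#   sampler_params = {
#       'depth_range': (1.0, 50.0), 'n_samples': 32,
#       'perturb_sample': True, 'spherical': True,
#       'lindisp': True, 'inverse_r': True,
#   }
#   net = MslNet(fc_params, sampler_params,
#                normalize_coord=False, dir_as_input=False,
#                color=color_mode.RGB, encode_to_dim=10).to(device.GetDevice())
#   rays_o = torch.zeros(4, 3, device=device.GetDevice())   # ray origins
#   rays_d = torch.tensor([[0., 0., 1.]] * 4, device=device.GetDevice())
#   colors = net(rays_o, rays_d)                             # (4, 3) rendered colors
#   colors, depth = net(rays_o, rays_d, ret_depth=True)      # with depth map
# ---------------------------------------------------------------------------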