import torch
import torch.nn.functional as nn_f
from typing import Any, List, Mapping, Tuple, Union

from torch import nn

from utils.view import *
from utils.constants import *
from .post_process import *
from .foveation import Foveation


class FoveatedNeuralRenderer(object):

    def __init__(self, layers_fov: List[float], layers_res: List[Tuple[int, int]],
                 layers_net: nn.ModuleList, output_res: Tuple[int, int], *,
                 using_mask=True, device: torch.device = None):
        super().__init__()
        self.layers_net = layers_net.to(device=device)

        # One camera per foveation layer, plus a full-resolution output camera
        # sharing the outermost layer's FOV.
        self.layers_cam = [
            CameraParam({
                'fov': layers_fov[i],
                'cx': 0.5,
                'cy': 0.5,
                'normalized': True
            }, layers_res[i], device=device)
            for i in range(len(layers_fov))
        ]
        self.cam = CameraParam({
            'fov': layers_fov[-1],
            'cx': 0.5,
            'cy': 0.5,
            'normalized': True
        }, output_res, device=device)

        self.foveation = Foveation(layers_fov, layers_res, output_res, device=device)
        self.layers_mask = self.foveation.get_layers_mask() if using_mask else None
        self.device = device

    def to(self, device: torch.device):
        self.layers_net.to(device)
        self.foveation.to(device)
        self.cam.to(device)
        for cam in self.layers_cam:
            cam.to(device)
        if self.layers_mask is not None:
            self.layers_mask = self.layers_mask.to(device)
        self.device = device
        return self

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.render(*args, **kwds)

    def render(self, view: Trans, gaze, right_gaze=None, *,
               stereo_disparity=0, ret_raw=False) -> Union[Mapping[str, torch.Tensor],
                                                           Tuple[Mapping[str, torch.Tensor]]]:
        if stereo_disparity > TINY_FLOAT:
            # Stereo: offset the view by half the disparity along the x-axis for each eye.
            left_view = Trans(
                view.trans_point(torch.tensor([-stereo_disparity / 2, 0, 0], device=view.device())),
                view.r)
            right_view = Trans(
                view.trans_point(torch.tensor([stereo_disparity / 2, 0, 0], device=view.device())),
                view.r)
            left_gaze = gaze
            right_gaze = gaze if right_gaze is None else right_gaze
            res_raw_left = [
                self._render(i, left_view, left_gaze if i < 2 else None)['color']
                for i in range(3)
            ]
            res_raw_right = [
                self._render(i, right_view, right_gaze if i < 2 else None)['color']
                for i in range(3)
            ]
            return self._gen_output(res_raw_left, left_gaze, ret_raw), \
                self._gen_output(res_raw_right, right_gaze, ret_raw)
        else:
            # Mono: the periphery layer (index 2) is not shifted by gaze.
            res_raw = [
                self._render(i, view, gaze if i < 2 else None)['color']
                for i in range(3)
            ]
            return self._gen_output(res_raw, gaze, ret_raw)

    '''
    if mono_trans != None and shift == 0:  # do warp
        fovea_depth[torch.isnan(fovea_depth)] = 50
        mid_depth[torch.isnan(mid_depth)] = 50
        periph_depth[torch.isnan(periph_depth)] = 50
        if warp_by_depth:
            z_list = misc.depth_sample((1, 50), 4, True)
            mid_inferred = self._warp(trans, mono_trans, mid_cam, z_list,
                                      mid_inferred, mid_depth)
            periph_inferred = self._warp(trans, mono_trans, periph_cam, z_list,
                                         periph_inferred, periph_depth)
        else:
            p = torch.tensor([[0, 0, torch.mean(fovea_depth)]], device=self.device)
            p_ = trans.trans_point(mono_trans.trans_point(p), inverse=True)
            shift = self.full_cam.proj(p_, center_as_origin=True)[..., 0].item()
            shift = round(shift)

    blended = self.foveation.synthesis([
        fovea_refined, mid_refined, periph_refined
    ], (gaze[0], gaze[1]), [0, shift, shift] if shift != 0 else None)
    '''

    def _render(self, layer: int, view: Trans, gaze=None,
                ret_depth=False) -> Mapping[str, torch.Tensor]:
        net = self.layers_net[layer]
        cam = self.layers_cam[layer]
        if gaze is not None:
            cam = self._adjust_cam(cam, gaze)
        rays_o, rays_d = cam.get_global_rays(view, True)  # (1, N, 3)
        if self.layers_mask is not None and layer < len(self.layers_mask):
            # Only infer rays inside the layer's foveation mask, then scatter the
            # results back into a full-resolution layer image.
            mask = self.layers_mask[layer] >= 0
            rays_o = rays_o[:, mask]
            rays_d = rays_d[:, mask]
            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
            ret = {
                'color': torch.zeros(1, cam.res[0], cam.res[1], 3, device=self.device)
            }
            ret['color'][:, mask] = net_output['color']
            ret['color'] = ret['color'].permute(0, 3, 1, 2)
            if ret_depth:
                ret['depth'] = torch.zeros(1, cam.res[0], cam.res[1], device=self.device)
                ret['depth'][:, mask] = net_output['depth']
            return ret
        else:
            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
            return {
                'color': net_output['color'].view(1, cam.res[0], cam.res[1], -1).permute(0, 3, 1, 2),
                'depth': net_output['depth'].view(1, cam.res[0], cam.res[1]) if ret_depth else None
            }

    def _gen_output(self, layers_img: List[torch.Tensor], gaze: Tuple[float, float],
                    ret_raw=False) -> Mapping[str, torch.Tensor]:
        refined = self._post_process(layers_img)
        blended = self.foveation.synthesis(refined, gaze)
        ret = {
            'layers_img': refined,
            'blended': blended
        }
        if ret_raw:
            ret['layers_raw'] = layers_img
            ret['blended_raw'] = self.foveation.synthesis(layers_img, gaze)
        return ret

    def _post_process(self, layers_img: List[torch.Tensor]) -> List[torch.Tensor]:
        return [
            #grad_aware_median(constrast_enhance(layers_img[0], 3, 0.2), 3, 3, True),
            constrast_enhance(layers_img[0], 3, 0.2),
            constrast_enhance(layers_img[1], 5, 0.2),
            constrast_enhance(layers_img[2], 5, 0.2)
        ]

    def _adjust_cam(self, layer_cam: CameraParam, gaze: Tuple[float, float]) -> CameraParam:
        # Shift the layer camera's principal point toward the gaze, converting the
        # gaze offset from output-camera pixels to layer-camera pixels.
        fovea_offset = (
            gaze[0] / self.cam.f[0].item() * layer_cam.f[0].item(),
            gaze[1] / self.cam.f[1].item() * layer_cam.f[1].item()
        )
        return CameraParam({
            'fx': layer_cam.f[0].item(),
            'fy': layer_cam.f[1].item(),
            'cx': layer_cam.c[0].item() - fovea_offset[0],
            'cy': layer_cam.c[1].item() - fovea_offset[1]
        }, layer_cam.res, device=self.device)

    def _warp(self, trans: Trans, trans0: Trans, cam: CameraParam,
              z_list: torch.Tensor, image: torch.Tensor, depthmap: torch.Tensor) -> torch.Tensor:
        """
        Warp `image`, rendered at pose `trans0`, to the target pose `trans` by
        intersecting the target rays with the depth surface described by `depthmap`.

        :param trans: target pose to warp the image to
        :param trans0: pose at which `image` and `depthmap` were rendered
        :param cam: camera intrinsics shared by both poses
        :param z_list: candidate depths along each ray used to bracket the surface
        :param image `Tensor(B, C, H, W)`: image rendered at `trans0`
        :param depthmap `Tensor(B, H, W)`: depth map rendered at `trans0`
        :return `Tensor(B, C, H, W)`: image warped to `trans`
        """
        B = image.size(0)
        rays_d = cam.get_global_rays(trans, norm=False)[1]  # (1, H, W, 3)
        rays_d_0 = trans0.trans_vector(rays_d, True)[0]  # (1, H, W, 3)
        t_0 = trans0.trans_point(trans.t, True)[0]  # (1, 3)
        q1_0 = torch.empty(B, cam.res[0], cam.res[1], 3, device=cam.device)  # near
        q2_0 = torch.empty(B, cam.res[0], cam.res[1], 3, device=cam.device)  # far
        determined = torch.zeros(B, cam.res[0], cam.res[1], 1,
                                 dtype=torch.bool, device=cam.device)
        for z in z_list:
            p_0 = rays_d_0 * z + t_0  # (1, H, W, 3)
            d_of_p_0 = torch.norm(p_0 - trans0.t, dim=-1, keepdim=True)  # (1, H, W, 1)
            v_of_p_0 = p_0 / d_of_p_0  # (1, H, W, 3)
            coords = cam.proj(p_0, True) * 2 - 1  # (1, H, W, 2)
            d = nn_f.grid_sample(depthmap[:, None, :, :],
                                 coords.expand(B, -1, -1, -1)).permute(0, 2, 3, 1)  # (B, H, W, 1)
            q = v_of_p_0 * d  # (B, H, W, 3)
            near_selector = d < d_of_p_0
            # Fill q2 (far) while undetermined and d >= d_of_p_0
            q2_selector = (~determined & ~near_selector).expand(-1, -1, -1, 3)
            q2_0[q2_selector] = q[q2_selector]
            # Fill q1 (near) while undetermined and d < d_of_p_0
            q1_selector = (~determined & near_selector).expand(-1, -1, -1, 3)
            q1_0[q1_selector] = q[q1_selector]
            # Mark the ray as determined once d < d_of_p_0
            determined[near_selector] = True

        # Compute intersection x of segment q1-q2 with the target rays (in trans0 space)
        k = torch.cross(q1_0 - t_0, rays_d_0, dim=-1).norm(dim=-1, keepdim=True) / \
            torch.cross(rays_d_0, q2_0 - t_0, dim=-1).norm(dim=-1, keepdim=True)  # (B, H, W, 1)
        x_0 = (q2_0 - q1_0) * k / (k + 1) + q1_0
        coords = cam.proj(x_0, True) * 2 - 1  # (B, H, W, 2)
        return nn_f.grid_sample(image, coords)
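

# --- Minimal usage sketch (illustrative only) --------------------------------
# Hedged example of how FoveatedNeuralRenderer is meant to be driven, assuming
# three layer networks (fovea, mid, periphery) that accept (rays_o, rays_d,
# ret_depth) and return a dict with a 'color' entry, as _render expects.
# `DummyLayerNet`, the FOV/resolution values, and the Trans construction below
# are hypothetical placeholders, not part of this repository.
if __name__ == '__main__':
    class DummyLayerNet(nn.Module):
        """Stand-in network: returns mid-gray for every queried ray."""

        def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, ret_depth=False):
            out = {'color': torch.full((rays_o.size(0), 3), 0.5, device=rays_o.device)}
            if ret_depth:
                out['depth'] = torch.ones(rays_o.size(0), device=rays_o.device)
            return out

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    renderer = FoveatedNeuralRenderer(
        layers_fov=[20., 45., 110.],                      # fovea, mid, periphery (deg)
        layers_res=[(128, 128), (256, 256), (256, 230)],  # per-layer resolutions
        layers_net=nn.ModuleList([DummyLayerNet() for _ in range(3)]),
        output_res=(1600, 1440),
        device=device)

    # Assumed: Trans(t, r) takes a translation vector and a rotation, matching its
    # use in render() above; gaze is the gaze position in output-image pixels
    # (the exact convention is defined by utils.view and Foveation).
    view = Trans(torch.zeros(3, device=device), torch.eye(3, device=device))
    out = renderer(view, gaze=(0., 0.))
    print(out['blended'].shape)  # full-resolution foveated composition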