import torch
import torch.nn.functional as nn_f
from typing import Any, List, Mapping, Tuple, Union

from torch import nn

from utils.view import *
from utils.constants import *
from .post_process import *
from .foveation import Foveation


class FoveatedNeuralRenderer(object):
    """Foveated renderer that composes three neural layers (fovea, mid, periphery) into one output image."""

    def __init__(self, layers_fov: List[float],
                 layers_res: List[Tuple[int, int]],
                 layers_net: nn.ModuleList,
                 output_res: Tuple[int, int], *,
                 device: torch.device = None):
        super().__init__()
        self.layers_net = layers_net.to(device=device)
        self.layers_cam = [
            CameraParam({
                'fov': layers_fov[i],
                'cx': 0.5,
                'cy': 0.5,
                'normalized': True
            }, layers_res[i], device=device)
            for i in range(len(layers_fov))
        ]
        self.cam = CameraParam({
            'fov': layers_fov[-1],
            'cx': 0.5,
            'cy': 0.5,
            'normalized': True
        }, output_res, device=device)
        self.foveation = Foveation(layers_fov, layers_res, output_res, device=device)
        self.device = device

    def to(self, device: torch.device):
        self.layers_net.to(device)
        self.foveation.to(device)
        self.cam.to(device)
        for cam in self.layers_cam:
            cam.to(device)
        self.device = device
        return self

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.render(*args, **kwds)

    def render(self, view: Trans, gaze, right_gaze=None, *,
               stereo_disparity=0, using_mask=True, mono_periph_mode=0,
               ret_raw=False) -> Union[Mapping[str, torch.Tensor],
                                       Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]]]:
        """
        Render one frame (or a stereo pair when `stereo_disparity` > 0) at the given view and gaze.

        :param view: camera pose of the mono (or centre) eye
        :param gaze: gaze position of the left/mono eye, as pixel offsets from the output-image centre
        :param right_gaze: gaze position of the right eye; defaults to `gaze`
        :param stereo_disparity: inter-eye separation in scene units; values > 0 enable stereo rendering
        :param using_mask: restrict inference of each layer to the pixels selected by the foveation masks
        :param mono_periph_mode: 0 renders every layer per eye; non-zero modes reuse mono content outside the fovea
        :param ret_raw: also return the layer images before post-processing and blending
        :return: a dict with 'layers_img' and 'blended' (plus raw variants if `ret_raw`), or a pair of such dicts for stereo
        """
        if stereo_disparity > TINY_FLOAT:
            left_view = Trans(
                view.trans_point(torch.tensor([-stereo_disparity / 2, 0, 0], device=self.device)),
                view.r)
            right_view = Trans(
                view.trans_point(torch.tensor([stereo_disparity / 2, 0, 0], device=self.device)),
                view.r)
            left_gaze = gaze
            right_gaze = gaze if right_gaze is None else right_gaze
            layers_mask = self.foveation.get_layers_mask() if using_mask else [None] * 3
            left_shifts = None
            right_shifts = None
            if using_mask and mono_periph_mode != 0:
                # The fovea layer is always rendered per eye.
                fovea_left = self._render(self.layers_net[0], self.layers_cam[0], left_view, left_gaze,
                                          layer_mask=layers_mask[0])['color']
                fovea_right = self._render(self.layers_net[0], self.layers_cam[0], right_view, right_gaze,
                                           layer_mask=layers_mask[0])['color']
                if mono_periph_mode == 3:
                    # Mode 3: render mid (at the averaged gaze) and periphery once from the centre view;
                    # opposite horizontal shifts are applied per eye when blending the shared layers.
                    mid = self._render(self.layers_net[1], self.layers_cam[1], view,
                                       ((left_gaze[0] + right_gaze[0]) // 2, left_gaze[1]),
                                       layer_mask=layers_mask[1])['color']
                    periph = self._render(self.layers_net[2], self.layers_cam[2], view)['color']
                    raw_left = [fovea_left, mid, periph]
                    raw_right = [fovea_right, mid, periph]
                    shift = int(left_gaze[0] - right_gaze[0]) // 2
                    left_shifts = [0, 0, shift]
                    right_shifts = [0, 0, -shift]
                else:
                    # Modes 1/2: render the mid layer per eye but move its ray origins towards the
                    # centre view; the periphery is rendered once from the centre view.
                    mid_left = self._render_mid(self.layers_net[1], self.layers_cam[1], left_view, left_gaze,
                                                layer_mask=layers_mask[1], mono_view=view,
                                                blend_view=mono_periph_mode == 1)['color']
                    mid_right = self._render_mid(self.layers_net[1], self.layers_cam[1], right_view, right_gaze,
                                                 layer_mask=layers_mask[1], mono_view=view,
                                                 blend_view=mono_periph_mode == 1)['color']
                    periph = self._render(self.layers_net[2], self.layers_cam[2], view)['color']
                    raw_left = [fovea_left, mid_left, periph]
                    raw_right = [fovea_right, mid_right, periph]
            else:
                # Render every layer separately for each eye.
                raw_left = [
                    self._render(self.layers_net[i], self.layers_cam[i], left_view,
                                 left_gaze if i < 2 else None,
                                 layer_mask=layers_mask[i])['color']
                    for i in range(3)
                ]
                raw_right = [
                    self._render(self.layers_net[i], self.layers_cam[i], right_view,
                                 right_gaze if i < 2 else None,
                                 layer_mask=layers_mask[i])['color']
                    for i in range(3)
                ]
            return self._gen_output(raw_left, left_gaze, left_shifts, ret_raw=ret_raw), \
                self._gen_output(raw_right, right_gaze, right_shifts, ret_raw=ret_raw)
        else:
            layers_mask = self.foveation.get_layers_mask(gaze) if using_mask else None
            res_raw = [
                self._render(self.layers_net[i], self.layers_cam[i], view,
                             gaze if i < 2 else None,
                             layer_mask=layers_mask[i] if layers_mask is not None else None)['color']
                for i in range(3)
            ]
            return self._gen_output(res_raw, gaze, ret_raw=ret_raw)

    def _render(self, net, cam: CameraParam, view: Trans, gaze=None, *,
                ret_depth=False, layer_mask=None) -> Mapping[str, torch.Tensor]:
        if gaze is not None:
            cam = self._adjust_cam(cam, gaze)
        rays_o, rays_d = cam.get_global_rays(view, False)  # (1, H, W, 3)
        if layer_mask is not None:
            # Only infer rays whose mask value is non-negative; the remaining pixels stay zero.
            infer_mask = layer_mask >= 0
            rays_o = rays_o[:, infer_mask]
            rays_d = rays_d[:, infer_mask]
            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
            ret = {
                'color': torch.zeros(1, cam.res[0], cam.res[1], 3, device=self.device)
            }
            ret['color'][:, infer_mask] = net_output['color']
            ret['color'] = ret['color'].permute(0, 3, 1, 2)
            if ret_depth:
                ret['depth'] = torch.zeros(1, cam.res[0], cam.res[1], device=self.device)
                ret['depth'][:, infer_mask] = net_output['depth']
            return ret
        else:
            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
            return {
                'color': net_output['color'].view(1, cam.res[0], cam.res[1], -1).permute(0, 3, 1, 2),
                'depth': net_output['depth'].view(1, cam.res[0], cam.res[1]) if ret_depth else None
            }

    def _render_mid(self, net, cam: CameraParam, view: Trans, gaze=None, *,
                    layer_mask: torch.Tensor, mono_view: Trans, blend_view: bool,
                    ret_depth=False) -> Mapping[str, torch.Tensor]:
        """
        Render the mid layer for one eye, reusing the mono (centre) view for part of the image.

        Ray origins are interpolated between `view.t` and `mono_view.t` with a per-pixel weight
        derived from `layer_mask` (clamped to [1, 2] and shifted to [0, 1]); when `blend_view` is
        False the weight is 1 everywhere, i.e. all rays originate from the mono view. Ray directions
        always come from the per-eye view.

        :param net: network of the mid layer
        :param cam: camera of the mid layer
        :param view: camera pose of this eye
        :param layer_mask: foveation mask of the mid layer; pixels with negative values are not inferred
        :param mono_view: shared centre (mono) camera pose
        :param blend_view: blend per-eye and mono ray origins instead of using the mono origin only
        :param gaze: gaze position used to re-centre the layer camera, defaults to None
        :param ret_depth: also return the depth map, defaults to False
        :return: dict with 'color' `(1, 3, H, W)` and, if `ret_depth`, 'depth' `(1, H, W)`
        """
        if gaze is not None:
            cam = self._adjust_cam(cam, gaze)
        k = layer_mask[None, ..., None].clamp(1 if blend_view else 2, 2) - 1  # (1, H, W, 1)
        rays_o = (1 - k) * view.t + k * mono_view.t  # (1, H, W, 3)
        rays_d = view.trans_vector(cam.get_local_rays())  # (1, H, W, 3)
        if layer_mask is not None:
            infer_mask = layer_mask >= 0
            rays_o = rays_o[:, infer_mask]
            rays_d = rays_d[:, infer_mask]
            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
            ret = {
                'color': torch.zeros(1, cam.res[0], cam.res[1], 3, device=self.device)
            }
            ret['color'][:, infer_mask] = net_output['color']
            ret['color'] = ret['color'].permute(0, 3, 1, 2)
            if ret_depth:
                ret['depth'] = torch.zeros(1, cam.res[0], cam.res[1], device=self.device)
                ret['depth'][:, infer_mask] = net_output['depth']
            return ret
        else:
            net_output = net(rays_o.view(-1, 3), rays_d.view(-1, 3), ret_depth=ret_depth)
            return {
                'color': net_output['color'].view(1, cam.res[0], cam.res[1], -1).permute(0, 3, 1, 2),
                'depth': net_output['depth'].view(1, cam.res[0], cam.res[1]) if ret_depth else None
            }

    def _gen_output(self, layers_img: List[torch.Tensor], gaze: Tuple[float, float],
                    shifts=None, ret_raw=False) -> Mapping[str, torch.Tensor]:
        refined = self._post_process(layers_img)
        blended = self.foveation.synthesis(refined, gaze, shifts)
        ret = {
            'layers_img': refined,
            'blended': blended
        }
        if ret_raw:
            ret['layers_raw'] = layers_img
            ret['blended_raw'] = self.foveation.synthesis(layers_img, gaze)
        return ret

    def _post_process(self, layers_img: List[torch.Tensor]) -> List[torch.Tensor]:
        return [
            #grad_aware_median(constrast_enhance(layers_img[0], 3, 0.2), 3, 3, True),
            constrast_enhance(layers_img[0], 3, 0.2),
            constrast_enhance(layers_img[1], 5, 0.2),
            constrast_enhance(layers_img[2], 5, 0.2)
        ]
    def _adjust_cam(self, layer_cam: CameraParam, gaze: Tuple[float, float]) -> CameraParam:
        """Return a copy of `layer_cam` whose principal point is shifted so the layer is centred on `gaze` (pixel offsets in the output camera, rescaled by the focal-length ratio)."""
        fovea_offset = (
            (gaze[0]) / self.cam.f[0].item() * layer_cam.f[0].item(),
            (gaze[1]) / self.cam.f[1].item() * layer_cam.f[1].item()
        )
        return CameraParam({
            'fx': layer_cam.f[0].item(),
            'fy': layer_cam.f[1].item(),
            'cx': layer_cam.c[0].item() - fovea_offset[0],
            'cy': layer_cam.c[1].item() - fovea_offset[1]
        }, layer_cam.res, device=self.device)

    def _warp(self, trans: Trans, trans0: Trans, cam: CameraParam,
              z_list: torch.Tensor, image: torch.Tensor, depthmap: torch.Tensor) -> torch.Tensor:
        """
        Warp `image`, rendered at pose `trans0`, to the target pose `trans` using its depth map.

        For each target-view ray, points at the depths in `z_list` are projected into the source
        view and the sampled depth is compared with each point's distance; the pair of surface
        samples bracketing the crossing is intersected with the ray, and the source image is
        sampled at the projection of that intersection.

        :param trans: target pose to warp to
        :param trans0: source pose the image was rendered at
        :param cam: camera parameters shared by both views
        :param z_list: candidate depths to sweep along the target-view rays
        :param image `Tensor(B, C, H, W)`: source image
        :param depthmap `Tensor(B, H, W)`: depth map of the source image
        :return `Tensor(B, C, H, W)`: the warped image
        """
        B = image.size(0)
        rays_d = cam.get_global_rays(trans, norm=False)[1]  # (1, H, W, 3)
        rays_d_0 = trans0.trans_vector(rays_d, True)[0]  # (1, H, W, 3)
        t_0 = trans0.trans_point(trans.t, True)[0]  # (1, 3)
        q1_0 = torch.empty(B, cam.res[0], cam.res[1], 3, device=cam.device)  # near
        q2_0 = torch.empty(B, cam.res[0], cam.res[1], 3, device=cam.device)  # far
        determined = torch.zeros(B, cam.res[0], cam.res[1], 1, dtype=torch.bool, device=cam.device)
        for z in z_list:
            p_0 = rays_d_0 * z + t_0  # (1, H, W, 3)
            d_of_p_0 = torch.norm(p_0 - trans0.t, dim=-1, keepdim=True)  # (1, H, W, 1)
            v_of_p_0 = p_0 / d_of_p_0  # (1, H, W, 3)
            coords = cam.proj(p_0, True) * 2 - 1  # (1, H, W, 2)
            d = nn_f.grid_sample(
                depthmap[:, None, :, :],
                coords.expand(B, -1, -1, -1)).permute(0, 2, 3, 1)  # (B, H, W, 1)
            q = v_of_p_0 * d  # (B, H, W, 3)
            near_selector = d < d_of_p_0
            # Fill q2 (far) when undetermined and d >= d_of_p_0
            q2_selector = (~determined & ~near_selector).expand(-1, -1, -1, 3)
            q2_0[q2_selector] = q[q2_selector]
            # Fill q1 (near) when undetermined and d < d_of_p_0
            q1_selector = (~determined & near_selector).expand(-1, -1, -1, 3)
            q1_0[q1_selector] = q[q1_selector]
            # Mark as determined once d < d_of_p_0
            determined[near_selector] = True

        # Compute the intersection x of segment q1-q2 with the rays (in trans0 space)
        k = torch.cross(q1_0 - t_0, rays_d_0, dim=-1).norm(dim=-1, keepdim=True) / \
            torch.cross(rays_d_0, q2_0 - t_0, dim=-1).norm(dim=-1, keepdim=True)  # (B, H, W, 1)
        x_0 = (q2_0 - q1_0) * k / (k + 1) + q1_0
        coords = cam.proj(x_0, True) * 2 - 1  # (B, H, W, 2)
        return nn_f.grid_sample(image, coords)
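

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows how
# the renderer might be driven end to end. The three layer networks are
# replaced by a stub matching the call signature this module expects
# (net(rays_o, rays_d, ret_depth=...) -> {'color': ...}); the FOV/resolution
# values and the Trans construction (translation vector + rotation matrix) are
# assumptions, not the project's actual configuration or trained checkpoints.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class _StubLayerNet(nn.Module):
        """Stand-in network returning a constant colour for every queried ray."""

        def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, ret_depth=False):
            out = {'color': torch.full((rays_o.size(0), 3), 0.5, device=rays_o.device)}
            if ret_depth:
                out['depth'] = torch.ones(rays_o.size(0), device=rays_o.device)
            return out

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    renderer = FoveatedNeuralRenderer(
        layers_fov=[20.0, 45.0, 110.0],                    # placeholder fovea/mid/periphery FOVs
        layers_res=[(128, 128), (128, 128), (128, 128)],   # placeholder layer resolutions
        layers_net=nn.ModuleList([_StubLayerNet() for _ in range(3)]),
        output_res=(720, 720),                             # placeholder output resolution
        device=device)
    # Identity pose; gaze at the image centre (pixel offsets from the centre).
    view = Trans(torch.zeros(3, device=device), torch.eye(3, device=device))
    frame = renderer(view, (0.0, 0.0))
    print(frame['blended'].shape)  # expected: (1, 3, *output_res)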