import torch
import torch.nn.functional as nn_f
from torch import nn

from utils.view import *
from utils import math
from .post_process import *
from .foveation import Foveation


class FoveatedNeuralRenderer(object):

    def __init__(self, layers_fov: list[float], layers_res: list[tuple[int, int]],
                 layers_net: nn.ModuleList, output_res: tuple[int, int], *,
                 coord_sys: str = "gl", device: torch.device = None):
        super().__init__()
        self.layers_net = layers_net.to(device=device)
        # One camera per layer, all centered and normalized, sharing the coordinate system.
        self.layers_cam = [
            Camera.create({
                'fov': layers_fov[i],
                'cx': 0.5,
                'cy': 0.5,
                'normalized': True
            }, layers_res[i], coord_sys=coord_sys, device=device)
            for i in range(len(layers_fov))
        ]
        # The output camera covers the outermost layer's full field of view.
        self.cam = Camera.create({
            'fov': layers_fov[-1],
            'cx': 0.5,
            'cy': 0.5,
            'normalized': True
        }, output_res, coord_sys=coord_sys, device=device)
        self.foveation = Foveation(layers_fov, layers_res, output_res, device=device)
        self.device = device

    def to(self, device: torch.device):
        self.layers_net.to(device)
        self.foveation.to(device)
        self.cam.to(device)
        for cam in self.layers_cam:
            cam.to(device)
        self.device = device
        return self

    def __call__(self, view: Trans, gaze, right_gaze=None, *,
                 stereo_disparity: float = 0,
                 using_mask: bool = True,
                 mono_periph_mode: int = 0,
                 ret_raw: bool = False) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        if stereo_disparity > math.tiny:
            # Stereo: offset the eye origins by half the disparity along the view's x-axis.
            left_view = Trans(
                view.trans_point(torch.tensor([-stereo_disparity / 2, 0, 0], device=self.device)),
                view.r)
            right_view = Trans(
                view.trans_point(torch.tensor([stereo_disparity / 2, 0, 0], device=self.device)),
                view.r)
            left_gaze = gaze
            right_gaze = gaze if right_gaze is None else right_gaze
            layers_mask = self.foveation.get_layers_mask() if using_mask else [None] * 3
            left_shifts = None
            right_shifts = None
            if using_mask and mono_periph_mode != 0:
                # The fovea layer is always rendered per eye.
                fovea_left = self._render(self.layers_net[0], self.layers_cam[0], left_view,
                                          left_gaze, layer_mask=layers_mask[0])['color']
                fovea_right = self._render(self.layers_net[0], self.layers_cam[0], right_view,
                                           right_gaze, layer_mask=layers_mask[0])['color']
                if mono_periph_mode == 3 or mono_periph_mode == 4:
                    # Modes 3/4: share one mid rendering (at the average gaze) and one
                    # periph rendering between both eyes.
                    mid = self._render(self.layers_net[1], self.layers_cam[1], view,
                                       ((left_gaze[0] + right_gaze[0]) // 2, left_gaze[1]),
                                       layer_mask=layers_mask[1])['color']
                    periph = self._render(self.layers_net[2], self.layers_cam[2], view)['color']
                    raw_left = [fovea_left, mid, periph]
                    raw_right = [fovea_right, mid, periph]
                    # Mode 3 shifts the shared periphery by half the gaze disparity.
                    shift = int(left_gaze[0] - right_gaze[0]) // 2
                    left_shifts = [0, 0, shift if mono_periph_mode == 3 else 0]
                    right_shifts = [0, 0, -shift if mono_periph_mode == 3 else 0]
                else:
                    # Modes 1/2: render the mid layer per eye, pulling its ray origins
                    # toward the mono view (mode 1 blends by region, mode 2 uses the
                    # mono origin everywhere).
                    mid_left = self._render_mid(self.layers_net[1], self.layers_cam[1],
                                                left_view, left_gaze,
                                                layer_mask=layers_mask[1], mono_view=view,
                                                blend_view=mono_periph_mode == 1)['color']
                    mid_right = self._render_mid(self.layers_net[1], self.layers_cam[1],
                                                 right_view, right_gaze,
                                                 layer_mask=layers_mask[1], mono_view=view,
                                                 blend_view=mono_periph_mode == 1)['color']
                    periph = self._render(self.layers_net[2], self.layers_cam[2], view)['color']
                    raw_left = [fovea_left, mid_left, periph]
                    raw_right = [fovea_right, mid_right, periph]
            else:
                raw_left = [
                    self._render(self.layers_net[i], self.layers_cam[i], left_view,
                                 left_gaze if i < 2 else None,
                                 layer_mask=layers_mask[i])['color']
                    for i in range(3)
                ]
                raw_right = [
                    self._render(self.layers_net[i], self.layers_cam[i], right_view,
                                 right_gaze if i < 2 else None,
                                 layer_mask=layers_mask[i])['color']
                    for i in range(3)
                ]
            return self._gen_output(raw_left, left_gaze, left_shifts, ret_raw=ret_raw), \
                self._gen_output(raw_right, right_gaze, right_shifts, ret_raw=ret_raw)
        else:
            layers_mask = self.foveation.get_layers_mask(gaze) if using_mask else None
            res_raw = [
                self._render(self.layers_net[i], self.layers_cam[i], view,
                             gaze if i < 2 else None,
                             layer_mask=layers_mask[i] if layers_mask is not None else None)['color']
                for i in range(3)
            ]
            return self._gen_output(res_raw, gaze, ret_raw=ret_raw)
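    # Example invocations (a sketch; `renderer`, `view`, and the gaze coordinates
    # `gx`, `gy` are assumed to come from the host application, and the disparity
    # value is illustrative):
    #
    #   out = renderer(view, gaze=(gx, gy))            # mono: a single output dict
    #   left, right = renderer(view, gaze=(gx, gy),    # stereo: a (left, right) pair
    #                          stereo_disparity=0.06,
    #                          mono_periph_mode=3)
    #
    # `mono_periph_mode` trades peripheral stereo for speed: 0 renders every layer
    # per eye; 1/2 render the mid layer against the mono view via `_render_mid`;
    # 3/4 share the mid and periph renderings between both eyes.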
    def _render(self, net, cam: Camera, view: Trans, gaze=None, *,
                ret_depth=False, layer_mask=None) -> dict[str, torch.Tensor]:
        output_types = ["color"]
        if ret_depth:
            output_types.append("depth")
        if gaze is not None:
            # Re-center the layer camera on the gaze point.
            cam = self._adjust_cam(cam, gaze)
        rays_d = view.trans_vector(cam.local_rays.reshape(*cam.res, -1))  # (1, H, W, 3)
        rays_o = view.t.broadcast_to(rays_d.shape)
        if layer_mask is not None:
            # Only infer the rays inside this layer's visible region (mask >= 0).
            infer_mask = layer_mask >= 0
            net_input = Rays({
                "rays_o": rays_o[:, infer_mask].reshape(-1, 3),
                "rays_d": rays_d[:, infer_mask].reshape(-1, 3)
            })
            net_output = net(net_input, *output_types)
            ret = {
                'color': torch.zeros(1, cam.res[0], cam.res[1], 3, device=self.device)
            }
            ret['color'][:, infer_mask] = net_output['color']
            ret['color'] = ret['color'].permute(0, 3, 1, 2)  # (1, 3, H, W)
            if ret_depth:
                ret['depth'] = torch.zeros(1, cam.res[0], cam.res[1], device=self.device)
                ret['depth'][:, infer_mask] = net_output['depth']
            return ret
        else:
            net_input = Rays({
                "rays_o": rays_o.reshape(-1, 3),
                "rays_d": rays_d.reshape(-1, 3)
            })
            net_output = net(net_input, *output_types)
            return {
                'color': net_output['color'].view(1, cam.res[0], cam.res[1], -1).permute(0, 3, 1, 2),
                'depth': net_output['depth'].view(1, cam.res[0], cam.res[1]) if ret_depth else None
            }

    def _render_mid(self, net, cam: Camera, view: Trans, gaze=None, *,
                    layer_mask: torch.Tensor, mono_view: Trans, blend_view: bool,
                    ret_depth=False) -> dict[str, torch.Tensor]:
        """
        Render the mid layer for one eye, moving the ray origins of its outer
        region onto the mono view so the periphery needs no per-eye rendering.

        :param net: the mid-layer network
        :param cam: the mid-layer camera
        :param view: the per-eye view transform
        :param gaze: gaze position used to re-center the camera, defaults to None
        :param layer_mask: per-pixel region labels; pixels with negative labels
            are skipped entirely
        :param mono_view: the shared (mono) view transform
        :param blend_view: if True, only the outer region (label >= 2) uses the
            mono origin; if False, the whole layer is rendered from the mono origin
        :param ret_depth: also return a depth map, defaults to False
        :return: dict with 'color' `Tensor(1, 3, H, W)` and, if requested,
            'depth' `Tensor(1, H, W)`
        """
        output_types = ["color"]
        if ret_depth:
            output_types.append("depth")
        if gaze is not None:
            cam = self._adjust_cam(cam, gaze)
        # k selects the ray origin per pixel: 0 -> per-eye origin, 1 -> mono origin.
        k = layer_mask[None, ..., None].clamp(1 if blend_view else 2, 2) - 1  # (1, H, W, 1)
        rays_o = (1 - k) * view.t + k * mono_view.t  # (1, H, W, 3)
        rays_d = view.trans_vector(cam.local_rays.reshape(*cam.res, -1))  # (1, H, W, 3)
        if layer_mask is not None:
            infer_mask = layer_mask >= 0
            net_input = Rays({
                "rays_o": rays_o[:, infer_mask].reshape(-1, 3),
                "rays_d": rays_d[:, infer_mask].reshape(-1, 3)
            })
            net_output = net(net_input, *output_types)
            ret = {
                'color': torch.zeros(1, cam.res[0], cam.res[1], 3, device=self.device)
            }
            ret['color'][:, infer_mask] = net_output['color']
            ret['color'] = ret['color'].permute(0, 3, 1, 2)  # (1, 3, H, W)
            if ret_depth:
                ret['depth'] = torch.zeros(1, cam.res[0], cam.res[1], device=self.device)
                ret['depth'][:, infer_mask] = net_output['depth']
            return ret
        else:
            net_input = Rays({
                "rays_o": rays_o.reshape(-1, 3),
                "rays_d": rays_d.reshape(-1, 3)
            })
            net_output = net(net_input, *output_types)
            return {
                'color': net_output['color'].view(1, cam.res[0], cam.res[1], -1).permute(0, 3, 1, 2),
                'depth': net_output['depth'].view(1, cam.res[0], cam.res[1]) if ret_depth else None
            }
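    # The layer networks are only assumed to follow the interface used by
    # `_render`/`_render_mid` above: called with a `Rays` batch plus the requested
    # output types, and returning flat per-ray tensors. A minimal sketch
    # (hypothetical, not part of this module):
    #
    #   class LayerNet(nn.Module):
    #       def forward(self, rays: Rays, *output_types: str) -> dict[str, torch.Tensor]:
    #           color = ...                        # (N, 3) radiance per ray
    #           out = {"color": color}
    #           if "depth" in output_types:
    #               out["depth"] = ...             # (N,) expected ray depth
    #           return out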
    def _gen_output(self, layers_img: list[torch.Tensor], gaze: tuple[float, float],
                    shifts=None, ret_raw=False) -> dict[str, torch.Tensor]:
        refined = self._post_process(layers_img)
        blended = self.foveation.synthesis(refined, gaze, shifts)
        ret = {
            'layers_img': refined,
            'blended': blended
        }
        if ret_raw:
            ret['layers_raw'] = layers_img
            ret['blended_raw'] = self.foveation.synthesis(layers_img, gaze, shifts)
        return ret

    def _post_process(self, layers_img: list[torch.Tensor]) -> list[torch.Tensor]:
        return [
            # grad_aware_median(constrast_enhance(layers_img[0], 3, 0.2), 3, 3, True),
            constrast_enhance(layers_img[0], 3, 0.2),
            constrast_enhance(layers_img[1], 5, 0.2),
            constrast_enhance(layers_img[2], 5, 0.2)
        ]

    def _adjust_cam(self, layer_cam: Camera, gaze: tuple[float, float]) -> Camera:
        # Convert the gaze offset from output-image pixels to this layer's pixels,
        # then shift the principal point so the layer is centered on the gaze.
        fovea_offset = (
            gaze[0] / self.cam.f[0].item() * layer_cam.f[0].item(),
            gaze[1] / self.cam.f[1].item() * layer_cam.f[1].item()
        )
        return Camera.create({
            'f': [layer_cam.f[0].item(), layer_cam.f[1].item()],
            'c': [layer_cam.c[0].item() - fovea_offset[0],
                  layer_cam.c[1].item() - fovea_offset[1]]
        }, layer_cam.res, coord_sys=layer_cam.coord_sys, device=self.device)

    def _warp(self, trans: Trans, trans0: Trans, cam: Camera, z_list: torch.Tensor,
              image: torch.Tensor, depthmap: torch.Tensor) -> torch.Tensor:
        """
        Warp `image`, rendered at view `trans0`, to the novel view `trans` using
        its depth map: march the candidate depths in `z_list` along each target
        ray, bracket the depth surface, and resample the image at the intersection.

        :param trans: the target view transform
        :param trans0: the view transform `image` was rendered at
        :param cam: the camera shared by both views
        :param z_list: candidate depths to march along each target ray
        :param image `Tensor(B, C, H, W)`: source image
        :param depthmap `Tensor(B, H, W)`: source depth map
        :return `Tensor(B, C, H, W)`: the image warped to view `trans`
        """
        B = image.size(0)
        rays_d = cam.get_global_rays(trans, norm=False)[1]  # (1, H, W, 3)
        rays_d_0 = trans0.trans_vector(rays_d, True)[0]  # (1, H, W, 3)
        t_0 = trans0.trans_point(trans.t, True)[0]  # (1, 3)
        q1_0 = torch.empty(B, cam.res[0], cam.res[1], 3, device=cam.device)  # near
        q2_0 = torch.empty(B, cam.res[0], cam.res[1], 3, device=cam.device)  # far
        determined = torch.zeros(B, cam.res[0], cam.res[1], 1, dtype=torch.bool,
                                 device=cam.device)
        for z in z_list:
            p_0 = rays_d_0 * z + t_0  # (1, H, W, 3)
            d_of_p_0 = torch.norm(p_0 - trans0.t, dim=-1, keepdim=True)  # (1, H, W, 1)
            v_of_p_0 = p_0 / d_of_p_0  # (1, H, W, 3)
            coords = cam.proj(p_0, True) * 2 - 1  # (1, H, W, 2)
            d = nn_f.grid_sample(
                depthmap[:, None, :, :],
                coords.expand(B, -1, -1, -1)).permute(0, 2, 3, 1)  # (B, H, W, 1)
            q = v_of_p_0 * d  # (B, H, W, 3)
            near_selector = d < d_of_p_0
            # Fill q2 (far) while undetermined and d >= d_of_p_0
            q2_selector = (~determined & ~near_selector).expand(-1, -1, -1, 3)
            q2_0[q2_selector] = q[q2_selector]
            # Fill q1 (near) while undetermined and d < d_of_p_0
            q1_selector = (~determined & near_selector).expand(-1, -1, -1, 3)
            q1_0[q1_selector] = q[q1_selector]
            # Once the near side is hit, the bracket [q1, q2] is determined.
            determined[near_selector] = True

        # Intersect each ray with the segment q1-q2 (in trans0 space): k is the
        # ratio of the two triangle areas, i.e. the interpolation weight along q1->q2.
        k = torch.cross(q1_0 - t_0, rays_d_0, dim=-1).norm(dim=-1, keepdim=True) / \
            torch.cross(rays_d_0, q2_0 - t_0, dim=-1).norm(dim=-1, keepdim=True)  # (B, H, W, 1)
        x_0 = (q2_0 - q1_0) * k / (k + 1) + q1_0
        coords = cam.proj(x_0, True) * 2 - 1  # (B, H, W, 2)
        return nn_f.grid_sample(image, coords)
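# A minimal end-to-end sketch. `DummyNet` merely satisfies the network interface
# assumed by `_render`; real use would load trained layer networks instead. The
# layer FoVs and resolutions below are illustrative values, not the configuration
# of any particular model, and `Trans` is assumed to take a (1, 3) translation
# and a (1, 3, 3) rotation (shapes inferred from its usage in this file).
if __name__ == "__main__":
    class DummyNet(nn.Module):
        def forward(self, rays, *output_types):
            # `Rays` is assumed to expose its tensors by key, like a mapping.
            n = rays["rays_o"].size(0)
            out = {"color": torch.full((n, 3), 0.5, device=rays["rays_o"].device)}
            if "depth" in output_types:
                out["depth"] = torch.ones(n, device=rays["rays_o"].device)
            return out

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    renderer = FoveatedNeuralRenderer(
        layers_fov=[20.0, 45.0, 110.0],  # fovea, mid, periph (illustrative)
        layers_res=[(128, 128), (128, 128), (160, 160)],
        layers_net=nn.ModuleList([DummyNet(), DummyNet(), DummyNet()]),
        output_res=(360, 360),
        device=device)
    view = Trans(torch.zeros(1, 3, device=device), torch.eye(3, device=device)[None])
    out = renderer(view, gaze=(0.0, 0.0))  # monocular call
    print(out["blended"].shape)            # expected: (1, 3, 360, 360)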