import torch
import torch.nn.functional as nn_f
from typing import List, Tuple

from utils import img
from utils import view
from utils import misc
from utils import math


class Foveation(object):

    def __init__(self, layers_fov: List[float], layers_res: List[Tuple[float, float]],
                 out_res: Tuple[int, int], *, blend: float = 0.6,
                 device: torch.device = None):
        self.layers_fov = layers_fov
        self.layers_res = layers_res
        self.out_res = out_res
        self.blend = blend
        self.device = device
        self.n_layers = len(self.layers_fov)
        self.eye_fovea_blend = [
            self._gen_layer_blendmap(i)
            for i in range(self.n_layers - 1)
        ]  # blend maps of fovea layers
        self.coords = misc.grid2d(*out_res, device=device)

    def to(self, device: torch.device):
        self.eye_fovea_blend = [x.to(device=device) for x in self.eye_fovea_blend]
        self.coords = self.coords.to(device=device)
        return self

    def synthesis(self, layers: List[torch.Tensor],
                  fovea_center: Tuple[float, float],
                  shifts: List[int] = None,
                  do_blend: bool = True,
                  crop_mode: bool = False) -> torch.Tensor:
        """
        Generate a foveated retinal image by blending fovea layers.

        **Note: the current implementation only supports two fovea layers.**

        :param layers `List(Tensor(B, C, H'{l}, W'{l}))`: list of foveated layers
        :param fovea_center: fovea center in pixels, relative to the image center
        :param shifts: optional per-layer horizontal shifts (in pixels)
        :param do_blend: if False, paste inner layers with a hard circular mask
                         instead of smoothly blending them
        :param crop_mode: if True, upsample inner layers to `out_res` instead of
                          resampling them into their covered region
        :return `Tensor(B, C, H:out, W:out)`: foveated images
        """
        # Start from the outermost (periphery) layer, upsampled to the output size
        output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
                                                mode='bilinear', align_corners=False)
        if shifts is not None:
            output = img.horizontal_shift(output, shifts[-1])
        # Fovea center in absolute pixel coordinates (x, y)
        c = torch.tensor([
            fovea_center[0] + (self.out_res[1] - 1) / 2,
            fovea_center[1] + (self.out_res[0] - 1) / 2
        ], device=self.device)
        # Composite the inner layers from outermost to innermost
        for i in range(self.n_layers - 2, -1, -1):
            if layers[i] is None:
                continue
            R = self.get_layer_size_in_final_image(i) / 2
            grid = ((self.coords - c) / R)[None, ...]
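            # `grid` maps every output pixel to coordinates normalized to [-1, 1]
            # over the fovea disc of radius R centered at c, matching
            # grid_sample's sampling convention; pixels outside the disc fall
            # outside [-1, 1] and sample zeros (the default padding mode), so
            # the peripheral content already in `output` is kept there.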
            if shifts is not None:
                grid = img.horizontal_shift(grid, shifts[i], -2)
            # blend: (1, 1, H:out, W:out)
            if do_blend:
                blend = nn_f.grid_sample(self.eye_fovea_blend[i][None, None], grid,
                                         align_corners=False)
            else:
                blend = nn_f.grid_sample(torch.ones_like(self.eye_fovea_blend[i][None, None]),
                                         grid, align_corners=False)
            # Alpha-composite the inner layer over the current output
            output.mul_(1 - blend)
            if crop_mode:
                output.add_(blend * nn_f.interpolate(layers[i], self.out_res,
                                                     mode='bilinear', align_corners=False))
            else:
                output.add_(blend * nn_f.grid_sample(layers[i], grid, align_corners=False))
        return output

    def get_layer_size_in_final_image(self, i: int) -> int:
        """
        Get the size of layer i in the final image.

        :param i: index of layer
        :return: size of layer i in the final image (in pixels)
        """
        return self.get_source_layer_cover_size_in_target_layer(
            self.layers_fov[i], self.layers_fov[-1], self.out_res[0])

    def get_source_layer_cover_size_in_target_layer(self, source_fov, target_fov,
                                                    target_pixel_height) -> int:
        """
        Get the size (in pixels) that a layer with FOV `source_fov` covers when
        projected into a target layer with FOV `target_fov`.

        :param source_fov: field of view of the source layer
        :param target_fov: field of view of the target layer
        :param target_pixel_height: pixel height of the target layer
        :return: covered size in the target layer (in pixels)
        """
        source_physical_height = view.fov2length(source_fov)
        target_physical_height = view.fov2length(target_fov)
        return int(math.ceil(target_pixel_height * source_physical_height
                             / target_physical_height))

    def _gen_layer_blendmap(self, i: int) -> torch.Tensor:
        """
        Generate the blend map for fovea layer i.

        :param i: index of fovea layer
        :return `Tensor(H{i}, W{i})`: blend map
        """
        size = self.get_layer_size_in_final_image(i)
        R = size / 2
        p = misc.grid2d(size, device=self.device)  # (size, size, 2)
        r = torch.norm(p - R, dim=2)  # (size, size): distance to the layer center
        # Full weight inside R * blend, smoothly falling off toward the rim at R
        return misc.smooth_step(R, R * self.blend, r)

    def get_layers_mask(self, gaze=None) -> List[torch.Tensor]:
        """
        Generate a mask image for every layer.

        The meaning of values in the mask images:

        * -1: skipped
        * 0~1: blend with inner layer
        * 1~2: only self layer
        * 2~3: blend with outer layer

        :param gaze: optional gaze position in pixels, relative to the image
                     center; only used for the outermost layer
        :return: mask images for the layers
        """
        layers_mask = []
        for i in range(self.n_layers):
            if i == self.n_layers - 1:
                if gaze is None:
                    layers_mask.append(torch.ones(*self.layers_res[i], device=self.device))
                    continue
                # Gaze position normalized by the output height
                c = torch.tensor([
                    (gaze[0] + 0.5 * self.out_res[1]) / self.out_res[0],
                    (gaze[1] + 0.5 * self.out_res[0]) / self.out_res[0]
                ], device=self.device)
            else:
                c = torch.tensor([0.5, 0.5], device=self.device)
            layers_mask.append(torch.ones(*self.layers_res[i], device=self.device) * -1)
            coord = misc.grid2d(*self.layers_res[i], device=self.device) / self.layers_res[i][0]
            r = 2 * torch.norm(coord - c, dim=-1)
            # Normalized radius of the region covered by the next inner layer
            inner_radius = self.get_source_layer_cover_size_in_target_layer(
                self.layers_fov[i - 1], self.layers_fov[i],
                self.layers_res[i][0]) / self.layers_res[i][0] if i > 0 else 0
            if i == self.n_layers - 1:
                bounds = [inner_radius * (1 - self.blend), inner_radius, 100, 100]
            else:
                bounds = [inner_radius * (1 - self.blend), inner_radius, self.blend, 1]
            # Map each annulus [bounds[bi], bounds[bi + 1]) to the value range [bi, bi + 1)
            for bi in range(len(bounds) - 1):
                region = torch.logical_and(r >= bounds[bi], r < bounds[bi + 1])
                layers_mask[i][region] = bi + \
                    (r[region] - bounds[bi]) / (bounds[bi + 1] - bounds[bi])
        return layers_mask
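

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module). It assumes a two-layer setup with a narrow fovea inside a wide
# periphery, and that `utils.misc.grid2d` and `utils.view.fov2length` behave
# as they are used above. The layer tensors are random stand-ins for actually
# rendered layers; all numbers (FOVs, resolutions, gaze) are made up.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    foveation = Foveation(
        layers_fov=[20.0, 110.0],             # fovea FOV, periphery FOV
        layers_res=[(256, 256), (256, 256)],  # per-layer (H, W)
        out_res=(1600, 1440),                 # output (H, W)
        blend=0.6, device=device)
    layers = [
        torch.rand(1, 3, 256, 256, device=device),  # fovea layer (B, C, H, W)
        torch.rand(1, 3, 256, 256, device=device),  # periphery layer
    ]
    # Gaze 120 px right of and 40 px above the image center
    frame = foveation.synthesis(layers, fovea_center=(120.0, -40.0))
    print(frame.shape)  # torch.Size([1, 3, 1600, 1440])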