import math
from typing import List, Tuple

import torch
import torch.nn.functional as nn_f

from utils import img
from utils import view
from utils import misc


class Foveation(object):

    def __init__(self, layers_fov: List[float], layers_res: List[Tuple[float, float]],
                 out_res: Tuple[int, int], *, blend: float = 0.6, device=None):
        self.layers_fov = layers_fov
        self.layers_res = layers_res
        self.out_res = out_res
        self.blend = blend
        self.device = device
        self.n_layers = len(self.layers_fov)
        self.eye_fovea_blend = [
            self._gen_layer_blendmap(i)
            for i in range(self.n_layers - 1)
        ]  # blend maps of fovea layers
        self.coords = misc.meshgrid(*out_res).to(device=device)  # (H:out, W:out, 2)

    def to(self, device):
        self.eye_fovea_blend = [x.to(device=device) for x in self.eye_fovea_blend]
        self.coords = self.coords.to(device=device)
        return self

    def synthesis(self, layers: List[torch.Tensor],
                  fovea_center: Tuple[float, float],
                  shifts: List[int] = None) -> torch.Tensor:
        """
        Generate foveated retinal image by blending fovea layers

        **Note: the current implementation only supports two fovea layers**

        :param layers `List(Tensor(B, C, H'{l}, W'{l}))`: list of foveated layers
        :param fovea_center: fovea center (x, y) as an offset from the image center, in pixels
        :param shifts: optional per-layer horizontal shifts, in pixels
        :return `Tensor(B, C, H:out, W:out)`: foveated images
        """
        # Start from the outermost (periphery) layer upscaled to the output resolution
        output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
                                                mode='bilinear', align_corners=False)
        if shifts is not None:
            output = img.horizontal_shift(output, shifts[-1])
        c = torch.tensor([
            fovea_center[0] + self.out_res[1] / 2,
            fovea_center[1] + self.out_res[0] / 2
        ], device=self.coords.device)
        # Composite inner layers from the second-outermost inward
        for i in range(self.n_layers - 2, -1, -1):
            if layers[i] is None:
                continue
            R = self.get_layer_size_in_final_image(i) / 2
            grid = ((self.coords - c) / R)[None, ...]
            if shifts is not None:
                grid = img.horizontal_shift(grid, shifts[i], -2)
            blend = nn_f.grid_sample(self.eye_fovea_blend[i][None, None, ...], grid,
                                     align_corners=False)  # (1, 1, H:out, W:out)
            output.mul_(1 - blend).add_(
                nn_f.grid_sample(layers[i], grid, align_corners=False) * blend)
        return output

    def get_layer_size_in_final_image(self, i: int) -> int:
        """
        Get size of layer i in final image

        :param i: index of layer
        :return: size of layer i in final image (in pixels)
        """
        return self.get_source_layer_cover_size_in_target_layer(
            self.layers_fov[i], self.layers_fov[-1], self.out_res[0])

    def get_source_layer_cover_size_in_target_layer(self, source_fov, target_fov,
                                                    target_pixel_height) -> int:
        """
        Get the size (in pixels) that a source layer covers when projected
        into a target layer

        :param source_fov: field of view of the source layer
        :param target_fov: field of view of the target layer
        :param target_pixel_height: pixel height of the target layer
        :return: cover size of the source layer in the target layer (in pixels)
        """
        source_physical_height = view.fov2length(source_fov)
        target_physical_height = view.fov2length(target_fov)
        return int(math.ceil(target_pixel_height * source_physical_height
                             / target_physical_height))

    def _gen_layer_blendmap(self, i: int) -> torch.Tensor:
        """
        Generate blend map for fovea layer i

        :param i: index of fovea layer
        :return `Tensor(H{i}, W{i})`: blend map
        """
        size = self.get_layer_size_in_final_image(i)
        R = size / 2
        p = misc.meshgrid(size, size).to(device=self.device)  # (size, size, 2)
        r = torch.norm(p - R, dim=2)  # (size, size)
        return misc.smooth_step(R, R * self.blend, r)

    def get_layers_mask(self) -> List[torch.Tensor]:
        """
        Generate mask images for layers[:-1]

        The meaning of values in mask images:

        * -1: skipped
        * 0~1: blend with inner layer
        * 1~2: only self layer
        * 2~3: blend with outer layer

        :return: mask images for every layer except the outermost
        """
        layers_mask = []
        for i in range(self.n_layers - 1):
            layers_mask.append(torch.ones(*self.layers_res[i], device=self.device) * -1)
            # Normalized radial distance from the layer center, in [0, sqrt(2)]
            r = torch.norm(misc.meshgrid(*self.layers_res[i], normalize=True)
                           .to(device=self.device) * 2 - 1, dim=-1)
            inner_radius = self.get_source_layer_cover_size_in_target_layer(
                self.layers_fov[i - 1], self.layers_fov[i],
                self.layers_res[i][0]) / self.layers_res[i][0] if i > 0 else 0
            bounds = [inner_radius * (1 - self.blend), inner_radius, self.blend, 1]
            for bi in range(len(bounds) - 1):
                region = torch.logical_and(r > bounds[bi], r <= bounds[bi + 1])
                layers_mask[i][region] = bi + \
                    (r[region] - bounds[bi]) / (bounds[bi + 1] - bounds[bi])
        return layers_mask
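

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal two-layer setup, assuming the
# `utils` helpers imported above behave as this module uses them --
# `misc.meshgrid` yielding an (H, W, 2) pixel-coordinate grid,
# `view.fov2length` converting a FOV angle to a physical length, and
# `misc.smooth_step` returning a 0..1 falloff. The FOVs, resolutions, and
# random layer tensors below are made-up values, not values from the project.
if __name__ == "__main__":
    foveation = Foveation(layers_fov=[20.0, 110.0],  # fovea FOV, full FOV
                          layers_res=[(128, 128), (240, 320)],
                          out_res=(480, 640),
                          device=torch.device("cpu"))
    # One batch of random "rendered" layers: fovea layer, then periphery layer
    layers = [torch.rand(1, 3, 128, 128), torch.rand(1, 3, 240, 320)]
    # Blend the fovea layer onto the upscaled periphery at the image center
    output = foveation.synthesis(layers, fovea_center=(0.0, 0.0))
    print(output.shape)  # expected: torch.Size([1, 3, 480, 640])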