foveation.py 3.17 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
3
4
import math
import torch
import torch.nn.functional as nn_f
from typing import List, Tuple
Nianchen Deng's avatar
sync    
Nianchen Deng committed
5
6
7
from utils import img
from utils import view
from utils import misc
BobYeah's avatar
sync    
BobYeah committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21


class Foveation(object):
    """Composite a list of foveated layers into a single retinal image.

    The last entry of ``fov_list`` / ``layers`` is the periphery layer: it is
    resized to cover the whole output. Every other layer ``i`` is sampled
    around the fovea center and blended on top of it, weighted by a radial
    blend map of side ``get_layer_size_in_final_image(i)`` pixels.
    """

    def __init__(self, fov_list: List[float],
                 out_res: Tuple[int, int], *, device=None):
        """
        :param fov_list `List(float)`: field of view of each layer; the last
            entry is the periphery layer that covers the whole output
        :param out_res `(H, W)`: resolution of the output image
        :param device: torch device on which cached tensors are created
        """
        self.fov_list = fov_list
        self.out_res = out_res
        self.device = device
        self.n_layers = len(self.fov_list)
        # Blend maps of the inner (fovea) layers; the last layer needs none
        self.eye_fovea_blend = [
            self._gen_layer_blendmap(i)
            for i in range(self.n_layers - 1)
        ]
        # Output-image coordinates used to build the sampling grids
        self.coords = misc.meshgrid(*out_res).to(device=device)

    def to(self, device):
        """Move all cached tensors to ``device`` and return ``self``."""
        # Keep the recorded device in sync with the tensors actually held
        self.device = device
        self.eye_fovea_blend = [x.to(device=device)
                                for x in self.eye_fovea_blend]
        self.coords = self.coords.to(device=device)
        return self

    def synthesis(self, layers: List[torch.Tensor],
                  fovea_center: Tuple[float, float],
                  shifts: List[int] = None) -> torch.Tensor:
        """
        Generate foveated retinal image by blending fovea layers.

        The last layer is bilinearly resized to ``out_res`` and used as the
        base image. Each inner layer ``i`` (from ``n_layers - 2`` down to 0)
        is then sampled around the fovea center and alpha-blended on top of
        the result using its precomputed blend map.

        NOTE(review): the original note states that only two fovea layers are
        supported; the loop itself iterates over all inner layers. Confirm
        before relying on more than two.

        :param layers `List(Tensor(B, C, H'{l}, W'{l}))`: list of foveated
            layers; an inner layer may be ``None`` to skip it
        :param fovea_center `(float, float)`: fovea center (x, y), added to
            half the output size, i.e. an offset from the image center
        :param shifts `List(int)`: optional per-layer horizontal shifts passed
            to ``img.horizontal_shift``; ``shifts[-1]`` shifts the base image,
            ``shifts[i]`` shifts the sampling grid of layer ``i``
        :return `Tensor(B, C, H:out, W:out)`: foveated images
        """
        output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
                                                mode='bilinear', align_corners=False)
        if shifts is not None:
            output = img.horizontal_shift(output, shifts[-1])
        # Fovea center in output pixel coordinates: x uses W (out_res[1]),
        # y uses H (out_res[0])
        c = torch.tensor([
            fovea_center[0] + self.out_res[1] / 2,
            fovea_center[1] + self.out_res[0] / 2
        ], device=self.coords.device)
        for i in range(self.n_layers - 2, -1, -1):
            if layers[i] is None:
                continue
            R = self.get_layer_size_in_final_image(i) / 2
            # Map output pixels into layer i's normalized [-1, 1] sample space
            grid = ((self.coords - c) / R)[None, ...]  # (1, H:out, W:out, 2)
            if shifts is not None:
                grid = img.horizontal_shift(grid, shifts[i], -2)
            # Default zero padding makes blend 0 outside the layer's extent,
            # so pixels there keep the underlying output unchanged
            blend = nn_f.grid_sample(self.eye_fovea_blend[i][None, None, ...], grid,
                                     align_corners=False)  # (1, 1, H:out, W:out)
            # grid_sample requires input and grid to share the batch size;
            # expand (no copy) the single grid to the layer's batch
            layer_grid = grid.expand(layers[i].size(0), -1, -1, -1)
            sampled = nn_f.grid_sample(layers[i], layer_grid, align_corners=False)
            output.mul_(1 - blend).add_(sampled * blend)
        return output

    def get_layer_size_in_final_image(self, i: int) -> int:
        """
        Get size of layer i in final image.

        The output height is scaled by the ratio of ``view.fov2length`` of
        layer ``i`` to that of the last (periphery) layer, then rounded up.

        :param i: index of layer
        :return: size of layer i in final image (in pixels)
        """
        length_i = view.fov2length(self.fov_list[i])
        length = view.fov2length(self.fov_list[-1])
        k = length_i / length
        return int(math.ceil(self.out_res[0] * k))

    def _gen_layer_blendmap(self, i: int) -> torch.Tensor:
        """
        Generate blend map for fovea layer i.

        The map depends only on the distance of each pixel from the map
        center, transitioned by ``misc.smooth_step`` between ``R`` and
        ``0.6 * R``, where ``R`` is half the layer size.

        :param i: index of fovea layer
        :return `Tensor(H{i}, W{i})`: blend map
        """
        size = self.get_layer_size_in_final_image(i)
        R = size / 2
        p = misc.meshgrid(size, size).to(device=self.device)  # (size, size, 2)
        r = torch.norm(p - R, dim=2)  # (size, size): distance to map center
        # NOTE(review): presumably 1 inside 0.6R and 0 beyond R — confirm
        # against misc.smooth_step's argument order
        return misc.smooth_step(R, R * 0.6, r)