foveation.py
import math
import torch
import torch.nn.functional as nn_f
from typing import Any, List, Tuple
from . import util
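# The helpers referenced below come from the sibling `util` module, which is not
# shown here. As a rough guide (an assumption about `util`, not its actual code):
#   util.Fov2Length(fov)        -- image-plane length subtended by `fov` degrees,
#                                  presumably on the order of 2 * tan(radians(fov) / 2)
#   util.MeshGrid((H, W))       -- an (H, W, 2) tensor of per-pixel coordinates
#   util.SmoothStep(e0, e1, x)  -- a clamped smoothstep of x between edges e0 and e1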


class Foveation(object):

    def __init__(self, fov_list: List[float],
                 out_res: Tuple[int, int], *, device=None):
        self.fov_list = fov_list
        self.out_res = out_res
        self.device = device
        self.n_layers = len(self.fov_list)
        self.eye_fovea_blend = [
            self._gen_layer_blendmap(i)
            for i in range(self.n_layers - 1)
        ]  # blend maps of fovea layers

    def to(self, device):
        self.eye_fovea_blend = [x.to(device=device) for x in self.eye_fovea_blend]
        return self

    def synthesis(self, layers: List[torch.Tensor]) -> torch.Tensor:
        """
        Generate foveated retinal image by blending fovea layers
        **Note: the current implementation only supports two fovea layers**

        :param layers ```List(Tensor(B, C, H'{l}, W'{l}))```: list of foveated layers
        :return ```Tensor(B, C, H:out, W:out)```: foveated images
        """
        output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
                                                mode='bilinear', align_corners=False)
        for i in range(self.n_layers - 2, -1, -1):
            # Blend layer i into its centred region of the output image, in place
            output_roi = output[self.get_layer_region_in_final_image(i)]
            image = nn_f.interpolate(layers[i], output_roi.size()[-2:],
                                     mode='bilinear', align_corners=False)
            blend = self.eye_fovea_blend[i]
            output_roi.mul_(1 - blend).add_(image * blend)
        return output

    def get_layer_size_in_final_image(self, i: int) -> int:
        """
        Get size of layer i in final image

        :param i: index of layer
        :return: size of layer i in final image (in pixels)
        """
        length_i = util.Fov2Length(self.fov_list[i])
        length = util.Fov2Length(self.fov_list[-1])
        k = length_i / length
        return int(math.ceil(self.out_res[0] * k))

    def get_layer_region_in_final_image(self, i: int) -> Tuple[Any, slice, slice]:
        """
        Get region of fovea layer i in final image

        :param i: index of layer
        :return: index tuple (ellipsis plus vertical and horizontal slices) selecting
                 the region of layer i in the final image
        """
        roi_size = self.get_layer_size_in_final_image(i)
        roi_offset_y = (self.out_res[0] - roi_size) // 2
        roi_offset_x = (self.out_res[1] - roi_size) // 2
        return (
            ...,
            slice(roi_offset_y, roi_offset_y + roi_size),
            slice(roi_offset_x, roi_offset_x + roi_size)
        )

    def _gen_layer_blendmap(self, i: int) -> torch.Tensor:
        """
        Generate blend map for fovea layer i

        :param i: index of fovea layer
        :return ```Tensor(H{i}, W{i})```: blend map
        """
        size = self.get_layer_size_in_final_image(i)
        R = size / 2
        p = util.MeshGrid((size, size)).to(device=self.device)  # (size, size, 2)
        r = torch.norm(p - R, dim=2)  # distance of each pixel from the layer centre, (size, size)
        return util.SmoothStep(R, R * 0.6, r)
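

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how this class might be driven: two rendered layers, a
# high-resolution fovea layer and a wide-FoV periphery layer, are blended into
# one retinal image. The FoV values and resolutions below are hypothetical, and
# because of the relative `util` import the sketch assumes it is executed inside
# the package (e.g. `python -m <package>.foveation`).
if __name__ == '__main__':
    fov_list = [20.0, 110.0]              # fovea FoV and full (periphery) FoV, in degrees
    out_res = (1440, 1600)                # (H, W) of the final retinal image
    fovea = Foveation(fov_list, out_res, device=torch.device('cpu'))

    layers = [
        torch.rand(1, 3, 256, 256),       # fovea layer (B, C, H, W)
        torch.rand(1, 3, 360, 400),       # periphery layer, upsampled to out_res inside synthesis()
    ]
    image = fovea.synthesis(layers)
    print(image.shape)                    # expected: torch.Size([1, 3, 1440, 1600])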