foveation.py 6.32 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
3
import torch
import torch.nn.functional as nn_f
from typing import List, Tuple
Nianchen Deng's avatar
sync    
Nianchen Deng committed
4
5
6
from utils import img
from utils import view
from utils import misc
Nianchen Deng's avatar
sync    
Nianchen Deng committed
7
from utils import math
BobYeah's avatar
sync    
BobYeah committed
8
9
10

class Foveation(object):
    """
    Blend a stack of foveated layers into a single output image.

    Layers are ordered from the innermost (index 0) to the outermost
    (index -1). The outermost layer is resized to fill the output; every
    other layer is resampled around the fovea center and blended on top
    using a precomputed blend map.
    """

    def __init__(self, layers_fov: list[float], layers_res: list[tuple[float, float]],
                 out_res: tuple[int, int], *, blend: float = 0.6, device: torch.device = None):
        """
        :param layers_fov: field of view of every layer, innermost first
        :param layers_res: resolution of every layer, innermost first
        :param out_res: resolution of the synthesized image, passed as `size`
            to `interpolate`, i.e. (height, width)
        :param blend: blend factor used to build the blend maps and the
            band boundaries in `get_layers_mask`
        :param device: device on which helper tensors are created
            (`None` lets torch pick its default)
        """
        self.layers_fov = layers_fov
        self.layers_res = layers_res
        self.out_res = out_res
        self.blend = blend
        self.device = device
        self.n_layers = len(self.layers_fov)
        self.eye_fovea_blend = [
            self._gen_layer_blendmap(i)
            for i in range(self.n_layers - 1)
        ]  # blend maps of fovea layers
        self.coords = misc.grid2d(*out_res, device=device)

    def to(self, device: torch.device):
        """
        Move the cached tensors to `device`.

        :param device: target device
        :return: `self`, so calls can be chained
        """
        self.eye_fovea_blend = [x.to(device=device) for x in self.eye_fovea_blend]
        self.coords = self.coords.to(device=device)
        # Keep the bookkeeping attribute in sync: other methods create new
        # tensors on `self.device` and would otherwise land on the old device.
        self.device = device
        return self

    def synthesis(self, layers: list[torch.Tensor], fovea_center: tuple[float, float],
                  shifts: list[int] = None,
                  do_blend: bool = True,
                  crop_mode: bool = False) -> torch.Tensor:
        """
        Generate foveated retinal image by blending fovea layers

        :param layers `List(Tensor(B, C, H'{l}, W'{l}))`: list of foveated layers,
            innermost first; an entry other than the last may be `None` to skip it
        :param fovea_center: (x, y) offset of the fovea from the image center,
            in output pixels
        :param shifts: optional per-layer horizontal shifts forwarded to
            `img.horizontal_shift`
        :param do_blend: use the smooth blend maps; if `False`, a map of ones
            is sampled instead
        :param crop_mode: if `True`, each inner layer is resized to the whole
            output instead of being resampled around the fovea
        :return `Tensor(B, C, H:out, W:out)`: foveated images
        """
        output: torch.Tensor = nn_f.interpolate(layers[-1], self.out_res,
                                                mode='bilinear', align_corners=False)
        if shifts is not None:
            output = img.horizontal_shift(output, shifts[-1])
        # Build the center on the device of `coords` so the subtraction below
        # cannot hit a device mismatch.
        c = torch.tensor([
            fovea_center[0] + (self.out_res[1] - 1) / 2,
            fovea_center[1] + (self.out_res[0] - 1) / 2
        ], device=self.coords.device)
        for i in range(self.n_layers - 2, -1, -1):
            layer = layers[i]
            if layer is None:
                continue
            R = self.get_layer_size_in_final_image(i) / 2
            # Normalize so that the area covered by layer i maps to [-1, 1],
            # the coordinate range `grid_sample` expects.
            grid = ((self.coords - c) / R)[None, ...]
            if shifts is not None:
                grid = img.horizontal_shift(grid, shifts[i], -2)
            blend_src = self.eye_fovea_blend[i][None, None]
            if not do_blend:
                blend_src = torch.ones_like(blend_src)
            # (1, 1, H:out, W:out)
            blend = nn_f.grid_sample(blend_src, grid, align_corners=False)
            output.mul_(1 - blend)
            if crop_mode:
                output.add_(blend * nn_f.interpolate(layer, self.out_res, mode='bilinear',
                                                     align_corners=False))
            else:
                # `grid_sample` needs one grid per batch item; share the
                # single grid across the batch without copying it.
                batch_grid = grid.expand(layer.shape[0], -1, -1, -1)
                output.add_(blend * nn_f.grid_sample(layer, batch_grid, align_corners=False))
        return output

    def get_layer_size_in_final_image(self, i: int) -> int:
        """
        Get size of layer i in final image

        :param i: index of layer
        :return: size of layer i in final image (in pixels)
        """
        return self.get_source_layer_cover_size_in_target_layer(
            self.layers_fov[i], self.layers_fov[-1], self.out_res[0])

    def get_source_layer_cover_size_in_target_layer(self, source_fov, target_fov,
                                                    target_pixel_height) -> int:
        """
        Get the size that a source layer covers inside a target layer

        :param source_fov: field of view of the source layer
        :param target_fov: field of view of the target layer
        :param target_pixel_height: height of the target layer (in pixels)
        :return: size covered by the source layer in the target layer
            (in pixels, rounded up)
        """
        source_physical_height = view.fov2length(source_fov)
        target_physical_height = view.fov2length(target_fov)
        return int(math.ceil(target_pixel_height * source_physical_height / target_physical_height))

    def _gen_layer_blendmap(self, i: int) -> torch.Tensor:
        """
        Generate blend map for fovea layer i

        :param i: index of fovea layer
        :return `Tensor(H{i}, W{i})`: blend map
        """
        size = self.get_layer_size_in_final_image(i)
        R = size / 2
        p = misc.grid2d(size, device=self.device)  # (size, size, 2)
        r = torch.norm(p - R, dim=2)  # (size, size): distance to the map center
        return misc.smooth_step(R, R * self.blend, r)

    def get_layers_mask(self, gaze=None) -> list[torch.Tensor]:
        """
        Generate mask images for all layers
        the meaning of values in mask images:
        -1: skipped
        0~1: blend with inner layer
        1~2: only self layer
        2~3: blend with outer layer

        :param gaze: optional (x, y) gaze offset from the image center; if
            `None`, the mask of the outermost layer is filled with ones
        :return: one mask image per layer, innermost first
        """
        layers_mask = []
        outermost = self.n_layers - 1
        for i in range(self.n_layers):
            res = self.layers_res[i]
            if i == outermost:
                if gaze is None:
                    layers_mask.append(torch.ones(*res, device=self.device))
                    continue
                c = torch.tensor([
                    (gaze[0] + 0.5 * self.out_res[1]) / self.out_res[0],
                    (gaze[1] + 0.5 * self.out_res[0]) / self.out_res[0]
                ], device=self.device)
            else:
                c = torch.tensor([0.5, 0.5], device=self.device)
            mask = torch.ones(*res, device=self.device) * -1
            layers_mask.append(mask)
            # Both axes are divided by res[0] so distances stay isotropic.
            coord = misc.grid2d(*res, device=self.device) / res[0]
            r = 2 * torch.norm(coord - c, dim=-1)
            if i > 0:
                inner_radius = self.get_source_layer_cover_size_in_target_layer(
                    self.layers_fov[i - 1], self.layers_fov[i], res[0]) / res[0]
            else:
                inner_radius = 0
            if i == outermost:
                # The outermost layer has no outer blend band.
                bounds = [inner_radius * (1 - self.blend), inner_radius, 100, 100]
            else:
                bounds = [inner_radius * (1 - self.blend), inner_radius, self.blend, 1]
            for bi, (lower, upper) in enumerate(zip(bounds, bounds[1:])):
                if upper <= lower:
                    # Empty band: selects no pixels, and skipping it avoids a
                    # division by zero.
                    continue
                region = torch.logical_and(r >= lower, r < upper)
                mask[region] = bi + (r[region] - lower) / (upper - lower)
        return layers_mask