# gen_image.py
import time

import glm
import matplotlib.pyplot as plt
import numpy as np
import torch

import util

def RandomGenSamplesInPupil(pupil_size, n_samples):
    '''
    Randomly sample n_samples positions inside the circular pupil region

    Parameters
    --------
    pupil_size - diameter of the pupil; accepted samples lie within pupil_size / 2 of the origin
    n_samples  - number of samples to generate

    Returns
    --------
    a n_samples x 3 tensor with 3D sample position in each row (z is always 0)
    '''
    # Zero-initialised so the z column never needs to be written
    samples = torch.zeros(n_samples, 3)
    radius = pupil_size / 2.
    i = 0
    while i < n_samples:
        # Rejection sampling: draw uniformly in the bounding square of the pupil
        # and retry whenever the point falls outside the inscribed circle
        s = (torch.rand(2) - 0.5) * pupil_size
        if torch.norm(s) > radius:
            continue
        # Assign the tensor directly; torch refuses to assign a Python list to a slice
        samples[i, 0:2] = s
        i += 1
    return samples


def GenSamplesInPupil(pupil_size, circles):
    '''
    Sample positions on concentric circles in pupil region

    The first sample is the pupil center. Circle i (1 <= i < circles) has radius
    pupil_size / 2 * i / (circles - 1) and carries 4 * i evenly spaced samples,
    so the outermost circle lies on the pupil boundary.

    Parameters
    --------
    pupil_size - diameter of the pupil
    circles    - number of circles to sample (the center counts as circle 0)

    Returns
    --------
    a n_samples x 3 tensor with 3D sample position in each row (z is always 0),
    where n_samples = 1 + 2 * circles * (circles - 1) for circles >= 1
    '''
    # Collect rows in a list and convert once, instead of growing a tensor
    # with torch.cat on every sample
    rows = [[0., 0., 0.]]
    for i in range(1, circles):
        r = pupil_size / 2. / (circles - 1) * i
        n = 4 * i
        for j in range(0, n):
            angle = 2 * np.pi / n * j
            rows.append([r * np.cos(angle), r * np.sin(angle), 0.])
    # Pin the dtype so it does not depend on how torch infers numpy scalars
    return torch.tensor(rows, dtype=torch.get_default_dtype())


class RetinalGen(object):
    '''
    Class for retinal generation process

    Properties
    --------
    conf - multi-layers' parameters configuration
    u    - M x 3 tensor, M sample positions in pupil
    D_r  - retinal resolution taken from conf.retinal_res, indexed as [H_r, W_r]
    N    - number of layers, from conf.GetNLayers()
    M    - number of pupil samples (rows of u)
    p_r  - H_r x W_r x 3 tensor, retinal pixel grid with last component 1
           (assumes util.MeshGrid returns an H_r x W_r x 2 grid - TODO confirm)

    Methods
    --------
    CalculateRetinal2LayerMappings - compute retinal-to-layer pixel mapping for an eye pose
    GenRetinalFromLayers           - render one retinal image from stacked layers
    GenRetinalFromLayersBatch      - batched version of GenRetinalFromLayers
    GenFoveaLayers                 - crop/downsample retinal images or masks into fovea layers
    GenFoveaLayersBatch            - crop/downsample retinal images together with float masks
    GenFoveaRetinal                - blend two fovea layers back into one retinal image
    '''
    def __init__(self, conf):
        '''
        Initialize retinal generator instance

        Parameters
        --------
        conf - multi-layers' parameters configuration; must provide pupil_size,
               retinal_res, GetNLayers() and GetEyeViewportSize()
        '''
        self.conf = conf
        # 5 concentric circles -> 1 + 4 + 8 + 12 + 16 = 41 pupil samples
        self.u = GenSamplesInPupil(conf.pupil_size, 5)
        self.D_r = conf.retinal_res
        self.N = conf.GetNLayers()
        self.M = self.u.size()[0]

        # Pixel centers normalised to [-0.5, 0.5), scaled by the eye viewport size,
        # with a constant 1 appended as the third component
        self.p_r = torch.cat([
            ((util.MeshGrid(self.D_r) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(),
            torch.ones(self.D_r[0], self.D_r[1], 1)
        ], 2)

    def CalculateRetinal2LayerMappings(self, position, gaze_dir, df):
        '''
        Calculate the mapping matrix from retinal to layers.

        Parameters
        --------
        position - eye's position, indexed as [x, y, z]
        gaze_dir - gaze forward vector, indexed as [x, y] (z is taken as 1)
        df       - focus distance

        Returns
        --------
        phi             - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers;
                          last dimension is (x, y) pixel coordinate, out-of-range entries are set to 0
        phi_invalid     - N x H_r x W_r x M x 1, indicates invalid (out-of-range) mapping
        retinal_invalid - 1 x H_r x W_r, indicates invalid pixels in retinal image
        '''
        D = self.conf.layer_res
        c = torch.tensor([ D[1] / 2, D[0] / 2 ])     # c: Center of layers (pixel), in (x, y) order
        p_f = self.p_r * df                # p_f: H x W x 3, focus positions of retinal pixels on focus plane

        # Calculate transformation from eye to display (y axis is flipped)
        gvec_lookat = glm.dvec3(gaze_dir[0], -gaze_dir[1], 1)
        gmat_eye = glm.inverse(glm.lookAtLH(glm.dvec3(), gvec_lookat, glm.dvec3(0, 1, 0)))
        eye_rot = util.Glm2Tensor(glm.dmat3(gmat_eye))
        eye_center = torch.tensor([ position[0], -position[1], position[2] ])

        u_rot = torch.mm(self.u, eye_rot)
        v_rot = torch.matmul(p_f, eye_rot).unsqueeze(2).expand(
            -1, -1, self.M, -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
        u_rot.add_(eye_center)                            # translate by eye's center
        v_rot = v_rot.div(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot

        phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long)

        for i in range(0, self.N):
            dp_i = self.conf.GetPixelSizeOfLayer(i)     # dp_i: Pixel size of layer i
            d_i = self.conf.d_layer[i]                  # d_i: Distance of layer i
            k = (d_i - u_rot[:, 2]).unsqueeze(1)        # k: M x 1, z distance from each sample to layer i
            pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i      # pi_r: H x W x M x 2, rays' pixel coord on layer i
            phi[i, :, :, :, :] = torch.floor(pi_r + c)

        # Calculate invalid mask (out-of-range elements in phi) and reduced to retinal
        phi_invalid = (phi[:, :, :, :, 0] < 0) | (phi[:, :, :, :, 0] >= D[1]) | \
                       (phi[:, :, :, :, 1] < 0) | (phi[:, :, :, :, 1] >= D[0])
        phi_invalid = phi_invalid.unsqueeze(4)
        # A retinal pixel is invalid if any layer/sample mapping is invalid.
        # Only the trailing singleton is squeezed so that H_r == 1 or W_r == 1 is preserved.
        retinal_invalid = phi_invalid.amax((0, 3)).squeeze(-1).unsqueeze(0)
        # Fix invalid elements in phi so they can still be used as indices
        phi[phi_invalid.expand(-1, -1, -1, -1, 2)] = 0

        return [ phi, phi_invalid, retinal_invalid  ]

    def GenRetinalFromLayers(self, layers, Phi):
        '''
        Generate retinal image from layers, using precalculated mapping matrix

        Parameters
        --------
        layers - 3N x H x W, stacked layer images, with 3 channels in each layer
        Phi    - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers;
                 last dimension is (x, y) pixel coordinate

        Returns
        --------
        3 x H_r x W_r, 3 channels retinal image (same device and dtype as layers)
        '''
        # Gather every layer through its own mapping. The stack is sized by Phi,
        # so no uninitialised slot can leak into the product below.
        mapped_layers = torch.stack([
            layers[(i * 3) : (i * 3 + 3),
                   Phi[i, :, :, :, 1],
                   Phi[i, :, :, :, 0]]
            for i in range(0, Phi.size()[0])
        ]) # N x 3 x H_r x W_r x M
        # Multiply across layers, then average over the M pupil samples
        retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
        return retinal

    def GenRetinalFromLayersBatch(self, layers, Phi):
        '''
        Generate batch of retinal images from layers, using precalculated mapping matrix

        Parameters
        --------
        layers - B x 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
        Phi    - B x N x H_r x W_r x M x 2 tensor, retinal to layers mapping;
                 last dimension is (x, y) pixel coordinate

        Returns
        --------
        B x 3 x H_r x W_r tensor, 3 channels retinal images
        '''
        batch_size, n_layers = Phi.shape[0], Phi.shape[1]
        # Fold batch and layer into one leading "image" dimension
        flat_layers = layers[:, :(n_layers * 3)].reshape(
            batch_size * n_layers, 3, layers.shape[2], layers.shape[3]) # (B*N) x 3 x H_l x W_l
        flat_phi = Phi.reshape(batch_size * n_layers,
                               Phi.shape[2], Phi.shape[3], Phi.shape[4], Phi.shape[5])
        x = flat_phi[:, :, :, :, 0] # (B*N) x H_r x W_r x M
        y = flat_phi[:, :, :, :, 1] # (B*N) x H_r x W_r x M
        image_idx = torch.arange(flat_layers.shape[0], device=Phi.device)[:, None, None, None]

        # The index tensors are separated by a slice, so the indexed dimensions
        # come first in the result: (B*N) x H_r x W_r x M x 3
        mapped_layers = flat_layers[image_idx, :, y, x]
        mapped_layers = mapped_layers.permute(0, 4, 1, 2, 3).reshape(
            batch_size, n_layers, 3, Phi.shape[2], Phi.shape[3], Phi.shape[4])

        # Multiply across layers, then average over the M pupil samples
        retinal = mapped_layers.prod(1).sum(4).div(Phi.size()[4])
        return retinal

    ## TO BE CHECKED
    def GenFoveaLayers(self, b_retinal, is_mask):
        '''
        Generate foveated layers for retinal images or masks

        Parameters
        --------
        b_retinal - B x C x H_r x W_r, Batch of retinal images/masks
        is_mask   - Whether b_retinal is masks or images

        Returns
        --------
        b_fovea_layers - N_f x (B x C x H[f] x W[f]) list of batch of foveated layers
        '''
        b_fovea_layers = []
        for i in range(0, len(self.conf.eye_fovea_angles)):
            k = self.conf.eye_fovea_downsamples[i]
            # The same region object indexes both rows and columns
            region = self.conf.GetRegionOfFoveaLayer(i)
            b_roi = b_retinal[:, :, region, region]
            if k == 1:
                b_fovea_layers.append(b_roi)
            elif is_mask:
                # A downsampled mask pixel is set if any covered pixel is set
                b_fovea_layers.append(torch.nn.functional.max_pool2d(b_roi.to(torch.float), k).to(torch.bool))
            else:
                b_fovea_layers.append(torch.nn.functional.avg_pool2d(b_roi, k))
        return b_fovea_layers

    ## TO BE CHECKED
    def GenFoveaLayersBatch(self, retinal, retinal_mask):
        '''
        Generate foveated layers and corresponding masks

        Parameters
        --------
        retinal      - B x C x H_r x W_r, batch of retinal images
        retinal_mask - B x H_r x W_r, masks of retinal images; must be numeric
                       because it is combined with `1 - mask`

        Returns
        --------
        fovea_layers      - list of foveated layers
        fovea_layer_masks - list of mask images, corresponding to foveated layers
        '''
        fovea_layers = []
        fovea_layer_masks = []
        # The last fovea angle is used as the full field of view
        fov = self.conf.eye_fovea_angles[-1]
        # Only the first resolution component is used; ROIs are square
        retinal_res = int(self.conf.retinal_res[0])
        for i in range(0, len(self.conf.eye_fovea_angles)):
            angle = self.conf.eye_fovea_angles[i]
            k = self.conf.eye_fovea_downsamples[i]
            roi_size = int(np.ceil(retinal_res * angle / fov))
            roi_offset = int((retinal_res - roi_size) / 2)
            roi = slice(roi_offset, roi_offset + roi_size)
            roi_img = retinal[:, :, roi, roi]
            roi_mask = retinal_mask[:, roi, roi]
            if k == 1:
                fovea_layers.append(roi_img)
                fovea_layer_masks.append(roi_mask)
            else:
                fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img, k))
                # Min-pooling written as 1 - maxpool(1 - mask)
                fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask), k))
        return [ fovea_layers, fovea_layer_masks ]

    ## TO BE CHECKED
    def GenFoveaRetinal(self, b_fovea_layers):
        '''
        Generate foveated retinal image by blending fovea layers
        **Note: current implementation only support two fovea layers**

        Parameters
        --------
        b_fovea_layers - N_f x (B x 3 x H[f] x W[f]), list of batch of (masked) foveated layers

        Returns
        --------
        B x 3 x H_r x W_r, batch of foveated retinal images
        '''
        # Upsample the outer (downsampled) layer back to its original resolution
        b_fovea_retinal = torch.nn.functional.interpolate(b_fovea_layers[1],
            scale_factor=self.conf.eye_fovea_downsamples[1],
            mode='bilinear', align_corners=False)
        region = self.conf.GetRegionOfFoveaLayer(0)
        blend = self.conf.eye_fovea_blend[0]
        # In-place blend; this relies on region being a slice so that b_roi is a
        # view into b_fovea_retinal - TODO confirm against GetRegionOfFoveaLayer
        b_roi = b_fovea_retinal[:, :, region, region]
        b_roi.mul_(1 - blend).add_(b_fovea_layers[0] * blend)
        return b_fovea_retinal