import torch
import argparse
import os
import glob
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image

from torchvision import datasets
from torch.utils.data import DataLoader 
from torch.autograd import Variable

import cv2
from gen_image import *
import json
from ssim import *
from perc_loss import *
from conf import Conf

from model.baseline import *

import torch.autograd.profiler as profiler
# Parameters
BATCH_SIZE = 2
NUM_EPOCH = 300

INTERLEAVE_RATE = 2

IM_H = 320
IM_W = 320

Retinal_IM_H = 320
Retinal_IM_W = 320

N = 9 # number of views in the input light-field stack (3x3 atlas)
M = 2 # number of display layers

DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_fovea"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq.json"
DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq"


OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3

LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
FIRST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2

class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
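    """Per-frame dataset (a torch Dataset, despite the "DataLoader" name): yields the
    3x3 light-field view stack, two ground-truth layer images, and focus/gaze metadata."""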
    def __init__(self, file_dir_path, file_json, transforms=None):
        self.file_dir_path = file_dir_path
        self.transforms = transforms
        # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
        with open(file_json, encoding='utf-8') as file:
            self.dataset_desc = json.loads(file.read())

    def __len__(self):
        return len(self.dataset_desc["focaldepth"])

    def __getitem__(self, idx):
        lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
        if self.transforms:
            lightfield_images = self.transforms(lightfield_images)
        # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
        return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)

    def get_datum(self, idx):
        lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
        # print(lf_image_paths)
        fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx])
        fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][idx])
        # print(fd_gt_path)
        lf_images = []
        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)

        ## IF GrayScale
        # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        # lf_image_big = np.expand_dims(lf_image_big, axis=-1)
        # print(lf_image_big.shape)

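        # The 9 views are tiled in a 3x3 atlas; crop out each IM_H x IM_W view.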
        for i in range(9):
            lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
            ## IF GrayScale
            # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
            # print(lf_image.shape)
            lf_images.append(lf_image)
        gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
        gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB)
        ## IF GrayScale
        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        # gt = np.expand_dims(gt, axis=-1)

        fd = self.dataset_desc["focaldepth"][idx]
        gazeX = self.dataset_desc["gazeX"][idx]
        gazeY = self.dataset_desc["gazeY"][idx]
        sample_idx = self.dataset_desc["idx"][idx]
        return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx

class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
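    """Validation variant of the dataset above: loads the light-field views and
    focus/gaze metadata only, without ground-truth images."""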
    def __init__(self, file_dir_path, file_json, transforms=None):
        self.file_dir_path = file_dir_path
        self.transforms = transforms
        # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
        with open(file_json, encoding='utf-8') as file:
            self.dataset_desc = json.loads(file.read())
    def __len__(self):
        return len(self.dataset_desc["focaldepth"])
    def __getitem__(self, idx):
        lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
        if self.transforms:
            lightfield_images = self.transforms(lightfield_images)
        # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
        return (lightfield_images, fd, gazeX, gazeY, sample_idx)
    def get_datum(self, idx):
        lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
        # print(fd_gt_path)
        lf_images = []
        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
        ## IF GrayScale
        # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        # lf_image_big = np.expand_dims(lf_image_big, axis=-1)
        # print(lf_image_big.shape)

        for i in range(9):
            lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
            ## IF GrayScale
            # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
            # print(lf_image.shape)
            lf_images.append(lf_image)
        ## IF GrayScale
        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        # gt = np.expand_dims(gt, axis=-1)

        fd = self.dataset_desc["focaldepth"][idx]
        gazeX = self.dataset_desc["gazeX"][idx]
        gazeY = self.dataset_desc["gazeY"][idx]
        sample_idx = self.dataset_desc["idx"][idx]
        return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx

class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
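    """Sequence dataset for the recurrent model: each item stacks the S consecutive
    frames listed in the JSON "seq" entry (views, ground truths, and metadata).

    Example (hypothetical usage):
        ds = lightFieldSeqDataLoader(DATA_FILE, DATA_JSON)
        lf, gt, gt2, fd, gazeX, gazeY, idx = ds[0]
    """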
    def __init__(self, file_dir_path, file_json, transforms=None):
        self.file_dir_path = file_dir_path
        self.transforms = transforms
        with open(file_json, encoding='utf-8') as file:
            self.dataset_desc = json.loads(file.read())

    def __len__(self):
        return len(self.dataset_desc["seq"])

    def __getitem__(self, idx):
        lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
        fd = fd.astype(np.float32)
        gazeX = gazeX.astype(np.float32)
        gazeY = gazeY.astype(np.float32)
        sample_idx = sample_idx.astype(np.int64)
        # print(fd)
        # print(gazeX)
        # print(gazeY)
        # print(sample_idx)

        # print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype)
        # print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape)
        if self.transforms:
            lightfield_images = self.transforms(lightfield_images)
        return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)

    def get_datum(self, idx):
        indices = self.dataset_desc["seq"][idx]
        # print("indices:",indices)
        lf_images = []
        fd = []
        gazeX = []
        gazeY = []
        sample_idx = []
        gt = []
        gt2 = []
        for i in range(len(indices)):
            lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][indices[i]])
            fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][indices[i]])
            fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][indices[i]])
            lf_image_one_sample = []
            lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
            lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)

            for j in range(9):
                lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3]
                ## IF GrayScale
                # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
                # print(lf_image.shape)
                lf_image_one_sample.append(lf_image)

            gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
            gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
            gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
            gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))

            # print("indices[i]:",indices[i])
            fd.append([self.dataset_desc["focaldepth"][indices[i]]])
            gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
            gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
            sample_idx.append([self.dataset_desc["idx"][indices[i]]])
            lf_images.append(lf_image_one_sample)
        # lf_images: S x 9 x IM_H x IM_W x 3 (S = sequence length, here 5)

        return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx)

#### Image Gen
conf = Conf()
u = GenSamplesInPupil(conf.pupil_size, 5)
gen = RetinalGen(conf, u)
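# u: sample positions in the pupil; gen: the retinal-image generator from gen_image,
# supplying CalculateRetinal2LayerMappings / GenRetinalFromLayers / GenFoveaLayers used below.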
def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
    # layers: batchsize, 2*color, height, width 
    # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
    # df : batchsize,..
    
    #  retinal bs x color x height x width
    retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W)
    mask = [] # mask shape 480 x 640
    for i in range(0, layers.size()[0]):
        phi = phi_dict[int(sample_idx[i].data)]
        # print("phi_i:",phi.shape)
        phi = var_or_cuda(phi)
        phi.requires_grad = False
        # print("layers[i]:",layers[i].shape)
        # print("retinal[i]:",retinal[i].shape)
        retinal[i] = gen.GenRetinalFromLayers(layers[i],phi)
        mask.append(mask_dict[int(sample_idx[i].data)])
    retinal = var_or_cuda(retinal)
    mask = torch.stack(mask,dim = 0).unsqueeze(1) # batch x 1 x height x width
    return retinal, mask

def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
    # layers: batchsize, 2*color, height, width 
    # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
    # df : batchsize,..
    
    #  retinal bs x color x height x width
    retinal_fovea = torch.empty(layers.shape[0], 6, 160, 160)
    mask_fovea = torch.empty(layers.shape[0], 2, 160, 160)
    for i in range(0, layers.size()[0]):
        phi = phi_dict[int(sample_idx[i].data)]
        # print("phi_i:",phi.shape)
        phi = var_or_cuda(phi)
        phi.requires_grad = False
        mask_i = var_or_cuda(mask_dict[int(sample_idx[i].data)])
        mask_i.requires_grad = False
        # print("layers[i]:",layers[i].shape)
        # print("retinal[i]:",retinal[i].shape)
        retinal_i = gen.GenRetinalFromLayers(layers[i],phi)
        fovea_layers, fovea_layer_masks = gen.GenFoveaLayers(retinal_i,mask_i)
        retinal_fovea[i] = torch.cat([fovea_layers[0],fovea_layers[1]],dim=0)
        mask_fovea[i] = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=0)
        
    retinal_fovea = var_or_cuda(retinal_fovea)
    mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
    # mask = torch.stack(mask,dim = 0).unsqueeze(1) 
    return retinal_fovea, mask_fovea

def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
    # layers: batchsize, 2*color, height, width 
    # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
    # df : batchsize,..
    
    #  retinal bs x color x height x width
    # retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W)
    # retinal = var_or_cuda(retinal)
    phi = var_or_cuda(phi)
    phi.requires_grad = False
    retinal = gen.GenRetinalFromLayers(layers[0],phi)
    retinal = var_or_cuda(retinal)
    mask_out = mask.unsqueeze(0).unsqueeze(0)
    # print("maskOUt:",mask_out.shape) # 1,1,240,320
    # mask_out = torch.stack(mask,dim = 0).unsqueeze(1) # batch x 1 x height x width
    return retinal.unsqueeze(0), mask_out
#### Image Gen End

weightVarScale = 0.25
bias_stddev = 0.01

def weight_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:  # conv layers created with bias=False have no bias to init
            torch.nn.init.normal_(m.bias.data, mean=0.0, std=bias_stddev)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


def calImageGradients(images):
    # images: N x C x H x W; forward differences along height (dx) and width (dy)
    dx = images[:, :, 1:, :] - images[:, :, :-1, :]
    dy = images[:, :, :, 1:] - images[:, :, :, :-1]
    return dx, dy


perc_loss = VGGPerceptualLoss() 
perc_loss = perc_loss.to("cuda:1")

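# loss_new combines a log10-MSE intensity term (a PSNR-like measure up to sign and
# scale), log10-MSE terms on horizontal/vertical image gradients, and a VGG
# perceptual term; the constant 10 offsets the negative log terms.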
def loss_new(generated, gt):
    mse_loss = torch.nn.MSELoss()
    rmse_intensity = mse_loss(generated, gt)
    psnr_intensity = torch.log10(rmse_intensity)  # log10(MSE); PSNR would be -10*log10(MSE) for unit-range images
    # print("psnr:",psnr_intensity)
    # ssim_intensity = ssim(generated, gt)
    labels_dx, labels_dy = calImageGradients(gt)
    # print("generated:",generated.shape)
    preds_dx, preds_dy = calImageGradients(generated)
    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
    # print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
    p_loss = perc_loss(generated,gt)
    # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
    total_loss = 10 + psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
    # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
    return total_loss

def save_checkpoints(file_path, epoch_idx, model, model_solver):
    print('[INFO] Saving checkpoint to %s ...' % (file_path))
    checkpoint = {
        'epoch_idx': epoch_idx,
        'model_state_dict': model.state_dict(),
        'model_solver_state_dict': model_solver.state_dict()
    }
    torch.save(checkpoint, file_path)

mode = "train"

import pickle
def save_obj(obj, name ):
    # with open('./outputF/dict/'+ name + '.pkl', 'wb') as f:
    #     pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
    torch.save(obj,'./outputF/dict/'+ name + '.pkl')
def load_obj(name):
    # with open('./outputF/dict/' + name + '.pkl', 'rb') as f:
    #     return pickle.load(f)
    return torch.load('./outputF/dict/'+ name + '.pkl')

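# Debug hooks: attach with module.register_forward_hook / register_backward_hook
# to print per-module feature and gradient statistics.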
def hook_fn_back(m, i, o):
    for grad in i:
        try:
            print("Input Grad:", m, grad.shape, grad.sum())
        except AttributeError:
            print("None found for Gradient")
    for grad in o:
        try:
            print("Output Grad:", m, grad.shape, grad.sum())
        except AttributeError:
            print("None found for Gradient")
    print("\n")

def hook_fn_for(m, i, o):
    for feat in i:
        try:
            print("Input Feats:", m, feat.shape, feat.sum())
        except AttributeError:
            print("None found for Features")
    for feat in o:
        try:
            print("Output Feats:", m, feat.shape, feat.sum())
        except AttributeError:
            print("None found for Features")
    print("\n")

def generatePhiMaskDict(data_json, generator):
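    """Precompute the retinal mapping (phi) and mask for every sample listed in the
    dataset JSON, keyed by sample idx, so training can look them up instead of
    recomputing them each iteration."""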
    phi_dict = {}
    mask_dict = {}
    idx_info_dict = {}
    with open(data_json, encoding='utf-8') as file:
        dataset_desc = json.loads(file.read())
        for i in range(len(dataset_desc["focaldepth"])):
            # if i == 2:
            #     break
            idx = dataset_desc["idx"][i] 
            focaldepth = dataset_desc["focaldepth"][i]
            gazeX = dataset_desc["gazeX"][i]
            gazeY = dataset_desc["gazeY"][i]
            print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
            phi,mask =  generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
            phi_dict[idx]=phi
            mask_dict[idx]=mask
            idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
    return phi_dict,mask_dict,idx_info_dict

if __name__ == "__main__":
    ############################## generate phi and mask in pre-training
    # print("generating phi and mask...")
    # phi_dict,mask_dict,idx_info_dict = generatePhiMaskDict(DATA_JSON,gen)
    # save_obj(phi_dict,"phi_1204")
    # save_obj(mask_dict,"mask_1204")
    # save_obj(idx_info_dict,"idx_info_1204")
    # print("generating phi and mask end.")
    # exit(0)
    ############################# load phi and mask in pre-training
    print("loading phi and mask ...")
    phi_dict = load_obj("phi_1204")
    mask_dict = load_obj("mask_1204")
    idx_info_dict = load_obj("idx_info_1204")
    print(len(phi_dict))
    print(len(mask_dict))
    print("loading phi and mask end") 

    #train
    train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSeqDataLoader(DATA_FILE,DATA_JSON),
                                                    batch_size=BATCH_SIZE,
                                                    num_workers=0,
                                                    pin_memory=True,
                                                    shuffle=True,
                                                    drop_last=False)
    print(len(train_data_loader))

    # exit(0)


    ################################################ val #########################################################
    # val_data_loader = torch.utils.data.DataLoader(dataset=lightFieldValDataLoader(DATA_FILE,DATA_VAL_JSON),
    #                                                 batch_size=1,
    #                                                 num_workers=0,
    #                                                 pin_memory=True,
    #                                                 shuffle=False,
    #                                                 drop_last=False)

    # print(len(val_data_loader))

    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # lf_model = baseline.model()
    # if torch.cuda.is_available():
    #     lf_model = torch.nn.DataParallel(lf_model).cuda()

    # checkpoint = torch.load(os.path.join(OUTPUT_DIR,"gaze-ckpt-epoch-0201.pth"))
    # lf_model.load_state_dict(checkpoint["model_state_dict"])
    # lf_model.eval()

    # print("Eval::")
    # for sample_idx, (image_set, df, gazeX, gazeY, sample_idx) in enumerate(val_data_loader):
    #     print("sample_idx::",sample_idx)
    #     with torch.no_grad():
            
    #         #reshape for input
    #         image_set = image_set.permute(0,1,4,2,3) # N LF C H W
    #         image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
    #         image_set = var_or_cuda(image_set)

    #         # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
    #         output = lf_model(image_set,df,gazeX,gazeY)
    #         output1,mask = GenRetinalGazeFromLayersBatch(output, gen, sample_idx, phi_dict, mask_dict)

    #         for i in range(0, 2):
    #             output1[:,i*3:i*3+3].mul_(mask[:,i:i+1])
    #             output1[:,i*3:i*3+3].clamp_(0., 1.)
            
    #         print("output:",output.shape," df:",df[0].data, ",gazeX:",gazeX[0].data,",gazeY:", gazeY[0].data)
    #         for i in range(output1.size()[0]):
    #             save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
    #             save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
    #             save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
    #             save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))

    #         # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.3f.png"%(df[0].data)))
    #         # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.3f.png"%(df[0].data)))
    #         # output = GenRetinalFromLayersBatch(output,conf,df,v,u)
    #         # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.3f.png"%(df[0].data)))
    # exit()

    ################################################ train #########################################################
    lf_model = model(FIRST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE)
    lf_model.apply(weight_init_normal)

    epoch_begin = 0

    ################################ load model file
    # WEIGHTS = os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (101))
    # print('[INFO] Recovering from %s ...' % (WEIGHTS))
    # checkpoint = torch.load(WEIGHTS)
    # init_epoch = checkpoint['epoch_idx']
    # lf_model.load_state_dict(checkpoint['model_state_dict'])
    # epoch_begin = init_epoch + 1
    # print(lf_model)
    ############################################################

    if torch.cuda.is_available():
        # lf_model = torch.nn.DataParallel(lf_model).cuda()
        lf_model = lf_model.to('cuda:1')
    lf_model.train()
    optimizer = torch.optim.Adam(lf_model.parameters(),lr=1e-2,betas=(0.9,0.999))
    l1loss = torch.nn.L1Loss()
    # lf_model.output_layer.register_backward_hook(hook_fn_back)
    print("begin training....")
    for epoch in range(epoch_begin, NUM_EPOCH):
        for batch_idx, (image_set, gt, gt2, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader):
            # print(sample_idx.shape,df.shape,gazeX.shape,gazeY.shape) # torch.Size([2, 5])
            # print(image_set.shape,gt.shape,gt2.shape) #torch.Size([2, 5, 9, 320, 320, 3]) torch.Size([2, 5, 160, 160, 3]) torch.Size([2, 5, 160, 160, 3])
            # print(delta.shape) # delta: torch.Size([2, 4, 160, 160, 3])
            
            #reshape for input
            image_set = image_set.permute(0,1,2,5,3,4) # N S LF C H W
            image_set = image_set.reshape(image_set.shape[0],image_set.shape[1],-1,image_set.shape[4],image_set.shape[5]) # N, S, LF*C, H, W
            image_set = var_or_cuda(image_set)
            gt = gt.permute(0,1,4,2,3) # N S C H W
            gt = var_or_cuda(gt)

            gt2 = gt2.permute(0,1,4,2,3)
            gt2 = var_or_cuda(gt2)

            gen1 = torch.empty(gt.shape)
            gen1 = var_or_cuda(gen1)

            gen2 = torch.empty(gt2.shape)
            gen2 = var_or_cuda(gen2)

            warped = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
            warped = var_or_cuda(warped)

            delta = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
            delta = var_or_cuda(delta)
            
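            # Walk the S frames of each sequence; the model's recurrent hidden
            # state is reset at the first frame of each sequence.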
            for k in range(image_set.shape[1]):
                if k == 0:
                    lf_model.reset_hidden(image_set[:,k])
                
                # start = torch.cuda.Event(enable_timing=True)
                # end = torch.cuda.Event(enable_timing=True)
                # start.record()
                output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k])
                # end.record()
                # torch.cuda.synchronize()
                # print("Model Forward:",start.elapsed_time(end))
                # print("output:",output.shape) # [2, 6, 320, 320]
                # exit()
                ########################### Use Pregen Phi and Mask ###################
                # start.record()
                output1,mask = GenRetinalGazeFromLayersBatch(output, gen, sample_idx[:,k], phi_dict, mask_dict)
                # end.record()
                # torch.cuda.synchronize()
                # print("Merge:",start.elapsed_time(end))

                # print("output1 shape:",output1.shape, "mask shape:",mask.shape)
                # output1 shape: torch.Size([2, 6, 160, 160]) mask shape: torch.Size([2, 2, 160, 160])
                for i in range(0, 2):
                    output1[:,i*3:i*3+3].mul_(mask[:,i:i+1])
                    if i == 0:
                        gt[:,k].mul_(mask[:,i:i+1])
                    if i == 1:
                        gt2[:,k].mul_(mask[:,i:i+1])
                
                gen1[:,k] = output1[:,0:3]
                gen2[:,k] = output1[:,3:6]
                if epoch % 5 == 0 or epoch == 2:
                    for i in range(output.shape[0]):
                        save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data)))
                        save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data)))

            ########################### Update ###################
            for i in range(1,image_set.shape[1]):
                delta[:,i-1] = gt2[:,i] - gt2[:,i-1]    # temporal difference of ground truth
                warped[:,i-1] = gen2[:,i] - gen2[:,i-1] # temporal difference of generated frames
            
            optimizer.zero_grad()

            # flatten batch and sequence dims: N, S, C, H, W -> N*S, C, H, W
            gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
            gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4])
            gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
            gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
            warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4])
            delta = delta.reshape(-1,delta.shape[2],delta.shape[3],delta.shape[4])


            # start = torch.cuda.Event(enable_timing=True)
            # end = torch.cuda.Event(enable_timing=True)
            # start.record()
            loss1 = loss_new(gen1,gt)
            loss2 = loss_new(gen2,gt2)
            loss3 = l1loss(warped,delta)
            loss = loss1+loss2+loss3
            # end.record()
            # torch.cuda.synchronize()
            # print("loss comp:",start.elapsed_time(end))

            
            # start.record()
            loss.backward()
            # end.record()
            # torch.cuda.synchronize()
            # print("backward:",start.elapsed_time(end))

            # start.record()
            optimizer.step()
            # end.record()
            # torch.cuda.synchronize()
            # print("optimizer step:",start.elapsed_time(end))
            
            ## Update Prev
            print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
            ########################### Save #####################
            if epoch % 5 == 0 or epoch == 2:
                for i in range(gt.size()[0]):
                    # df has shape N x S (here 2 x 5); recover (n, s) from the flattened index i
                    save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
                    save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
                    save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
                    save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
            if epoch % 100 == 0 and epoch != 0 and batch_idx == len(train_data_loader)-1:
                save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),epoch,lf_model,optimizer)

            ########################## test Phi and Mask ##########################
            # phi,mask =  gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]]))
            # # print("gaze Online:",gazeX[0]," ,",gazeY[0])
            # # print("df Online:",df[0])
            # # print("idx:",int(sample_idx[0].data))
            # phi_t = phi_dict[int(sample_idx[0].data)]
            # mask_t = mask_dict[int(sample_idx[0].data)]
            # # print("idx info:",idx_info_dict[int(sample_idx[0].data)])
            # # print("phi online:", phi.shape, " phi_t:", phi_t.shape)
            # # print("mask online:", mask.shape, " mask_t:", mask_t.shape)
            # print("phi delta:", (phi-phi_t).sum()," mask delta:",(mask -mask_t).sum())
            # exit(0)

            ###########################Gen Batch 1 by 1###################
            # phi,mask =  gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]]))
            # # print(phi.shape) # 2,240,320,41,2
            # output1, mask = GenRetinalFromLayersBatch_Online(output, gen, phi, mask)
            ###########################Gen Batch 1 by 1###################