import torch
import os
import json
import numpy as np
import cv2
from torchvision.utils import save_image

from gen_image import *
from ssim import *
from perc_loss import * 
# Parameters
BATCH_SIZE = 16
NUM_EPOCH = 1000

INTERLEAVE_RATE = 2 # space-to-depth block size used by interleave/deinterleave

IM_H = 320
IM_W = 320

Retinal_IM_H = 320
Retinal_IM_W = 320

N = 9 # number of views in the input light field stack
M = 2 # number of display layers

DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_small_nar_new"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_low_new.json"
DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/output/gaze_low_new_1125_minibatch"

class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
    # Note: despite its name, this is a torch Dataset (indexed per sample), not a DataLoader.
    def __init__(self, file_dir_path, file_json, transforms=None):
        self.file_dir_path = file_dir_path
        self.transforms = transforms
        with open(file_json, encoding='utf-8') as file:
            self.dataset_desc = json.loads(file.read())

    def __len__(self):
        return len(self.dataset_desc["focaldepth"])

    def __getitem__(self, idx):
        lightfield_images, gt, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
        if self.transforms:
            lightfield_images = self.transforms(lightfield_images)
        return (lightfield_images, gt, fd, gazeX, gazeY, sample_idx)

    def get_datum(self, idx):
        lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
        fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx])
        lf_images = []
        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
        for i in range(N):  # the N views are tiled 3x3 within a single image
            lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H, i%3*IM_W:i%3*IM_W+IM_W, 0:3]
            ## IF GrayScale
            # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
            # print(lf_image.shape)
            lf_images.append(lf_image)
        gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
        ## IF GrayScale
        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        # gt = np.expand_dims(gt, axis=-1)

        fd = self.dataset_desc["focaldepth"][idx]
        gazeX = self.dataset_desc["gazeX"][idx]
        gazeY = self.dataset_desc["gazeY"][idx]
        sample_idx = self.dataset_desc["idx"][idx]
        return np.asarray(lf_images),gt,fd,gazeX,gazeY,sample_idx

OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3

class residual_block(torch.nn.Module):
    # Two conv + BN + ELU layers with an additive skip connection; delta_channel_dim
    # widens the block to fit extra conditioning channels concatenated upstream.
    def __init__(self,delta_channel_dim):
        super(residual_block,self).__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
            torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
            torch.nn.ELU()
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
            torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
            torch.nn.ELU()
        )

    def forward(self,input):
        output = self.layer1(input)
        output = self.layer2(output)
        output = input+output
        return output

class deinterleave(torch.nn.Module):
    # Depth-to-space: (N, C * r^2, H, W) -> (N, C, H * r, W * r), with r = block_size.
    def __init__(self, block_size):
        super(deinterleave, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size*block_size

    def forward(self, input):
        output = input.permute(0, 2, 3, 1)
        (batch_size, d_height, d_width, d_depth) = output.size()
        s_depth = int(d_depth / self.block_size_sq)
        s_width = int(d_width * self.block_size)
        s_height = int(d_height * self.block_size)
        t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth)
        spl = t_1.split(self.block_size, 3)
        stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl]
        output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size, s_height, s_width, s_depth)
        output = output.permute(0, 3, 1, 2)
        return output

class interleave(torch.nn.Module):
    # Space-to-depth: (N, C, H * r, W * r) -> (N, C * r^2, H, W), with r = block_size.
    def __init__(self, block_size):
        super(interleave, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size*block_size

    def forward(self, input):
        output = input.permute(0, 2, 3, 1)
        (batch_size, s_height, s_width, s_depth) = output.size()
        d_depth = s_depth * self.block_size_sq
        d_width = int(s_width / self.block_size)
        d_height = int(s_height / self.block_size)
        t_1 = output.split(self.block_size, 2)
        stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
        output = torch.stack(stack, 1)
        output = output.permute(0, 2, 1, 3)
        output = output.permute(0, 3, 1, 2)
        return output
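
# A minimal round-trip sanity check for the pair above (our helper, added for
# illustration; it is not called anywhere in this script). interleave is a
# space-to-depth transform and deinterleave is its inverse, so composing the
# two should reproduce the input exactly.
def _check_interleave_roundtrip(block_size=INTERLEAVE_RATE):
    x = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
    down = interleave(block_size)(x)      # (2, 3 * block_size**2, 8/bs, 8/bs)
    up = deinterleave(block_size)(down)   # back to (2, 3, 8, 8)
    assert down.shape == (2, 3 * block_size ** 2, 8 // block_size, 8 // block_size)
    assert torch.equal(up, x)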

LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2   # M display layers x RGB, before deinterleave
FIRST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2  # N views x RGB, after interleave

class model(torch.nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.interleave = interleave(INTERLEAVE_RATE)

        self.first_layer = torch.nn.Sequential(
            torch.nn.Conv2d(FIRST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,stride=1,padding=1),
            torch.nn.BatchNorm2d(OUT_CHANNELS_RB),
            torch.nn.ELU()
        )
        
        self.residual_block1 = residual_block(0)
        self.residual_block2 = residual_block(3)
        self.residual_block3 = residual_block(3)
        self.residual_block4 = residual_block(3)
        self.residual_block5 = residual_block(3)

        self.output_layer = torch.nn.Sequential(
            torch.nn.Conv2d(OUT_CHANNELS_RB+3,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
            torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
            torch.nn.Sigmoid()
        )
        self.deinterleave = deinterleave(INTERLEAVE_RATE)


    def forward(self, lightfield_images, focal_length, gazeX, gazeY):
        input_to_net = self.interleave(lightfield_images)
        input_to_rb = self.first_layer(input_to_net)
        output = self.residual_block1(input_to_rb)
        # Per-sample conditioning maps (inverse focal depth and normalized gaze),
        # broadcast over the spatial dimensions.
        depth_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
        gazeX_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
        gazeY_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
        for i in range(focal_length.shape[0]):
            depth_layer[i] *= 1. / focal_length[i]
            gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333 * 2)  # map gaze in [-3.333, 3.333] to [0, 1]
            gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333 * 2)
        depth_layer = var_or_cuda(depth_layer)
        gazeX_layer = var_or_cuda(gazeX_layer)
        gazeY_layer = var_or_cuda(gazeY_layer)

        output = torch.cat((output,depth_layer,gazeX_layer,gazeY_layer),dim=1)
        output = self.residual_block2(output)
        output = self.residual_block3(output)
        output = self.residual_block4(output)
        output = self.residual_block5(output)
        output = self.output_layer(output)
        output = self.deinterleave(output)
        return output
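
# Shape trace through model.forward (a sketch for batch size B, using the
# constants above: INTERLEAVE_RATE = 2, N = 9 views, M = 2 display layers):
#   lightfield_images   (B,  27, IM_H,   IM_W)     9 RGB views stacked on channels
#   interleave       -> (B, 108, IM_H/2, IM_W/2)   space-to-depth
#   first_layer      -> (B, 128, IM_H/2, IM_W/2)
#   + conditioning   -> (B, 131, IM_H/2, IM_W/2)   inverse-depth and gaze maps appended
#   output_layer     -> (B,  24, IM_H/2, IM_W/2)
#   deinterleave     -> (B,   6, IM_H,   IM_W)     M RGB display layers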

class Conf(object):
    def __init__(self):
        self.pupil_size = 0.02 # 2cm
        self.retinal_res = torch.tensor([ Retinal_IM_H, Retinal_IM_W ])
        self.layer_res = torch.tensor([ IM_H, IM_W ])
        self.layer_hfov = 90    # layers' horizontal FOV
        self.eye_hfov = 85      # eye's horizontal FOV
        self.d_layer = [ 1, 3 ] # layers' distance
        
    def GetNLayers(self):
        return len(self.d_layer)
    
    def GetLayerSize(self, i):
        w = Fov2Length(self.layer_hfov)
        h = w * self.layer_res[0] / self.layer_res[1]
        return torch.tensor([ h, w ]) * self.d_layer[i]

    def GetEyeViewportSize(self): 
        w = Fov2Length(self.eye_hfov) 
        h = w * self.retinal_res[0] / self.retinal_res[1] 
        return torch.tensor([ h, w ])
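
# Illustrative only (our helper, not used by the pipeline): at a fixed FOV the
# physical layer size grows linearly with distance, so with d_layer = [1, 3]
# the far layer is three times the size of the near one. Exact values depend on
# gen_image.Fov2Length, which we assume returns the extent subtended at unit distance.
def _example_layer_sizes():
    c = Conf()
    return c.GetLayerSize(0), c.GetLayerSize(1)  # the second is 3x the first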

#### Image Gen
conf = Conf()

u = GenSamplesInPupil(conf.pupil_size, 5)

gen = RetinalGen(conf, u)

def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
    # layers: batchsize, 2*color, height, width 
    # phi_dict[idx]: per-sample retinal-to-layer mapping from CalculateRetinal2LayerMappings
    # mask_dict[idx]: per-sample validity mask over the retinal image
    
    #  retinal bs x color x height x width
    retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W)
    mask = [] # per-sample validity masks, stacked below into batch x 1 x H x W
    for i in range(0, layers.size()[0]):
        phi = phi_dict[int(sample_idx[i].data)]
        phi = var_or_cuda(phi)
        phi.requires_grad = False
        retinal[i] = gen.GenRetinalFromLayers(layers[i],phi)
        mask.append(mask_dict[int(sample_idx[i].data)])
    retinal = var_or_cuda(retinal)
    mask = torch.stack(mask,dim = 0).unsqueeze(1) # batch x 1 x height x width
    return retinal, mask

def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
    # layers: batchsize, 2*color, height, width 
    # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
    # df : batchsize,..
    
    #  retinal bs x color x height x width
    # retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W)
    # retinal = var_or_cuda(retinal)
    phi = var_or_cuda(phi)
    phi.requires_grad = False
    retinal = gen.GenRetinalFromLayers(layers[0],phi)
    retinal = var_or_cuda(retinal)
    mask_out = mask.unsqueeze(0).unsqueeze(0)
    return retinal.unsqueeze(0), mask_out
#### Image Gen End

weightVarScale = 0.25
bias_stddev = 0.01

def weight_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.xavier_normal_(m.weight.data)
        torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

def var_or_cuda(x):
    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)
    return x

def calImageGradients(images):
    # images: 4-D tensor (N, C, H, W); finite differences along the two spatial axes
    dx = images[:, :, 1:, :] - images[:, :, :-1, :]
    dy = images[:, :, :, 1:] - images[:, :, :, :-1]
    return dx, dy


perc_loss = VGGPerceptualLoss() 
perc_loss = perc_loss.to("cuda")

def loss_new(generated, gt):
    mse_loss = torch.nn.MSELoss()
    rmse_intensity = mse_loss(generated, gt)
    
    # Note: despite the "psnr"/"rmse" names, these terms are log10 of MSE values;
    # minimizing log10(MSE) is equivalent to maximizing PSNR.
    psnr_intensity = torch.log10(rmse_intensity)
    labels_dx, labels_dy = calImageGradients(gt)
    preds_dx, preds_dy = calImageGradients(generated)
    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
    p_loss = perc_loss(generated, gt)
    # The constant 10 only shifts the loss to keep it positive; it has no gradient effect.
    total_loss = 10 + psnr_intensity + 0.5 * (psnr_grad_x + psnr_grad_y) + p_loss
    return total_loss
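
# A small debugging aid (our addition, unused by the training loop below): the
# individual terms of loss_new, returned separately so they can be logged. It
# mirrors loss_new exactly and reuses the perc_loss module defined above.
def loss_new_terms(generated, gt):
    mse_loss = torch.nn.MSELoss()
    labels_dx, labels_dy = calImageGradients(gt)
    preds_dx, preds_dy = calImageGradients(generated)
    return {
        "log_mse": torch.log10(mse_loss(generated, gt)),
        "log_mse_grad_x": torch.log10(mse_loss(labels_dx, preds_dx)),
        "log_mse_grad_y": torch.log10(mse_loss(labels_dy, preds_dy)),
        "perceptual": perc_loss(generated, gt),
    }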

def save_checkpoints(file_path, epoch_idx, model, model_solver):
    print('[INFO] Saving checkpoint to %s ...' % ( file_path))
    checkpoint = {
        'epoch_idx': epoch_idx,
        'model_state_dict': model.state_dict(),
        'model_solver_state_dict': model_solver.state_dict()
    }
    torch.save(checkpoint, file_path)

def hook_fn_back(m, i, o):
    for grad in i:
        try:
            print("Input Grad:", m, grad.shape, grad.sum())
        except AttributeError:
            print("None found for Gradient")
    for grad in o:
        try:
            print("Output Grad:", m, grad.shape, grad.sum())
        except AttributeError:
            print("None found for Gradient")
    print("\n")

def hook_fn_for(m, i, o):
    for feats in i:
        try:
            print("Input Feats:", m, feats.shape, feats.sum())
        except AttributeError:
            print("None found for Features")
    for feats in o:
        try:
            print("Output Feats:", m, feats.shape, feats.sum())
        except AttributeError:
            print("None found for Features")
    print("\n")

if __name__ == "__main__":

    ############################## generate phi and mask in pre-training
    phi_dict = {}
    mask_dict = {}
    idx_info_dict = {}
    print("generating phi and mask...")
    with open(DATA_JSON, encoding='utf-8') as file:
        dataset_desc = json.loads(file.read())
        for i in range(len(dataset_desc["focaldepth"])):
            # if i == 2:
            #     break
            idx = dataset_desc["idx"][i] 
            focaldepth = dataset_desc["focaldepth"][i]
            gazeX = dataset_desc["gazeX"][i]
            gazeY = dataset_desc["gazeY"][i]
            # print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
            phi,mask =  gen.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
            phi_dict[idx]=phi
            mask_dict[idx]=mask
            idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
    print("generating phi and mask end.")
    
    #train
    train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_JSON),
                                                    batch_size=BATCH_SIZE,
                                                    num_workers=0,
                                                    pin_memory=True,
                                                    shuffle=True,
                                                    drop_last=False)
    print(len(train_data_loader))


    ################################################ train #########################################################
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lf_model = model()
    lf_model.apply(weight_init_normal)

    epoch_begin = 0

    if torch.cuda.is_available():
        lf_model = torch.nn.DataParallel(lf_model).cuda()
    lf_model.train()
    optimizer = torch.optim.Adam(lf_model.parameters(),lr=1e-2,betas=(0.9,0.999))

    print("begin training....")
    for epoch in range(epoch_begin, NUM_EPOCH):
        for batch_idx, (image_set, gt, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader):
            #reshape for input
            image_set = image_set.permute(0,1,4,2,3) # N LF C H W
            image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
            image_set = var_or_cuda(image_set)
            gt = gt.permute(0,3,1,2)
            gt = var_or_cuda(gt)

            optimizer.zero_grad()
            output = lf_model(image_set,df,gazeX,gazeY)
            ########################### Use Pregen Phi and Mask ###################
            output1,mask = GenRetinalFromLayersBatch(output, gen, sample_idx, phi_dict, mask_dict)
            mask = var_or_cuda(mask)
            mask.requires_grad = False
            output_f = output1 * mask
            gt = gt * mask
            loss = loss_new(output_f,gt)
            print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)

            ########################### Update ###################
            loss.backward()
            optimizer.step()

            ########################### Save #####################
            if (epoch % 50 == 0) or epoch == 5:
                for i in range(output_f.size()[0]):
                    save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
                    save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
                    save_image(output_f[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_test1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
            if (epoch % 200 == 0) and epoch != 0 and batch_idx == len(train_data_loader) - 1:
                save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),
                                epoch,lf_model,optimizer)