From c5434e97764d2984441f82b6e667b93cb3c622e7 Mon Sep 17 00:00:00 2001
From: BobYeah <635596704@qq.com>
Date: Sat, 19 Dec 2020 15:17:45 +0800
Subject: [PATCH] Updatae CNN main

---
 main_lf_syn.py | 147 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 147 insertions(+)
 create mode 100644 main_lf_syn.py

diff --git a/main_lf_syn.py b/main_lf_syn.py
new file mode 100644
index 0000000..f6de1e1
--- /dev/null
+++ b/main_lf_syn.py
@@ -0,0 +1,147 @@
+import torch
+import argparse
+import os
+import glob
+import numpy as np
+import torchvision.transforms as transforms
+from torchvision.utils import save_image
+
+from torchvision import datasets
+from torch.utils.data import DataLoader 
+from torch.autograd import Variable
+
+import cv2
+from loss import *
+import json
+
+from baseline import *
+from data import * 
+
+import torch.autograd.profiler as profiler
+# param
+BATCH_SIZE = 2
+NUM_EPOCH = 1001
+INTERLEAVE_RATE = 2
+IM_H = 540
+IM_W = 376
+Retinal_IM_H = 540
+Retinal_IM_W = 376
+N = 4 # number of input light field stack
+M = 1 # number of display layers
+DATA_FILE = "/home/yejiannan/Project/LightField/data/lf_syn"
+DATA_JSON = "/home/yejiannan/Project/LightField/data/data_lf_syn_full.json"
+# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
+OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/lf_syn_full_perc"
+OUT_CHANNELS_RB = 128
+KERNEL_SIZE_RB = 3
+KERNEL_SIZE = 3
+LAST_LAYER_CHANNELS = 3 * INTERLEAVE_RATE**2
+FIRSST_LAYER_CHANNELS = 12 * INTERLEAVE_RATE**2
+
+from weight_init import weight_init_normal
+
+def save_checkpoints(file_path, epoch_idx, model, model_solver):
+    print('[INFO] Saving checkpoint to %s ...' % ( file_path))
+    checkpoint = {
+        'epoch_idx': epoch_idx,
+        'model_state_dict': model.state_dict(),
+        'model_solver_state_dict': model_solver.state_dict()
+    }
+    torch.save(checkpoint, file_path)
+
+mode = "Silence" #"Perf"
+w_frame = 0.9
+loss1 = PerceptionReconstructionLoss()
+if __name__ == "__main__":
+    #train
+    train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSynDataLoader(DATA_FILE,DATA_JSON),
+                                                    batch_size=BATCH_SIZE,
+                                                    num_workers=8,
+                                                    pin_memory=True,
+                                                    shuffle=True,
+                                                    drop_last=False)
+    #Data loader test
+    print(len(train_data_loader))
+
+    lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False)
+    lf_model.apply(weight_init_normal)
+    lf_model.train()
+    epoch_begin = 0
+
+    if torch.cuda.is_available():
+        # lf_model = torch.nn.DataParallel(lf_model).cuda()
+        lf_model = lf_model.to('cuda:3')
+    
+    optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
+    
+    # lf_model.output_layer.register_backward_hook(hook_fn_back)
+    if mode=="Perf":
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+    print("begin training....")
+    for epoch in range(epoch_begin, NUM_EPOCH):
+        for batch_idx, (image_set, gt, pos_row, pos_col) in enumerate(train_data_loader):
+            if mode=="Perf":
+                end.record()
+                torch.cuda.synchronize()
+                print("load:",start.elapsed_time(end))
+
+                start.record()
+            #reshape for input
+            image_set = image_set.permute(0,1,4,2,3) # N LF C H W 
+            image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N LFxC H W
+            image_set = var_or_cuda(image_set)
+
+            gt = gt.permute(0,3,1,2) # BS C H W 
+            gt = var_or_cuda(gt)
+
+            if mode=="Perf":
+                end.record()
+                torch.cuda.synchronize()
+                print("data prepare:",start.elapsed_time(end))
+
+                start.record()
+            
+            output = lf_model(image_set,pos_row, pos_col) # 2 6 376 540
+
+            if mode=="Perf":
+                end.record()
+                torch.cuda.synchronize()
+                print("forward:",start.elapsed_time(end))
+
+                start.record()
+            optimizer.zero_grad()
+            # print("output:",output.shape," gt:",gt.shape)
+            loss1_value = loss1(output,gt)
+            loss = (w_frame * loss1_value)
+
+            if mode=="Perf":
+                end.record()
+                torch.cuda.synchronize()
+                print("compute loss:",start.elapsed_time(end))
+
+                start.record()
+            loss.backward()
+            if mode=="Perf":
+                end.record()
+                torch.cuda.synchronize()
+                print("backward:",start.elapsed_time(end))
+
+                start.record()
+            optimizer.step()
+            if mode=="Perf":
+                end.record()
+                torch.cuda.synchronize()
+                print("update:",start.elapsed_time(end))
+
+            print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss.item())
+            
+            # exit(0)
+            ########################### Save #####################
+            if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
+                for i in range(gt.size()[0]):
+                    save_image(output[i].data,os.path.join(OUTPUT_DIR,"out_%.5f_%.5f.png"%(pos_col[i].data,pos_row[i].data)))
+                    save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gt_%.5f_%.5f.png"%(pos_col[i].data,pos_row[i].data)))
+            if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
+                save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
\ No newline at end of file
-- 
GitLab