From 648dfd2c0985793a0cccfc71d75ed744e5277044 Mon Sep 17 00:00:00 2001
From: BobYeah <635596704@qq.com>
Date: Tue, 10 Nov 2020 12:12:55 +0800
Subject: [PATCH] Insert focal depth channel in the middle of the network

---
 main.py | 49 +++++++++++++++++++++++++------------------------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/main.py b/main.py
index f818c0e..216fb7f 100644
--- a/main.py
+++ b/main.py
@@ -25,9 +25,9 @@ IM_W = 640
 N = 9 # number of input light field stack
 M = 2 # number of display layers
 
-DATA_FILE = "/home/yejiannan/Project/deeplightfield/data/try"
-DATA_JSON = "/home/yejiannan/Project/deeplightfield/data/data.json"
-OUTPUT_DIR = "/home/yejiannan/Project/deeplightfield/output"
+DATA_FILE = "/home/yejiannan/Project/LightField/data/try"
+DATA_JSON = "/home/yejiannan/Project/LightField/data/data.json"
+OUTPUT_DIR = "/home/yejiannan/Project/LightField/output"
 
 class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
     def __init__(self, file_dir_path, file_json, transforms=None):
@@ -68,16 +68,16 @@ KERNEL_SIZE_RB = 3
 KERNEL_SIZE = 3
 
 class residual_block(torch.nn.Module):
-    def __init__(self):
+    def __init__(self,delta_channel_dim):
         super(residual_block,self).__init__()
         self.layer1 = torch.nn.Sequential(
-            torch.nn.Conv2d(OUT_CHANNELS_RB,OUT_CHANNELS_RB,KERNEL_SIZE_RB,stride=1,padding = 1),
-            torch.nn.BatchNorm2d(OUT_CHANNELS_RB),
+            torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
+            torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
             torch.nn.ELU()
         )
         self.layer2 = torch.nn.Sequential(
-            torch.nn.Conv2d(OUT_CHANNELS_RB,OUT_CHANNELS_RB,KERNEL_SIZE_RB,stride=1,padding = 1),
-            torch.nn.BatchNorm2d(OUT_CHANNELS_RB,OUT_CHANNELS_RB),
+            torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
+            torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
             torch.nn.ELU()
         )
 
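Note (not part of the patch): a quick shape check of the widened block, assuming the constants and the residual_block class already defined in main.py, and assuming the block's forward method (outside this hunk) preserves the tensor shape, which both 3x3, stride-1, padding-1 conv layers do.

    import torch

    # With delta_channel_dim=1 the block operates on OUT_CHANNELS_RB + 1 channels,
    # i.e. the regular feature maps plus the injected focal-depth plane.
    block = residual_block(delta_channel_dim=1)
    x = torch.randn(2, OUT_CHANNELS_RB + 1, 40, 80)   # N, C+1, H, W (spatial size illustrative)
    y = block(x)
    assert y.shape == x.shape                         # channel count and spatial size are preserved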
@@ -127,7 +127,7 @@ class interleave(torch.nn.Module):
 
 
 LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
-FIRSST_LAYER_CHANNELS = 28 * INTERLEAVE_RATE**2
+FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2
 
 class model(torch.nn.Module):
     def __init__(self):
@@ -140,12 +140,12 @@ class model(torch.nn.Module):
             torch.nn.ELU()
         )
         
-        self.residual_block1 = residual_block()
-        self.residual_block2 = residual_block()
-        self.residual_block3 = residual_block()
+        self.residual_block1 = residual_block(0)
+        self.residual_block2 = residual_block(1)
+        self.residual_block3 = residual_block(1)
 
         self.output_layer = torch.nn.Sequential(
-            torch.nn.Conv2d(OUT_CHANNELS_RB,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
+            torch.nn.Conv2d(OUT_CHANNELS_RB+1,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
             torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
             torch.nn.Tanh()
         )
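Note (a sketch of the channel bookkeeping implied by these two hunks, assuming N = 9 light-field views with 3 RGB channels each, as in the constants at the top of main.py):

    # Before this patch the focal-depth plane was concatenated onto the network
    # input in the training loop, so the first layer saw 9*3 + 1 = 28 channels
    # (times INTERLEAVE_RATE**2). The plane is now injected after residual_block1,
    # so the input carries only the 9*3 = 27 image channels:
    #
    #   input -> first_layer                   : 27 * INTERLEAVE_RATE**2 -> OUT_CHANNELS_RB
    #   residual_block1(0)                     : OUT_CHANNELS_RB         -> OUT_CHANNELS_RB
    #   cat(depth plane)                       : OUT_CHANNELS_RB         -> OUT_CHANNELS_RB + 1
    #   residual_block2(1), residual_block3(1) : OUT_CHANNELS_RB + 1     -> OUT_CHANNELS_RB + 1
    #   output_layer                           : OUT_CHANNELS_RB + 1     -> LAST_LAYER_CHANNELS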
@@ -160,8 +160,16 @@ class model(torch.nn.Module):
         input_to_rb = self.first_layer(input_to_net)
         output = self.residual_block1(input_to_rb)
         # print("output1:",output.shape)
-        output = self.residual_block2(output)
         
+        depth_layer = torch.ones((output.shape[0],1,output.shape[2],output.shape[3]))
+        # print(df.shape[0])
+        for i in range(focal_length.shape[0]):
+            depth_layer[i] = depth_layer[i] * focal_length[i]
+            # print(depth_layer.shape)
+        depth_layer = var_or_cuda(depth_layer)
+        output = torch.cat((output,depth_layer),dim=1)
+
+        output = self.residual_block2(output)
         output = self.residual_block3(output)
         # output = output + input_to_net
         output = self.output_layer(output)
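Note (not a change the patch makes): the per-sample loop above can also be expressed as a single broadcast. A standalone sketch follows; the helper name append_focal_channel is hypothetical and the tensor sizes are only illustrative.

    import torch

    def append_focal_channel(features, focal_length):
        # Broadcast each sample's scalar focal depth into a constant H x W plane
        # and concatenate it to the feature maps as one extra channel.
        n, _, h, w = features.shape
        depth_plane = focal_length.view(-1, 1, 1, 1).expand(n, 1, h, w).to(features.device)
        return torch.cat((features, depth_plane), dim=1)

    features = torch.randn(4, 128, 40, 80)            # N, OUT_CHANNELS_RB, H, W (illustrative)
    df = torch.tensor([0.5, 1.0, 1.5, 2.0])           # per-sample focal depths
    print(append_focal_channel(features, df).shape)   # torch.Size([4, 129, 40, 80])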
@@ -263,8 +271,6 @@ if __name__ == "__main__":
     # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx)))
     #test end
 
-   
-
     train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_JSON),
                                                     batch_size=BATCH_SIZE,
                                                     num_workers=0,
@@ -285,19 +291,14 @@ if __name__ == "__main__":
             #reshape for input
             image_set = image_set.permute(0,1,4,2,3) # N LF C H W
             image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
-            depth_layer = torch.ones((image_set.shape[0],1,image_set.shape[2],image_set.shape[3]))
-            # print(df.shape[0])
-            for i in range(df.shape[0]):
-                depth_layer[i] = depth_layer[i] * df[i]
-            # print(depth_layer.shape)
-            image_set = torch.cat((image_set,depth_layer),dim=1)
+            
             image_set = var_or_cuda(image_set)
             # image_set.to(device)
             gt = gt.permute(0,3,1,2)
             gt = var_or_cuda(gt)
             # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
             optimizer.zero_grad()
-            output = lf_model(image_set,0)
+            output = lf_model(image_set,df)
             # print("output:",output.shape," df:",df.shape)
             output = GenRetinalFromLayersBatch(output,conf,df,v,u)
             loss = loss_two_images(output,gt)
@@ -305,5 +306,5 @@ if __name__ == "__main__":
             loss.backward()
             optimizer.step()
             for i in range(5):
-                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-3_o%d_%d.png"%(epoch,i)))
+                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-3_insertmid_o%d_%d.png"%(epoch,i)))
 
-- 
GitLab