From 648dfd2c0985793a0cccfc71d75ed744e5277044 Mon Sep 17 00:00:00 2001 From: BobYeah <635596704@qq.com> Date: Tue, 10 Nov 2020 12:12:55 +0800 Subject: [PATCH] focal depth insert middle --- main.py | 49 +++++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/main.py b/main.py index f818c0e..216fb7f 100644 --- a/main.py +++ b/main.py @@ -25,9 +25,9 @@ IM_W = 640 N = 9 # number of input light field stack M = 2 # number of display layers -DATA_FILE = "/home/yejiannan/Project/deeplightfield/data/try" -DATA_JSON = "/home/yejiannan/Project/deeplightfield/data/data.json" -OUTPUT_DIR = "/home/yejiannan/Project/deeplightfield/output" +DATA_FILE = "/home/yejiannan/Project/LightField/data/try" +DATA_JSON = "/home/yejiannan/Project/LightField/data/data.json" +OUTPUT_DIR = "/home/yejiannan/Project/LightField/output" class lightFieldDataLoader(torch.utils.data.dataset.Dataset): def __init__(self, file_dir_path, file_json, transforms=None): @@ -68,16 +68,16 @@ KERNEL_SIZE_RB = 3 KERNEL_SIZE = 3 class residual_block(torch.nn.Module): - def __init__(self): + def __init__(self,delta_channel_dim): super(residual_block,self).__init__() self.layer1 = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB,OUT_CHANNELS_RB,KERNEL_SIZE_RB,stride=1,padding = 1), - torch.nn.BatchNorm2d(OUT_CHANNELS_RB), + torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), + torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), torch.nn.ELU() ) self.layer2 = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB,OUT_CHANNELS_RB,KERNEL_SIZE_RB,stride=1,padding = 1), - torch.nn.BatchNorm2d(OUT_CHANNELS_RB,OUT_CHANNELS_RB), + torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), + torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), torch.nn.ELU() ) @@ -127,7 +127,7 @@ class interleave(torch.nn.Module): LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2 -FIRSST_LAYER_CHANNELS = 28 * INTERLEAVE_RATE**2 +FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2 class model(torch.nn.Module): def __init__(self): @@ -140,12 +140,12 @@ class model(torch.nn.Module): torch.nn.ELU() ) - self.residual_block1 = residual_block() - self.residual_block2 = residual_block() - self.residual_block3 = residual_block() + self.residual_block1 = residual_block(0) + self.residual_block2 = residual_block(1) + self.residual_block3 = residual_block(1) self.output_layer = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1), + torch.nn.Conv2d(OUT_CHANNELS_RB+1,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1), torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS), torch.nn.Tanh() ) @@ -160,8 +160,16 @@ class model(torch.nn.Module): input_to_rb = self.first_layer(input_to_net) output = self.residual_block1(input_to_rb) # print("output1:",output.shape) - output = self.residual_block2(output) + depth_layer = torch.ones((output.shape[0],1,output.shape[2],output.shape[3])) + # print(df.shape[0]) + for i in range(focal_length.shape[0]): + depth_layer[i] = depth_layer[i] * focal_length[i] + # print(depth_layer.shape) + depth_layer = var_or_cuda(depth_layer) + output = torch.cat((output,depth_layer),dim=1) + + output = self.residual_block2(output) output = self.residual_block3(output) # output = output + input_to_net output = self.output_layer(output) @@ -263,8 +271,6 @@ if __name__ == "__main__": # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx))) #test end - - train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_JSON), batch_size=BATCH_SIZE, num_workers=0, @@ -285,19 +291,14 @@ if __name__ == "__main__": #reshape for input image_set = image_set.permute(0,1,4,2,3) # N LF C H W image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W - depth_layer = torch.ones((image_set.shape[0],1,image_set.shape[2],image_set.shape[3])) - # print(df.shape[0]) - for i in range(df.shape[0]): - depth_layer[i] = depth_layer[i] * df[i] - # print(depth_layer.shape) - image_set = torch.cat((image_set,depth_layer),dim=1) + image_set = var_or_cuda(image_set) # image_set.to(device) gt = gt.permute(0,3,1,2) gt = var_or_cuda(gt) # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape) optimizer.zero_grad() - output = lf_model(image_set,0) + output = lf_model(image_set,df) # print("output:",output.shape," df:",df.shape) output = GenRetinalFromLayersBatch(output,conf,df,v,u) loss = loss_two_images(output,gt) @@ -305,5 +306,5 @@ if __name__ == "__main__": loss.backward() optimizer.step() for i in range(5): - save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-3_o%d_%d.png"%(epoch,i))) + save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-3_insertmid_o%d_%d.png"%(epoch,i))) -- GitLab