Commit 648dfd2c authored by BobYeah

focal depth insert middle

parent 4d8a98da
@@ -25,9 +25,9 @@ IM_W = 640
 N = 9 # number of input light field stack
 M = 2 # number of display layers
-DATA_FILE = "/home/yejiannan/Project/deeplightfield/data/try"
-DATA_JSON = "/home/yejiannan/Project/deeplightfield/data/data.json"
-OUTPUT_DIR = "/home/yejiannan/Project/deeplightfield/output"
+DATA_FILE = "/home/yejiannan/Project/LightField/data/try"
+DATA_JSON = "/home/yejiannan/Project/LightField/data/data.json"
+OUTPUT_DIR = "/home/yejiannan/Project/LightField/output"
 class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
     def __init__(self, file_dir_path, file_json, transforms=None):
@@ -68,16 +68,16 @@ KERNEL_SIZE_RB = 3
 KERNEL_SIZE = 3
 class residual_block(torch.nn.Module):
-    def __init__(self):
+    def __init__(self,delta_channel_dim):
         super(residual_block,self).__init__()
         self.layer1 = torch.nn.Sequential(
-            torch.nn.Conv2d(OUT_CHANNELS_RB,OUT_CHANNELS_RB,KERNEL_SIZE_RB,stride=1,padding = 1),
-            torch.nn.BatchNorm2d(OUT_CHANNELS_RB),
+            torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
+            torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
             torch.nn.ELU()
         )
         self.layer2 = torch.nn.Sequential(
-            torch.nn.Conv2d(OUT_CHANNELS_RB,OUT_CHANNELS_RB,KERNEL_SIZE_RB,stride=1,padding = 1),
-            torch.nn.BatchNorm2d(OUT_CHANNELS_RB,OUT_CHANNELS_RB),
+            torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
+            torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
             torch.nn.ELU()
         )
@@ -127,7 +127,7 @@ class interleave(torch.nn.Module):
 LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
-FIRSST_LAYER_CHANNELS = 28 * INTERLEAVE_RATE**2
+FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2
 class model(torch.nn.Module):
     def __init__(self):
@@ -140,12 +140,12 @@ class model(torch.nn.Module):
             torch.nn.ELU()
         )
-        self.residual_block1 = residual_block()
-        self.residual_block2 = residual_block()
-        self.residual_block3 = residual_block()
+        self.residual_block1 = residual_block(0)
+        self.residual_block2 = residual_block(1)
+        self.residual_block3 = residual_block(1)
         self.output_layer = torch.nn.Sequential(
-            torch.nn.Conv2d(OUT_CHANNELS_RB,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
+            torch.nn.Conv2d(OUT_CHANNELS_RB+1,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
             torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
             torch.nn.Tanh()
         )
@@ -160,8 +160,16 @@ class model(torch.nn.Module):
         input_to_rb = self.first_layer(input_to_net)
         output = self.residual_block1(input_to_rb)
         # print("output1:",output.shape)
-        output = self.residual_block2(output)
+        depth_layer = torch.ones((output.shape[0],1,output.shape[2],output.shape[3]))
+        # print(df.shape[0])
+        for i in range(focal_length.shape[0]):
+            depth_layer[i] = depth_layer[i] * focal_length[i]
+        # print(depth_layer.shape)
+        depth_layer = var_or_cuda(depth_layer)
+        output = torch.cat((output,depth_layer),dim=1)
+        output = self.residual_block2(output)
         output = self.residual_block3(output)
         # output = output + input_to_net
         output = self.output_layer(output)
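For reference, the block added to the forward pass above broadcasts each sample's scalar focal length into a constant one-channel map, moves it onto the feature tensor's device, and concatenates it along the channel dimension. A minimal equivalent sketch, assuming PyTorch broadcasting in place of the per-sample loop (the helper name make_depth_layer is illustrative, not code from this commit):

import torch

def make_depth_layer(features, focal_length):
    # features: (N, C, H, W); focal_length: (N,) scalar focal depth per sample.
    # Same result as the loop in the diff: a constant (N, 1, H, W) map filled
    # with each sample's focal length, placed on the features' device.
    n, _, h, w = features.shape
    return focal_length.view(n, 1, 1, 1).expand(n, 1, h, w).to(features.device)

# Usage: the concatenation grows the channel count from C to C+1, which is why
# residual_block2/3 are built with delta_channel_dim=1 and the output layer
# takes OUT_CHANNELS_RB+1 input channels.
features = torch.randn(4, 64, 80, 90)
focal_length = torch.tensor([0.5, 1.0, 1.5, 2.0])
features = torch.cat((features, make_depth_layer(features, focal_length)), dim=1)
print(features.shape)  # torch.Size([4, 65, 80, 90])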
@@ -263,8 +271,6 @@ if __name__ == "__main__":
     # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx)))
     #test end
     train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_JSON),
                                                     batch_size=BATCH_SIZE,
                                                     num_workers=0,
@@ -285,19 +291,14 @@ if __name__ == "__main__":
             #reshape for input
             image_set = image_set.permute(0,1,4,2,3) # N LF C H W
             image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
-            depth_layer = torch.ones((image_set.shape[0],1,image_set.shape[2],image_set.shape[3]))
-            # print(df.shape[0])
-            for i in range(df.shape[0]):
-                depth_layer[i] = depth_layer[i] * df[i]
-            # print(depth_layer.shape)
-            image_set = torch.cat((image_set,depth_layer),dim=1)
             image_set = var_or_cuda(image_set)
             # image_set.to(device)
             gt = gt.permute(0,3,1,2)
             gt = var_or_cuda(gt)
             # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
             optimizer.zero_grad()
-            output = lf_model(image_set,0)
+            output = lf_model(image_set,df)
             # print("output:",output.shape," df:",df.shape)
             output = GenRetinalFromLayersBatch(output,conf,df,v,u)
             loss = loss_two_images(output,gt)
@@ -305,5 +306,5 @@ if __name__ == "__main__":
             loss.backward()
             optimizer.step()
             for i in range(5):
-                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-3_o%d_%d.png"%(epoch,i)))
+                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-3_insertmid_o%d_%d.png"%(epoch,i)))
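Taken together, this commit moves the focal-depth channel from the data-preparation step in the training loop into the middle of the network: df is now passed to the model and concatenated after the first residual block, so FIRSST_LAYER_CHANNELS drops from 28 to 27 interleaved channels, residual_block gains a delta_channel_dim argument, blocks 2 and 3 run one channel wider, and the output layer consumes OUT_CHANNELS_RB+1 channels. Below is a toy sketch of this wiring, with placeholder module names, layer sizes, and a plain skip connection rather than the repository's actual model:

import torch

class ResidualBlock(torch.nn.Module):
    # Stand-in for residual_block(delta_channel_dim): the block's width is the
    # base channel count plus whatever extra channels were concatenated upstream.
    def __init__(self, channels, delta_channel_dim):
        super().__init__()
        width = channels + delta_channel_dim
        self.body = torch.nn.Sequential(
            torch.nn.Conv2d(width, width, 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(width),
            torch.nn.ELU(),
            torch.nn.Conv2d(width, width, 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(width),
            torch.nn.ELU(),
        )

    def forward(self, x):
        return x + self.body(x)

class ToyModel(torch.nn.Module):
    # Focal depth is injected *between* block 1 and block 2, so block 1 runs at
    # the base width while blocks 2 and 3 run one channel wider.
    def __init__(self, channels=16):
        super().__init__()
        self.block1 = ResidualBlock(channels, 0)
        self.block2 = ResidualBlock(channels, 1)
        self.block3 = ResidualBlock(channels, 1)

    def forward(self, x, focal_length):
        out = self.block1(x)
        n, _, h, w = out.shape
        depth = focal_length.view(n, 1, 1, 1).expand(n, 1, h, w).to(out.device)
        out = torch.cat((out, depth), dim=1)  # width: channels -> channels + 1
        out = self.block2(out)
        return self.block3(out)

x = torch.randn(2, 16, 32, 32)
df = torch.tensor([0.3, 1.2])
print(ToyModel()(x, df).shape)  # torch.Size([2, 17, 32, 32])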