Commit 67c4de9e authored by BobYeah

sync

parent 421085df
import torch
def var_or_cuda(x):
    # Move x to the GPU when one is available; note the hard-coded
    # second device ('cuda:1') instead of the default 'cuda:0'.
    if torch.cuda.is_available():
        # x = x.cuda(non_blocking=True)
        x = x.to('cuda:1')
    return x
class residual_block(torch.nn.Module):
    def __init__(self, OUT_CHANNELS_RB, delta_channel_dim, KERNEL_SIZE_RB, RNN=False):
super(residual_block,self).__init__()
self.delta_channel_dim = delta_channel_dim
self.out_channels_rb = OUT_CHANNELS_RB
self.hidden = None
self.RNN = RNN
if self.RNN:
self.layer1 = torch.nn.Sequential(
torch.nn.Conv2d((OUT_CHANNELS_RB+delta_channel_dim)*2,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
self.layer2 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
else:
self.layer1 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
self.layer2 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
def forward(self,input):
if self.RNN:
# print("input:",input.shape,"hidden:",self.hidden.shape)
inp = torch.cat((input,self.hidden.detach()),dim=1)
# print(inp.shape)
output = self.layer1(inp)
output = self.layer2(output)
output = input+output
self.hidden = output
else:
output = self.layer1(input)
output = self.layer2(output)
output = input+output
return output
    def reset_hidden(self, inp):
        # The hidden state lives at the post-interleave resolution: half the
        # network input in each spatial dimension (for INTERLEAVE_RATE = 2).
        size = list(inp.size())
        size[1] = self.delta_channel_dim + self.out_channels_rb
        size[2] = size[2] // 2
        size[3] = size[3] // 2
        hidden = torch.zeros(*size)
        self.hidden = var_or_cuda(hidden)
class deinterleave(torch.nn.Module):
def __init__(self, block_size):
super(deinterleave, self).__init__()
self.block_size = block_size
self.block_size_sq = block_size*block_size
def forward(self, input):
output = input.permute(0, 2, 3, 1)
(batch_size, d_height, d_width, d_depth) = output.size()
s_depth = int(d_depth / self.block_size_sq)
s_width = int(d_width * self.block_size)
s_height = int(d_height * self.block_size)
t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth)
spl = t_1.split(self.block_size, 3)
stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl]
output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size, s_height, s_width, s_depth)
output = output.permute(0, 3, 1, 2)
return output
class interleave(torch.nn.Module):
def __init__(self, block_size):
super(interleave, self).__init__()
self.block_size = block_size
self.block_size_sq = block_size*block_size
def forward(self, input):
output = input.permute(0, 2, 3, 1)
(batch_size, s_height, s_width, s_depth) = output.size()
d_depth = s_depth * self.block_size_sq
d_width = int(s_width / self.block_size)
d_height = int(s_height / self.block_size)
t_1 = output.split(self.block_size, 2)
stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
output = torch.stack(stack, 1)
output = output.permute(0, 2, 1, 3)
output = output.permute(0, 3, 1, 2)
return output
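# Editor's sketch (sizes are illustrative assumptions, not from the repo):
# interleave is a space-to-depth rearrangement and deinterleave its inverse
# (depth-to-space); the channel ordering differs from torch.nn.PixelShuffle,
# so they are not drop-in replacements for it. Round-trip check:
#   x = torch.randn(2, 12, 64, 64)      # B x C x H x W
#   down = interleave(2)(x)             # -> [2, 48, 32, 32]: C*r^2, H/r, W/r
#   up = deinterleave(2)(down)          # -> [2, 12, 64, 64]
#   torch.allclose(up, x)               # expected True: exact inverses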
class model(torch.nn.Module):
    def __init__(self, FIRST_LAYER_CHANNELS, LAST_LAYER_CHANNELS, OUT_CHANNELS_RB, KERNEL_SIZE, KERNEL_SIZE_RB, INTERLEAVE_RATE, RNN=False):
super(model, self).__init__()
self.interleave = interleave(INTERLEAVE_RATE)
self.first_layer = torch.nn.Sequential(
            torch.nn.Conv2d(FIRST_LAYER_CHANNELS, OUT_CHANNELS_RB, KERNEL_SIZE, stride=1, padding=1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB),
torch.nn.ELU()
)
self.residual_block1 = residual_block(OUT_CHANNELS_RB,0,KERNEL_SIZE_RB,False)
self.residual_block2 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False)
self.residual_block3 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False)
# if RNN:
# self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# else:
# self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
# self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
# self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
        self.output_layer = torch.nn.Sequential(
            torch.nn.Conv2d(OUT_CHANNELS_RB + 2, LAST_LAYER_CHANNELS, KERNEL_SIZE, stride=1, padding=1),  # +2 for the row/col position maps
torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
torch.nn.Sigmoid()
)
self.deinterleave = deinterleave(INTERLEAVE_RATE)
    def reset_hidden(self, inp):
        # residual_block4 and residual_block5 are commented out above, so only
        # the existing block can be reset here (resetting them would raise AttributeError).
        self.residual_block3.reset_hidden(inp)
def forward(self, lightfield_images, pos_row, pos_col):
        # lightfield_images: torch.Size([batch_size, N * 3, H, W])
        # N light-field views, 3 (RGB) channels each
# print("lightfield_images:",lightfield_images.shape)
input_to_net = self.interleave(lightfield_images)
# print("after interleave:",input_to_net.shape)
input_to_rb = self.first_layer(input_to_net)
# print("input_to_rb1:",input_to_rb.shape)
output = self.residual_block1(input_to_rb)
        # Encode the target view position as two constant feature maps.
        pos_row_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
        pos_col_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
        for i in range(pos_row.shape[0]):
            pos_row_layer[i] *= pos_row[i]
            pos_col_layer[i] *= pos_col[i]
pos_row_layer = var_or_cuda(pos_row_layer)
pos_col_layer = var_or_cuda(pos_col_layer)
output = torch.cat((output,pos_row_layer,pos_col_layer),dim=1)
output = self.residual_block2(output)
output = self.residual_block3(output)
output = self.output_layer(output)
output = self.deinterleave(output)
return output
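# Editor's sketch: a CPU-only smoke test of the model above. The constants
# mirror the training script later in this commit (FIRST_LAYER_CHANNELS = 48,
# LAST_LAYER_CHANNELS = 12 for INTERLEAVE_RATE = 2); the input sizes below are
# illustrative and must be divisible by INTERLEAVE_RATE.
if __name__ == "__main__":
    net = model(48, 12, 128, 3, 3, 2, RNN=False)
    lf = torch.randn(2, 12, 64, 64)      # B x (N*3) x H x W dummy light field
    row = torch.tensor([0.1, 0.4])       # per-sample view row positions
    col = torch.tensor([0.2, 0.8])       # per-sample view column positions
    out = net(lf, row, col)
    print(out.shape)                     # expected: torch.Size([2, 3, 64, 64])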
import matplotlib.pyplot as plt
import numpy as np
import torch
import glm
import time
from .my import util
from .my import sample_in_pupil
class RetinalGen(object):
'''
Class for retinal generation process
Properties
--------
conf - multi-layers' parameters configuration
u - M x 3 tensor, M sample positions in pupil
p_r - H_r x W_r x 3 tensor, retinal pixel grid, [H_r, W_r] is the retinal resolution
Phi - N x H_r x W_r x M x 2 tensor, retinal to layers mapping, N is number of layers
mask - N x H_r x W_r x M x 2 tensor, indicates invalid (out-of-range) mapping
Methods
--------
'''
def __init__(self, conf):
'''
Initialize retinal generator instance
Parameters
--------
conf - multi-layers' parameters configuration
u - a M x 3 tensor stores M sample positions in pupil
'''
self.conf = conf
self.u = sample_in_pupil.CircleGen(conf.pupil_size, 5)
# self.u = u.to(cuda_dev)
# self.u = u # M x 3 M sample positions
self.D_r = conf.retinal_res # retinal res 480 x 640
self.N = conf.GetNLayers() # 2
self.M = self.u.size()[0] # samples
# p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])),
# torch.tensor(range(0, self.D_r[1])))
# self.p_r = torch.cat([
        #     ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # eye viewport
# torch.ones(self.D_r[0], self.D_r[1], 1)
# ], 2)
self.p_r = torch.cat([
((util.MeshGrid(self.D_r) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(),
torch.ones(self.D_r[0], self.D_r[1], 1)
], 2)
# self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long)
# self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2
def CalculateRetinal2LayerMappings(self, position, gaze_dir, df):
'''
Calculate the mapping matrix from retinal to layers.
Parameters
--------
position - 1 x 3 tensor, eye's position
gaze_dir - 1 x 2 tensor, gaze forward vector (with z normalized)
df - focus distance
Returns
--------
phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers
phi_invalid - N x H_r x W_r x M x 1, indicates invalid (out-of-range) mapping
retinal_invalid - 1 x H_r x W_r, indicates invalid pixels in retinal image
'''
D = self.conf.layer_res
c = torch.tensor([ D[1] / 2, D[0] / 2 ]) # c: Center of layers (pixel)
D_r = self.conf.retinal_res # D_r: Resolution of retinal 480 640
V = self.conf.GetEyeViewportSize() # V: Viewport size of eye
p_f = self.p_r * df # p_f: H x W x 3, focus positions of retinal pixels on focus plane
# Calculate transformation from eye to display
gvec_lookat = glm.dvec3(gaze_dir[0], -gaze_dir[1], 1)
gmat_eye = glm.inverse(glm.lookAtLH(glm.dvec3(), gvec_lookat, glm.dvec3(0, 1, 0)))
eye_rot = util.Glm2Tensor(glm.dmat3(gmat_eye))
eye_center = torch.tensor([ position[0], -position[1], position[2] ])
u_rot = torch.mm(self.u, eye_rot)
v_rot = torch.matmul(p_f, eye_rot).unsqueeze(2).expand(
-1, -1, self.M, -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
u_rot.add_(eye_center) # translate by eye's center
v_rot = v_rot.div(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot
phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long)
for i in range(0, self.N):
dp_i = self.conf.GetPixelSizeOfLayer(i) # dp_i: Pixel size of layer i
d_i = self.conf.d_layer[i] # d_i: Distance of layer i
k = (d_i - u_rot[:, 2]).unsqueeze(1)
pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i # pi_r: H x W x M x 2, rays' pixel coord on layer i
phi[i, :, :, :, :] = torch.floor(pi_r + c)
# Calculate invalid mask (out-of-range elements in phi) and reduced to retinal
phi_invalid = (phi[:, :, :, :, 0] < 0) | (phi[:, :, :, :, 0] >= D[1]) | \
(phi[:, :, :, :, 1] < 0) | (phi[:, :, :, :, 1] >= D[0])
phi_invalid = phi_invalid.unsqueeze(4)
# print("phi_invalid:",phi_invalid.shape)
retinal_invalid = phi_invalid.amax((0, 3)).squeeze().unsqueeze(0)
# print("retinal_invalid:",retinal_invalid.shape)
# Fix invalid elements in phi
phi[phi_invalid.expand(-1, -1, -1, -1, 2)] = 0
return [ phi, phi_invalid, retinal_invalid ]
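    # Editor's note: the loop above computes, for each retinal pixel p and
    # pupil sample u, the ray-layer intersection in layer-i pixel coordinates:
    #     phi_i(p, u) = floor((u_xy + v_xy * (d_i - u_z)) / dp_i + c)
    # where v = p_f - u is scaled so v_z = 1, d_i is the distance of layer i,
    # dp_i its pixel size, and c the layer center in pixels.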
def GenRetinalFromLayers(self, layers, Phi):
'''
Generate retinal image from layers, using precalculated mapping matrix
Parameters
--------
layers - 3N x H x W, stacked layer images, with 3 channels in each layer
phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers
Returns
--------
3 x H_r x W_r, 3 channels retinal image
'''
        # channel dim: 1 for grayscale, 3 for RGB
mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M) # 2 x 3 x 480 x 640 x 41
# print("mapped_layers:",mapped_layers.shape)
for i in range(0, Phi.size()[0]):
# torch.Size([3, 2, 320, 320, 2])
# print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape)
mapped_layers[i, :, :, :, :] = layers[(i * 3) : (i * 3 + 3),
Phi[i, :, :, :, 1],
Phi[i, :, :, :, 0]]
# print("mapped_layers:",mapped_layers.shape)
retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
# print("retinal:",retinal.shape)
return retinal
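    # Editor's note: the last line implements multiplicative layer blending
    # averaged over the M pupil samples, matching prod(0).sum(3).div(M):
    #     R(p) = (1/M) * sum_m  prod_i  L_i[ phi_i(p, u_m) ]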
def GenRetinalFromLayersBatch(self, layers, Phi):
        '''
        Generate retinal images from layers, using precalculated mapping matrix
        Parameters
        --------
        layers - B x 3N x H_l x W_l tensor, batch of stacked layer images, with 3 channels per layer
        Phi - B x N x H_r x W_r x M x 2 tensor, batch of retinal-to-layers mappings
        Returns
        --------
        B x 3 x H_r x W_r tensor, batch of 3-channel retinal images
        '''
mapped_layers = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M) #BS x Layers x C x H x W x Sample
# truth = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M)
# layers_truth = layers.clone()
# Phi_truth = Phi.clone()
        layers = torch.stack((layers[:, 0:3, :, :], layers[:, 3:6, :, :]), dim=1)  # B x N x 3 x H x W (hard-coded for N = 2 layers)
# Phi = Phi[:,:,None,:,:,:,:].expand(-1,-1,3,-1,-1,-1,-1)
# print("mapped_layers:",mapped_layers.shape) #torch.Size([2, 2, 3, 320, 320, 41])
# print("input layers:",layers.shape) ## torch.Size([2, 2, 3, 320, 320])
# print("input Phi:",Phi.shape) #torch.Size([2, 2, 320, 320, 41, 2])
        # Unoptimized reference implementation (kept for comparison):
# for i in range(0, Phi_truth.size()[0]):
# for j in range(0, Phi_truth.size()[1]):
# truth[i, j, :, :, :, :] = layers_truth[i, (j * 3) : (j * 3 + 3),
# Phi_truth[i, j, :, :, :, 0],
# Phi_truth[i, j, :, :, :, 1]]
        # Optimization 2: one advanced-indexing gather over the flattened batch/layer dims
# start = time.time()
mapped_layers_op1 = mapped_layers.reshape(-1,
mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# BatchSizexLayer Channel 3 320 320 41
layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4]) # 2x2 3 320 320
Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5]) # 2x2 320 320 41 2
x = Phi_op1[:,:,:,:,0] # 2x2 320 320 41
y = Phi_op1[:,:,:,:,1] # 2x2 320 320 41
# print("reshape:",time.time() - start)
# start = time.time()
        mapped_layers_op1 = layers_op1[torch.arange(layers_op1.shape[0])[:, None, None, None], :, y, x]  # note: x and y are swapped here
        # -> (B*N) x 320 x 320 x 41 x 3
# print("mapping one step:",time.time() - start)
# print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
# start = time.time()
mapped_layers_op1 = mapped_layers_op1.permute(0,4,1,2,3)
mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# print("reshape end:",time.time() - start)
# print("test:")
# print((truth.cpu() == mapped_layers.cpu()).all())
        # Optimization 1 (superseded): per-sample gather in a Python loop
# start = time.time()
# mapped_layers_op1 = mapped_layers.reshape(-1,
# mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4])
# Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5])
# print("reshape:",time.time() - start)
# for i in range(0, Phi_op1.size()[0]):
# start = time.time()
# mapped_layers_op1[i, :, :, :, :] = layers_op1[i,:,
# Phi_op1[i, :, :, :, 0],
# Phi_op1[i, :, :, :, 1]]
# print("mapping one step:",time.time() - start)
# print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
# start = time.time()
# mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
# mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# print("reshape end:",time.time() - start)
# print("mapped_layers:",mapped_layers.shape) # torch.Size([2, 2, 3, 320, 320, 41])
retinal = mapped_layers.prod(1).sum(4).div(Phi.size()[4])
# print("retinal:",retinal.shape) # torch.Size([BatchSize, 3, 320, 320])
return retinal
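    # Editor's sketch of the "Optimization 2" gather above, on toy tensors
    # (names and sizes are illustrative, not from the repo):
    #     B, C, H, W, M = 4, 3, 8, 8, 5
    #     imgs = torch.randn(B, C, H, W)
    #     y = torch.randint(0, H, (B, H, W, M))
    #     x = torch.randint(0, W, (B, H, W, M))
    #     g = imgs[torch.arange(B)[:, None, None, None], :, y, x]
    # Because the advanced indices (batch, y, x) are separated by the ':' slice,
    # their broadcast shape B x H x W x M moves to the front and the channel dim
    # comes last, giving B x H x W x M x C -- hence the permute(0,4,1,2,3) above.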
    ## TO BE CHECKED
def GenFoveaLayers(self, b_retinal, is_mask):
'''
Generate foveated layers for retinal images or masks
Parameters
--------
b_retinal - B x C x H_r x W_r, Batch of retinal images/masks
is_mask - Whether b_retinal is masks or images
Returns
--------
b_fovea_layers - N_f x (B x C x H[f] x W[f]) list of batch of foveated layers
'''
b_fovea_layers = []
for i in range(0, len(self.conf.eye_fovea_angles)):
k = self.conf.eye_fovea_downsamples[i]
region = self.conf.GetRegionOfFoveaLayer(i)
b_roi = b_retinal[:, :, region, region]
if k == 1:
b_fovea_layers.append(b_roi)
elif is_mask:
b_fovea_layers.append(torch.nn.functional.max_pool2d(b_roi.to(torch.float), k).to(torch.bool))
else:
b_fovea_layers.append(torch.nn.functional.avg_pool2d(b_roi, k))
return b_fovea_layers
# fovea_layers = []
# fovea_layer_masks = []
# fov = self.conf.eye_fovea_angles[-1]
# retinal_res = int(self.conf.retinal_res[0])
# for i in range(0, len(self.conf.eye_fovea_angles)):
# angle = self.conf.eye_fovea_angles[i]
# k = self.conf.eye_fovea_downsamples[i]
# roi_size = int(np.ceil(retinal_res * angle / fov))
# roi_offset = int((retinal_res - roi_size) / 2)
# roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# if k == 1:
# fovea_layers.append(roi_img)
# fovea_layer_masks.append(roi_mask)
# else:
# fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0))
# fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0))
# return [ fovea_layers, fovea_layer_masks ]
    ## TO BE CHECKED
def GenFoveaLayersBatch(self, retinal, retinal_mask):
'''
Generate foveated layers and corresponding masks
Parameters
--------
retinal - Retinal image generated by GenRetinalFromLayers()
retinal_mask - Mask of retinal image, also generated by GenRetinalFromLayers()
Returns
--------
fovea_layers - list of foveated layers
fovea_layer_masks - list of mask images, corresponding to foveated layers
'''
fovea_layers = []
fovea_layer_masks = []
fov = self.conf.eye_fovea_angles[-1]
# print("fov:",fov)
retinal_res = int(self.conf.retinal_res[0])
# print("retinal_res:",retinal_res)
# print("len(self.conf.eye_fovea_angles):",len(self.conf.eye_fovea_angles))
for i in range(0, len(self.conf.eye_fovea_angles)):
angle = self.conf.eye_fovea_angles[i]
k = self.conf.eye_fovea_downsamples[i]
roi_size = int(np.ceil(retinal_res * angle / fov))
roi_offset = int((retinal_res - roi_size) / 2)
# [2, 3, 320, 320]
roi_img = retinal[:, :, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# print("roi_img:",roi_img.shape)
# [2, 320, 320]
roi_mask = retinal_mask[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# print("roi_mask:",roi_mask.shape)
if k == 1:
fovea_layers.append(roi_img)
fovea_layer_masks.append(roi_mask)
else:
fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img, k))
fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask), k))
return [ fovea_layers, fovea_layer_masks ]
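    # Editor's note, worked crop arithmetic (hypothetical numbers, not from the
    # repo's config): with retinal_res = 320, eye_fovea_angles = [20, 110] and
    # fov = 110, the inner layer gets roi_size = ceil(320 * 20 / 110) = 59 and
    # roi_offset = (320 - 59) // 2 = 130, i.e. a centered 59 x 59 crop; the
    # outer layer spans the full field and is avg-pooled by its factor k.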
    ## TO BE CHECKED
def GenFoveaRetinal(self, b_fovea_layers):
'''
Generate foveated retinal image by blending fovea layers
        **Note: the current implementation only supports two fovea layers**
Parameters
--------
b_fovea_layers - N_f x (B x 3 x H[f] x W[f]), list of batch of (masked) foveated layers
Returns
--------
B x 3 x H_r x W_r, batch of foveated retinal images
'''
b_fovea_retinal = torch.nn.functional.interpolate(b_fovea_layers[1],
scale_factor=self.conf.eye_fovea_downsamples[1],
mode='bilinear', align_corners=False)
region = self.conf.GetRegionOfFoveaLayer(0)
blend = self.conf.eye_fovea_blend[0]
b_roi = b_fovea_retinal[:, :, region, region]
b_roi.mul_(1 - blend).add_(b_fovea_layers[0] * blend)
return b_fovea_retinal
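    # Editor's note: only the fovea region of the upsampled periphery is
    # modified by the in-place blend above:
    #     R[region] = (1 - blend) * R[region] + blend * F_0
    # where F_0 is the inner fovea layer and blend = conf.eye_fovea_blend[0].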
import torch
import argparse
import os
import glob
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.autograd import Variable
import cv2
from loss import *
import json
from baseline import *
from data import *
import torch.autograd.profiler as profiler
# param
BATCH_SIZE = 2
NUM_EPOCH = 1001
INTERLEAVE_RATE = 2
IM_H = 540
IM_W = 376
Retinal_IM_H = 540
Retinal_IM_W = 376
N = 4 # number of views in the input light-field stack
M = 1 # number of display layers
DATA_FILE = "/home/yejiannan/Project/deeplightfield/data/lf_syn"
DATA_JSON = "/home/yejiannan/Project/deeplightfield/data/data_lf_syn_full.json"
# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/deeplightfield/outputE/lf_syn_full1219"
OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3
LAST_LAYER_CHANNELS = 3 * INTERLEAVE_RATE**2
FIRST_LAYER_CHANNELS = 12 * INTERLEAVE_RATE**2  # 12 = N views x 3 RGB channels
from weight_init import weight_init_normal
def save_checkpoints(file_path, epoch_idx, model, model_solver):
    print('[INFO] Saving checkpoint to %s ...' % file_path)
checkpoint = {
'epoch_idx': epoch_idx,
'model_state_dict': model.state_dict(),
'model_solver_state_dict': model_solver.state_dict()
}
torch.save(checkpoint, file_path)
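# Editor's sketch: the mirror of save_checkpoints, assuming the dict keys
# written above; load_checkpoints is not part of the original script.
def load_checkpoints(file_path, model, model_solver):
    checkpoint = torch.load(file_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model_solver.load_state_dict(checkpoint['model_solver_state_dict'])
    return checkpoint['epoch_idx']  # epoch index to resume from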
mode = "Silence"  # set to "Perf" to print per-stage timings
w_frame = 1.0
loss1 = PerceptionReconstructionLoss()
if __name__ == "__main__":
#train
train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSynDataLoader(DATA_FILE,DATA_JSON),
batch_size=BATCH_SIZE,
num_workers=8,
pin_memory=True,
shuffle=True,
drop_last=False)
#Data loader test
print(len(train_data_loader))
    lf_model = model(FIRST_LAYER_CHANNELS, LAST_LAYER_CHANNELS, OUT_CHANNELS_RB, KERNEL_SIZE, KERNEL_SIZE_RB, INTERLEAVE_RATE, RNN=False)
lf_model.apply(weight_init_normal)
lf_model.train()
epoch_begin = 0
if torch.cuda.is_available():
# lf_model = torch.nn.DataParallel(lf_model).cuda()
lf_model = lf_model.to('cuda:1')
optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
# lf_model.output_layer.register_backward_hook(hook_fn_back)
if mode=="Perf":
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
print("begin training....")
for epoch in range(epoch_begin, NUM_EPOCH):
for batch_idx, (image_set, gt, pos_row, pos_col) in enumerate(train_data_loader):
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("load:",start.elapsed_time(end))
start.record()
#reshape for input
image_set = image_set.permute(0,1,4,2,3) # N LF C H W
image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N LFxC H W
image_set = var_or_cuda(image_set)
gt = gt.permute(0,3,1,2) # BS C H W
gt = var_or_cuda(gt)
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("data prepare:",start.elapsed_time(end))
start.record()
            output = lf_model(image_set, pos_row, pos_col)  # B x 3 x H x W
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("forward:",start.elapsed_time(end))
start.record()
optimizer.zero_grad()
# print("output:",output.shape," gt:",gt.shape)
loss1_value = loss1(output,gt)
loss = (w_frame * loss1_value)
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("compute loss:",start.elapsed_time(end))
start.record()
loss.backward()
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("backward:",start.elapsed_time(end))
start.record()
optimizer.step()
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("update:",start.elapsed_time(end))
print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss.item())
# exit(0)
########################### Save #####################
            if (epoch % 10 == 0 and epoch != 0) or epoch == 2:
for i in range(gt.size()[0]):
save_image(output[i].data,os.path.join(OUTPUT_DIR,"out_%.5f_%.5f.png"%(pos_col[i].data,pos_row[i].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gt_%.5f_%.5f.png"%(pos_col[i].data,pos_row[i].data)))
if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
import torch
def GetDevice():
if torch.cuda.is_available():
return torch.device('cuda')
return torch.device('cpu')
@@ -10,21 +10,24 @@ from tensorboardX import SummaryWriter
 from .loss.loss import PerceptionReconstructionLoss
 from .my import netio
 from .my import util
+from .my import device
 from .my.simple_perf import SimplePerf
 from .data.lf_syn import LightFieldSynDataset
 from .trans_unet import TransUnet

-device = torch.device("cuda:2")
+torch.cuda.set_device(2)
+print("Set CUDA:%d as current device." % torch.cuda.current_device())

 DATA_DIR = os.path.dirname(__file__) + '/data/lf_syn_2020.12.23'
 TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'
-OUTPUT_DIR = DATA_DIR + '/output_low_lr'
-RUN_DIR = DATA_DIR + '/run_low_lr'
-BATCH_SIZE = 1
+OUTPUT_DIR = DATA_DIR + '/output_bat2'
+RUN_DIR = DATA_DIR + '/run_bat2'
+BATCH_SIZE = 8
 TEST_BATCH_SIZE = 10
 NUM_EPOCH = 1000
 MODE = "Silence" # "Perf"
-EPOCH_BEGIN = 500
+EPOCH_BEGIN = 0

 def train():
@@ -44,7 +47,7 @@ def train():
         view_images=train_dataset.sparse_view_images,
         view_depths=train_dataset.sparse_view_depths,
         view_positions=train_dataset.sparse_view_positions,
-        diopter_of_layers=train_dataset.diopter_of_layers).to(device)
+        diopter_of_layers=train_dataset.diopter_of_layers).to(device.GetDevice())
     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
     loss = PerceptionReconstructionLoss()
@@ -66,7 +69,7 @@ def train():
     for epoch in range(EPOCH_BEGIN, NUM_EPOCH):
         for _, view_images, _, view_positions in train_data_loader:
-            view_images = view_images.to(device)
+            view_images = view_images.to(device.GetDevice())
             perf.Checkpoint("Load")
@@ -106,7 +109,6 @@ def train():
                     solver=optimizer)
     print("Train finished")
-    netio.SaveNet('%s/model-epoch_%d.pth' % (RUN_DIR, epoch + 1), model)

 def test(net_file: str):
@@ -125,7 +127,7 @@ def test(net_file: str):
         view_images=train_dataset.sparse_view_images,
         view_depths=train_dataset.sparse_view_depths,
         view_positions=train_dataset.sparse_view_positions,
-        diopter_of_layers=train_dataset.diopter_of_layers).to(device)
+        diopter_of_layers=train_dataset.diopter_of_layers).to(device.GetDevice())
     netio.LoadNet(net_file, model)
     # 3. Test on train dataset
@@ -142,5 +144,5 @@ def test(net_file: str):
 if __name__ == "__main__":
-    #train()
-    test(RUN_DIR + '/model-epoch_1000.pth')
+    train()
+    #test(RUN_DIR + '/model-epoch_1000.pth')
@@ -3,8 +3,7 @@ import torch
 import torch.nn as nn
 from .pytorch_prototyping.pytorch_prototyping import *
 from .my import util
+from .my import device

-device = torch.device("cuda:2")

 class Encoder(nn.Module):
@@ -66,15 +65,15 @@ class LatentSpaceTransformer(nn.Module):
         self.n_views = view_positions.size()[0]
         self.diopter_of_layers = diopter_of_layers
         self.feat_coords = util.MeshGrid(
-            (feat_dim, feat_dim)).to(device=device)
+            (feat_dim, feat_dim)).to(device.GetDevice())

     def forward(self, feats: torch.Tensor,
                 feat_depths: torch.Tensor,
                 novel_views: torch.Tensor) -> torch.Tensor:
-        trans_feats = torch.zeros(novel_views.size()[0], feats.size()[0],
-                                  feats.size()[1], feats.size()[2], feats.size()[3],
-                                  device=device)
+        trans_feats = torch.zeros(novel_views.size()[0],
+                                  feats.size()[0], feats.size()[1],
+                                  feats.size()[2], feats.size()[3],
+                                  device=device.GetDevice())
         for i in range(novel_views.size()[0]):
             for v in range(self.n_views):
                 for l in range(len(self.diopter_of_layers)):
@@ -151,8 +150,8 @@ class TransUnet(nn.Module):
         latent_sidelength = 64  # The dimensions of the latent space
         image_sidelength = view_images.size()[2]
-        self.view_images = view_images.to(device)
-        self.view_depths = view_depths.to(device)
+        self.view_images = view_images.to(device.GetDevice())
+        self.view_depths = view_depths.to(device.GetDevice())
         self.n_views = view_images.size()[0]
         self.encoder = Encoder(nf0=nf0,
                                out_channels=nf,
 ...
import torch
weightVarScale = 0.25
bias_stddev = 0.01
def weight_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.xavier_normal_(m.weight.data)
torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)