Commit 67c4de9e authored by BobYeah's avatar BobYeah
Browse files

sync

parent 421085df
import torch
def var_or_cuda(x):
if torch.cuda.is_available():
# x = x.cuda(non_blocking=True)
x = x.to('cuda:1')
return x
class residual_block(torch.nn.Module):
def __init__(self, OUT_CHANNELS_RB, delta_channel_dim,KERNEL_SIZE_RB,RNN=False):
super(residual_block,self).__init__()
self.delta_channel_dim = delta_channel_dim
self.out_channels_rb = OUT_CHANNELS_RB
self.hidden = None
self.RNN = RNN
if self.RNN:
self.layer1 = torch.nn.Sequential(
torch.nn.Conv2d((OUT_CHANNELS_RB+delta_channel_dim)*2,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
self.layer2 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
else:
self.layer1 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
self.layer2 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
def forward(self,input):
if self.RNN:
# print("input:",input.shape,"hidden:",self.hidden.shape)
inp = torch.cat((input,self.hidden.detach()),dim=1)
# print(inp.shape)
output = self.layer1(inp)
output = self.layer2(output)
output = input+output
self.hidden = output
else:
output = self.layer1(input)
output = self.layer2(output)
output = input+output
return output
def reset_hidden(self, inp):
size = list(inp.size())
size[1] = self.delta_channel_dim + self.out_channels_rb
size[2] = size[2]//2
size[3] = size[3]//2
hidden = torch.zeros(*(size))
self.hidden = var_or_cuda(hidden)
class deinterleave(torch.nn.Module):
def __init__(self, block_size):
super(deinterleave, self).__init__()
self.block_size = block_size
self.block_size_sq = block_size*block_size
def forward(self, input):
output = input.permute(0, 2, 3, 1)
(batch_size, d_height, d_width, d_depth) = output.size()
s_depth = int(d_depth / self.block_size_sq)
s_width = int(d_width * self.block_size)
s_height = int(d_height * self.block_size)
t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth)
spl = t_1.split(self.block_size, 3)
stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl]
output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size, s_height, s_width, s_depth)
output = output.permute(0, 3, 1, 2)
return output
class interleave(torch.nn.Module):
def __init__(self, block_size):
super(interleave, self).__init__()
self.block_size = block_size
self.block_size_sq = block_size*block_size
def forward(self, input):
output = input.permute(0, 2, 3, 1)
(batch_size, s_height, s_width, s_depth) = output.size()
d_depth = s_depth * self.block_size_sq
d_width = int(s_width / self.block_size)
d_height = int(s_height / self.block_size)
t_1 = output.split(self.block_size, 2)
stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
output = torch.stack(stack, 1)
output = output.permute(0, 2, 1, 3)
output = output.permute(0, 3, 1, 2)
return output
class model(torch.nn.Module):
def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False):
super(model, self).__init__()
self.interleave = interleave(INTERLEAVE_RATE)
self.first_layer = torch.nn.Sequential(
torch.nn.Conv2d(FIRSST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,stride=1,padding=1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB),
torch.nn.ELU()
)
self.residual_block1 = residual_block(OUT_CHANNELS_RB,0,KERNEL_SIZE_RB,False)
self.residual_block2 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False)
self.residual_block3 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False)
# if RNN:
# self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# else:
# self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
# self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
# self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
self.output_layer = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+2,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
torch.nn.Sigmoid()
)
self.deinterleave = deinterleave(INTERLEAVE_RATE)
def reset_hidden(self,inp):
self.residual_block3.reset_hidden(inp)
self.residual_block4.reset_hidden(inp)
self.residual_block5.reset_hidden(inp)
def forward(self, lightfield_images, pos_row, pos_col):
# lightfield_images: torch.Size([batch_size, channels * D, H, W])
# channels : RGB*D: 3*9, H:256, W:256
# print("lightfield_images:",lightfield_images.shape)
input_to_net = self.interleave(lightfield_images)
# print("after interleave:",input_to_net.shape)
input_to_rb = self.first_layer(input_to_net)
# print("input_to_rb1:",input_to_rb.shape)
output = self.residual_block1(input_to_rb)
pos_row_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
pos_col_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
for i in range(pos_row.shape[0]):
pos_row_layer[i] *= pos_row[i]
pos_col_layer[i] *= pos_col[i]
# print(depth_layer.shape)
pos_row_layer = var_or_cuda(pos_row_layer)
pos_col_layer = var_or_cuda(pos_col_layer)
output = torch.cat((output,pos_row_layer,pos_col_layer),dim=1)
output = self.residual_block2(output)
output = self.residual_block3(output)
output = self.output_layer(output)
output = self.deinterleave(output)
return output
\ No newline at end of file
import matplotlib.pyplot as plt
import numpy as np
import torch
import glm
import time
from .my import util
from .my import sample_in_pupil
class RetinalGen(object):
'''
Class for retinal generation process
Properties
--------
conf - multi-layers' parameters configuration
u - M x 3 tensor, M sample positions in pupil
p_r - H_r x W_r x 3 tensor, retinal pixel grid, [H_r, W_r] is the retinal resolution
Phi - N x H_r x W_r x M x 2 tensor, retinal to layers mapping, N is number of layers
mask - N x H_r x W_r x M x 2 tensor, indicates invalid (out-of-range) mapping
Methods
--------
'''
def __init__(self, conf):
'''
Initialize retinal generator instance
Parameters
--------
conf - multi-layers' parameters configuration
u - a M x 3 tensor stores M sample positions in pupil
'''
self.conf = conf
self.u = sample_in_pupil.CircleGen(conf.pupil_size, 5)
# self.u = u.to(cuda_dev)
# self.u = u # M x 3 M sample positions
self.D_r = conf.retinal_res # retinal res 480 x 640
self.N = conf.GetNLayers() # 2
self.M = self.u.size()[0] # samples
# p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])),
# torch.tensor(range(0, self.D_r[1])))
# self.p_r = torch.cat([
# ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # 眼球视野
# torch.ones(self.D_r[0], self.D_r[1], 1)
# ], 2)
self.p_r = torch.cat([
((util.MeshGrid(self.D_r) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(),
torch.ones(self.D_r[0], self.D_r[1], 1)
], 2)
# self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long)
# self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2
def CalculateRetinal2LayerMappings(self, position, gaze_dir, df):
'''
Calculate the mapping matrix from retinal to layers.
Parameters
--------
position - 1 x 3 tensor, eye's position
gaze_dir - 1 x 2 tensor, gaze forward vector (with z normalized)
df - focus distance
Returns
--------
phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers
phi_invalid - N x H_r x W_r x M x 1, indicates invalid (out-of-range) mapping
retinal_invalid - 1 x H_r x W_r, indicates invalid pixels in retinal image
'''
D = self.conf.layer_res
c = torch.tensor([ D[1] / 2, D[0] / 2 ]) # c: Center of layers (pixel)
D_r = self.conf.retinal_res # D_r: Resolution of retinal 480 640
V = self.conf.GetEyeViewportSize() # V: Viewport size of eye
p_f = self.p_r * df # p_f: H x W x 3, focus positions of retinal pixels on focus plane
# Calculate transformation from eye to display
gvec_lookat = glm.dvec3(gaze_dir[0], -gaze_dir[1], 1)
gmat_eye = glm.inverse(glm.lookAtLH(glm.dvec3(), gvec_lookat, glm.dvec3(0, 1, 0)))
eye_rot = util.Glm2Tensor(glm.dmat3(gmat_eye))
eye_center = torch.tensor([ position[0], -position[1], position[2] ])
u_rot = torch.mm(self.u, eye_rot)
v_rot = torch.matmul(p_f, eye_rot).unsqueeze(2).expand(
-1, -1, self.M, -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
u_rot.add_(eye_center) # translate by eye's center
v_rot = v_rot.div(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot
phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long)
for i in range(0, self.N):
dp_i = self.conf.GetPixelSizeOfLayer(i) # dp_i: Pixel size of layer i
d_i = self.conf.d_layer[i] # d_i: Distance of layer i
k = (d_i - u_rot[:, 2]).unsqueeze(1)
pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i # pi_r: H x W x M x 2, rays' pixel coord on layer i
phi[i, :, :, :, :] = torch.floor(pi_r + c)
# Calculate invalid mask (out-of-range elements in phi) and reduced to retinal
phi_invalid = (phi[:, :, :, :, 0] < 0) | (phi[:, :, :, :, 0] >= D[1]) | \
(phi[:, :, :, :, 1] < 0) | (phi[:, :, :, :, 1] >= D[0])
phi_invalid = phi_invalid.unsqueeze(4)
# print("phi_invalid:",phi_invalid.shape)
retinal_invalid = phi_invalid.amax((0, 3)).squeeze().unsqueeze(0)
# print("retinal_invalid:",retinal_invalid.shape)
# Fix invalid elements in phi
phi[phi_invalid.expand(-1, -1, -1, -1, 2)] = 0
return [ phi, phi_invalid, retinal_invalid ]
def GenRetinalFromLayers(self, layers, Phi):
'''
Generate retinal image from layers, using precalculated mapping matrix
Parameters
--------
layers - 3N x H x W, stacked layer images, with 3 channels in each layer
phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers
Returns
--------
3 x H_r x W_r, 3 channels retinal image
'''
# FOR GRAYSCALE 1 FOR RGB 3
mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M) # 2 x 3 x 480 x 640 x 41
# print("mapped_layers:",mapped_layers.shape)
for i in range(0, Phi.size()[0]):
# torch.Size([3, 2, 320, 320, 2])
# print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape)
mapped_layers[i, :, :, :, :] = layers[(i * 3) : (i * 3 + 3),
Phi[i, :, :, :, 1],
Phi[i, :, :, :, 0]]
# print("mapped_layers:",mapped_layers.shape)
retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
# print("retinal:",retinal.shape)
return retinal
def GenRetinalFromLayersBatch(self, layers, Phi):
'''
Generate retinal image from layers, using precalculated mapping matrix
Parameters
--------
layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
Returns
--------
3 x H_r x W_r tensor, 3 channels retinal image
H_r x W_r tensor, retinal image mask, indicates pixels valid or not
'''
mapped_layers = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M) #BS x Layers x C x H x W x Sample
# truth = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M)
# layers_truth = layers.clone()
# Phi_truth = Phi.clone()
layers = torch.stack((layers[:,0:3,:,:],layers[:,3:6,:,:]),dim=1) ## torch.Size([BS, Layer, RGB 3, 320, 320])
# Phi = Phi[:,:,None,:,:,:,:].expand(-1,-1,3,-1,-1,-1,-1)
# print("mapped_layers:",mapped_layers.shape) #torch.Size([2, 2, 3, 320, 320, 41])
# print("input layers:",layers.shape) ## torch.Size([2, 2, 3, 320, 320])
# print("input Phi:",Phi.shape) #torch.Size([2, 2, 320, 320, 41, 2])
# #没优化
# for i in range(0, Phi_truth.size()[0]):
# for j in range(0, Phi_truth.size()[1]):
# truth[i, j, :, :, :, :] = layers_truth[i, (j * 3) : (j * 3 + 3),
# Phi_truth[i, j, :, :, :, 0],
# Phi_truth[i, j, :, :, :, 1]]
#优化2
# start = time.time()
mapped_layers_op1 = mapped_layers.reshape(-1,
mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# BatchSizexLayer Channel 3 320 320 41
layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4]) # 2x2 3 320 320
Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5]) # 2x2 320 320 41 2
x = Phi_op1[:,:,:,:,0] # 2x2 320 320 41
y = Phi_op1[:,:,:,:,1] # 2x2 320 320 41
# print("reshape:",time.time() - start)
# start = time.time()
mapped_layers_op1 = layers_op1[torch.arange(layers_op1.shape[0])[:, None, None, None], :, y, x] # x,y 切换
#2x2 320 320 41 3
# print("mapping one step:",time.time() - start)
# print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
# start = time.time()
mapped_layers_op1 = mapped_layers_op1.permute(0,4,1,2,3)
mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# print("reshape end:",time.time() - start)
# print("test:")
# print((truth.cpu() == mapped_layers.cpu()).all())
#优化1
# start = time.time()
# mapped_layers_op1 = mapped_layers.reshape(-1,
# mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4])
# Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5])
# print("reshape:",time.time() - start)
# for i in range(0, Phi_op1.size()[0]):
# start = time.time()
# mapped_layers_op1[i, :, :, :, :] = layers_op1[i,:,
# Phi_op1[i, :, :, :, 0],
# Phi_op1[i, :, :, :, 1]]
# print("mapping one step:",time.time() - start)
# print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
# start = time.time()
# mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
# mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# print("reshape end:",time.time() - start)
# print("mapped_layers:",mapped_layers.shape) # torch.Size([2, 2, 3, 320, 320, 41])
retinal = mapped_layers.prod(1).sum(4).div(Phi.size()[4])
# print("retinal:",retinal.shape) # torch.Size([BatchSize, 3, 320, 320])
return retinal
## TO BE CHECK
def GenFoveaLayers(self, b_retinal, is_mask):
'''
Generate foveated layers for retinal images or masks
Parameters
--------
b_retinal - B x C x H_r x W_r, Batch of retinal images/masks
is_mask - Whether b_retinal is masks or images
Returns
--------
b_fovea_layers - N_f x (B x C x H[f] x W[f]) list of batch of foveated layers
'''
b_fovea_layers = []
for i in range(0, len(self.conf.eye_fovea_angles)):
k = self.conf.eye_fovea_downsamples[i]
region = self.conf.GetRegionOfFoveaLayer(i)
b_roi = b_retinal[:, :, region, region]
if k == 1:
b_fovea_layers.append(b_roi)
elif is_mask:
b_fovea_layers.append(torch.nn.functional.max_pool2d(b_roi.to(torch.float), k).to(torch.bool))
else:
b_fovea_layers.append(torch.nn.functional.avg_pool2d(b_roi, k))
return b_fovea_layers
# fovea_layers = []
# fovea_layer_masks = []
# fov = self.conf.eye_fovea_angles[-1]
# retinal_res = int(self.conf.retinal_res[0])
# for i in range(0, len(self.conf.eye_fovea_angles)):
# angle = self.conf.eye_fovea_angles[i]
# k = self.conf.eye_fovea_downsamples[i]
# roi_size = int(np.ceil(retinal_res * angle / fov))
# roi_offset = int((retinal_res - roi_size) / 2)
# roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# if k == 1:
# fovea_layers.append(roi_img)
# fovea_layer_masks.append(roi_mask)
# else:
# fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0))
# fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0))
# return [ fovea_layers, fovea_layer_masks ]
## TO BE CHECK
def GenFoveaLayersBatch(self, retinal, retinal_mask):
'''
Generate foveated layers and corresponding masks
Parameters
--------
retinal - Retinal image generated by GenRetinalFromLayers()
retinal_mask - Mask of retinal image, also generated by GenRetinalFromLayers()
Returns
--------
fovea_layers - list of foveated layers
fovea_layer_masks - list of mask images, corresponding to foveated layers
'''
fovea_layers = []
fovea_layer_masks = []
fov = self.conf.eye_fovea_angles[-1]
# print("fov:",fov)
retinal_res = int(self.conf.retinal_res[0])
# print("retinal_res:",retinal_res)
# print("len(self.conf.eye_fovea_angles):",len(self.conf.eye_fovea_angles))
for i in range(0, len(self.conf.eye_fovea_angles)):
angle = self.conf.eye_fovea_angles[i]
k = self.conf.eye_fovea_downsamples[i]
roi_size = int(np.ceil(retinal_res * angle / fov))
roi_offset = int((retinal_res - roi_size) / 2)
# [2, 3, 320, 320]
roi_img = retinal[:, :, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# print("roi_img:",roi_img.shape)
# [2, 320, 320]
roi_mask = retinal_mask[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# print("roi_mask:",roi_mask.shape)
if k == 1:
fovea_layers.append(roi_img)
fovea_layer_masks.append(roi_mask)
else:
fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img, k))
fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask), k))
return [ fovea_layers, fovea_layer_masks ]
## TO BE CHECK
def GenFoveaRetinal(self, b_fovea_layers):
'''
Generate foveated retinal image by blending fovea layers
**Note: current implementation only support two fovea layers**
Parameters
--------
b_fovea_layers - N_f x (B x 3 x H[f] x W[f]), list of batch of (masked) foveated layers
Returns
--------
B x 3 x H_r x W_r, batch of foveated retinal images
'''
b_fovea_retinal = torch.nn.functional.interpolate(b_fovea_layers[1],
scale_factor=self.conf.eye_fovea_downsamples[1],
mode='bilinear', align_corners=False)
region = self.conf.GetRegionOfFoveaLayer(0)
blend = self.conf.eye_fovea_blend[0]
b_roi = b_fovea_retinal[:, :, region, region]
b_roi.mul_(1 - blend).add_(b_fovea_layers[0] * blend)
return b_fovea_retinal
This diff is collapsed.
import torch
import argparse
import os
import glob
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.autograd import Variable
import cv2
from loss import *
import json
from baseline import *
from data import *
import torch.autograd.profiler as profiler
# param
BATCH_SIZE = 2
NUM_EPOCH = 1001
INTERLEAVE_RATE = 2
IM_H = 540
IM_W = 376
Retinal_IM_H = 540
Retinal_IM_W = 376
N = 4 # number of input light field stack
M = 1 # number of display layers
DATA_FILE = "/home/yejiannan/Project/deeplightfield/data/lf_syn"
DATA_JSON = "/home/yejiannan/Project/deeplightfield/data/data_lf_syn_full.json"
# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/deeplightfield/outputE/lf_syn_full1219"
OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3
LAST_LAYER_CHANNELS = 3 * INTERLEAVE_RATE**2
FIRSST_LAYER_CHANNELS = 12 * INTERLEAVE_RATE**2
from weight_init import weight_init_normal
def save_checkpoints(file_path, epoch_idx, model, model_solver):
print('[INFO] Saving checkpoint to %s ...' % ( file_path))
checkpoint = {
'epoch_idx': epoch_idx,
'model_state_dict': model.state_dict(),
'model_solver_state_dict': model_solver.state_dict()
}
torch.save(checkpoint, file_path)
mode = "Silence" #"Perf"
w_frame = 1.0
loss1 = PerceptionReconstructionLoss()
if __name__ == "__main__":
#train
train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSynDataLoader(DATA_FILE,DATA_JSON),
batch_size=BATCH_SIZE,
num_workers=8,
pin_memory=True,
shuffle=True,
drop_last=False)
#Data loader test
print(len(train_data_loader))
lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False)
lf_model.apply(weight_init_normal)
lf_model.train()
epoch_begin = 0
if torch.cuda.is_available():
# lf_model = torch.nn.DataParallel(lf_model).cuda()
lf_model = lf_model.to('cuda:1')
optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
# lf_model.output_layer.register_backward_hook(hook_fn_back)
if mode=="Perf":
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
print("begin training....")
for epoch in range(epoch_begin, NUM_EPOCH):
for batch_idx, (image_set, gt, pos_row, pos_col) in enumerate(train_data_loader):
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("load:",start.elapsed_time(end))
start.record()
#reshape for input
image_set = image_set.permute(0,1,4,2,3) # N LF C H W
image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N LFxC H W
image_set = var_or_cuda(image_set)
gt = gt.permute(0,3,1,2) # BS C H W
gt = var_or_cuda(gt)
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("data prepare:",start.elapsed_time(end))
start.record()
output = lf_model(image_set,pos_row, pos_col) # 2 6 376 540
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("forward:",start.elapsed_time(end))
start.record()
optimizer.zero_grad()
# print("output:",output.shape," gt:",gt.shape)
loss1_value = loss1(output,gt)
loss = (w_frame * loss1_value)
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("compute loss:",start.elapsed_time(end))
start.record()
loss.backward()
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("backward:",start.elapsed_time(end))
start.record()
optimizer.step()
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("update:",start.elapsed_time(end))
print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss.item())
# exit(0)
########################### Save #####################
if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
for i in range(gt.size()[0]):
save_image(output[i].data,os.path.join(OUTPUT_DIR,"out_%.5f_%.5f.png"%(pos_col[i].data,pos_row[i].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gt_%.5f_%.5f.png"%(pos_col[i].data,pos_row[i].data)))
if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
\ No newline at end of file
import torch
def GetDevice():
if torch.cuda.is_available():
return torch.device('cuda')
return torch.device('cpu')
\ No newline at end of file
......@@ -10,21 +10,24 @@ from tensorboardX import SummaryWriter
from .loss.loss import PerceptionReconstructionLoss
from .my import netio
from .my import util
from .my import device
from .my.simple_perf import SimplePerf
from .data.lf_syn import LightFieldSynDataset
from .trans_unet import TransUnet
device = torch.device("cuda:2")
torch.cuda.set_device(2)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
DATA_DIR = os.path.dirname(__file__) + '/data/lf_syn_2020.12.23'
TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'
OUTPUT_DIR = DATA_DIR + '/output_low_lr'
RUN_DIR = DATA_DIR + '/run_low_lr'
BATCH_SIZE = 1
OUTPUT_DIR = DATA_DIR + '/output_bat2'
RUN_DIR = DATA_DIR + '/run_bat2'
BATCH_SIZE = 8
TEST_BATCH_SIZE = 10
NUM_EPOCH = 1000
MODE = "Silence" # "Perf"
EPOCH_BEGIN = 500
EPOCH_BEGIN = 0
def train():
......@@ -44,7 +47,7 @@ def train():
view_images=train_dataset.sparse_view_images,
view_depths=train_dataset.sparse_view_depths,
view_positions=train_dataset.sparse_view_positions,
diopter_of_layers=train_dataset.diopter_of_layers).to(device)
diopter_of_layers=train_dataset.diopter_of_layers).to(device.GetDevice())
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = PerceptionReconstructionLoss()
......@@ -66,7 +69,7 @@ def train():
for epoch in range(EPOCH_BEGIN, NUM_EPOCH):
for _, view_images, _, view_positions in train_data_loader:
view_images = view_images.to(device)
view_images = view_images.to(device.GetDevice())
perf.Checkpoint("Load")
......@@ -106,7 +109,6 @@ def train():
solver=optimizer)
print("Train finished")
netio.SaveNet('%s/model-epoch_%d.pth' % (RUN_DIR, epoch + 1), model)
def test(net_file: str):
......@@ -125,7 +127,7 @@ def test(net_file: str):
view_images=train_dataset.sparse_view_images,
view_depths=train_dataset.sparse_view_depths,
view_positions=train_dataset.sparse_view_positions,
diopter_of_layers=train_dataset.diopter_of_layers).to(device)
diopter_of_layers=train_dataset.diopter_of_layers).to(device.GetDevice())
netio.LoadNet(net_file, model)
# 3. Test on train dataset
......@@ -142,5 +144,5 @@ def test(net_file: str):
if __name__ == "__main__":
#train()
test(RUN_DIR + '/model-epoch_1000.pth')
train()
#test(RUN_DIR + '/model-epoch_1000.pth')
......@@ -3,8 +3,7 @@ import torch
import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import util
device = torch.device("cuda:2")
from .my import device
class Encoder(nn.Module):
......@@ -66,15 +65,15 @@ class LatentSpaceTransformer(nn.Module):
self.n_views = view_positions.size()[0]
self.diopter_of_layers = diopter_of_layers
self.feat_coords = util.MeshGrid(
(feat_dim, feat_dim)).to(device=device)
(feat_dim, feat_dim)).to(device.GetDevice())
def forward(self, feats: torch.Tensor,
feat_depths: torch.Tensor,
novel_views: torch.Tensor) -> torch.Tensor:
trans_feats = torch.zeros(novel_views.size()[0], feats.size()[0],
feats.size()[1], feats.size()[
2], feats.size()[3],
device=device)
trans_feats = torch.zeros(novel_views.size()[0],
feats.size()[0], feats.size()[1],
feats.size()[2], feats.size()[3],
device=device.GetDevice())
for i in range(novel_views.size()[0]):
for v in range(self.n_views):
for l in range(len(self.diopter_of_layers)):
......@@ -151,8 +150,8 @@ class TransUnet(nn.Module):
latent_sidelength = 64 # The dimensions of the latent space
image_sidelength = view_images.size()[2]
self.view_images = view_images.to(device)
self.view_depths = view_depths.to(device)
self.view_images = view_images.to(device.GetDevice())
self.view_depths = view_depths.to(device.GetDevice())
self.n_views = view_images.size()[0]
self.encoder = Encoder(nf0=nf0,
out_channels=nf,
......
import torch
weightVarScale = 0.25
bias_stddev = 0.01
def weight_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.xavier_normal_(m.weight.data)
torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment