Commit 69e1d015 authored by BobYeah

Update1205ForHPC

parent 5069f8ae
import torch
from gen_image import *
class Conf(object):
def __init__(self):
self.pupil_size = 0.02
self.retinal_res = torch.tensor([ 320, 320 ])
self.layer_res = torch.tensor([ 320, 320 ])
self.layer_hfov = 90 # layers' horizontal FOV
self.eye_hfov = 85 # eye's horizontal FOV (ignored in foveated rendering)
self.eye_enable_fovea = True # enable foveated rendering
self.eye_fovea_angles = [ 40, 80 ] # eye's foveation layers' angles
self.eye_fovea_downsamples = [ 1, 2 ] # eye's foveation layers' downsamples
self.d_layer = [ 1, 3 ] # layers' distance
def GetNLayers(self):
return len(self.d_layer)
def GetLayerSize(self, i):
w = Fov2Length(self.layer_hfov)
h = w * self.layer_res[0] / self.layer_res[1]
return torch.tensor([ h, w ]) * self.d_layer[i]
def GetEyeViewportSize(self):
fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov
w = Fov2Length(fov)
h = w * self.retinal_res[0] / self.retinal_res[1]
return torch.tensor([ h, w ])
\ No newline at end of file
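# Usage sketch for Conf (hypothetical; Fov2Length comes from gen_image and is
# assumed to map a field of view in degrees to a width at unit distance):
# conf = Conf()
# layer_size = conf.GetLayerSize(0)      # physical [h, w] of the nearest layer
# viewport = conf.GetEyeViewportSize()   # [h, w] of the eye viewport at unit distance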
@@ -159,3 +159,36 @@ class RetinalGen(object):
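# Per pixel, the layer transmittances are multiplied together (prod over the
# layer dimension), then averaged over the pupil samples (Phi's dimension 3).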
retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
# print("retinal:",retinal.shape)
return retinal
def GenFoveaLayers(self, retinal, retinal_mask):
'''
Generate foveated layers and corresponding masks
Parameters
--------
retinal - Retinal image generated by GenRetinalFromLayers()
retinal_mask - Mask of retinal image, also generated by GenRetinalFromLayers()
Returns
--------
fovea_layers - list of foveated layers
fovea_layer_masks - list of mask images, corresponding to foveated layers
'''
fovea_layers = []
fovea_layer_masks = []
fov = self.conf.eye_fovea_angles[-1]
retinal_res = int(self.conf.retinal_res[0])
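# Each foveation ring crops a centered square ROI covering angle/fov of the
# full retinal image, then downsamples it by that ring's factor k.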
for i in range(0, len(self.conf.eye_fovea_angles)):
angle = self.conf.eye_fovea_angles[i]
k = self.conf.eye_fovea_downsamples[i]
roi_size = int(np.ceil(retinal_res * angle / fov))
roi_offset = int((retinal_res - roi_size) / 2)
roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
if k == 1:
fovea_layers.append(roi_img)
fovea_layer_masks.append(roi_mask)
else:
fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0))
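# Min-pool the binary mask: max-pool the complement and invert, so a pixel
# survives only if every pixel in its k x k window was valid (PyTorch has no
# min_pool2d, hence the workaround).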
fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0))
return [ fovea_layers, fovea_layer_masks ]
\ No newline at end of file
@@ -15,9 +15,14 @@ from gen_image import *
import json
from ssim import *
from perc_loss import *
from conf import Conf
from model.baseline import *
import torch.autograd.profiler as profiler
# param
BATCH_SIZE = 16
NUM_EPOCH = 1000
BATCH_SIZE = 2
NUM_EPOCH = 300
INTERLEAVE_RATE = 2
@@ -30,15 +35,24 @@ Retinal_IM_W = 320
N = 9 # number of views in the input light field stack
M = 2 # number of display layers
DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_small_nar_new"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_low_new.json"
DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/output/gaze_low_new_1125_minibatch"
DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_fovea"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq.json"
DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq"
OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3
LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2
class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
@@ -46,17 +60,27 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, gt, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
return (lightfield_images, gt, fd, gazeX, gazeY, sample_idx)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
# print(lf_image_paths)
fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx])
fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
@@ -65,6 +89,8 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
lf_images.append(lf_image)
gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
@@ -73,149 +99,124 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),gt,fd,gazeX,gazeY,sample_idx
return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx
OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3
class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
class residual_block(torch.nn.Module):
def __init__(self,delta_channel_dim):
super(residual_block,self).__init__()
self.layer1 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
self.layer2 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
def forward(self,input):
output = self.layer1(input)
output = self.layer2(output)
output = input+output
return output
class deinterleave(torch.nn.Module):
def __init__(self, block_size):
super(deinterleave, self).__init__()
self.block_size = block_size
self.block_size_sq = block_size*block_size
def forward(self, input):
output = input.permute(0, 2, 3, 1)
(batch_size, d_height, d_width, d_depth) = output.size()
s_depth = int(d_depth / self.block_size_sq)
s_width = int(d_width * self.block_size)
s_height = int(d_height * self.block_size)
t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth)
spl = t_1.split(self.block_size, 3)
stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl]
output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size, s_height, s_width, s_depth)
output = output.permute(0, 3, 1, 2)
return output
class interleave(torch.nn.Module):
def __init__(self, block_size):
super(interleave, self).__init__()
self.block_size = block_size
self.block_size_sq = block_size*block_size
def forward(self, input):
output = input.permute(0, 2, 3, 1)
(batch_size, s_height, s_width, s_depth) = output.size()
d_depth = s_depth * self.block_size_sq
d_width = int(s_width / self.block_size)
d_height = int(s_height / self.block_size)
t_1 = output.split(self.block_size, 2)
stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
output = torch.stack(stack, 1)
output = output.permute(0, 2, 1, 3)
output = output.permute(0, 3, 1, 2)
return output
def __len__(self):
return len(self.dataset_desc["focaldepth"])
LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2
def __getitem__(self, idx):
lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx
class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["seq"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
fd = fd.astype(np.float32)
gazeX = gazeX.astype(np.float32)
gazeY = gazeY.astype(np.float32)
sample_idx = sample_idx.astype(np.int64)
# print(fd)
# print(gazeX)
# print(gazeY)
# print(sample_idx)
# print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype)
# print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
class model(torch.nn.Module):
def __init__(self):
super(model, self).__init__()
self.interleave = interleave(INTERLEAVE_RATE)
self.first_layer = torch.nn.Sequential(
torch.nn.Conv2d(FIRSST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,stride=1,padding=1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB),
torch.nn.ELU()
)
self.residual_block1 = residual_block(0)
self.residual_block2 = residual_block(3)
self.residual_block3 = residual_block(3)
self.residual_block4 = residual_block(3)
self.residual_block5 = residual_block(3)
self.output_layer = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+3,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
torch.nn.Sigmoid()
)
self.deinterleave = deinterleave(INTERLEAVE_RATE)
def forward(self, lightfield_images, focal_length, gazeX, gazeY):
input_to_net = self.interleave(lightfield_images)
input_to_rb = self.first_layer(input_to_net)
output = self.residual_block1(input_to_rb)
depth_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
gazeX_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
gazeY_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
for i in range(focal_length.shape[0]):
depth_layer[i] *= 1. / focal_length[i]
gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333*2)
gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333*2)
depth_layer = var_or_cuda(depth_layer)
gazeX_layer = var_or_cuda(gazeX_layer)
gazeY_layer = var_or_cuda(gazeY_layer)
output = torch.cat((output,depth_layer,gazeX_layer,gazeY_layer),dim=1)
output = self.residual_block2(output)
output = self.residual_block3(output)
output = self.residual_block4(output)
output = self.residual_block5(output)
output = self.output_layer(output)
output = self.deinterleave(output)
return output
class Conf(object):
def __init__(self):
self.pupil_size = 0.02 # 2cm
self.retinal_res = torch.tensor([ Retinal_IM_H, Retinal_IM_W ])
self.layer_res = torch.tensor([ IM_H, IM_W ])
self.layer_hfov = 90 # layers' horizontal FOV
self.eye_hfov = 85 # eye's horizontal FOV
self.d_layer = [ 1, 3 ] # layers' distance
def GetNLayers(self):
return len(self.d_layer)
def GetLayerSize(self, i):
w = Fov2Length(self.layer_hfov)
h = w * self.layer_res[0] / self.layer_res[1]
return torch.tensor([ h, w ]) * self.d_layer[i]
def GetEyeViewportSize(self):
w = Fov2Length(self.eye_hfov)
h = w * self.retinal_res[0] / self.retinal_res[1]
return torch.tensor([ h, w ])
def get_datum(self, idx):
indices = self.dataset_desc["seq"][idx]
# print("indices:",indices)
lf_images = []
fd = []
gazeX = []
gazeY = []
sample_idx = []
gt = []
gt2 = []
for i in range(len(indices)):
lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][indices[i]])
fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][indices[i]])
fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][indices[i]])
lf_image_one_sample = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
for j in range(9):
lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
# print("indices[i]:",indices[i])
fd.append([self.dataset_desc["focaldepth"][indices[i]]])
gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
sample_idx.append([self.dataset_desc["idx"][indices[i]]])
lf_images.append(lf_image_one_sample)
#lf_images: 5,9,320,320
return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx)
#### Image Gen
conf = Conf()
u = GenSamplesInPupil(conf.pupil_size, 5)
gen = RetinalGen(conf, u)
def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
@@ -228,14 +229,44 @@ def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
mask = [] # mask shape 480 x 640
for i in range(0, layers.size()[0]):
phi = phi_dict[int(sample_idx[i].data)]
# print("phi_i:",phi.shape)
phi = var_or_cuda(phi)
phi.requires_grad = False
# print("layers[i]:",layers[i].shape)
# print("retinal[i]:",retinal[i].shape)
retinal[i] = gen.GenRetinalFromLayers(layers[i],phi)
mask.append(mask_dict[int(sample_idx[i].data)])
retinal = var_or_cuda(retinal)
mask = torch.stack(mask,dim = 0).unsqueeze(1) # batch x 1 x height x width
return retinal, mask
def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
# layers: batchsize, 2*color, height, width
# Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
# df : batchsize,..
# retinal bs x color x height x width
retinal_fovea = torch.empty(layers.shape[0], 6, 160, 160)
mask_fovea = torch.empty(layers.shape[0], 2, 160, 160)
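# With the values in conf.py (retinal_res 320, fovea angles [40, 80],
# downsamples [1, 2]) both rings come out at 160 x 160: ring 0 crops 160 px at
# k=1, ring 1 crops 320 px and is average-pooled by k=2; 6 = 2 rings x RGB.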
for i in range(0, layers.size()[0]):
phi = phi_dict[int(sample_idx[i].data)]
# print("phi_i:",phi.shape)
phi = var_or_cuda(phi)
phi.requires_grad = False
mask_i = var_or_cuda(mask_dict[int(sample_idx[i].data)])
mask_i.requires_grad = False
# print("layers[i]:",layers[i].shape)
# print("retinal[i]:",retinal[i].shape)
retinal_i = gen.GenRetinalFromLayers(layers[i],phi)
fovea_layers, fovea_layer_masks = gen.GenFoveaLayers(retinal_i,mask_i)
retinal_fovea[i] = torch.cat([fovea_layers[0],fovea_layers[1]],dim=0)
mask_fovea[i] = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=0)
retinal_fovea = var_or_cuda(retinal_fovea)
mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
# mask = torch.stack(mask,dim = 0).unsqueeze(1)
return retinal_fovea, mask_fovea
def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
# layers: batchsize, 2*color, height, width
# Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
@@ -249,6 +280,8 @@ def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
retinal = gen.GenRetinalFromLayers(layers[0],phi)
retinal = var_or_cuda(retinal)
mask_out = mask.unsqueeze(0).unsqueeze(0)
# print("maskOUt:",mask_out.shape) # 1,1,240,320
# mask_out = torch.stack(mask,dim = 0).unsqueeze(1) # batch x 1 x height x width
return retinal.unsqueeze(0), mask_out
#### Image Gen End
@@ -264,10 +297,6 @@ def weight_init_normal(m):
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
def var_or_cuda(x):
if torch.cuda.is_available():
x = x.cuda(non_blocking=True)
return x
def calImageGradients(images):
# x is a 4-D tensor
@@ -277,19 +306,25 @@ def calImageGradients(images):
perc_loss = VGGPerceptualLoss()
perc_loss = perc_loss.to("cuda")
perc_loss = perc_loss.to("cuda:1")
def loss_new(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
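# log10(MSE) serves as a differentiable surrogate for (negative) PSNR; the
# usual -10x scaling and peak term are dropped, and the +10 added to
# total_loss below only keeps the printed value positive (it has no gradient).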
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
p_loss = perc_loss(generated,gt)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = 10 + psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
def save_checkpoints(file_path, epoch_idx, model, model_solver):
@@ -301,6 +336,18 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
}
torch.save(checkpoint, file_path)
mode = "train"
import pickle
def save_obj(obj, name ):
# with open('./outputF/dict/'+ name + '.pkl', 'wb') as f:
# pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
torch.save(obj,'./outputF/dict/'+ name + '.pkl')
def load_obj(name):
# with open('./outputF/dict/' + name + '.pkl', 'rb') as f:
# return pickle.load(f)
return torch.load('./outputF/dict/'+ name + '.pkl')
def hook_fn_back(m, i, o):
for grad in i:
try:
@@ -327,14 +374,11 @@ def hook_fn_for(m, i, o):
print ("None found for Gradient")
print("\n")
if __name__ == "__main__":
############################## generate phi and mask in pre-training
def generatePhiMaskDict(data_json, generator):
phi_dict = {}
mask_dict = {}
idx_info_dict = {}
print("generating phi and mask...")
with open(DATA_JSON, encoding='utf-8') as file:
with open(data_json, encoding='utf-8') as file:
dataset_desc = json.loads(file.read())
for i in range(len(dataset_desc["focaldepth"])):
# if i == 2:
@@ -343,16 +387,34 @@ if __name__ == "__main__":
focaldepth = dataset_desc["focaldepth"][i]
gazeX = dataset_desc["gazeX"][i]
gazeY = dataset_desc["gazeY"][i]
# print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
phi,mask = gen.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
phi_dict[idx]=phi
mask_dict[idx]=mask
idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
print("generating phi and mask end.")
return phi_dict,mask_dict,idx_info_dict
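# CalculateRetinal2LayerMappings is expensive, so phi/mask are precomputed once
# per dataset sample (keyed by its "idx") and cached to disk with save_obj;
# training then only looks them up each iteration via load_obj.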
if __name__ == "__main__":
############################## generate phi and mask in pre-training
# print("generating phi and mask...")
# phi_dict,mask_dict,idx_info_dict = generatePhiMaskDict(DATA_JSON,gen)
# save_obj(phi_dict,"phi_1204")
# save_obj(mask_dict,"mask_1204")
# save_obj(idx_info_dict,"idx_info_1204")
# print("generating phi and mask end.")
# exit(0)
############################# load phi and mask in pre-training
print("loading phi and mask ...")
phi_dict = load_obj("phi_1204")
mask_dict = load_obj("mask_1204")
idx_info_dict = load_obj("idx_info_1204")
print(len(phi_dict))
print(len(mask_dict))
print("loading phi and mask end")
#train
train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_JSON),
train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSeqDataLoader(DATA_FILE,DATA_JSON),
batch_size=BATCH_SIZE,
num_workers=0,
pin_memory=True,
@@ -362,49 +424,213 @@ if __name__ == "__main__":
# exit(0)
################################################ val #########################################################
# val_data_loader = torch.utils.data.DataLoader(dataset=lightFieldValDataLoader(DATA_FILE,DATA_VAL_JSON),
# batch_size=1,
# num_workers=0,
# pin_memory=True,
# shuffle=False,
# drop_last=False)
# print(len(val_data_loader))
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# lf_model = baseline.model()
# if torch.cuda.is_available():
# lf_model = torch.nn.DataParallel(lf_model).cuda()
# checkpoint = torch.load(os.path.join(OUTPUT_DIR,"gaze-ckpt-epoch-0201.pth"))
# lf_model.load_state_dict(checkpoint["model_state_dict"])
# lf_model.eval()
# print("Eval::")
# for sample_idx, (image_set, df, gazeX, gazeY, sample_idx) in enumerate(val_data_loader):
# print("sample_idx::",sample_idx)
# with torch.no_grad():
# #reshape for input
# image_set = image_set.permute(0,1,4,2,3) # N LF C H W
# image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
# image_set = var_or_cuda(image_set)
# # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
# output = lf_model(image_set,df,gazeX,gazeY)
# output1,mask = GenRetinalGazeFromLayersBatch(output, gen, sample_idx, phi_dict, mask_dict)
# for i in range(0, 2):
# output1[:,i*3:i*3+3].mul_(mask[:,i:i+1])
# output1[:,i*3:i*3+3].clamp_(0., 1.)
# print("output:",output.shape," df:",df[0].data, ",gazeX:",gazeX[0].data,",gazeY:", gazeY[0].data)
# for i in range(output1.size()[0]):
# save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.3f.png"%(df[0].data)))
# # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.3f.png"%(df[0].data)))
# # output = GenRetinalFromLayersBatch(output,conf,df,v,u)
# # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.3f.png"%(df[0].data)))
# exit()
################################################ train #########################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lf_model = model()
lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE)
lf_model.apply(weight_init_normal)
epoch_begin = 0
################################ load model file
# WEIGHTS = os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (101))
# print('[INFO] Recovering from %s ...' % (WEIGHTS))
# checkpoint = torch.load(WEIGHTS)
# init_epoch = checkpoint['epoch_idx']
# lf_model.load_state_dict(checkpoint['model_state_dict'])
# epoch_begin = init_epoch + 1
# print(lf_model)
############################################################
if torch.cuda.is_available():
lf_model = torch.nn.DataParallel(lf_model).cuda()
# lf_model = torch.nn.DataParallel(lf_model).cuda()
lf_model = lf_model.to('cuda:1')
lf_model.train()
optimizer = torch.optim.Adam(lf_model.parameters(),lr=1e-2,betas=(0.9,0.999))
l1loss = torch.nn.L1Loss()
# lf_model.output_layer.register_backward_hook(hook_fn_back)
print("begin training....")
for epoch in range(epoch_begin, NUM_EPOCH):
for batch_idx, (image_set, gt, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader):
for batch_idx, (image_set, gt, gt2, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader):
# print(sample_idx.shape,df.shape,gazeX.shape,gazeY.shape) # torch.Size([2, 5])
# print(image_set.shape,gt.shape,gt2.shape) #torch.Size([2, 5, 9, 320, 320, 3]) torch.Size([2, 5, 160, 160, 3]) torch.Size([2, 5, 160, 160, 3])
# print(delta.shape) # delta: torch.Size([2, 4, 160, 160, 3])
#reshape for input
image_set = image_set.permute(0,1,4,2,3) # N LF C H W
image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
image_set = image_set.permute(0,1,2,5,3,4) # N S LF C H W
image_set = image_set.reshape(image_set.shape[0],image_set.shape[1],-1,image_set.shape[4],image_set.shape[5]) # N, LFxC, H, W
image_set = var_or_cuda(image_set)
gt = gt.permute(0,3,1,2)
gt = gt.permute(0,1,4,2,3) # N S C H W
gt = var_or_cuda(gt)
optimizer.zero_grad()
output = lf_model(image_set,df,gazeX,gazeY)
gt2 = gt2.permute(0,1,4,2,3)
gt2 = var_or_cuda(gt2)
gen1 = torch.empty(gt.shape)
gen1 = var_or_cuda(gen1)
gen2 = torch.empty(gt2.shape)
gen2 = var_or_cuda(gen2)
warped = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
warped = var_or_cuda(warped)
delta = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
delta = var_or_cuda(delta)
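# warped/delta hold the S-1 frame-to-frame differences fed to the temporal
# consistency L1 term below.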
for k in range(image_set.shape[1]):
if k == 0:
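# start of a new sequence: zero the hidden state of the RNN residual blocks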
lf_model.reset_hidden(image_set[:,k])
# start = torch.cuda.Event(enable_timing=True)
# end = torch.cuda.Event(enable_timing=True)
# start.record()
output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k])
# end.record()
# torch.cuda.synchronize()
# print("Model Forward:",start.elapsed_time(end))
# print("output:",output.shape) # [2, 6, 320, 320]
# exit()
########################### Use Pregen Phi and Mask ###################
output1,mask = GenRetinalFromLayersBatch(output, gen, sample_idx, phi_dict, mask_dict)
mask = var_or_cuda(mask)
mask.requires_grad = False
output_f = output1 * mask
gt = gt * mask
loss = loss_new(output_f,gt)
print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
# start.record()
output1,mask = GenRetinalGazeFromLayersBatch(output, gen, sample_idx[:,k], phi_dict, mask_dict)
# end.record()
# torch.cuda.synchronize()
# print("Merge:",start.elapsed_time(end))
# print("output1 shape:",output1.shape, "mask shape:",mask.shape)
# output1 shape: torch.Size([2, 6, 160, 160]) mask shape: torch.Size([2, 2, 160, 160])
for i in range(0, 2):
output1[:,i*3:i*3+3].mul_(mask[:,i:i+1])
if i == 0:
gt[:,k].mul_(mask[:,i:i+1])
if i == 1:
gt2[:,k].mul_(mask[:,i:i+1])
gen1[:,k] = output1[:,0:3]
gen2[:,k] = output1[:,3:6]
if ((epoch%5== 0) or epoch == 2):
for i in range(output.shape[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data)))
########################### Update ###################
for i in range(1,image_set.shape[1]):
delta[:,i-1] = gt2[:,i] - gt2[:,i-1] # ground-truth frame-to-frame difference
warped[:,i-1] = gen2[:,i]-gen2[:,i-1]
optimizer.zero_grad()
# # N S C H W
gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4])
gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4])
delta = delta.reshape(-1,delta.shape[2],delta.shape[3],delta.shape[4])
# start = torch.cuda.Event(enable_timing=True)
# end = torch.cuda.Event(enable_timing=True)
# start.record()
loss1 = loss_new(gen1,gt)
loss2 = loss_new(gen2,gt2)
loss3 = l1loss(warped,delta)
loss = loss1+loss2+loss3
# end.record()
# torch.cuda.synchronize()
# print("loss comp:",start.elapsed_time(end))
# start.record()
loss.backward()
# end.record()
# torch.cuda.synchronize()
# print("backward:",start.elapsed_time(end))
# start.record()
optimizer.step()
# end.record()
# torch.cuda.synchronize()
# print("optimizer step:",start.elapsed_time(end))
## Update Prev
print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
########################### Save #####################
if ((epoch%50== 0) or epoch == 5):
for i in range(output_f.size()[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
save_image(output_f[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_test1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
if ((epoch%200 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),
epoch,lf_model,optimizer)
\ No newline at end of file
if ((epoch%5== 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
for i in range(gt.size()[0]):
# df 2,5
save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),epoch,lf_model,optimizer)
########################## test Phi and Mask ##########################
# phi,mask = gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]]))
# # print("gaze Online:",gazeX[0]," ,",gazeY[0])
# # print("df Online:",df[0])
# # print("idx:",int(sample_idx[0].data))
# phi_t = phi_dict[int(sample_idx[0].data)]
# mask_t = mask_dict[int(sample_idx[0].data)]
# # print("idx info:",idx_info_dict[int(sample_idx[0].data)])
# # print("phi online:", phi.shape, " phi_t:", phi_t.shape)
# # print("mask online:", mask.shape, " mask_t:", mask_t.shape)
# print("phi delta:", (phi-phi_t).sum()," mask delta:",(mask -mask_t).sum())
# exit(0)
###########################Gen Batch 1 by 1###################
# phi,mask = gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]]))
# # print(phi.shape) # 2,240,320,41,2
# output1, mask = GenRetinalFromLayersBatch_Online(output, gen, phi, mask)
###########################Gen Batch 1 by 1###################
\ No newline at end of file
import torch
def var_or_cuda(x):
if torch.cuda.is_available():
# x = x.cuda(non_blocking=True)
x = x.to('cuda:1')
return x
class residual_block(torch.nn.Module):
def __init__(self, OUT_CHANNELS_RB, delta_channel_dim,KERNEL_SIZE_RB,RNN=False):
super(residual_block,self).__init__()
self.delta_channel_dim = delta_channel_dim
self.out_channels_rb = OUT_CHANNELS_RB
self.hidden = None
self.RNN = RNN
if self.RNN:
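# RNN variant: layer1 takes the current features concatenated channel-wise
# with the previous step's hidden state, hence the doubled input channels.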
self.layer1 = torch.nn.Sequential(
torch.nn.Conv2d((OUT_CHANNELS_RB+delta_channel_dim)*2,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
self.layer2 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
else:
self.layer1 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
self.layer2 = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim),
torch.nn.ELU()
)
def forward(self,input):
if self.RNN:
# print("input:",input.shape,"hidden:",self.hidden.shape)
inp = torch.cat((input,self.hidden),dim=1)
# print(inp.shape)
output = self.layer1(inp)
output = self.layer2(output)
output = input+output
self.hidden = output
else:
output = self.layer1(input)
output = self.layer2(output)
output = input+output
return output
def reset_hidden(self, inp):
size = list(inp.size())
size[1] = self.delta_channel_dim + self.out_channels_rb
size[2] = size[2]//2
size[3] = size[3]//2
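# Hidden state is allocated at half the raw input resolution, matching the
# feature maps after interleave (this sizing assumes INTERLEAVE_RATE == 2).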
hidden = torch.zeros(*(size))
self.hidden = var_or_cuda(hidden)
class deinterleave(torch.nn.Module):
def __init__(self, block_size):
super(deinterleave, self).__init__()
self.block_size = block_size
self.block_size_sq = block_size*block_size
def forward(self, input):
output = input.permute(0, 2, 3, 1)
(batch_size, d_height, d_width, d_depth) = output.size()
s_depth = int(d_depth / self.block_size_sq)
s_width = int(d_width * self.block_size)
s_height = int(d_height * self.block_size)
t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth)
spl = t_1.split(self.block_size, 3)
stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl]
output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size, s_height, s_width, s_depth)
output = output.permute(0, 3, 1, 2)
return output
class interleave(torch.nn.Module):
def __init__(self, block_size):
super(interleave, self).__init__()
self.block_size = block_size
self.block_size_sq = block_size*block_size
def forward(self, input):
output = input.permute(0, 2, 3, 1)
(batch_size, s_height, s_width, s_depth) = output.size()
d_depth = s_depth * self.block_size_sq
d_width = int(s_width / self.block_size)
d_height = int(s_height / self.block_size)
t_1 = output.split(self.block_size, 2)
stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
output = torch.stack(stack, 1)
output = output.permute(0, 2, 1, 3)
output = output.permute(0, 3, 1, 2)
return output
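# interleave is a space-to-depth transform and deinterleave the matching
# depth-to-space inverse (cf. torch.nn.PixelUnshuffle / PixelShuffle). A quick
# round-trip sanity check, as a sketch:
# x = torch.randn(1, 12, 8, 8)
# y = deinterleave(2)(interleave(2)(x)) # expected to equal x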
class model(torch.nn.Module):
def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE):
super(model, self).__init__()
self.interleave = interleave(INTERLEAVE_RATE)
self.first_layer = torch.nn.Sequential(
torch.nn.Conv2d(FIRSST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,stride=1,padding=1),
torch.nn.BatchNorm2d(OUT_CHANNELS_RB),
torch.nn.ELU()
)
self.residual_block1 = residual_block(OUT_CHANNELS_RB,0,KERNEL_SIZE_RB,False)
self.residual_block2 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,False)
self.residual_block3 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True)
self.residual_block4 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True)
self.residual_block5 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True)
self.output_layer = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+3,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
torch.nn.Sigmoid()
)
self.deinterleave = deinterleave(INTERLEAVE_RATE)
def reset_hidden(self,inp):
self.residual_block3.reset_hidden(inp)
self.residual_block4.reset_hidden(inp)
self.residual_block5.reset_hidden(inp)
def forward(self, lightfield_images, focal_length, gazeX, gazeY):
# lightfield_images: torch.Size([batch_size, channels * D, H, W])
# channels : RGB*D: 3*9, H:256, W:256
# print("lightfield_images:",lightfield_images.shape)
input_to_net = self.interleave(lightfield_images)
# print("after interleave:",input_to_net.shape)
input_to_rb = self.first_layer(input_to_net)
# print("input_to_rb1:",input_to_rb.shape)
output = self.residual_block1(input_to_rb)
depth_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
gazeX_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
gazeY_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
# print("depth_layer:",depth_layer.shape)
# print("focal_depth:",focal_length," gazeX:",gazeX," gazeY:",gazeY, " gazeX norm:",(gazeX[0] - (-3.333)) / (3.333*2))
for i in range(focal_length.shape[0]):
depth_layer[i] *= 1. / focal_length[i]
gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333*2)
gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333*2)
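# Focal depth is injected as 1/f (diopters if f is in meters) and gaze is
# mapped from its assumed working range [-3.333, 3.333] into [0, 1].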
# print(depth_layer.shape)
depth_layer = var_or_cuda(depth_layer)
gazeX_layer = var_or_cuda(gazeX_layer)
gazeY_layer = var_or_cuda(gazeY_layer)
output = torch.cat((output,depth_layer,gazeX_layer,gazeY_layer),dim=1)
# output = torch.cat((output,depth_layer),dim=1)
# print("output to rb2:",output.shape)
output = self.residual_block2(output)
# print("output to rb3:",output.shape)
output = self.residual_block3(output)
# print("output to rb4:",output.shape)
output = self.residual_block4(output)
# print("output to rb5:",output.shape)
output = self.residual_block5(output)
# output = output + input_to_net
output = self.output_layer(output)
output = self.deinterleave(output)
return output
import torch, os, sys, cv2
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as func
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import torch
class RecurrentBlock(nn.Module):
def __init__(self, input_nc, output_nc, downsampling=False, bottleneck=False, upsampling=False):
super(RecurrentBlock, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.downsampling = downsampling
self.upsampling = upsampling
self.bottleneck = bottleneck
self.hidden = None
if self.downsampling:
self.l1 = nn.Sequential(
nn.Conv2d(input_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1)
)
self.l2 = nn.Sequential(
nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2d(output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
)
elif self.upsampling:
self.l1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(2 * input_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2d(output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
)
elif self.bottleneck:
self.l1 = nn.Sequential(
nn.Conv2d(input_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1)
)
self.l2 = nn.Sequential(
nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2d(output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
)
def forward(self, inp):
if self.downsampling:
op1 = self.l1(inp)
op2 = self.l2(torch.cat((op1, self.hidden), dim=1))
self.hidden = op2
return op2
elif self.upsampling:
op1 = self.l1(inp)
return op1
elif self.bottleneck:
op1 = self.l1(inp)
op2 = self.l2(torch.cat((op1, self.hidden), dim=1))
self.hidden = op2
return op2
def reset_hidden(self, inp, dfac):
size = list(inp.size())
size[1] = self.output_nc
size[2] //= dfac # integer division: torch.zeros requires int sizes
size[3] //= dfac
self.hidden_size = size
self.hidden = torch.zeros(*(size)).to('cuda:0')
class RecurrentAE(nn.Module):
def __init__(self, input_nc):
super(RecurrentAE, self).__init__()
self.d1 = RecurrentBlock(input_nc=input_nc, output_nc=32, downsampling=True)
self.d2 = RecurrentBlock(input_nc=32, output_nc=43, downsampling=True)
self.d3 = RecurrentBlock(input_nc=43, output_nc=57, downsampling=True)
self.d4 = RecurrentBlock(input_nc=57, output_nc=76, downsampling=True)
self.d5 = RecurrentBlock(input_nc=76, output_nc=101, downsampling=True)
self.bottleneck = RecurrentBlock(input_nc=101, output_nc=101, bottleneck=True)
self.u5 = RecurrentBlock(input_nc=101, output_nc=76, upsampling=True)
self.u4 = RecurrentBlock(input_nc=76, output_nc=57, upsampling=True)
self.u3 = RecurrentBlock(input_nc=57, output_nc=43, upsampling=True)
self.u2 = RecurrentBlock(input_nc=43, output_nc=32, upsampling=True)
self.u1 = RecurrentBlock(input_nc=32, output_nc=3, upsampling=True)
def set_input(self, inp):
self.inp = inp['A']
def forward(self):
d1 = func.max_pool2d(input=self.d1(self.inp), kernel_size=2)
d2 = func.max_pool2d(input=self.d2(d1), kernel_size=2)
d3 = func.max_pool2d(input=self.d3(d2), kernel_size=2)
d4 = func.max_pool2d(input=self.d4(d3), kernel_size=2)
d5 = func.max_pool2d(input=self.d5(d4), kernel_size=2)
b = self.bottleneck(d5)
u5 = self.u5(torch.cat((b, d5), dim=1))
u4 = self.u4(torch.cat((u5, d4), dim=1))
u3 = self.u3(torch.cat((u4, d3), dim=1))
u2 = self.u2(torch.cat((u3, d2), dim=1))
u1 = self.u1(torch.cat((u2, d1), dim=1))
return u1
def reset_hidden(self):
self.d1.reset_hidden(self.inp, dfac=1)
self.d2.reset_hidden(self.inp, dfac=2)
self.d3.reset_hidden(self.inp, dfac=4)
self.d4.reset_hidden(self.inp, dfac=8)
self.d5.reset_hidden(self.inp, dfac=16)
self.bottleneck.reset_hidden(self.inp, dfac=32)
self.u4.reset_hidden(self.inp, dfac=16)
self.u3.reset_hidden(self.inp, dfac=8)
self.u5.reset_hidden(self.inp, dfac=4)
self.u2.reset_hidden(self.inp, dfac=2)
self.u1.reset_hidden(self.inp, dfac=1)
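# Note: in forward() the upsampling blocks never read self.hidden, so only the
# d1-d5 and bottleneck resets matter; the u* resets (with their uneven dfac
# values) allocate buffers that are never used.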