Commit a22cfc90 authored by BobYeah

Update CNN scripts

parent c5434e97
import matplotlib.pyplot as plt
import torch
import util
import numpy as np
def FlowMap(b_last_frame, b_map):
'''
Map images using the flow data.
Parameters
--------
b_last_frame - B x 3 x H x W tensor, batch of images
b_map - B x H x W x 2, batch of map data recording pixel coords in the last frames
Returns
--------
B x 3 x H x W tensor, batch of images mapped by flow data
'''
return torch.nn.functional.grid_sample(b_last_frame, b_map, align_corners=False)
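# A minimal usage sketch (hypothetical 'paths' list of encoded flow images;
# the Flow class is defined below):
#
#   flow = Flow.Load(paths)
#   warped = FlowMap(b_last_frame, flow.getMap())
#
# Invalid map entries are set to -2, outside grid_sample's [-1, 1] range,
# so with the default zero padding those pixels come out black.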
class Flow(object):
'''
Class representing optical flow
Properties
--------
b_data - B x H x W x 2, batch of flow data
b_map - B x H x W x 2, batch of map data recording pixel coords in the last frames
b_invalid_mask - B x H x W, batch of masks indicating invalid elements in the corresponding flow data
'''
@staticmethod
def Load(paths):
'''
Create a Flow instance using a batch of encoded data images loaded from paths
Parameters
--------
paths - list of encoded data image paths
Returns
--------
Flow instance
'''
b_encoded_image = util.ReadImageTensor(paths, rgb_only=False, permute=False, batch_dim=True)
return Flow(b_encoded_image)
def __init__(self, b_encoded_image):
'''
Initialize a Flow instance from a batch of encoded data images
Parameters
--------
b_encoded_image - batch of encoded data images
'''
b_encoded_image = b_encoded_image.mul(255)
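# Decoding scheme (inferred from the lines below, not documented in the
# source): channels 0-1 hold the fractional part of each flow component
# (0..254 maps to 0..1), channels 2-3 the integer part offset by 127;
# a first-channel value of 255 marks an invalid flow pixel.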
# print("b_encoded_image:",b_encoded_image.shape)
self.b_invalid_mask = (b_encoded_image[:, :, :, 0] == 255)
self.b_data = (b_encoded_image[:, :, :, 0:2] / 254 + b_encoded_image[:, :, :, 2:4] - 127) / 127
self.b_data[:, :, :, 1] = -self.b_data[:, :, :, 1]
D = self.b_data.size()
grid = util.MeshGrid((D[1], D[2]), True)
self.b_map = (grid - self.b_data - 0.5) * 2
self.b_map[self.b_invalid_mask] = torch.tensor([ -2.0, -2.0 ])
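# b_map targets grid_sample: util.MeshGrid(..., True) is assumed to return
# normalized pixel centers in [0, 1]; subtracting the flow and remapping to
# [-1, 1] yields each pixel's sampling coordinate in the previous frame.
# Invalid entries are pushed to -2 so they sample zero padding.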
def getMap(self):
return self.b_map
def Visualize(self, scale_factor = 1):
'''
Visualize the flow data using a "color wheel".
Parameters
--------
scale_factor - scale factor of flow data to visualize, default is 1
Returns
--------
B x 3 x H x W tensor, visualization of flow data
'''
try:
Flow.b_color_wheel
except AttributeError:
Flow.b_color_wheel = util.ReadImageTensor('color_wheel.png')
return torch.nn.functional.grid_sample(Flow.b_color_wheel.expand(self.b_data.size()[0], -1, -1, -1),
(self.b_data * scale_factor), align_corners=False)
\ No newline at end of file
@@ -38,7 +38,7 @@ class residual_block(torch.nn.Module):
def forward(self,input):
if self.RNN:
# print("input:",input.shape,"hidden:",self.hidden.shape)
inp = torch.cat((input,self.hidden),dim=1)
inp = torch.cat((input,self.hidden.detach()),dim=1)
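# Detaching the hidden state above truncates backpropagation through time:
# gradients from the current frame do not flow back into earlier frames
# through the recurrent connection.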
# print(inp.shape)
output = self.layer1(inp)
output = self.layer2(output)
@@ -97,7 +97,7 @@ class interleave(torch.nn.Module):
return output
class model(torch.nn.Module):
def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE):
def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False):
super(model, self).__init__()
self.interleave = interleave(INTERLEAVE_RATE)
@@ -108,13 +108,19 @@ class model(torch.nn.Module):
)
self.residual_block1 = residual_block(OUT_CHANNELS_RB,0,KERNEL_SIZE_RB,False)
self.residual_block2 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,False)
self.residual_block3 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True)
self.residual_block4 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True)
self.residual_block5 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True)
self.residual_block2 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False)
self.residual_block3 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False)
# if RNN:
# self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# else:
# self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
# self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
# self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
self.output_layer = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+3,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
torch.nn.Conv2d(OUT_CHANNELS_RB+2,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
torch.nn.Sigmoid()
)
@@ -125,7 +131,7 @@ class model(torch.nn.Module):
self.residual_block4.reset_hidden(inp)
self.residual_block5.reset_hidden(inp)
def forward(self, lightfield_images, focal_length, gazeX, gazeY):
def forward(self, lightfield_images, pos_row, pos_col):
# lightfield_images: torch.Size([batch_size, channels * D, H, W])
# channels : RGB*D: 3*9, H:256, W:256
# print("lightfield_images:",lightfield_images.shape)
@@ -136,32 +142,18 @@ class model(torch.nn.Module):
# print("input_to_rb1:",input_to_rb.shape)
output = self.residual_block1(input_to_rb)
depth_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
gazeX_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
gazeY_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
# print("depth_layer:",depth_layer.shape)
# print("focal_depth:",focal_length," gazeX:",gazeX," gazeY:",gazeY, " gazeX norm:",(gazeX[0] - (-3.333)) / (3.333*2))
for i in range(focal_length.shape[0]):
depth_layer[i] *= 1. / focal_length[i]
gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333*2)
gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333*2)
pos_row_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
pos_col_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
for i in range(pos_row.shape[0]):
pos_row_layer[i] *= pos_row[i]
pos_col_layer[i] *= pos_col[i]
# print(depth_layer.shape)
depth_layer = var_or_cuda(depth_layer)
gazeX_layer = var_or_cuda(gazeX_layer)
gazeY_layer = var_or_cuda(gazeY_layer)
output = torch.cat((output,depth_layer,gazeX_layer,gazeY_layer),dim=1)
# output = torch.cat((output,depth_layer),dim=1)
# print("output to rb2:",output.shape)
pos_row_layer = var_or_cuda(pos_row_layer)
pos_col_layer = var_or_cuda(pos_col_layer)
output = torch.cat((output,pos_row_layer,pos_col_layer),dim=1)
output = self.residual_block2(output)
# print("output to rb3:",output.shape)
output = self.residual_block3(output)
# print("output to rb4:",output.shape)
output = self.residual_block4(output)
# print("output to rb5:",output.shape)
output = self.residual_block5(output)
# output = output + input_to_net
output = self.output_layer(output)
output = self.deinterleave(output)
return output
\ No newline at end of file
import torch
from gen_image import *
import util
import numpy as np
class Conf(object):
def __init__(self):
self.pupil_size = 0.02
self.retinal_res = torch.tensor([ 320, 320 ])
self.layer_res = torch.tensor([ 320, 320 ])
self.layer_hfov = 90 # layers' horizontal FOV
self.eye_hfov = 85 # eye's horizontal FOV (ignored in foveated rendering)
self.eye_enable_fovea = True # enable foveated rendering
self.eye_hfov = 80 # eye's horizontal FOV (ignored in foveated rendering)
self.eye_enable_fovea = False # enable foveated rendering
self.eye_fovea_angles = [ 40, 80 ] # eye's foveation layers' angles
self.eye_fovea_downsamples = [ 1, 2 ] # eye's foveation layers' downsamples
self.d_layer = [ 1, 3 ] # layers' distance
self.eye_fovea_blend = [ self._GenFoveaLayerBlend(0) ]
# blend maps of fovea layers
self.light_field_dim = 5
def GetNLayers(self):
return len(self.d_layer)
def GetLayerSize(self, i):
w = Fov2Length(self.layer_hfov)
w = util.Fov2Length(self.layer_hfov)
h = w * self.layer_res[0] / self.layer_res[1]
return torch.tensor([ h, w ]) * self.d_layer[i]
def GetPixelSizeOfLayer(self, i):
'''
Get pixel size of layer i
'''
return util.Fov2Length(self.layer_hfov) * self.d_layer[i] / self.layer_res[0]
def GetEyeViewportSize(self):
fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov
w = Fov2Length(fov)
w = util.Fov2Length(fov)
h = w * self.retinal_res[0] / self.retinal_res[1]
return torch.tensor([ h, w ])
def GetRegionOfFoveaLayer(self, i):
'''
Get region of fovea layer i in retinal image
Returns
--------
slice object storing the start and end of the region
'''
roi_size = int(np.ceil(self.retinal_res[0] * self.eye_fovea_angles[i] / self.eye_fovea_angles[-1]))
roi_offset = int((self.retinal_res[0] - roi_size) / 2)
return slice(roi_offset, roi_offset + roi_size)
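# Worked example with the Conf values above: retinal_res 320 and fovea
# angles [40, 80] give, for layer 0, roi_size = ceil(320 * 40 / 80) = 160
# and roi_offset = (320 - 160) // 2 = 80, i.e. slice(80, 240).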
def _GenFoveaLayerBlend(self, i):
'''
Generate blend map for fovea layer i
Parameters
--------
i - index of fovea layer
Returns
--------
H[i] x W[i], blend map
'''
region = self.GetRegionOfFoveaLayer(i)
width = region.stop - region.start
R = width / 2
p = util.MeshGrid([ width, width ])
r = torch.linalg.norm(p - R, 2, dim=2, keepdim=False)
return util.SmoothStep(R, R * 0.6, r)
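# Assuming util.SmoothStep(edge0, edge1, x) is a standard smoothstep, the
# inverted edges (R, 0.6 R) make the blend map 1 for r <= 0.6 R and fall
# smoothly to 0 at r = R, fading the fovea layer out toward its border.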
import torch
import os
import glob
import numpy as np
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import cv2
import json
from Flow import *
from gen_image import *
import util
import time
class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path,file_json):
self.file_dir_path = file_dir_path
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
self.input_img = []
for i in self.dataset_desc["train"]:
lf_element = os.path.join(self.file_dir_path,i)
lf_element = cv2.imread(lf_element, cv2.IMREAD_UNCHANGED)[:, :, 0:3]
lf_element = cv2.cvtColor(lf_element, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
lf_element = lf_element[:,:-1,:]
self.input_img.append(lf_element)
self.input_img = np.asarray(self.input_img)
def __len__(self):
return len(self.dataset_desc["gt"])
def __getitem__(self, idx):
gt, pos_row, pos_col = self.get_datum(idx)
return (self.input_img, gt, pos_row, pos_col)
def get_datum(self, idx):
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
gt = gt[:,:-1,:]
pos_col = self.dataset_desc["x"][idx]
pos_row = self.dataset_desc["y"][idx]
return gt, pos_row, pos_col
class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
# print(lf_image_paths)
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx
class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx
class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["seq"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
fd = fd.astype(np.float32)
gazeX = gazeX.astype(np.float32)
gazeY = gazeY.astype(np.float32)
sample_idx = sample_idx.astype(np.int64)
# print(fd)
# print(gazeX)
# print(gazeY)
# print(sample_idx)
# print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype)
# print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
indices = self.dataset_desc["seq"][idx]
# print("indices:",indices)
lf_images = []
fd = []
gazeX = []
gazeY = []
sample_idx = []
gt = []
gt2 = []
for i in range(len(indices)):
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
lf_image_one_sample = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
for j in range(9):
lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3]
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
# print("indices[i]:",indices[i])
fd.append([self.dataset_desc["focaldepth"][indices[i]]])
gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
sample_idx.append([self.dataset_desc["idx"][indices[i]]])
lf_images.append(lf_image_one_sample)
#lf_images: 5,9,320,320
return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx)
class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, gen, conf, transforms=None):
self.file_dir_path = file_dir_path
self.file_json = file_json
self.gen = gen
self.conf = conf
self.transforms = transforms
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["seq"])
def __getitem__(self, idx):
# start = time.time()
lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx = self.get_datum(idx)
fd = fd.astype(np.float32)
gazeX = gazeX.astype(np.float32)
gazeY = gazeY.astype(np.float32)
posX = posX.astype(np.float32)
posY = posY.astype(np.float32)
posZ = posZ.astype(np.float32)
sample_idx = sample_idx.astype(np.int64)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print("read once:",time.time() - start) # 8 ms
return (lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx)
def get_datum(self, idx):
IM_H = 320
IM_W = 320
indices = self.dataset_desc["seq"][idx]
# print("indices:",indices)
lf_images = []
fd = []
gazeX = []
gazeY = []
posX = []
posY = []
posZ = []
sample_idx = []
gt = []
# gt2 = []
phi = []
phi_invalid = []
retinal_invalid = []
for i in range(len(indices)): # 5
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
# fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
lf_image_one_sample = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
lf_dim = int(self.conf.light_field_dim)
for j in range(lf_dim**2):
lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim*IM_H+IM_H,j%lf_dim*IM_W:j%lf_dim*IM_W+IM_W,0:3]
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
# gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
# gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
# print("indices[i]:",indices[i])
fd.append([self.dataset_desc["focaldepth"][indices[i]]])
gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
posX.append([self.dataset_desc["x"][indices[i]]])
posY.append([self.dataset_desc["y"][indices[i]]])
posZ.append([0.0])
sample_idx.append([self.dataset_desc["idx"][indices[i]]])
lf_images.append(lf_image_one_sample)
idx_i = sample_idx[i][0]
focaldepth_i = fd[i][0]
gazeX_i = gazeX[i][0]
gazeY_i = gazeY[i][0]
posX_i = posX[i][0]
posY_i = posY[i][0]
posZ_i = posZ[i][0]
# print("x:%.3f,y:%.3f,z:%.3f;gaze:%.4f,%.4f,focaldepth:%.3f."%(posX_i,posY_i,posZ_i,gazeX_i,gazeY_i,focaldepth_i))
phi_i,phi_invalid_i,retinal_invalid_i = self.gen.CalculateRetinal2LayerMappings(torch.tensor([posX_i, posY_i,posZ_i]),torch.tensor([gazeX_i, gazeY_i]),focaldepth_i)
phi.append(phi_i)
phi_invalid.append(phi_invalid_i)
retinal_invalid.append(retinal_invalid_i)
#lf_images: 5,9,320,320
flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"][indices[i-1]]) for i in range(1,len(indices))])
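# One flow image per consecutive frame pair (len(indices) - 1 in total);
# flow[i] is assumed to map frame i+1 back onto frame i.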
flow_map = flow.getMap()
flow_invalid_mask = flow.b_invalid_mask
# print("flow:",flow_map.shape)
return np.asarray(lf_images),np.asarray(gt), torch.stack(phi,dim=0), torch.stack(phi_invalid,dim=0),torch.stack(retinal_invalid,dim=0), flow_map, flow_invalid_mask, np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(posX),np.asarray(posY),np.asarray(posZ),np.asarray(sample_idx)
@@ -2,13 +2,8 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
import glm
def Fov2Length(angle):
'''
'''
return np.tan(angle * np.pi / 360) * 2
import time
import util
def RandomGenSamplesInPupil(pupil_size, n_samples):
'''
@@ -21,9 +16,9 @@ def RandomGenSamplesInPupil(pupil_size, n_samples):
Returns
--------
a n_samples x 2 tensor with 2D sample position in each row
a n_samples x 3 tensor with 3D sample position in each row
'''
samples = torch.empty(n_samples, 2)
samples = torch.empty(n_samples, 3)
i = 0
while i < n_samples:
s = (torch.rand(2) - 0.5) * pupil_size
@@ -44,7 +39,7 @@ def GenSamplesInPupil(pupil_size, circles):
Returns
--------
a n_samples x 2 tensor with 2D sample position in each row
a n_samples x 3 tensor with 3D sample position in each row
'''
samples = torch.zeros(1, 3)
for i in range(1, circles):
@@ -70,7 +65,7 @@ class RetinalGen(object):
Methods
--------
'''
def __init__(self, conf, u):
def __init__(self, conf):
'''
Initialize retinal generator instance
@@ -80,58 +75,83 @@ class RetinalGen(object):
u - an M x 3 tensor storing M sample positions in the pupil
'''
self.conf = conf
self.u = GenSamplesInPupil(conf.pupil_size, 5)
# self.u = u.to(cuda_dev)
self.u = u # M x 3 M sample positions
# self.u = u # M x 3 M sample positions
self.D_r = conf.retinal_res # retinal resolution (320 x 320 with the Conf above)
self.N = conf.GetNLayers() # 2
self.M = u.size()[0] # samples
p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])),
torch.tensor(range(0, self.D_r[1])))
self.M = self.u.size()[0] # samples
# p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])),
# torch.tensor(range(0, self.D_r[1])))
# self.p_r = torch.cat([
# ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # eye viewport
# torch.ones(self.D_r[0], self.D_r[1], 1)
# ], 2)
self.p_r = torch.cat([
((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # eye viewport
((util.MeshGrid(self.D_r) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(),
torch.ones(self.D_r[0], self.D_r[1], 1)
], 2)
# self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long)
# self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2
def CalculateRetinal2LayerMappings(self, df, gaze):
def CalculateRetinal2LayerMappings(self, position, gaze_dir, df):
'''
Calculate the mapping matrix from retinal to layers.
Parameters
--------
position - 1 x 3 tensor, eye's position
gaze_dir - 1 x 2 tensor, gaze forward vector (with z normalized)
df - focus distance
gaze - 2 x 1 tensor, eye rotation angle (degs) in horizontal and vertical direction
Returns
--------
phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers
phi_invalid - N x H_r x W_r x M x 1, indicates invalid (out-of-range) mapping
retinal_invalid - 1 x H_r x W_r, indicates invalid pixels in retinal image
'''
Phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long) # 2 x 480 x 640 x 41 x 2
mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float)
D = self.conf.layer_res
c = torch.tensor([ D[1] / 2, D[0] / 2 ]) # c: Center of layers (pixel)
D_r = self.conf.retinal_res # D_r: Resolution of retinal 480 640
V = self.conf.GetEyeViewportSize() # V: Viewport size of eye
c = (self.conf.layer_res / 2) # c: Center of layers (pixel)
p_f = self.p_r * df # p_f: H x W x 3, focus positions of retinal pixels on focus plane
rot_forward = glm.dvec3(glm.tan(glm.radians(glm.dvec2(gaze[1], -gaze[0]))), 1)
rot_mat = torch.from_numpy(np.array(
glm.dmat3(glm.lookAtLH(glm.dvec3(), rot_forward, glm.dvec3(0, 1, 0)))))
rot_mat = rot_mat.float()
u_rot = torch.mm(self.u, rot_mat)
v_rot = torch.matmul(p_f, rot_mat).unsqueeze(2).expand(
-1, -1, self.u.size()[0], -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
v_rot.div_(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot
for i in range(0, self.conf.GetNLayers()):
dp_i = self.conf.GetLayerSize(i)[0] / self.conf.layer_res[0] # dp_i: Pixel size of layer i
# Calculate transformation from eye to display
gvec_lookat = glm.dvec3(gaze_dir[0], -gaze_dir[1], 1)
gmat_eye = glm.inverse(glm.lookAtLH(glm.dvec3(), gvec_lookat, glm.dvec3(0, 1, 0)))
eye_rot = util.Glm2Tensor(glm.dmat3(gmat_eye))
eye_center = torch.tensor([ position[0], -position[1], position[2] ])
u_rot = torch.mm(self.u, eye_rot)
v_rot = torch.matmul(p_f, eye_rot).unsqueeze(2).expand(
-1, -1, self.M, -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
u_rot.add_(eye_center) # translate by eye's center
v_rot = v_rot.div(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot
phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long)
for i in range(0, self.N):
dp_i = self.conf.GetPixelSizeOfLayer(i) # dp_i: Pixel size of layer i
d_i = self.conf.d_layer[i] # d_i: Distance of layer i
k = (d_i - u_rot[:, 2]).unsqueeze(1)
pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i # pi_r: H x W x M x 2, rays' pixel coord on layer i
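# Ray-plane intersection: each ray u + t*v (v normalized so v_z = 1) reaches
# the layer's depth d_i at t = d_i - u_z; the hit point's x,y divided by the
# layer pixel size dp_i gives layer pixel coordinates (re-centered by c below).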
Phi[i, :, :, :, :] = torch.floor(pi_r + c)
mask[:, :, :, :, 0] = ((Phi[:, :, :, :, 0] >= 0) & (Phi[:, :, :, :, 0] < self.conf.layer_res[0])).float()
mask[:, :, :, :, 1] = ((Phi[:, :, :, :, 1] >= 0) & (Phi[:, :, :, :, 1] < self.conf.layer_res[1])).float()
Phi[:, :, :, :, 0].clamp_(0, self.conf.layer_res[0] - 1)
Phi[:, :, :, :, 1].clamp_(0, self.conf.layer_res[1] - 1)
retinal_mask = mask.prod(0).prod(2).prod(2)
return [ Phi, retinal_mask ]
phi[i, :, :, :, :] = torch.floor(pi_r + c)
# Calculate invalid mask (out-of-range elements in phi) and reduce it to a per-pixel retinal mask
phi_invalid = (phi[:, :, :, :, 0] < 0) | (phi[:, :, :, :, 0] >= D[1]) | \
(phi[:, :, :, :, 1] < 0) | (phi[:, :, :, :, 1] >= D[0])
phi_invalid = phi_invalid.unsqueeze(4)
# print("phi_invalid:",phi_invalid.shape)
retinal_invalid = phi_invalid.amax((0, 3)).squeeze().unsqueeze(0)
# print("retinal_invalid:",retinal_invalid.shape)
# Fix invalid elements in phi
phi[phi_invalid.expand(-1, -1, -1, -1, 2)] = 0
return [ phi, phi_invalid, retinal_invalid ]
def GenRetinalFromLayers(self, layers, Phi):
'''
@@ -139,28 +159,159 @@ class RetinalGen(object):
Parameters
--------
layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
layers - 3N x H x W, stacked layer images, with 3 channels in each layer
phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers
Returns
--------
3 x H_r x W_r tensor, 3 channels retinal image
H_r x W_r tensor, retinal image mask, indicates pixels valid or not
3 x H_r x W_r, 3 channels retinal image
'''
# FOR GRAYSCALE 1 FOR RGB 3
mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M) # 2 x 3 x 480 x 640 x 41
# print("mapped_layers:",mapped_layers.shape)
for i in range(0, Phi.size()[0]):
# torch.Size([3, 2, 320, 320, 2])
# print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape)
mapped_layers[i, :, :, :, :] = layers[(i * 3) : (i * 3 + 3),
Phi[i, :, :, :, 0],
Phi[i, :, :, :, 1]]
Phi[i, :, :, :, 1],
Phi[i, :, :, :, 0]]
# print("mapped_layers:",mapped_layers.shape)
retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
# print("retinal:",retinal.shape)
return retinal
def GenFoveaLayers(self, retinal, retinal_mask):
def GenRetinalFromLayersBatch(self, layers, Phi):
'''
Generate retinal image from layers, using precalculated mapping matrix
Parameters
--------
layers - B x 3N x H_l x W_l tensor, batch of stacked layer images, with 3 channels per layer
Phi - B x N x H_r x W_r x M x 2 tensor, batch of retinal-to-layer mappings
Returns
--------
B x 3 x H_r x W_r tensor, batch of 3-channel retinal images
'''
mapped_layers = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M) #BS x Layers x C x H x W x Sample
# truth = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M)
# layers_truth = layers.clone()
# Phi_truth = Phi.clone()
layers = torch.stack((layers[:,0:3,:,:],layers[:,3:6,:,:]),dim=1) ## torch.Size([BS, Layer, RGB 3, 320, 320])
# Phi = Phi[:,:,None,:,:,:,:].expand(-1,-1,3,-1,-1,-1,-1)
# print("mapped_layers:",mapped_layers.shape) #torch.Size([2, 2, 3, 320, 320, 41])
# print("input layers:",layers.shape) ## torch.Size([2, 2, 3, 320, 320])
# print("input Phi:",Phi.shape) #torch.Size([2, 2, 320, 320, 41, 2])
# # unoptimized (reference) version
# for i in range(0, Phi_truth.size()[0]):
# for j in range(0, Phi_truth.size()[1]):
# truth[i, j, :, :, :, :] = layers_truth[i, (j * 3) : (j * 3 + 3),
# Phi_truth[i, j, :, :, :, 0],
# Phi_truth[i, j, :, :, :, 1]]
# optimization 2
# start = time.time()
mapped_layers_op1 = mapped_layers.reshape(-1,
mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# BatchSizexLayer Channel 3 320 320 41
layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4]) # 2x2 3 320 320
Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5]) # 2x2 320 320 41 2
x = Phi_op1[:,:,:,:,0] # 2x2 320 320 41
y = Phi_op1[:,:,:,:,1] # 2x2 320 320 41
# print("reshape:",time.time() - start)
# start = time.time()
mapped_layers_op1 = layers_op1[torch.arange(layers_op1.shape[0])[:, None, None, None], :, y, x] # x,y swapped
#2x2 320 320 41 3
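# Advanced indexing: the batch index of shape (B*N, 1, 1, 1) broadcasts with
# y and x of shape (B*N, H, W, M); with the channel slice kept in between,
# the result is (B*N) x H x W x M x 3, i.e. layers_op1[b, :, y, x] per element.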
# print("mapping one step:",time.time() - start)
# print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
# start = time.time()
mapped_layers_op1 = mapped_layers_op1.permute(0,4,1,2,3)
mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# print("reshape end:",time.time() - start)
# print("test:")
# print((truth.cpu() == mapped_layers.cpu()).all())
# optimization 1
# start = time.time()
# mapped_layers_op1 = mapped_layers.reshape(-1,
# mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4])
# Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5])
# print("reshape:",time.time() - start)
# for i in range(0, Phi_op1.size()[0]):
# start = time.time()
# mapped_layers_op1[i, :, :, :, :] = layers_op1[i,:,
# Phi_op1[i, :, :, :, 0],
# Phi_op1[i, :, :, :, 1]]
# print("mapping one step:",time.time() - start)
# print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
# start = time.time()
# mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
# mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# print("reshape end:",time.time() - start)
# print("mapped_layers:",mapped_layers.shape) # torch.Size([2, 2, 3, 320, 320, 41])
retinal = mapped_layers.prod(1).sum(4).div(Phi.size()[4])
# print("retinal:",retinal.shape) # torch.Size([BatchSize, 3, 320, 320])
return retinal
## TO BE CHECKED
def GenFoveaLayers(self, b_retinal, is_mask):
'''
Generate foveated layers for retinal images or masks
Parameters
--------
b_retinal - B x C x H_r x W_r, Batch of retinal images/masks
is_mask - Whether b_retinal is masks or images
Returns
--------
b_fovea_layers - N_f x (B x C x H[f] x W[f]) list of batch of foveated layers
'''
b_fovea_layers = []
for i in range(0, len(self.conf.eye_fovea_angles)):
k = self.conf.eye_fovea_downsamples[i]
region = self.conf.GetRegionOfFoveaLayer(i)
b_roi = b_retinal[:, :, region, region]
if k == 1:
b_fovea_layers.append(b_roi)
elif is_mask:
b_fovea_layers.append(torch.nn.functional.max_pool2d(b_roi.to(torch.float), k).to(torch.bool))
else:
b_fovea_layers.append(torch.nn.functional.avg_pool2d(b_roi, k))
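# For masks, max_pool2d keeps a downsampled pixel marked if any covered
# source pixel was marked; for images, avg_pool2d antialiases while
# downsampling by factor k.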
return b_fovea_layers
# fovea_layers = []
# fovea_layer_masks = []
# fov = self.conf.eye_fovea_angles[-1]
# retinal_res = int(self.conf.retinal_res[0])
# for i in range(0, len(self.conf.eye_fovea_angles)):
# angle = self.conf.eye_fovea_angles[i]
# k = self.conf.eye_fovea_downsamples[i]
# roi_size = int(np.ceil(retinal_res * angle / fov))
# roi_offset = int((retinal_res - roi_size) / 2)
# roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# if k == 1:
# fovea_layers.append(roi_img)
# fovea_layer_masks.append(roi_mask)
# else:
# fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0))
# fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0))
# return [ fovea_layers, fovea_layer_masks ]
## TO BE CHECKED
def GenFoveaLayersBatch(self, retinal, retinal_mask):
'''
Generate foveated layers and corresponding masks
@@ -177,18 +328,48 @@ class RetinalGen(object):
fovea_layers = []
fovea_layer_masks = []
fov = self.conf.eye_fovea_angles[-1]
# print("fov:",fov)
retinal_res = int(self.conf.retinal_res[0])
# print("retinal_res:",retinal_res)
# print("len(self.conf.eye_fovea_angles):",len(self.conf.eye_fovea_angles))
for i in range(0, len(self.conf.eye_fovea_angles)):
angle = self.conf.eye_fovea_angles[i]
k = self.conf.eye_fovea_downsamples[i]
roi_size = int(np.ceil(retinal_res * angle / fov))
roi_offset = int((retinal_res - roi_size) / 2)
roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# [2, 3, 320, 320]
roi_img = retinal[:, :, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# print("roi_img:",roi_img.shape)
# [2, 320, 320]
roi_mask = retinal_mask[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# print("roi_mask:",roi_mask.shape)
if k == 1:
fovea_layers.append(roi_img)
fovea_layer_masks.append(roi_mask)
else:
fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0))
fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0))
fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img, k))
fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask), k))
return [ fovea_layers, fovea_layer_masks ]
## TO BE CHECKED
def GenFoveaRetinal(self, b_fovea_layers):
'''
Generate foveated retinal image by blending fovea layers
**Note: the current implementation only supports two fovea layers**
Parameters
--------
b_fovea_layers - N_f x (B x 3 x H[f] x W[f]), list of batch of (masked) foveated layers
Returns
--------
B x 3 x H_r x W_r, batch of foveated retinal images
'''
b_fovea_retinal = torch.nn.functional.interpolate(b_fovea_layers[1],
scale_factor=self.conf.eye_fovea_downsamples[1],
mode='bilinear', align_corners=False)
region = self.conf.GetRegionOfFoveaLayer(0)
blend = self.conf.eye_fovea_blend[0]
b_roi = b_fovea_retinal[:, :, region, region]
b_roi.mul_(1 - blend).add_(b_fovea_layers[0] * blend)
return b_fovea_retinal
import torch
from ssim import *
from perc_loss import *
l1loss = torch.nn.L1Loss()
perc_loss = VGGPerceptualLoss()
perc_loss = perc_loss.to("cuda:1")
##### LOSS #####
def calImageGradients(images):
# images is a 4-D tensor (B x C x H x W)
dx = images[:, :, 1:, :] - images[:, :, :-1, :]
dy = images[:, :, :, 1:] - images[:, :, :, :-1]
return dx, dy
def loss_new(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
p_loss = perc_loss(generated,gt)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
def loss_without_perc(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
##### LOSS #####
class ReconstructionLoss(torch.nn.Module):
def __init__(self):
super(ReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
return total_loss
class PerceptionReconstructionLoss(torch.nn.Module):
def __init__(self):
super(PerceptionReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
p_loss = perc_loss(generated,gt)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
return total_loss
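# A minimal usage sketch (assumed shapes: B x 3 x H x W tensors in [0, 1];
# the perceptual term runs on the "cuda:1" device configured above):
#
#   criterion = ReconstructionLoss()
#   loss = criterion(generated, gt)  # log10(MSE) plus 0.5 * gradient terms
#   loss.backward()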
@@ -12,212 +12,33 @@ from torch.autograd import Variable
import cv2
from gen_image import *
from loss import *
import json
from ssim import *
from perc_loss import *
from conf import Conf
from model.baseline import *
from baseline import *
from data import *
import torch.autograd.profiler as profiler
# param
BATCH_SIZE = 2
NUM_EPOCH = 300
BATCH_SIZE = 1
NUM_EPOCH = 1001
INTERLEAVE_RATE = 2
IM_H = 320
IM_W = 320
Retinal_IM_H = 320
Retinal_IM_W = 320
N = 9 # number of views in the input light field stack
N = 25 # number of views in the input light field stack
M = 2 # number of display layers
DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_fovea"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq.json"
DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq"
DATA_FILE = "/home/yejiannan/Project/LightField/data/FlowRPG1211"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq_flow_RPG.json"
# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq_flow_RPG_seq5_same_loss"
OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3
LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2
class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
# print(lf_image_paths)
fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx])
fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx
class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx
class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["seq"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
fd = fd.astype(np.float32)
gazeX = gazeX.astype(np.float32)
gazeY = gazeY.astype(np.float32)
sample_idx = sample_idx.astype(np.int64)
# print(fd)
# print(gazeX)
# print(gazeY)
# print(sample_idx)
# print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype)
# print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
indices = self.dataset_desc["seq"][idx]
# print("indices:",indices)
lf_images = []
fd = []
gazeX = []
gazeY = []
sample_idx = []
gt = []
gt2 = []
for i in range(len(indices)):
lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][indices[i]])
fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][indices[i]])
fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][indices[i]])
lf_image_one_sample = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
for j in range(9):
lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
# print("indices[i]:",indices[i])
fd.append([self.dataset_desc["focaldepth"][indices[i]]])
gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
sample_idx.append([self.dataset_desc["idx"][indices[i]]])
lf_images.append(lf_image_one_sample)
#lf_images: 5,9,320,320
return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx)
#### Image Gen
conf = Conf()
u = GenSamplesInPupil(conf.pupil_size, 5)
gen = RetinalGen(conf, u)
FIRSST_LAYER_CHANNELS = 75 * INTERLEAVE_RATE**2
def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
# layers: batchsize, 2*color, height, width
@@ -248,6 +69,7 @@ def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
# retinal bs x color x height x width
retinal_fovea = torch.empty(layers.shape[0], 6, 160, 160)
mask_fovea = torch.empty(layers.shape[0], 2, 160, 160)
for i in range(0, layers.size()[0]):
phi = phi_dict[int(sample_idx[i].data)]
# print("phi_i:",phi.shape)
@@ -262,11 +84,64 @@ def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
retinal_fovea[i] = torch.cat([fovea_layers[0],fovea_layers[1]],dim=0)
mask_fovea[i] = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=0)
retinal_fovea = var_or_cuda(retinal_fovea)
mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
# mask = torch.stack(mask,dim = 0).unsqueeze(1)
return retinal_fovea, mask_fovea
import time
def GenRetinalGazeFromLayersBatchSpeed(layers, gen, phi, phi_invalid, retinal_invalid):
# layers: batchsize, 2*color, height, width
# Phi:torch.Size([batchsize, Layer, h, w, 41, 2])
# df : batchsize,..
# start1 = time.time()
# retinal bs x color x height x width
retinal_fovea = torch.empty((layers.shape[0], 6, 160, 160),device="cuda:2")
mask_fovea = torch.empty((layers.shape[0], 2, 160, 160),device="cuda:2")
# start = time.time()
retinal = gen.GenRetinalFromLayersBatch(layers, phi)
# print("retinal:",retinal.shape) #retinal: torch.Size([2, 3, 320, 320])
# print("t2:",time.time() - start)
# start = time.time()
fovea_layers, fovea_layer_masks = gen.GenFoveaLayersBatch(retinal, retinal_invalid) # assumption: retinal_invalid stands in for the undefined 'mask_batch'
mask_fovea = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=1)
retinal_fovea = torch.cat([fovea_layers[0],fovea_layers[1]],dim=1)
# print("t3:",time.time() - start)
retinal_fovea = var_or_cuda(retinal_fovea)
mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
# mask = torch.stack(mask,dim = 0).unsqueeze(1)
return retinal_fovea, mask_fovea
def MergeBatchSpeed(layers, gen, phi, phi_invalid, retinal_invalid):
# layers: batchsize, 2*color, height, width
# Phi:torch.Size([batchsize, Layer, h, w, 41, 2])
# df : batchsize,..
# start1 = time.time()
# retinal bs x color x height x width
# retinal_fovea = torch.empty((layers.shape[0], 6, 160, 160),device="cuda:2")
# mask_fovea = torch.empty((layers.shape[0], 2, 160, 160),device="cuda:2")
# start = time.time()
retinal = gen.GenRetinalFromLayersBatch(layers,phi) #retinal: torch.Size([BatchSize , 3, 320, 320])
retinal.mul_(~retinal_invalid.to("cuda:2"))
# print("retinal:",retinal.shape)
# print("t2:",time.time() - start)
# start = time.time()
# fovea_layers, fovea_layer_masks = gen.GenFoveaLayersBatch(retinal,mask_batch)
# mask_fovea = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=1)
# retinal_fovea = torch.cat([fovea_layers[0],fovea_layers[1]],dim=1)
# print("t3:",time.time() - start)
# retinal_fovea = var_or_cuda(retinal_fovea)
# mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
# mask = torch.stack(mask,dim = 0).unsqueeze(1)
return retinal
def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
# layers: batchsize, 2*color, height, width
# Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
@@ -285,47 +160,7 @@ def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
return retinal.unsqueeze(0), mask_out
#### Image Gen End
weightVarScale = 0.25
bias_stddev = 0.01
def weight_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.xavier_normal_(m.weight.data)
torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
def calImageGradients(images):
# images is a 4-D tensor (B x C x H x W)
dx = images[:, :, 1:, :] - images[:, :, :-1, :]
dy = images[:, :, :, 1:] - images[:, :, :, :-1]
return dx, dy
perc_loss = VGGPerceptualLoss()
perc_loss = perc_loss.to("cuda:1")
def loss_new(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
p_loss = perc_loss(generated,gt)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = 10 + psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
from weight_init import weight_init_normal
def save_checkpoints(file_path, epoch_idx, model, model_solver):
print('[INFO] Saving checkpoint to %s ...' % ( file_path))
@@ -336,93 +171,112 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
}
torch.save(checkpoint, file_path)
mode = "train"
import pickle
def save_obj(obj, name ):
# with open('./outputF/dict/'+ name + '.pkl', 'wb') as f:
# pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
torch.save(obj,'./outputF/dict/'+ name + '.pkl')
def load_obj(name):
# with open('./outputF/dict/' + name + '.pkl', 'rb') as f:
# return pickle.load(f)
return torch.load('./outputF/dict/'+ name + '.pkl')
def hook_fn_back(m, i, o):
for grad in i:
try:
print("Input Grad:",m,grad.shape,grad.sum())
except AttributeError:
print ("None found for Gradient")
for grad in o:
try:
print("Output Grad:",m,grad.shape,grad.sum())
except AttributeError:
print ("None found for Gradient")
print("\n")
def hook_fn_for(m, i, o):
for grad in i:
try:
print("Input Feats:",m,grad.shape,grad.sum())
except AttributeError:
print ("None found for Gradient")
for grad in o:
try:
print("Output Feats:",m,grad.shape,grad.sum())
except AttributeError:
print ("None found for Gradient")
print("\n")
def generatePhiMaskDict(data_json, generator):
phi_dict = {}
mask_dict = {}
idx_info_dict = {}
with open(data_json, encoding='utf-8') as file:
dataset_desc = json.loads(file.read())
for i in range(len(dataset_desc["focaldepth"])):
# if i == 2:
# break
idx = dataset_desc["idx"][i]
focaldepth = dataset_desc["focaldepth"][i]
gazeX = dataset_desc["gazeX"][i]
gazeY = dataset_desc["gazeY"][i]
print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
phi_dict[idx]=phi
mask_dict[idx]=mask
idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
return phi_dict,mask_dict,idx_info_dict
# import pickle
# def save_obj(obj, name ):
# # with open('./outputF/dict/'+ name + '.pkl', 'wb') as f:
# # pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
# torch.save(obj,'./outputF/dict/'+ name + '.pkl')
# def load_obj(name):
# # with open('./outputF/dict/' + name + '.pkl', 'rb') as f:
# # return pickle.load(f)
# return torch.load('./outputF/dict/'+ name + '.pkl')
# def generatePhiMaskDict(data_json, generator):
# phi_dict = {}
# mask_dict = {}
# idx_info_dict = {}
# with open(data_json, encoding='utf-8') as file:
# dataset_desc = json.loads(file.read())
# for i in range(len(dataset_desc["focaldepth"])):
# # if i == 2:
# # break
# idx = dataset_desc["idx"][i]
# focaldepth = dataset_desc["focaldepth"][i]
# gazeX = dataset_desc["gazeX"][i]
# gazeY = dataset_desc["gazeY"][i]
# print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
# phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
# phi_dict[idx]=phi
# mask_dict[idx]=mask
# idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
# return phi_dict,mask_dict,idx_info_dict
# def generatePhiMaskDictNew(data_json, generator):
# phi_dict = {}
# mask_dict = {}
# idx_info_dict = {}
# with open(data_json, encoding='utf-8') as file:
# dataset_desc = json.loads(file.read())
# for i in range(len(dataset_desc["seq"])):
# for j in dataset_desc["seq"][i]:
# idx = dataset_desc["idx"][j]
# focaldepth = dataset_desc["focaldepth"][j]
# gazeX = dataset_desc["gazeX"][j]
# gazeY = dataset_desc["gazeY"][j]
# print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
# phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
# phi_dict[idx]=phi
# mask_dict[idx]=mask
# idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
# return phi_dict,mask_dict,idx_info_dict
mode = "Silence" #"Perf"
model_type = "RNN" #"RNN"
w_frame = 0.9
w_inter_frame = 0.1
batch_model = "NoSingle"
loss1 = ReconstructionLoss()
loss2 = ReconstructionLoss()
if __name__ == "__main__":
############################## generate phi and mask in pre-training
# print("generating phi and mask...")
# phi_dict,mask_dict,idx_info_dict = generatePhiMaskDict(DATA_JSON,gen)
# save_obj(phi_dict,"phi_1204")
# save_obj(mask_dict,"mask_1204")
# save_obj(idx_info_dict,"idx_info_1204")
# phi_dict,mask_dict,idx_info_dict = generatePhiMaskDictNew(DATA_JSON,gen)
# # save_obj(phi_dict,"phi_1204")
# # save_obj(mask_dict,"mask_1204")
# # save_obj(idx_info_dict,"idx_info_1204")
# print("generating phi and mask end.")
# exit(0)
############################# load phi and mask in pre-training
print("loading phi and mask ...")
phi_dict = load_obj("phi_1204")
mask_dict = load_obj("mask_1204")
idx_info_dict = load_obj("idx_info_1204")
print(len(phi_dict))
print(len(mask_dict))
print("loading phi and mask end")
# print("loading phi and mask ...")
# phi_dict = load_obj("phi_1204")
# mask_dict = load_obj("mask_1204")
# idx_info_dict = load_obj("idx_info_1204")
# print(len(phi_dict))
# print(len(mask_dict))
# print("loading phi and mask end")
#### Image Gen and conf
conf = Conf()
gen = RetinalGen(conf)
#train
train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldFlowSeqDataLoader(DATA_FILE,DATA_JSON, gen, conf),
batch_size=BATCH_SIZE,
num_workers=8,
pin_memory=True,
shuffle=True,
drop_last=False)
#Data loader test
print(len(train_data_loader))
# # lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, fd, gazeX, gazeY, posX, posY, posZ, sample_idx
# for batch_idx, (image_set, gt,phi, phi_invalid, retinal_invalid, flow, df, gazeX, gazeY, posX, posY, posZ, sample_idx) in enumerate(train_data_loader):
# print(image_set.shape,type(image_set))
# print(gt.shape,type(gt))
# print(phi.shape,type(phi))
# print(phi_invalid.shape,type(phi_invalid))
# print(retinal_invalid.shape,type(retinal_invalid))
# print(flow.shape,type(flow))
# print(df.shape,type(df))
# print(gazeX.shape,type(gazeX))
# print(posX.shape,type(posX))
# print(sample_idx.shape,type(sample_idx))
# print("test train dataloader.")
# exit(0)
#Data loader test end
################################################ val #########################################################
......@@ -464,21 +318,24 @@ if __name__ == "__main__":
# print("output:",output.shape," df:",df[0].data, ",gazeX:",gazeX[0].data,",gazeY:", gazeY[0].data)
# for i in range(output1.size()[0]):
# save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.3f.png"%(df[0].data)))
# # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.3f.png"%(df[0].data)))
# # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.5f.png"%(df[0].data)))
# # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.5f.png"%(df[0].data)))
# # output = GenRetinalFromLayersBatch(output,conf,df,v,u)
# # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.3f.png"%(df[0].data)))
# # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.5f.png"%(df[0].data)))
# exit()
################################################ train #########################################################
if model_type == "RNN":
lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE)
else:
lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False)
lf_model.apply(weight_init_normal)
lf_model.train()
epoch_begin = 0
################################ load model file
......@@ -493,144 +350,243 @@ if __name__ == "__main__":
if torch.cuda.is_available():
# lf_model = torch.nn.DataParallel(lf_model).cuda()
lf_model = lf_model.to('cuda:2')
lf_model.train()
optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
l1loss = torch.nn.L1Loss()
# lf_model.output_layer.register_backward_hook(hook_fn_back)
if mode=="Perf":
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
print("begin training....")
for epoch in range(epoch_begin, NUM_EPOCH):
for batch_idx, (image_set, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, df, gazeX, gazeY, posX, posY, posZ, sample_idx) in enumerate(train_data_loader):
# print(sample_idx.shape,df.shape,gazeX.shape,gazeY.shape) # torch.Size([2, 5])
# print(image_set.shape,gt.shape,gt2.shape) #torch.Size([2, 5, 9, 320, 320, 3]) torch.Size([2, 5, 160, 160, 3]) torch.Size([2, 5, 160, 160, 3])
# print(delta.shape) # delta: torch.Size([2, 4, 160, 160, 3])
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("load:",start.elapsed_time(end))
start.record()
#reshape for input
image_set = image_set.permute(0,1,2,5,3,4) # N Seq(5) LF(25) C(3) H W
image_set = image_set.reshape(image_set.shape[0],image_set.shape[1],-1,image_set.shape[4],image_set.shape[5]) # N Seq(5) LFxC(75) H W
image_set = var_or_cuda(image_set)
# print(image_set.shape) #torch.Size([2, 5, 75, 320, 320])
gt = gt.permute(0,1,4,2,3) # BS Seq(5) C(3) H W
gt = var_or_cuda(gt)
flow = var_or_cuda(flow) # BS, Seq-1, H, W, 2
# gt2 = gt2.permute(0,1,4,2,3)
# gt2 = var_or_cuda(gt2)
gen1 = torch.empty(gt.shape) # BS Seq C H W
gen1 = var_or_cuda(gen1)
# print(gen1.shape) #torch.Size([2, 5, 3, 320, 320])
# gen2 = torch.empty(gt2.shape)
# gen2 = var_or_cuda(gen2)
# BS, Seq-1, C, H, W
warped = torch.empty(gt.shape[0],gt.shape[1]-1,gt.shape[2],gt.shape[3],gt.shape[4])
warped = var_or_cuda(warped)
gen_temp = torch.empty(warped.shape)
gen_temp = var_or_cuda(gen_temp)
# print("warped:",warped.shape) #warped: torch.Size([2, 4, 3, 320, 320])
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("data prepare:",start.elapsed_time(end))
start.record()
if model_type == "RNN":
if batch_model != "Single":
for k in range(image_set.shape[1]):
if k == 0:
lf_model.reset_hidden(image_set[:,k])
output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
gen1[:,k] = output1[:,0:3]
gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
for i in range(output.shape[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
for i in range(1,gt.shape[1]):
# print(flow_invalid_mask.shape) #torch.Size([2, 4, 320, 320])
# print(FlowMap(gen1[:,i-1],flow[:,i-1]).shape) #torch.Size([2, 3, 320, 320])
warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
else:
for k in range(image_set.shape[1]):
if k == 0:
lf_model.reset_hidden(image_set[:,k])
output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
gen1[:,k] = output1[:,0:3]
gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
if k != image_set.shape[1]-1:
warped[:,k] = FlowMap(output1.detach(),flow[:,k]).mul(~flow_invalid_mask[:,k].unsqueeze(1).to("cuda:2"))
loss1 = loss_without_perc(output1,gt[:,k])
loss = (w_frame * loss1)
if k==0:
loss.backward(retain_graph=False)
optimizer.step()
lf_model.zero_grad()
optimizer.zero_grad()
print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item())
else:
output1mask = output1.mul(~flow_invalid_mask[:,k-1].unsqueeze(1).to("cuda:2"))
loss2 = l1loss(output1mask,warped[:,k-1])
loss += (w_inter_frame * loss2)
loss.backward(retain_graph=False)
optimizer.step()
lf_model.zero_grad()
optimizer.zero_grad()
# print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item())
print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",frame loss:",loss1.item(),",inter loss:",w_inter_frame * loss2.item())
if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
for i in range(output.shape[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
# BSxSeq C H W
gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
for i in range(gt.size()[0]):
save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
else:
if batch_model != "Single":
for k in range(image_set.shape[1]):
output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
gen1[:,k] = output1[:,0:3]
gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
for i in range(output.shape[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
for i in range(1,gt.shape[1]):
# print(flow_invalid_mask.shape) #torch.Size([2, 4, 320, 320])
# print(FlowMap(gen1[:,i-1],flow[:,i-1]).shape) #torch.Size([2, 3, 320, 320])
warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
else:
for k in range(image_set.shape[1]):
output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
gen1[:,k] = output1[:,0:3]
gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
# print(output1.shape) #torch.Size([BS, 3, 320, 320])
loss1 = loss_without_perc(output1,gt[:,k])
loss = (w_frame * loss1)
# print("loss:",loss1.item())
loss.backward(retain_graph=False)
optimizer.step()
lf_model.zero_grad()
optimizer.zero_grad()
print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item())
if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
for i in range(output.shape[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
for i in range(1,gt.shape[1]):
warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4])
gen_temp = gen_temp.reshape(-1,gen_temp.shape[2],gen_temp.shape[3],gen_temp.shape[4])
loss3 = l1loss(warped,gen_temp)
print("Epoch:",epoch,",Iter:",batch_idx,",inter-frame loss:",w_inter_frame *loss3.item())
# BSxSeq C H W
gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
for i in range(gt.size()[0]):
save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
if batch_model != "Single":
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("forward:",start.elapsed_time(end))
start.record()
optimizer.zero_grad()
# BSxSeq C H W
gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
# gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4])
# BSxSeq C H W
gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
# gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
# BSx(Seq-1) C H W
warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4])
gen_temp = gen_temp.reshape(-1,gen_temp.shape[2],gen_temp.shape[3],gen_temp.shape[4])
loss1_value = loss1(gen1,gt)
loss2_value = loss2(warped,gen_temp)
if model_type == "RNN":
loss = (w_frame * loss1_value)+ (w_inter_frame * loss2_value)
else:
loss = (w_frame * loss1_value)
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("compute loss:",start.elapsed_time(end))
start.record()
loss.backward()
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("backward:",start.elapsed_time(end))
start.record()
optimizer.step()
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("update:",start.elapsed_time(end))
print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss.item(),",frame loss:",loss1_value.item(),",inter-frame loss:",loss2_value.item())
# exit(0)
########################### Save #####################
if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
for i in range(gt.size()[0]):
# df 2,5
save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
# save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
# save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),epoch,lf_model,optimizer)
########################## test Phi and Mask ##########################
# phi,mask = gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]]))
# # print("gaze Online:",gazeX[0]," ,",gazeY[0])
# # print("df Online:",df[0])
# # print("idx:",int(sample_idx[0].data))
# phi_t = phi_dict[int(sample_idx[0].data)]
# mask_t = mask_dict[int(sample_idx[0].data)]
# # print("idx info:",idx_info_dict[int(sample_idx[0].data)])
# # print("phi online:", phi.shape, " phi_t:", phi_t.shape)
# # print("mask online:", mask.shape, " mask_t:", mask_t.shape)
# print("phi delta:", (phi-phi_t).sum()," mask delta:",(mask -mask_t).sum())
# exit(0)
###########################Gen Batch 1 by 1###################
# phi,mask = gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]]))
# # print(phi.shape) # 2,240,320,41,2
# output1, mask = GenRetinalFromLayersBatch_Online(output, gen, phi, mask)
###########################Gen Batch 1 by 1###################
\ No newline at end of file
save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
\ No newline at end of file
......@@ -31,7 +31,7 @@ M = 1 # number of display layers
DATA_FILE = "/home/yejiannan/Project/LightField/data/lf_syn"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_lf_syn_full.json"
# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/lf_syn_full_perc"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/lf_syn_full"
OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3
......@@ -50,7 +50,7 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
torch.save(checkpoint, file_path)
mode = "Silence" #"Perf"
w_frame = 1.0
loss1 = PerceptionReconstructionLoss()
if __name__ == "__main__":
#train
......@@ -70,7 +70,7 @@ if __name__ == "__main__":
if torch.cuda.is_available():
# lf_model = torch.nn.DataParallel(lf_model).cuda()
lf_model = lf_model.to('cuda:1')
optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
......
import torch, os, sys, cv2
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as func
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
class RecurrentBlock(nn.Module):
def __init__(self, input_nc, output_nc, downsampling=False, bottleneck=False, upsampling=False):
super(RecurrentBlock, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.downsampling = downsampling
self.upsampling = upsampling
self.bottleneck = bottleneck
self.hidden = None
if self.downsampling:
self.l1 = nn.Sequential(
nn.Conv2d(input_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1)
)
self.l2 = nn.Sequential(
nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2d(output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
)
elif self.upsampling:
self.l1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(2 * input_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2d(output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
)
elif self.bottleneck:
self.l1 = nn.Sequential(
nn.Conv2d(input_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1)
)
self.l2 = nn.Sequential(
nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2d(output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
)
def forward(self, inp):
if self.downsampling:
op1 = self.l1(inp)
op2 = self.l2(torch.cat((op1, self.hidden), dim=1))
self.hidden = op2
return op2
elif self.upsampling:
op1 = self.l1(inp)
return op1
elif self.bottleneck:
op1 = self.l1(inp)
op2 = self.l2(torch.cat((op1, self.hidden), dim=1))
self.hidden = op2
return op2
def reset_hidden(self, inp, dfac):
size = list(inp.size())
size[1] = self.output_nc
size[2] //= dfac
size[3] //= dfac
self.hidden_size = size
self.hidden = torch.zeros(*(size)).to('cuda:0') # note: hidden is pinned to 'cuda:0' regardless of the model's device
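# Minimal shape-check sketch for RecurrentBlock (assumption: running on CPU, so the
# hidden state is allocated by hand instead of via reset_hidden, which is pinned to
# 'cuda:0'):
#
#   blk = RecurrentBlock(input_nc=3, output_nc=8, downsampling=True)
#   x = torch.zeros(1, 3, 64, 64)
#   blk.hidden = torch.zeros(1, 8, 64, 64)  # what reset_hidden(x, dfac=1) would allocate
#   y = blk(x)                              # conv -> cat with hidden -> conv: [1, 8, 64, 64]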
class RecurrentAE(nn.Module):
def __init__(self, input_nc):
super(RecurrentAE, self).__init__()
self.d1 = RecurrentBlock(input_nc=input_nc, output_nc=32, downsampling=True)
self.d2 = RecurrentBlock(input_nc=32, output_nc=43, downsampling=True)
self.d3 = RecurrentBlock(input_nc=43, output_nc=57, downsampling=True)
self.d4 = RecurrentBlock(input_nc=57, output_nc=76, downsampling=True)
self.d5 = RecurrentBlock(input_nc=76, output_nc=101, downsampling=True)
self.bottleneck = RecurrentBlock(input_nc=101, output_nc=101, bottleneck=True)
self.u5 = RecurrentBlock(input_nc=101, output_nc=76, upsampling=True)
self.u4 = RecurrentBlock(input_nc=76, output_nc=57, upsampling=True)
self.u3 = RecurrentBlock(input_nc=57, output_nc=43, upsampling=True)
self.u2 = RecurrentBlock(input_nc=43, output_nc=32, upsampling=True)
self.u1 = RecurrentBlock(input_nc=32, output_nc=3, upsampling=True)
def set_input(self, inp):
self.inp = inp['A']
def forward(self):
d1 = func.max_pool2d(input=self.d1(self.inp), kernel_size=2)
d2 = func.max_pool2d(input=self.d2(d1), kernel_size=2)
d3 = func.max_pool2d(input=self.d3(d2), kernel_size=2)
d4 = func.max_pool2d(input=self.d4(d3), kernel_size=2)
d5 = func.max_pool2d(input=self.d5(d4), kernel_size=2)
b = self.bottleneck(d5)
u5 = self.u5(torch.cat((b, d5), dim=1))
u4 = self.u4(torch.cat((u5, d4), dim=1))
u3 = self.u3(torch.cat((u4, d3), dim=1))
u2 = self.u2(torch.cat((u3, d2), dim=1))
u1 = self.u1(torch.cat((u2, d1), dim=1))
return u1
def reset_hidden(self):
self.d1.reset_hidden(self.inp, dfac=1)
self.d2.reset_hidden(self.inp, dfac=2)
self.d3.reset_hidden(self.inp, dfac=4)
self.d4.reset_hidden(self.inp, dfac=8)
self.d5.reset_hidden(self.inp, dfac=16)
self.bottleneck.reset_hidden(self.inp, dfac=32)
self.u4.reset_hidden(self.inp, dfac=16)
self.u3.reset_hidden(self.inp, dfac=8)
self.u5.reset_hidden(self.inp, dfac=4)
self.u2.reset_hidden(self.inp, dfac=2)
self.u1.reset_hidden(self.inp, dfac=1)
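# Usage sketch (assumptions: a 'cuda:0' device is available since reset_hidden
# allocates hidden states there, and H and W are divisible by 32 to survive the
# five pooling levels; the 4 x 1 x 3 x 256 x 256 clip is hypothetical):
#
#   net = RecurrentAE(input_nc=3).to('cuda:0')
#   frames = torch.rand(4, 1, 3, 256, 256, device='cuda:0')
#   net.set_input({'A': frames[0]})
#   net.reset_hidden()                       # zero hidden states sized from the first frame
#   with torch.no_grad():
#       for t in range(frames.shape[0]):
#           net.set_input({'A': frames[t]})
#           out = net()                      # [1, 3, 256, 256]; hidden states carry over
#
# Note that only the downsampling and bottleneck blocks read self.hidden in forward,
# so the reset_hidden calls on the u* blocks (and their scrambled dfac order) are inert.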
import numpy as np
import torch
import matplotlib.pyplot as plt
import glm
gvec_type = [ glm.dvec1, glm.dvec2, glm.dvec3, glm.dvec4 ]
gmat_type = [ [ glm.dmat2, glm.dmat2x3, glm.dmat2x4 ],
[ glm.dmat3x2, glm.dmat3, glm.dmat3x4 ],
[ glm.dmat4x2, glm.dmat4x3, glm.dmat4 ] ]
def Fov2Length(angle):
return np.tan(angle * np.pi / 360) * 2
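# Converts a field-of-view angle given in degrees to the extent of the image plane
# at unit distance: 2 * tan(fov / 2), with fov / 2 expressed in radians
# (angle * pi / 360).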
def SmoothStep(x0, x1, x):
y = torch.clamp((x - x0) / (x1 - x0), 0, 1)
return y * y * (3 - 2 * y)
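# Classic Hermite smoothstep: clamp (x - x0) / (x1 - x0) to [0, 1], then apply
# 3*y^2 - 2*y^3 for a smooth transition from 0 to 1 with zero slope at both ends.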
def MatImg2Tensor(img, permute = True, batch_dim = True):
batch_input = len(img.shape) == 4
if permute:
t = torch.from_numpy(np.transpose(img,
[0, 3, 1, 2] if batch_input else [2, 0, 1]))
else:
t = torch.from_numpy(img)
if not batch_input and batch_dim:
t = t.unsqueeze(0)
return t
def MatImg2Numpy(img, permute = True, batch_dim = True):
batch_input = len(img.shape) == 4
if permute:
t = np.transpose(img, [0, 3, 1, 2] if batch_input else [2, 0, 1])
else:
t = img
if not batch_input and batch_dim:
t = np.expand_dims(t, 0) # numpy arrays have no unsqueeze; add the batch axis explicitly
return t
def Tensor2MatImg(t):
img = t.squeeze().cpu().numpy()
batch_input = len(img.shape) == 4
if t.size()[batch_input] <= 4:
return np.transpose(img, [0, 2, 3, 1] if batch_input else [1, 2, 0])
return img
def ReadImageTensor(path, permute = True, rgb_only = True, batch_dim = True):
channels = 3 if rgb_only else 4
if isinstance(path,list):
first_image = plt.imread(path[0])[:, :, 0:channels]
b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
b_image[0] = first_image
for i in range(1, len(path)):
b_image[i] = plt.imread(path[i])[:, :, 0:channels]
return MatImg2Tensor(b_image, permute)
return MatImg2Tensor(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
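# Note: plt.imread returns float32 values in [0, 1] for PNG files but uint8 for
# JPEG, so tensors produced here are only guaranteed to be normalized for PNG input.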
def ReadImageNumpyArray(path, permute = True, rgb_only = True, batch_dim = True):
channels = 3 if rgb_only else 4
if isinstance(path,list):
first_image = plt.imread(path[0])[:, :, 0:channels]
b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
b_image[0] = first_image
for i in range(1, len(path)):
b_image[i] = plt.imread(path[i])[:, :, 0:channels]
return MatImg2Numpy(b_image, permute)
return MatImg2Numpy(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
def WriteImageTensor(t, path):
image = Tensor2MatImg(t)
if isinstance(path,list):
if len(image.shape) != 4 or image.shape[0] != len(path):
raise ValueError
for i in range(len(path)):
plt.imsave(path[i], image[i])
else:
if len(image.shape) == 4 and image.shape[0] != 1:
raise ValueError
plt.imsave(path, image)
def PlotImageTensor(t):
plt.imshow(Tensor2MatImg(t))
def Tensor2Glm(t):
t = t.squeeze()
size = t.size()
if len(size) == 1:
if size[0] <= 0 or size[0] > 4:
raise ValueError
return gvec_type[size[0] - 1](t.cpu().numpy())
if len(size) == 2:
if size[0] <= 1 or size[0] > 4 or size[1] <= 1 or size[1] > 4:
raise ValueError
return gmat_type[size[1] - 2][size[0] - 2](t.cpu().numpy())
raise ValueError
def Glm2Tensor(val):
return torch.from_numpy(np.array(val))
def MeshGrid(size, normalize=False):
y,x = torch.meshgrid(torch.tensor(range(size[0])),
torch.tensor(range(size[1])))
if normalize:
return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2)
return torch.stack([x, y], 2)
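# Example: MeshGrid((2, 3)) yields a 2 x 3 x 2 tensor of integer (x, y) pixel
# coordinates; MeshGrid((2, 3), normalize=True) rescales them so x and y span
# [0, 1] across the grid.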
\ No newline at end of file
import torch
weightVarScale = 0.25
bias_stddev = 0.01
def weight_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.xavier_normal_(m.weight.data)
torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
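# Applied recursively via Module.apply, e.g. lf_model.apply(weight_init_normal) as
# in the training script above; note the Conv branch assumes every Conv layer has a
# bias term (bias=False layers would raise on m.bias.data).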
\ No newline at end of file