Commit a22cfc90 authored by BobYeah's avatar BobYeah
Browse files

Update CNN scripts

parent c5434e97
import matplotlib.pyplot as plt
import torch
import util
import numpy as np
def FlowMap(b_last_frame, b_map):
'''
Map images using the flow data.
Parameters
--------
b_last_frame - B x 3 x H x W tensor, batch of images
b_map - B x H x W x 2, batch of map data records pixel coords in last frames
Returns
--------
B x 3 x H x W tensor, batch of images mapped by flow data
'''
return torch.nn.functional.grid_sample(b_last_frame, b_map, align_corners=False)
class Flow(object):
'''
Class representating optical flow
Properties
--------
b_data - B x H x W x 2, batch of flow data
b_map - B x H x W x 2, batch of map data records pixel coords in last frames
b_invalid_mask - B x H x W, batch of masks, indicate invalid elements in corresponding flow data
'''
def Load(paths):
'''
Create a Flow instance using a batch of encoded data images loaded from paths
Parameters
--------
paths - list of encoded data image paths
Returns
--------
Flow instance
'''
b_encoded_image = util.ReadImageTensor(paths, rgb_only=False, permute=False, batch_dim=True)
return Flow(b_encoded_image)
def __init__(self, b_encoded_image):
'''
Initialize a Flow instance from a batch of encoded data images
Parameters
--------
b_encoded_image - batch of encoded data images
'''
b_encoded_image = b_encoded_image.mul(255)
# print("b_encoded_image:",b_encoded_image.shape)
self.b_invalid_mask = (b_encoded_image[:, :, :, 0] == 255)
self.b_data = (b_encoded_image[:, :, :, 0:2] / 254 + b_encoded_image[:, :, :, 2:4] - 127) / 127
self.b_data[:, :, :, 1] = -self.b_data[:, :, :, 1]
D = self.b_data.size()
grid = util.MeshGrid((D[1], D[2]), True)
self.b_map = (grid - self.b_data - 0.5) * 2
self.b_map[self.b_invalid_mask] = torch.tensor([ -2.0, -2.0 ])
def getMap(self):
return self.b_map
def Visualize(self, scale_factor = 1):
'''
Visualize the flow data by "color wheel".
Parameters
--------
scale_factor - scale factor of flow data to visualize, default is 1
Returns
--------
B x 3 x H x W tensor, visualization of flow data
'''
try:
Flow.b_color_wheel
except AttributeError:
Flow.b_color_wheel = util.ReadImageTensor('color_wheel.png')
return torch.nn.functional.grid_sample(Flow.b_color_wheel.expand(self.b_data.size()[0], -1, -1, -1),
(self.b_data * scale_factor), align_corners=False)
\ No newline at end of file
...@@ -38,7 +38,7 @@ class residual_block(torch.nn.Module): ...@@ -38,7 +38,7 @@ class residual_block(torch.nn.Module):
def forward(self,input): def forward(self,input):
if self.RNN: if self.RNN:
# print("input:",input.shape,"hidden:",self.hidden.shape) # print("input:",input.shape,"hidden:",self.hidden.shape)
inp = torch.cat((input,self.hidden),dim=1) inp = torch.cat((input,self.hidden.detach()),dim=1)
# print(inp.shape) # print(inp.shape)
output = self.layer1(inp) output = self.layer1(inp)
output = self.layer2(output) output = self.layer2(output)
...@@ -97,7 +97,7 @@ class interleave(torch.nn.Module): ...@@ -97,7 +97,7 @@ class interleave(torch.nn.Module):
return output return output
class model(torch.nn.Module): class model(torch.nn.Module):
def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE): def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False):
super(model, self).__init__() super(model, self).__init__()
self.interleave = interleave(INTERLEAVE_RATE) self.interleave = interleave(INTERLEAVE_RATE)
...@@ -108,13 +108,19 @@ class model(torch.nn.Module): ...@@ -108,13 +108,19 @@ class model(torch.nn.Module):
) )
self.residual_block1 = residual_block(OUT_CHANNELS_RB,0,KERNEL_SIZE_RB,False) self.residual_block1 = residual_block(OUT_CHANNELS_RB,0,KERNEL_SIZE_RB,False)
self.residual_block2 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,False) self.residual_block2 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False)
self.residual_block3 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True) self.residual_block3 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False)
self.residual_block4 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True) # if RNN:
self.residual_block5 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True) # self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True)
# else:
# self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
# self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
# self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False)
self.output_layer = torch.nn.Sequential( self.output_layer = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+3,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1), torch.nn.Conv2d(OUT_CHANNELS_RB+2,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS), torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
torch.nn.Sigmoid() torch.nn.Sigmoid()
) )
...@@ -125,7 +131,7 @@ class model(torch.nn.Module): ...@@ -125,7 +131,7 @@ class model(torch.nn.Module):
self.residual_block4.reset_hidden(inp) self.residual_block4.reset_hidden(inp)
self.residual_block5.reset_hidden(inp) self.residual_block5.reset_hidden(inp)
def forward(self, lightfield_images, focal_length, gazeX, gazeY): def forward(self, lightfield_images, pos_row, pos_col):
# lightfield_images: torch.Size([batch_size, channels * D, H, W]) # lightfield_images: torch.Size([batch_size, channels * D, H, W])
# channels : RGB*D: 3*9, H:256, W:256 # channels : RGB*D: 3*9, H:256, W:256
# print("lightfield_images:",lightfield_images.shape) # print("lightfield_images:",lightfield_images.shape)
...@@ -136,32 +142,18 @@ class model(torch.nn.Module): ...@@ -136,32 +142,18 @@ class model(torch.nn.Module):
# print("input_to_rb1:",input_to_rb.shape) # print("input_to_rb1:",input_to_rb.shape)
output = self.residual_block1(input_to_rb) output = self.residual_block1(input_to_rb)
depth_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) pos_row_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
gazeX_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) pos_col_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
gazeY_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) for i in range(pos_row.shape[0]):
# print("depth_layer:",depth_layer.shape) pos_row_layer[i] *= pos_row[i]
# print("focal_depth:",focal_length," gazeX:",gazeX," gazeY:",gazeY, " gazeX norm:",(gazeX[0] - (-3.333)) / (3.333*2)) pos_col_layer[i] *= pos_col[i]
for i in range(focal_length.shape[0]):
depth_layer[i] *= 1. / focal_length[i]
gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333*2)
gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333*2)
# print(depth_layer.shape) # print(depth_layer.shape)
depth_layer = var_or_cuda(depth_layer) pos_row_layer = var_or_cuda(pos_row_layer)
gazeX_layer = var_or_cuda(gazeX_layer) pos_col_layer = var_or_cuda(pos_col_layer)
gazeY_layer = var_or_cuda(gazeY_layer)
output = torch.cat((output,depth_layer,gazeX_layer,gazeY_layer),dim=1)
# output = torch.cat((output,depth_layer),dim=1)
# print("output to rb2:",output.shape)
output = torch.cat((output,pos_row_layer,pos_col_layer),dim=1)
output = self.residual_block2(output) output = self.residual_block2(output)
# print("output to rb3:",output.shape)
output = self.residual_block3(output) output = self.residual_block3(output)
# print("output to rb4:",output.shape)
output = self.residual_block4(output)
# print("output to rb5:",output.shape)
output = self.residual_block5(output)
# output = output + input_to_net
output = self.output_layer(output) output = self.output_layer(output)
output = self.deinterleave(output) output = self.deinterleave(output)
return output return output
\ No newline at end of file
import torch import torch
from gen_image import * import util
import numpy as np
class Conf(object): class Conf(object):
def __init__(self): def __init__(self):
self.pupil_size = 0.02 self.pupil_size = 0.02
self.retinal_res = torch.tensor([ 320, 320 ]) self.retinal_res = torch.tensor([ 320, 320 ])
self.layer_res = torch.tensor([ 320, 320 ]) self.layer_res = torch.tensor([ 320, 320 ])
self.layer_hfov = 90 # layers' horizontal FOV self.layer_hfov = 90 # layers' horizontal FOV
self.eye_hfov = 85 # eye's horizontal FOV (ignored in foveated rendering) self.eye_hfov = 80 # eye's horizontal FOV (ignored in foveated rendering)
self.eye_enable_fovea = True # enable foveated rendering self.eye_enable_fovea = False # enable foveated rendering
self.eye_fovea_angles = [ 40, 80 ] # eye's foveation layers' angles self.eye_fovea_angles = [ 40, 80 ] # eye's foveation layers' angles
self.eye_fovea_downsamples = [ 1, 2 ] # eye's foveation layers' downsamples self.eye_fovea_downsamples = [ 1, 2 ] # eye's foveation layers' downsamples
self.d_layer = [ 1, 3 ] # layers' distance self.d_layer = [ 1, 3 ] # layers' distance
self.eye_fovea_blend = [ self._GenFoveaLayerBlend(0) ]
# blend maps of fovea layers
self.light_field_dim = 5
def GetNLayers(self): def GetNLayers(self):
return len(self.d_layer) return len(self.d_layer)
def GetLayerSize(self, i): def GetLayerSize(self, i):
w = Fov2Length(self.layer_hfov) w = util.Fov2Length(self.layer_hfov)
h = w * self.layer_res[0] / self.layer_res[1] h = w * self.layer_res[0] / self.layer_res[1]
return torch.tensor([ h, w ]) * self.d_layer[i] return torch.tensor([ h, w ]) * self.d_layer[i]
def GetPixelSizeOfLayer(self, i):
'''
Get pixel size of layer i
'''
return util.Fov2Length(self.layer_hfov) * self.d_layer[i] / self.layer_res[0]
def GetEyeViewportSize(self): def GetEyeViewportSize(self):
fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov
w = Fov2Length(fov) w = util.Fov2Length(fov)
h = w * self.retinal_res[0] / self.retinal_res[1] h = w * self.retinal_res[0] / self.retinal_res[1]
return torch.tensor([ h, w ]) return torch.tensor([ h, w ])
\ No newline at end of file
def GetRegionOfFoveaLayer(self, i):
'''
Get region of fovea layer i in retinal image
Returns
--------
slice object stores the start and end of region
'''
roi_size = int(np.ceil(self.retinal_res[0] * self.eye_fovea_angles[i] / self.eye_fovea_angles[-1]))
roi_offset = int((self.retinal_res[0] - roi_size) / 2)
return slice(roi_offset, roi_offset + roi_size)
def _GenFoveaLayerBlend(self, i):
'''
Generate blend map for fovea layer i
Parameters
--------
i - index of fovea layer
Returns
--------
H[i] x W[i], blend map
'''
region = self.GetRegionOfFoveaLayer(i)
width = region.stop - region.start
R = width / 2
p = util.MeshGrid([ width, width ])
r = torch.linalg.norm(p - R, 2, dim=2, keepdim=False)
return util.SmoothStep(R, R * 0.6, r)
import torch
import os
import glob
import numpy as np
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import cv2
import json
from Flow import *
from gen_image import *
import util
import time
class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path,file_json):
self.file_dir_path = file_dir_path
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
self.input_img = []
for i in self.dataset_desc["train"]:
lf_element = os.path.join(self.file_dir_path,i)
lf_element = cv2.imread(lf_element, -cv2.IMREAD_ANYDEPTH)[:, :, 0:3]
lf_element = cv2.cvtColor(lf_element, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
lf_element = lf_element[:,:-1,:]
self.input_img.append(lf_element)
self.input_img = np.asarray(self.input_img)
def __len__(self):
return len(self.dataset_desc["gt"])
def __getitem__(self, idx):
gt, pos_row, pos_col = self.get_datum(idx)
return (self.input_img, gt, pos_row, pos_col)
def get_datum(self, idx):
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
gt = gt[:,:-1,:]
pos_col = self.dataset_desc["x"][idx]
pos_row = self.dataset_desc["y"][idx]
return gt, pos_row, pos_col
class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
# print(lf_image_paths)
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx
class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx
class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["seq"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
fd = fd.astype(np.float32)
gazeX = gazeX.astype(np.float32)
gazeY = gazeY.astype(np.float32)
sample_idx = sample_idx.astype(np.int64)
# print(fd)
# print(gazeX)
# print(gazeY)
# print(sample_idx)
# print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype)
# print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
indices = self.dataset_desc["seq"][idx]
# print("indices:",indices)
lf_images = []
fd = []
gazeX = []
gazeY = []
sample_idx = []
gt = []
gt2 = []
for i in range(len(indices)):
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
lf_image_one_sample = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
for j in range(9):
lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3]
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
# print("indices[i]:",indices[i])
fd.append([self.dataset_desc["focaldepth"][indices[i]]])
gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
sample_idx.append([self.dataset_desc["idx"][indices[i]]])
lf_images.append(lf_image_one_sample)
#lf_images: 5,9,320,320
return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx)
class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, gen, conf, transforms=None):
self.file_dir_path = file_dir_path
self.file_json = file_json
self.gen = gen
self.conf = conf
self.transforms = transforms
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["seq"])
def __getitem__(self, idx):
# start = time.time()
lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx = self.get_datum(idx)
fd = fd.astype(np.float32)
gazeX = gazeX.astype(np.float32)
gazeY = gazeY.astype(np.float32)
posX = posX.astype(np.float32)
posY = posY.astype(np.float32)
posZ = posZ.astype(np.float32)
sample_idx = sample_idx.astype(np.int64)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print("read once:",time.time() - start) # 8 ms
return (lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx)
def get_datum(self, idx):
IM_H = 320
IM_W = 320
indices = self.dataset_desc["seq"][idx]
# print("indices:",indices)
lf_images = []
fd = []
gazeX = []
gazeY = []
posX = []
posY = []
posZ = []
sample_idx = []
gt = []
# gt2 = []
phi = []
phi_invalid = []
retinal_invalid = []
for i in range(len(indices)): # 5
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
# fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
lf_image_one_sample = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
lf_dim = int(self.conf.light_field_dim)
for j in range(lf_dim**2):
lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim*IM_H+IM_H,j%lf_dim*IM_W:j%lf_dim*IM_W+IM_W,0:3]
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
# gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
# gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
# print("indices[i]:",indices[i])
fd.append([self.dataset_desc["focaldepth"][indices[i]]])
gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
posX.append([self.dataset_desc["x"][indices[i]]])
posY.append([self.dataset_desc["y"][indices[i]]])
posZ.append([0.0])
sample_idx.append([self.dataset_desc["idx"][indices[i]]])
lf_images.append(lf_image_one_sample)
idx_i = sample_idx[i][0]
focaldepth_i = fd[i][0]
gazeX_i = gazeX[i][0]
gazeY_i = gazeY[i][0]
posX_i = posX[i][0]
posY_i = posY[i][0]
posZ_i = posZ[i][0]
# print("x:%.3f,y:%.3f,z:%.3f;gaze:%.4f,%.4f,focaldepth:%.3f."%(posX_i,posY_i,posZ_i,gazeX_i,gazeY_i,focaldepth_i))
phi_i,phi_invalid_i,retinal_invalid_i = self.gen.CalculateRetinal2LayerMappings(torch.tensor([posX_i, posY_i,posZ_i]),torch.tensor([gazeX_i, gazeY_i]),focaldepth_i)
phi.append(phi_i)
phi_invalid.append(phi_invalid_i)
retinal_invalid.append(retinal_invalid_i)
#lf_images: 5,9,320,320
flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"][indices[i-1]]) for i in range(1,len(indices))])
flow_map = flow.getMap()
flow_invalid_mask = flow.b_invalid_mask
# print("flow:",flow_map.shape)
return np.asarray(lf_images),np.asarray(gt), torch.stack(phi,dim=0), torch.stack(phi_invalid,dim=0),torch.stack(retinal_invalid,dim=0), flow_map, flow_invalid_mask, np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(posX),np.asarray(posY),np.asarray(posZ),np.asarray(sample_idx)
...@@ -2,13 +2,8 @@ import matplotlib.pyplot as plt ...@@ -2,13 +2,8 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import glm import glm
import time
def Fov2Length(angle): import util
'''
'''
return np.tan(angle * np.pi / 360) * 2
def RandomGenSamplesInPupil(pupil_size, n_samples): def RandomGenSamplesInPupil(pupil_size, n_samples):
''' '''
...@@ -21,9 +16,9 @@ def RandomGenSamplesInPupil(pupil_size, n_samples): ...@@ -21,9 +16,9 @@ def RandomGenSamplesInPupil(pupil_size, n_samples):
Returns Returns
-------- --------
a n_samples x 2 tensor with 2D sample position in each row a n_samples x 3 tensor with 3D sample position in each row
''' '''
samples = torch.empty(n_samples, 2) samples = torch.empty(n_samples, 3)
i = 0 i = 0
while i < n_samples: while i < n_samples:
s = (torch.rand(2) - 0.5) * pupil_size s = (torch.rand(2) - 0.5) * pupil_size
...@@ -44,7 +39,7 @@ def GenSamplesInPupil(pupil_size, circles): ...@@ -44,7 +39,7 @@ def GenSamplesInPupil(pupil_size, circles):
Returns Returns
-------- --------
a n_samples x 2 tensor with 2D sample position in each row a n_samples x 3 tensor with 3D sample position in each row
''' '''
samples = torch.zeros(1, 3) samples = torch.zeros(1, 3)
for i in range(1, circles): for i in range(1, circles):
...@@ -70,7 +65,7 @@ class RetinalGen(object): ...@@ -70,7 +65,7 @@ class RetinalGen(object):
Methods Methods
-------- --------
''' '''
def __init__(self, conf, u): def __init__(self, conf):
''' '''
Initialize retinal generator instance Initialize retinal generator instance
...@@ -80,58 +75,83 @@ class RetinalGen(object): ...@@ -80,58 +75,83 @@ class RetinalGen(object):
u - a M x 3 tensor stores M sample positions in pupil u - a M x 3 tensor stores M sample positions in pupil
''' '''
self.conf = conf self.conf = conf
self.u = GenSamplesInPupil(conf.pupil_size, 5)
# self.u = u.to(cuda_dev) # self.u = u.to(cuda_dev)
self.u = u # M x 3 M sample positions # self.u = u # M x 3 M sample positions
self.D_r = conf.retinal_res # retinal res 480 x 640 self.D_r = conf.retinal_res # retinal res 480 x 640
self.N = conf.GetNLayers() # 2 self.N = conf.GetNLayers() # 2
self.M = u.size()[0] # samples self.M = self.u.size()[0] # samples
p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])), # p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])),
torch.tensor(range(0, self.D_r[1]))) # torch.tensor(range(0, self.D_r[1])))
# self.p_r = torch.cat([
# ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # 眼球视野
# torch.ones(self.D_r[0], self.D_r[1], 1)
# ], 2)
self.p_r = torch.cat([ self.p_r = torch.cat([
((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # 眼球视野 ((util.MeshGrid(self.D_r) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(),
torch.ones(self.D_r[0], self.D_r[1], 1) torch.ones(self.D_r[0], self.D_r[1], 1)
], 2) ], 2)
# self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long) # self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long)
# self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2 # self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2
def CalculateRetinal2LayerMappings(self, df, gaze): def CalculateRetinal2LayerMappings(self, position, gaze_dir, df):
''' '''
Calculate the mapping matrix from retinal to layers. Calculate the mapping matrix from retinal to layers.
Parameters Parameters
-------- --------
df - focus distance position - 1 x 3 tensor, eye's position
gaze - 2 x 1 tensor, eye rotation angle (degs) in horizontal and vertical direction gaze_dir - 1 x 2 tensor, gaze forward vector (with z normalized)
df - focus distance
Returns
--------
phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers
phi_invalid - N x H_r x W_r x M x 1, indicates invalid (out-of-range) mapping
retinal_invalid - 1 x H_r x W_r, indicates invalid pixels in retinal image
''' '''
Phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long) # 2 x 480 x 640 x 41 x 2 D = self.conf.layer_res
mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) c = torch.tensor([ D[1] / 2, D[0] / 2 ]) # c: Center of layers (pixel)
D_r = self.conf.retinal_res # D_r: Resolution of retinal 480 640 D_r = self.conf.retinal_res # D_r: Resolution of retinal 480 640
V = self.conf.GetEyeViewportSize() # V: Viewport size of eye V = self.conf.GetEyeViewportSize() # V: Viewport size of eye
c = (self.conf.layer_res / 2) # c: Center of layers (pixel)
p_f = self.p_r * df # p_f: H x W x 3, focus positions of retinal pixels on focus plane p_f = self.p_r * df # p_f: H x W x 3, focus positions of retinal pixels on focus plane
rot_forward = glm.dvec3(glm.tan(glm.radians(glm.dvec2(gaze[1], -gaze[0]))), 1)
rot_mat = torch.from_numpy(np.array( # Calculate transformation from eye to display
glm.dmat3(glm.lookAtLH(glm.dvec3(), rot_forward, glm.dvec3(0, 1, 0))))) gvec_lookat = glm.dvec3(gaze_dir[0], -gaze_dir[1], 1)
rot_mat = rot_mat.float() gmat_eye = glm.inverse(glm.lookAtLH(glm.dvec3(), gvec_lookat, glm.dvec3(0, 1, 0)))
u_rot = torch.mm(self.u, rot_mat) eye_rot = util.Glm2Tensor(glm.dmat3(gmat_eye))
v_rot = torch.matmul(p_f, rot_mat).unsqueeze(2).expand( eye_center = torch.tensor([ position[0], -position[1], position[2] ])
-1, -1, self.u.size()[0], -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
v_rot.div_(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot u_rot = torch.mm(self.u, eye_rot)
v_rot = torch.matmul(p_f, eye_rot).unsqueeze(2).expand(
for i in range(0, self.conf.GetNLayers()): -1, -1, self.M, -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
dp_i = self.conf.GetLayerSize(i)[0] / self.conf.layer_res[0] # dp_i: Pixel size of layer i u_rot.add_(eye_center) # translate by eye's center
d_i = self.conf.d_layer[i] # d_i: Distance of layer i v_rot = v_rot.div(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot
phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long)
for i in range(0, self.N):
dp_i = self.conf.GetPixelSizeOfLayer(i) # dp_i: Pixel size of layer i
d_i = self.conf.d_layer[i] # d_i: Distance of layer i
k = (d_i - u_rot[:, 2]).unsqueeze(1) k = (d_i - u_rot[:, 2]).unsqueeze(1)
pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i # pi_r: H x W x M x 2, rays' pixel coord on layer i pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i # pi_r: H x W x M x 2, rays' pixel coord on layer i
Phi[i, :, :, :, :] = torch.floor(pi_r + c) phi[i, :, :, :, :] = torch.floor(pi_r + c)
mask[:, :, :, :, 0] = ((Phi[:, :, :, :, 0] >= 0) & (Phi[:, :, :, :, 0] < self.conf.layer_res[0])).float()
mask[:, :, :, :, 1] = ((Phi[:, :, :, :, 1] >= 0) & (Phi[:, :, :, :, 1] < self.conf.layer_res[1])).float() # Calculate invalid mask (out-of-range elements in phi) and reduced to retinal
Phi[:, :, :, :, 0].clamp_(0, self.conf.layer_res[0] - 1) phi_invalid = (phi[:, :, :, :, 0] < 0) | (phi[:, :, :, :, 0] >= D[1]) | \
Phi[:, :, :, :, 1].clamp_(0, self.conf.layer_res[1] - 1) (phi[:, :, :, :, 1] < 0) | (phi[:, :, :, :, 1] >= D[0])
retinal_mask = mask.prod(0).prod(2).prod(2) phi_invalid = phi_invalid.unsqueeze(4)
return [ Phi, retinal_mask ] # print("phi_invalid:",phi_invalid.shape)
retinal_invalid = phi_invalid.amax((0, 3)).squeeze().unsqueeze(0)
# print("retinal_invalid:",retinal_invalid.shape)
# Fix invalid elements in phi
phi[phi_invalid.expand(-1, -1, -1, -1, 2)] = 0
return [ phi, phi_invalid, retinal_invalid ]
def GenRetinalFromLayers(self, layers, Phi): def GenRetinalFromLayers(self, layers, Phi):
''' '''
...@@ -139,28 +159,159 @@ class RetinalGen(object): ...@@ -139,28 +159,159 @@ class RetinalGen(object):
Parameters Parameters
-------- --------
layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer layers - 3N x H x W, stacked layer images, with 3 channels in each layer
phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers
Returns Returns
-------- --------
3 x H_r x W_r tensor, 3 channels retinal image 3 x H_r x W_r, 3 channels retinal image
H_r x W_r tensor, retinal image mask, indicates pixels valid or not
''' '''
# FOR GRAYSCALE 1 FOR RGB 3 # FOR GRAYSCALE 1 FOR RGB 3
mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M) # 2 x 3 x 480 x 640 x 41 mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M) # 2 x 3 x 480 x 640 x 41
# print("mapped_layers:",mapped_layers.shape) # print("mapped_layers:",mapped_layers.shape)
for i in range(0, Phi.size()[0]): for i in range(0, Phi.size()[0]):
# torch.Size([3, 2, 320, 320, 2])
# print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape) # print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape)
mapped_layers[i, :, :, :, :] = layers[(i * 3) : (i * 3 + 3), mapped_layers[i, :, :, :, :] = layers[(i * 3) : (i * 3 + 3),
Phi[i, :, :, :, 0], Phi[i, :, :, :, 1],
Phi[i, :, :, :, 1]] Phi[i, :, :, :, 0]]
# print("mapped_layers:",mapped_layers.shape) # print("mapped_layers:",mapped_layers.shape)
retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3]) retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
# print("retinal:",retinal.shape) # print("retinal:",retinal.shape)
return retinal return retinal
def GenRetinalFromLayersBatch(self, layers, Phi):
'''
Generate retinal image from layers, using precalculated mapping matrix
Parameters
--------
layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
Returns
--------
3 x H_r x W_r tensor, 3 channels retinal image
H_r x W_r tensor, retinal image mask, indicates pixels valid or not
'''
mapped_layers = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M) #BS x Layers x C x H x W x Sample
# truth = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M)
# layers_truth = layers.clone()
# Phi_truth = Phi.clone()
layers = torch.stack((layers[:,0:3,:,:],layers[:,3:6,:,:]),dim=1) ## torch.Size([BS, Layer, RGB 3, 320, 320])
# Phi = Phi[:,:,None,:,:,:,:].expand(-1,-1,3,-1,-1,-1,-1)
# print("mapped_layers:",mapped_layers.shape) #torch.Size([2, 2, 3, 320, 320, 41])
# print("input layers:",layers.shape) ## torch.Size([2, 2, 3, 320, 320])
# print("input Phi:",Phi.shape) #torch.Size([2, 2, 320, 320, 41, 2])
# #没优化
# for i in range(0, Phi_truth.size()[0]):
# for j in range(0, Phi_truth.size()[1]):
# truth[i, j, :, :, :, :] = layers_truth[i, (j * 3) : (j * 3 + 3),
# Phi_truth[i, j, :, :, :, 0],
# Phi_truth[i, j, :, :, :, 1]]
#优化2
# start = time.time()
mapped_layers_op1 = mapped_layers.reshape(-1,
mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# BatchSizexLayer Channel 3 320 320 41
layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4]) # 2x2 3 320 320
Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5]) # 2x2 320 320 41 2
x = Phi_op1[:,:,:,:,0] # 2x2 320 320 41
y = Phi_op1[:,:,:,:,1] # 2x2 320 320 41
# print("reshape:",time.time() - start)
# start = time.time()
mapped_layers_op1 = layers_op1[torch.arange(layers_op1.shape[0])[:, None, None, None], :, y, x] # x,y 切换
#2x2 320 320 41 3
# print("mapping one step:",time.time() - start)
# print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
# start = time.time()
mapped_layers_op1 = mapped_layers_op1.permute(0,4,1,2,3)
mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# print("reshape end:",time.time() - start)
# print("test:")
# print((truth.cpu() == mapped_layers.cpu()).all())
#优化1
# start = time.time()
# mapped_layers_op1 = mapped_layers.reshape(-1,
# mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4])
# Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5])
# print("reshape:",time.time() - start)
# for i in range(0, Phi_op1.size()[0]):
# start = time.time()
# mapped_layers_op1[i, :, :, :, :] = layers_op1[i,:,
# Phi_op1[i, :, :, :, 0],
# Phi_op1[i, :, :, :, 1]]
# print("mapping one step:",time.time() - start)
# print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
# start = time.time()
# mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
# mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
# print("reshape end:",time.time() - start)
# print("mapped_layers:",mapped_layers.shape) # torch.Size([2, 2, 3, 320, 320, 41])
retinal = mapped_layers.prod(1).sum(4).div(Phi.size()[4])
# print("retinal:",retinal.shape) # torch.Size([BatchSize, 3, 320, 320])
return retinal
## TO BE CHECK
def GenFoveaLayers(self, b_retinal, is_mask):
'''
Generate foveated layers for retinal images or masks
Parameters
--------
b_retinal - B x C x H_r x W_r, Batch of retinal images/masks
is_mask - Whether b_retinal is masks or images
def GenFoveaLayers(self, retinal, retinal_mask): Returns
--------
b_fovea_layers - N_f x (B x C x H[f] x W[f]) list of batch of foveated layers
'''
b_fovea_layers = []
for i in range(0, len(self.conf.eye_fovea_angles)):
k = self.conf.eye_fovea_downsamples[i]
region = self.conf.GetRegionOfFoveaLayer(i)
b_roi = b_retinal[:, :, region, region]
if k == 1:
b_fovea_layers.append(b_roi)
elif is_mask:
b_fovea_layers.append(torch.nn.functional.max_pool2d(b_roi.to(torch.float), k).to(torch.bool))
else:
b_fovea_layers.append(torch.nn.functional.avg_pool2d(b_roi, k))
return b_fovea_layers
# fovea_layers = []
# fovea_layer_masks = []
# fov = self.conf.eye_fovea_angles[-1]
# retinal_res = int(self.conf.retinal_res[0])
# for i in range(0, len(self.conf.eye_fovea_angles)):
# angle = self.conf.eye_fovea_angles[i]
# k = self.conf.eye_fovea_downsamples[i]
# roi_size = int(np.ceil(retinal_res * angle / fov))
# roi_offset = int((retinal_res - roi_size) / 2)
# roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# if k == 1:
# fovea_layers.append(roi_img)
# fovea_layer_masks.append(roi_mask)
# else:
# fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0))
# fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0))
# return [ fovea_layers, fovea_layer_masks ]
## TO BE CHECK
def GenFoveaLayersBatch(self, retinal, retinal_mask):
''' '''
Generate foveated layers and corresponding masks Generate foveated layers and corresponding masks
...@@ -177,18 +328,48 @@ class RetinalGen(object): ...@@ -177,18 +328,48 @@ class RetinalGen(object):
fovea_layers = [] fovea_layers = []
fovea_layer_masks = [] fovea_layer_masks = []
fov = self.conf.eye_fovea_angles[-1] fov = self.conf.eye_fovea_angles[-1]
# print("fov:",fov)
retinal_res = int(self.conf.retinal_res[0]) retinal_res = int(self.conf.retinal_res[0])
# print("retinal_res:",retinal_res)
# print("len(self.conf.eye_fovea_angles):",len(self.conf.eye_fovea_angles))
for i in range(0, len(self.conf.eye_fovea_angles)): for i in range(0, len(self.conf.eye_fovea_angles)):
angle = self.conf.eye_fovea_angles[i] angle = self.conf.eye_fovea_angles[i]
k = self.conf.eye_fovea_downsamples[i] k = self.conf.eye_fovea_downsamples[i]
roi_size = int(np.ceil(retinal_res * angle / fov)) roi_size = int(np.ceil(retinal_res * angle / fov))
roi_offset = int((retinal_res - roi_size) / 2) roi_offset = int((retinal_res - roi_size) / 2)
roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] # [2, 3, 320, 320]
roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] roi_img = retinal[:, :, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# print("roi_img:",roi_img.shape)
# [2, 320, 320]
roi_mask = retinal_mask[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)]
# print("roi_mask:",roi_mask.shape)
if k == 1: if k == 1:
fovea_layers.append(roi_img) fovea_layers.append(roi_img)
fovea_layer_masks.append(roi_mask) fovea_layer_masks.append(roi_mask)
else: else:
fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0)) fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img, k))
fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0)) fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask), k))
return [ fovea_layers, fovea_layer_masks ] return [ fovea_layers, fovea_layer_masks ]
\ No newline at end of file
## TO BE CHECK
def GenFoveaRetinal(self, b_fovea_layers):
'''
Generate foveated retinal image by blending fovea layers
**Note: current implementation only support two fovea layers**
Parameters
--------
b_fovea_layers - N_f x (B x 3 x H[f] x W[f]), list of batch of (masked) foveated layers
Returns
--------
B x 3 x H_r x W_r, batch of foveated retinal images
'''
b_fovea_retinal = torch.nn.functional.interpolate(b_fovea_layers[1],
scale_factor=self.conf.eye_fovea_downsamples[1],
mode='bilinear', align_corners=False)
region = self.conf.GetRegionOfFoveaLayer(0)
blend = self.conf.eye_fovea_blend[0]
b_roi = b_fovea_retinal[:, :, region, region]
b_roi.mul_(1 - blend).add_(b_fovea_layers[0] * blend)
return b_fovea_retinal
import torch
from ssim import *
from perc_loss import *
l1loss = torch.nn.L1Loss()
perc_loss = VGGPerceptualLoss()
perc_loss = perc_loss.to("cuda:1")
##### LOSS #####
def calImageGradients(images):
# x is a 4-D tensor
dx = images[:, :, 1:, :] - images[:, :, :-1, :]
dy = images[:, :, :, 1:] - images[:, :, :, :-1]
return dx, dy
def loss_new(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
p_loss = perc_loss(generated,gt)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
def loss_without_perc(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
##### LOSS #####
class ReconstructionLoss(torch.nn.Module):
def __init__(self):
super(ReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
return total_loss
class PerceptionReconstructionLoss(torch.nn.Module):
def __init__(self):
super(PerceptionReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
p_loss = perc_loss(generated,gt)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
return total_loss
...@@ -12,212 +12,33 @@ from torch.autograd import Variable ...@@ -12,212 +12,33 @@ from torch.autograd import Variable
import cv2 import cv2
from gen_image import * from gen_image import *
from loss import *
import json import json
from ssim import *
from perc_loss import *
from conf import Conf from conf import Conf
from model.baseline import * from baseline import *
from data import *
import torch.autograd.profiler as profiler import torch.autograd.profiler as profiler
# param # param
BATCH_SIZE = 2 BATCH_SIZE = 1
NUM_EPOCH = 300 NUM_EPOCH = 1001
INTERLEAVE_RATE = 2 INTERLEAVE_RATE = 2
IM_H = 320 IM_H = 320
IM_W = 320 IM_W = 320
Retinal_IM_H = 320 Retinal_IM_H = 320
Retinal_IM_W = 320 Retinal_IM_W = 320
N = 25 # number of input light field stack
N = 9 # number of input light field stack
M = 2 # number of display layers M = 2 # number of display layers
DATA_FILE = "/home/yejiannan/Project/LightField/data/FlowRPG1211"
DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_fovea" DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq_flow_RPG.json"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq.json" # DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json" OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq_flow_RPG_seq5_same_loss"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq"
OUT_CHANNELS_RB = 128 OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3 KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3 KERNEL_SIZE = 3
LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2 LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2 FIRSST_LAYER_CHANNELS = 75 * INTERLEAVE_RATE**2
class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
# print(lf_image_paths)
fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx])
fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx
class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx
class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["seq"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
fd = fd.astype(np.float32)
gazeX = gazeX.astype(np.float32)
gazeY = gazeY.astype(np.float32)
sample_idx = sample_idx.astype(np.int64)
# print(fd)
# print(gazeX)
# print(gazeY)
# print(sample_idx)
# print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype)
# print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
indices = self.dataset_desc["seq"][idx]
# print("indices:",indices)
lf_images = []
fd = []
gazeX = []
gazeY = []
sample_idx = []
gt = []
gt2 = []
for i in range(len(indices)):
lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][indices[i]])
fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][indices[i]])
fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][indices[i]])
lf_image_one_sample = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
for j in range(9):
lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
# print("indices[i]:",indices[i])
fd.append([self.dataset_desc["focaldepth"][indices[i]]])
gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
sample_idx.append([self.dataset_desc["idx"][indices[i]]])
lf_images.append(lf_image_one_sample)
#lf_images: 5,9,320,320
return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx)
#### Image Gen
conf = Conf()
u = GenSamplesInPupil(conf.pupil_size, 5)
gen = RetinalGen(conf, u)
def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
# layers: batchsize, 2*color, height, width # layers: batchsize, 2*color, height, width
...@@ -248,6 +69,7 @@ def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): ...@@ -248,6 +69,7 @@ def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
# retinal bs x color x height x width # retinal bs x color x height x width
retinal_fovea = torch.empty(layers.shape[0], 6, 160, 160) retinal_fovea = torch.empty(layers.shape[0], 6, 160, 160)
mask_fovea = torch.empty(layers.shape[0], 2, 160, 160) mask_fovea = torch.empty(layers.shape[0], 2, 160, 160)
for i in range(0, layers.size()[0]): for i in range(0, layers.size()[0]):
phi = phi_dict[int(sample_idx[i].data)] phi = phi_dict[int(sample_idx[i].data)]
# print("phi_i:",phi.shape) # print("phi_i:",phi.shape)
...@@ -261,12 +83,65 @@ def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): ...@@ -261,12 +83,65 @@ def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
fovea_layers, fovea_layer_masks = gen.GenFoveaLayers(retinal_i,mask_i) fovea_layers, fovea_layer_masks = gen.GenFoveaLayers(retinal_i,mask_i)
retinal_fovea[i] = torch.cat([fovea_layers[0],fovea_layers[1]],dim=0) retinal_fovea[i] = torch.cat([fovea_layers[0],fovea_layers[1]],dim=0)
mask_fovea[i] = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=0) mask_fovea[i] = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=0)
retinal_fovea = var_or_cuda(retinal_fovea)
mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
# mask = torch.stack(mask,dim = 0).unsqueeze(1)
return retinal_fovea, mask_fovea
import time
def GenRetinalGazeFromLayersBatchSpeed(layers, gen, phi, phi_invalid, retinal_invalid):
# layers: batchsize, 2*color, height, width
# Phi:torch.Size([batchsize, Layer, h, w, 41, 2])
# df : batchsize,..
# start1 = time.time()
# retinal bs x color x height x width
retinal_fovea = torch.empty((layers.shape[0], 6, 160, 160),device="cuda:2")
mask_fovea = torch.empty((layers.shape[0], 2, 160, 160),device="cuda:2")
# start = time.time()
retinal = gen.GenRetinalFromLayersBatch(layers,phi_batch)
# print("retinal:",retinal.shape) #retinal: torch.Size([2, 3, 320, 320])
# print("t2:",time.time() - start)
# start = time.time()
fovea_layers, fovea_layer_masks = gen.GenFoveaLayersBatch(retinal,mask_batch)
mask_fovea = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=1)
retinal_fovea = torch.cat([fovea_layers[0],fovea_layers[1]],dim=1)
# print("t3:",time.time() - start)
retinal_fovea = var_or_cuda(retinal_fovea) retinal_fovea = var_or_cuda(retinal_fovea)
mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
# mask = torch.stack(mask,dim = 0).unsqueeze(1) # mask = torch.stack(mask,dim = 0).unsqueeze(1)
return retinal_fovea, mask_fovea return retinal_fovea, mask_fovea
def MergeBatchSpeed(layers, gen, phi, phi_invalid, retinal_invalid):
# layers: batchsize, 2*color, height, width
# Phi:torch.Size([batchsize, Layer, h, w, 41, 2])
# df : batchsize,..
# start1 = time.time()
# retinal bs x color x height x width
# retinal_fovea = torch.empty((layers.shape[0], 6, 160, 160),device="cuda:2")
# mask_fovea = torch.empty((layers.shape[0], 2, 160, 160),device="cuda:2")
# start = time.time()
retinal = gen.GenRetinalFromLayersBatch(layers,phi) #retinal: torch.Size([BatchSize , 3, 320, 320])
retinal.mul_(~retinal_invalid.to("cuda:2"))
# print("retinal:",retinal.shape)
# print("t2:",time.time() - start)
# start = time.time()
# fovea_layers, fovea_layer_masks = gen.GenFoveaLayersBatch(retinal,mask_batch)
# mask_fovea = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=1)
# retinal_fovea = torch.cat([fovea_layers[0],fovea_layers[1]],dim=1)
# print("t3:",time.time() - start)
# retinal_fovea = var_or_cuda(retinal_fovea)
# mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width
# mask = torch.stack(mask,dim = 0).unsqueeze(1)
return retinal
def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask): def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
# layers: batchsize, 2*color, height, width # layers: batchsize, 2*color, height, width
# Phi:torch.Size([batchsize, 480, 640, 2, 41, 2]) # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
...@@ -285,47 +160,7 @@ def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask): ...@@ -285,47 +160,7 @@ def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
return retinal.unsqueeze(0), mask_out return retinal.unsqueeze(0), mask_out
#### Image Gen End #### Image Gen End
weightVarScale = 0.25 from weight_init import weight_init_normal
bias_stddev = 0.01
def weight_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.xavier_normal_(m.weight.data)
torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
def calImageGradients(images):
# x is a 4-D tensor
dx = images[:, :, 1:, :] - images[:, :, :-1, :]
dy = images[:, :, :, 1:] - images[:, :, :, :-1]
return dx, dy
perc_loss = VGGPerceptualLoss()
perc_loss = perc_loss.to("cuda:1")
def loss_new(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
p_loss = perc_loss(generated,gt)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = 10 + psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
def save_checkpoints(file_path, epoch_idx, model, model_solver): def save_checkpoints(file_path, epoch_idx, model, model_solver):
print('[INFO] Saving checkpoint to %s ...' % ( file_path)) print('[INFO] Saving checkpoint to %s ...' % ( file_path))
...@@ -336,93 +171,112 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver): ...@@ -336,93 +171,112 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
} }
torch.save(checkpoint, file_path) torch.save(checkpoint, file_path)
mode = "train" # import pickle
# def save_obj(obj, name ):
import pickle # # with open('./outputF/dict/'+ name + '.pkl', 'wb') as f:
def save_obj(obj, name ): # # pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
# with open('./outputF/dict/'+ name + '.pkl', 'wb') as f: # torch.save(obj,'./outputF/dict/'+ name + '.pkl')
# pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) # def load_obj(name):
torch.save(obj,'./outputF/dict/'+ name + '.pkl') # # with open('./outputF/dict/' + name + '.pkl', 'rb') as f:
def load_obj(name): # # return pickle.load(f)
# with open('./outputF/dict/' + name + '.pkl', 'rb') as f: # return torch.load('./outputF/dict/'+ name + '.pkl')
# return pickle.load(f)
return torch.load('./outputF/dict/'+ name + '.pkl') # def generatePhiMaskDict(data_json, generator):
# phi_dict = {}
def hook_fn_back(m, i, o): # mask_dict = {}
for grad in i: # idx_info_dict = {}
try: # with open(data_json, encoding='utf-8') as file:
print("Input Grad:",m,grad.shape,grad.sum()) # dataset_desc = json.loads(file.read())
except AttributeError: # for i in range(len(dataset_desc["focaldepth"])):
print ("None found for Gradient") # # if i == 2:
for grad in o: # # break
try: # idx = dataset_desc["idx"][i]
print("Output Grad:",m,grad.shape,grad.sum()) # focaldepth = dataset_desc["focaldepth"][i]
except AttributeError: # gazeX = dataset_desc["gazeX"][i]
print ("None found for Gradient") # gazeY = dataset_desc["gazeY"][i]
print("\n") # print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
# phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
def hook_fn_for(m, i, o): # phi_dict[idx]=phi
for grad in i: # mask_dict[idx]=mask
try: # idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
print("Input Feats:",m,grad.shape,grad.sum()) # return phi_dict,mask_dict,idx_info_dict
except AttributeError:
print ("None found for Gradient") # def generatePhiMaskDictNew(data_json, generator):
for grad in o: # phi_dict = {}
try: # mask_dict = {}
print("Output Feats:",m,grad.shape,grad.sum()) # idx_info_dict = {}
except AttributeError: # with open(data_json, encoding='utf-8') as file:
print ("None found for Gradient") # dataset_desc = json.loads(file.read())
print("\n") # for i in range(len(dataset_desc["seq"])):
# for j in dataset_desc["seq"][i]:
def generatePhiMaskDict(data_json, generator): # idx = dataset_desc["idx"][j]
phi_dict = {} # focaldepth = dataset_desc["focaldepth"][j]
mask_dict = {} # gazeX = dataset_desc["gazeX"][j]
idx_info_dict = {} # gazeY = dataset_desc["gazeY"][j]
with open(data_json, encoding='utf-8') as file: # print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
dataset_desc = json.loads(file.read()) # phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
for i in range(len(dataset_desc["focaldepth"])): # phi_dict[idx]=phi
# if i == 2: # mask_dict[idx]=mask
# break # idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
idx = dataset_desc["idx"][i] # return phi_dict,mask_dict,idx_info_dict
focaldepth = dataset_desc["focaldepth"][i]
gazeX = dataset_desc["gazeX"][i] mode = "Silence" #"Perf"
gazeY = dataset_desc["gazeY"][i] model_type = "RNN" #"RNN"
print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY) w_frame = 0.9
phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY])) w_inter_frame = 0.1
phi_dict[idx]=phi batch_model = "NoSingle"
mask_dict[idx]=mask loss1 = ReconstructionLoss()
idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY] loss2 = ReconstructionLoss()
return phi_dict,mask_dict,idx_info_dict
if __name__ == "__main__": if __name__ == "__main__":
############################## generate phi and mask in pre-training ############################## generate phi and mask in pre-training
# print("generating phi and mask...") # print("generating phi and mask...")
# phi_dict,mask_dict,idx_info_dict = generatePhiMaskDict(DATA_JSON,gen) # phi_dict,mask_dict,idx_info_dict = generatePhiMaskDictNew(DATA_JSON,gen)
# save_obj(phi_dict,"phi_1204") # # save_obj(phi_dict,"phi_1204")
# save_obj(mask_dict,"mask_1204") # # save_obj(mask_dict,"mask_1204")
# save_obj(idx_info_dict,"idx_info_1204") # # save_obj(idx_info_dict,"idx_info_1204")
# print("generating phi and mask end.") # print("generating phi and mask end.")
# exit(0) # exit(0)
############################# load phi and mask in pre-training ############################# load phi and mask in pre-training
print("loading phi and mask ...") # print("loading phi and mask ...")
phi_dict = load_obj("phi_1204") # phi_dict = load_obj("phi_1204")
mask_dict = load_obj("mask_1204") # mask_dict = load_obj("mask_1204")
idx_info_dict = load_obj("idx_info_1204") # idx_info_dict = load_obj("idx_info_1204")
print(len(phi_dict)) # print(len(phi_dict))
print(len(mask_dict)) # print(len(mask_dict))
print("loading phi and mask end") # print("loading phi and mask end")
#### Image Gen and conf
conf = Conf()
gen = RetinalGen(conf)
#train #train
train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSeqDataLoader(DATA_FILE,DATA_JSON), train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldFlowSeqDataLoader(DATA_FILE,DATA_JSON, gen, conf),
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
num_workers=0, num_workers=8,
pin_memory=True, pin_memory=True,
shuffle=True, shuffle=True,
drop_last=False) drop_last=False)
#Data loader test
print(len(train_data_loader)) print(len(train_data_loader))
# exit(0) # # lightfield_images, gt, flow, fd, gazeX, gazeY, posX, posY, sample_idx, phi, mask
# # lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, fd, gazeX, gazeY, posX, posY, posZ, sample_idx
# for batch_idx, (image_set, gt,phi, phi_invalid, retinal_invalid, flow, df, gazeX, gazeY, posX, posY, posZ, sample_idx) in enumerate(train_data_loader):
# print(image_set.shape,type(image_set))
# print(gt.shape,type(gt))
# print(phi.shape,type(phi))
# print(phi_invalid.shape,type(phi_invalid))
# print(retinal_invalid.shape,type(retinal_invalid))
# print(flow.shape,type(flow))
# print(df.shape,type(df))
# print(gazeX.shape,type(gazeX))
# print(posX.shape,type(posX))
# print(sample_idx.shape,type(sample_idx))
# print("test train dataloader.")
# exit(0)
#Data loader test end
################################################ val ######################################################### ################################################ val #########################################################
...@@ -464,21 +318,24 @@ if __name__ == "__main__": ...@@ -464,21 +318,24 @@ if __name__ == "__main__":
# print("output:",output.shape," df:",df[0].data, ",gazeX:",gazeX[0].data,",gazeY:", gazeY[0].data) # print("output:",output.shape," df:",df[0].data, ",gazeX:",gazeX[0].data,",gazeY:", gazeY[0].data)
# for i in range(output1.size()[0]): # for i in range(output1.size()[0]):
# save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) # save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) # save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) # save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) # save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
# # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.3f.png"%(df[0].data))) # # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.5f.png"%(df[0].data)))
# # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.3f.png"%(df[0].data))) # # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.5f.png"%(df[0].data)))
# # output = GenRetinalFromLayersBatch(output,conf,df,v,u) # # output = GenRetinalFromLayersBatch(output,conf,df,v,u)
# # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.3f.png"%(df[0].data))) # # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.5f.png"%(df[0].data)))
# exit() # exit()
################################################ train ######################################################### ################################################ train #########################################################
lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE) if model_type == "RNN":
lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE)
else:
lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False)
lf_model.apply(weight_init_normal) lf_model.apply(weight_init_normal)
lf_model.train()
epoch_begin = 0 epoch_begin = 0
################################ load model file ################################ load model file
...@@ -493,144 +350,243 @@ if __name__ == "__main__": ...@@ -493,144 +350,243 @@ if __name__ == "__main__":
if torch.cuda.is_available(): if torch.cuda.is_available():
# lf_model = torch.nn.DataParallel(lf_model).cuda() # lf_model = torch.nn.DataParallel(lf_model).cuda()
lf_model = lf_model.to('cuda:1') lf_model = lf_model.to('cuda:2')
lf_model.train()
optimizer = torch.optim.Adam(lf_model.parameters(),lr=1e-2,betas=(0.9,0.999)) optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
l1loss = torch.nn.L1Loss()
# lf_model.output_layer.register_backward_hook(hook_fn_back) # lf_model.output_layer.register_backward_hook(hook_fn_back)
if mode=="Perf":
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
print("begin training....") print("begin training....")
for epoch in range(epoch_begin, NUM_EPOCH): for epoch in range(epoch_begin, NUM_EPOCH):
for batch_idx, (image_set, gt, gt2, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader): for batch_idx, (image_set, gt,phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, df, gazeX, gazeY, posX, posY, posZ, sample_idx) in enumerate(train_data_loader):
# print(sample_idx.shape,df.shape,gazeX.shape,gazeY.shape) # torch.Size([2, 5]) # print(sample_idx.shape,df.shape,gazeX.shape,gazeY.shape) # torch.Size([2, 5])
# print(image_set.shape,gt.shape,gt2.shape) #torch.Size([2, 5, 9, 320, 320, 3]) torch.Size([2, 5, 160, 160, 3]) torch.Size([2, 5, 160, 160, 3]) # print(image_set.shape,gt.shape,gt2.shape) #torch.Size([2, 5, 9, 320, 320, 3]) torch.Size([2, 5, 160, 160, 3]) torch.Size([2, 5, 160, 160, 3])
# print(delta.shape) # delta: torch.Size([2, 4, 160, 160, 3]) # print(delta.shape) # delta: torch.Size([2, 4, 160, 160, 3])
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("load:",start.elapsed_time(end))
start.record()
#reshape for input #reshape for input
image_set = image_set.permute(0,1,2,5,3,4) # N S LF C H W image_set = image_set.permute(0,1,2,5,3,4) # N Seq 5 LF 25 C 3 H W
image_set = image_set.reshape(image_set.shape[0],image_set.shape[1],-1,image_set.shape[4],image_set.shape[5]) # N, LFxC, H, W image_set = image_set.reshape(image_set.shape[0],image_set.shape[1],-1,image_set.shape[4],image_set.shape[5]) # N, LFxC, H, W
# N, Seq 5, LF 25 C 3, H, W
image_set = var_or_cuda(image_set) image_set = var_or_cuda(image_set)
gt = gt.permute(0,1,4,2,3) # N S C H W # print(image_set.shape) #torch.Size([2, 5, 75, 320, 320])
gt = gt.permute(0,1,4,2,3) # BS Seq 5 C 3 H W
gt = var_or_cuda(gt) gt = var_or_cuda(gt)
gt2 = gt2.permute(0,1,4,2,3) flow = var_or_cuda(flow) #BS,Seq-1,H,W,2
gt2 = var_or_cuda(gt2)
# gt2 = gt2.permute(0,1,4,2,3)
# gt2 = var_or_cuda(gt2)
gen1 = torch.empty(gt.shape) gen1 = torch.empty(gt.shape) # BS Seq C H W
gen1 = var_or_cuda(gen1) gen1 = var_or_cuda(gen1)
# print(gen1.shape) #torch.Size([2, 5, 3, 320, 320])
gen2 = torch.empty(gt2.shape) # gen2 = torch.empty(gt2.shape)
gen2 = var_or_cuda(gen2) # gen2 = var_or_cuda(gen2)
warped = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4]) #BS, Seq - 1, C, H, W
warped = torch.empty(gt.shape[0],gt.shape[1]-1,gt.shape[2],gt.shape[3],gt.shape[4])
warped = var_or_cuda(warped) warped = var_or_cuda(warped)
gen_temp = torch.empty(warped.shape)
gen_temp = var_or_cuda(gen_temp)
# print("warped:",warped.shape) #warped: torch.Size([2, 4, 3, 320, 320])
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("data prepare:",start.elapsed_time(end))
start.record()
if model_type == "RNN":
if batch_model != "Single":
for k in range(image_set.shape[1]):
if k == 0:
lf_model.reset_hidden(image_set[:,k])
output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
gen1[:,k] = output1[:,0:3]
gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
for i in range(output.shape[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
for i in range(1,gt.shape[1]):
# print(flow_invalid_mask.shape) #torch.Size([2, 4, 320, 320])
# print(FlowMap(gen1[:,i-1],flow[:,i-1]).shape) #torch.Size([2, 3, 320, 320])
warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
else:
for k in range(image_set.shape[1]):
if k == 0:
lf_model.reset_hidden(image_set[:,k])
output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
gen1[:,k] = output1[:,0:3]
gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
if k != image_set.shape[1]-1:
warped[:,k] = FlowMap(output1.detach(),flow[:,k]).mul(~flow_invalid_mask[:,k].unsqueeze(1).to("cuda:2"))
loss1 = loss_without_perc(output1,gt[:,k])
loss = (w_frame * loss1)
if k==0:
loss.backward(retain_graph=False)
optimizer.step()
lf_model.zero_grad()
optimizer.zero_grad()
print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item())
else:
output1mask = output1.mul(~flow_invalid_mask[:,k-1].unsqueeze(1).to("cuda:2"))
loss2 = l1loss(output1mask,warped[:,k-1])
loss += (w_inter_frame * loss2)
loss.backward(retain_graph=False)
optimizer.step()
lf_model.zero_grad()
optimizer.zero_grad()
# print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item())
print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",frame loss:",loss1.item(),",inter loss:",w_inter_frame * loss2.item())
if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
for i in range(output.shape[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
# BSxSeq C H W
gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
for i in range(gt.size()[0]):
save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
else:
if batch_model != "Single":
for k in range(image_set.shape[1]):
output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
gen1[:,k] = output1[:,0:3]
gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
for i in range(output.shape[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
for i in range(1,gt.shape[1]):
# print(flow_invalid_mask.shape) #torch.Size([2, 4, 320, 320])
# print(FlowMap(gen1[:,i-1],flow[:,i-1]).shape) #torch.Size([2, 3, 320, 320])
warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
else:
for k in range(image_set.shape[1]):
output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320
output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k])
gen1[:,k] = output1[:,0:3]
gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2"))
# print(output1.shape) #torch.Size([BS, 3, 320, 320])
loss1 = loss_without_perc(output1,gt[:,k])
loss = (w_frame * loss1)
# print("loss:",loss1.item())
loss.backward(retain_graph=False)
optimizer.step()
lf_model.zero_grad()
optimizer.zero_grad()
print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item())
if ((epoch%10 == 0 and epoch != 0) or epoch == 2):
for i in range(output.shape[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data)))
for i in range(1,gt.shape[1]):
warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2"))
warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4])
gen_temp = gen_temp.reshape(-1,gen_temp.shape[2],gen_temp.shape[3],gen_temp.shape[4])
loss3 = l1loss(warped,gen_temp)
print("Epoch:",epoch,",Iter:",batch_idx,",inter-frame loss:",w_inter_frame *loss3.item())
# BSxSeq C H W
gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
for i in range(gt.size()[0]):
save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
if batch_model != "Single":
if mode=="Perf":
end.record()
torch.cuda.synchronize()
print("forward:",start.elapsed_time(end))
start.record()
optimizer.zero_grad()
delta = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
delta = var_or_cuda(delta)
for k in range(image_set.shape[1]):
if k == 0:
lf_model.reset_hidden(image_set[:,k])
# start = torch.cuda.Event(enable_timing=True) # BSxSeq C H W
# end = torch.cuda.Event(enable_timing=True) gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4])
# start.record() # gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4])
output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k]) # BSxSeq C H W
# end.record() gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4])
# torch.cuda.synchronize() # gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
# print("Model Forward:",start.elapsed_time(end))
# print("output:",output.shape) # [2, 6, 320, 320] # BSx(Seq-1) C H W
# exit() warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4])
########################### Use Pregen Phi and Mask ################### gen_temp = gen_temp.reshape(-1,gen_temp.shape[2],gen_temp.shape[3],gen_temp.shape[4])
# start.record()
output1,mask = GenRetinalGazeFromLayersBatch(output, gen, sample_idx[:,k], phi_dict, mask_dict)
# end.record()
# torch.cuda.synchronize()
# print("Merge:",start.elapsed_time(end))
# print("output1 shape:",output1.shape, "mask shape:",mask.shape)
# output1 shape: torch.Size([2, 6, 160, 160]) mask shape: torch.Size([2, 2, 160, 160])
for i in range(0, 2):
output1[:,i*3:i*3+3].mul_(mask[:,i:i+1])
if i == 0:
gt[:,k].mul_(mask[:,i:i+1])
if i == 1:
gt2[:,k].mul_(mask[:,i:i+1])
gen1[:,k] = output1[:,0:3] loss1_value = loss1(gen1,gt)
gen2[:,k] = output1[:,3:6] loss2_value = loss2(warped,gen_temp)
if ((epoch%5== 0) or epoch == 2): if model_type == "RNN":
for i in range(output.shape[0]): loss = (w_frame * loss1_value)+ (w_inter_frame * loss2_value)
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data))) else:
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data))) loss = (w_frame * loss1_value)
########################### Update ################### if mode=="Perf":
for i in range(1,image_set.shape[1]): end.record()
delta[:,i-1] = gt2[:,i] - gt2[:,i] torch.cuda.synchronize()
warped[:,i-1] = gen2[:,i]-gen2[:,i-1] print("compute loss:",start.elapsed_time(end))
optimizer.zero_grad() start.record()
loss.backward()
# # N S C H W if mode=="Perf":
gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4]) end.record()
gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4]) torch.cuda.synchronize()
gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4]) print("backward:",start.elapsed_time(end))
gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4])
warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4]) start.record()
delta = delta.reshape(-1,delta.shape[2],delta.shape[3],delta.shape[4]) optimizer.step()
if mode=="Perf":
end.record()
# start = torch.cuda.Event(enable_timing=True) torch.cuda.synchronize()
# end = torch.cuda.Event(enable_timing=True) print("update:",start.elapsed_time(end))
# start.record()
loss1 = loss_new(gen1,gt) print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss.item(),",frame loss:",loss1_value.item(),",inter-frame loss:",loss2_value.item())
loss2 = loss_new(gen2,gt2)
loss3 = l1loss(warped,delta) # exit(0)
loss = loss1+loss2+loss3 ########################### Save #####################
# end.record() if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
# torch.cuda.synchronize() for i in range(gt.size()[0]):
# print("loss comp:",start.elapsed_time(end)) # df 2,5
save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
# save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
# start.record() save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data)))
loss.backward() # save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
# end.record() if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
# torch.cuda.synchronize() save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer)
# print("backward:",start.elapsed_time(end)) \ No newline at end of file
# start.record()
optimizer.step()
# end.record()
# torch.cuda.synchronize()
# print("optimizer step:",start.elapsed_time(end))
## Update Prev
print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
########################### Save #####################
if ((epoch%5== 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3])
for i in range(gt.size()[0]):
# df 2,5
save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data)))
if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),epoch,lf_model,optimizer)
########################## test Phi and Mask ##########################
# phi,mask = gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]]))
# # print("gaze Online:",gazeX[0]," ,",gazeY[0])
# # print("df Online:",df[0])
# # print("idx:",int(sample_idx[0].data))
# phi_t = phi_dict[int(sample_idx[0].data)]
# mask_t = mask_dict[int(sample_idx[0].data)]
# # print("idx info:",idx_info_dict[int(sample_idx[0].data)])
# # print("phi online:", phi.shape, " phi_t:", phi_t.shape)
# # print("mask online:", mask.shape, " mask_t:", mask_t.shape)
# print("phi delta:", (phi-phi_t).sum()," mask delta:",(mask -mask_t).sum())
# exit(0)
###########################Gen Batch 1 by 1###################
# phi,mask = gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]]))
# # print(phi.shape) # 2,240,320,41,2
# output1, mask = GenRetinalFromLayersBatch_Online(output, gen, phi, mask)
###########################Gen Batch 1 by 1###################
\ No newline at end of file
...@@ -31,7 +31,7 @@ M = 1 # number of display layers ...@@ -31,7 +31,7 @@ M = 1 # number of display layers
DATA_FILE = "/home/yejiannan/Project/LightField/data/lf_syn" DATA_FILE = "/home/yejiannan/Project/LightField/data/lf_syn"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_lf_syn_full.json" DATA_JSON = "/home/yejiannan/Project/LightField/data/data_lf_syn_full.json"
# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json" # DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/lf_syn_full_perc" OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/lf_syn_full"
OUT_CHANNELS_RB = 128 OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3 KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3 KERNEL_SIZE = 3
...@@ -50,7 +50,7 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver): ...@@ -50,7 +50,7 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
torch.save(checkpoint, file_path) torch.save(checkpoint, file_path)
mode = "Silence" #"Perf" mode = "Silence" #"Perf"
w_frame = 0.9 w_frame = 1.0
loss1 = PerceptionReconstructionLoss() loss1 = PerceptionReconstructionLoss()
if __name__ == "__main__": if __name__ == "__main__":
#train #train
...@@ -70,7 +70,7 @@ if __name__ == "__main__": ...@@ -70,7 +70,7 @@ if __name__ == "__main__":
if torch.cuda.is_available(): if torch.cuda.is_available():
# lf_model = torch.nn.DataParallel(lf_model).cuda() # lf_model = torch.nn.DataParallel(lf_model).cuda()
lf_model = lf_model.to('cuda:3') lf_model = lf_model.to('cuda:1')
optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999)) optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
......
import torch, os, sys, cv2
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as func
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import torch
class RecurrentBlock(nn.Module):
def __init__(self, input_nc, output_nc, downsampling=False, bottleneck=False, upsampling=False):
super(RecurrentBlock, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.downsampling = downsampling
self.upsampling = upsampling
self.bottleneck = bottleneck
self.hidden = None
if self.downsampling:
self.l1 = nn.Sequential(
nn.Conv2d(input_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1)
)
self.l2 = nn.Sequential(
nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2d(output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
)
elif self.upsampling:
self.l1 = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(2 * input_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2d(output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
)
elif self.bottleneck:
self.l1 = nn.Sequential(
nn.Conv2d(input_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1)
)
self.l2 = nn.Sequential(
nn.Conv2d(2 * output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
nn.Conv2d(output_nc, output_nc, 3, padding=1),
nn.LeakyReLU(negative_slope=0.1),
)
def forward(self, inp):
if self.downsampling:
op1 = self.l1(inp)
op2 = self.l2(torch.cat((op1, self.hidden), dim=1))
self.hidden = op2
return op2
elif self.upsampling:
op1 = self.l1(inp)
return op1
elif self.bottleneck:
op1 = self.l1(inp)
op2 = self.l2(torch.cat((op1, self.hidden), dim=1))
self.hidden = op2
return op2
def reset_hidden(self, inp, dfac):
size = list(inp.size())
size[1] = self.output_nc
size[2] /= dfac
size[3] /= dfac
self.hidden_size = size
self.hidden = torch.zeros(*(size)).to('cuda:0')
class RecurrentAE(nn.Module):
def __init__(self, input_nc):
super(RecurrentAE, self).__init__()
self.d1 = RecurrentBlock(input_nc=input_nc, output_nc=32, downsampling=True)
self.d2 = RecurrentBlock(input_nc=32, output_nc=43, downsampling=True)
self.d3 = RecurrentBlock(input_nc=43, output_nc=57, downsampling=True)
self.d4 = RecurrentBlock(input_nc=57, output_nc=76, downsampling=True)
self.d5 = RecurrentBlock(input_nc=76, output_nc=101, downsampling=True)
self.bottleneck = RecurrentBlock(input_nc=101, output_nc=101, bottleneck=True)
self.u5 = RecurrentBlock(input_nc=101, output_nc=76, upsampling=True)
self.u4 = RecurrentBlock(input_nc=76, output_nc=57, upsampling=True)
self.u3 = RecurrentBlock(input_nc=57, output_nc=43, upsampling=True)
self.u2 = RecurrentBlock(input_nc=43, output_nc=32, upsampling=True)
self.u1 = RecurrentBlock(input_nc=32, output_nc=3, upsampling=True)
def set_input(self, inp):
self.inp = inp['A']
def forward(self):
d1 = func.max_pool2d(input=self.d1(self.inp), kernel_size=2)
d2 = func.max_pool2d(input=self.d2(d1), kernel_size=2)
d3 = func.max_pool2d(input=self.d3(d2), kernel_size=2)
d4 = func.max_pool2d(input=self.d4(d3), kernel_size=2)
d5 = func.max_pool2d(input=self.d5(d4), kernel_size=2)
b = self.bottleneck(d5)
u5 = self.u5(torch.cat((b, d5), dim=1))
u4 = self.u4(torch.cat((u5, d4), dim=1))
u3 = self.u3(torch.cat((u4, d3), dim=1))
u2 = self.u2(torch.cat((u3, d2), dim=1))
u1 = self.u1(torch.cat((u2, d1), dim=1))
return u1
def reset_hidden(self):
self.d1.reset_hidden(self.inp, dfac=1)
self.d2.reset_hidden(self.inp, dfac=2)
self.d3.reset_hidden(self.inp, dfac=4)
self.d4.reset_hidden(self.inp, dfac=8)
self.d5.reset_hidden(self.inp, dfac=16)
self.bottleneck.reset_hidden(self.inp, dfac=32)
self.u4.reset_hidden(self.inp, dfac=16)
self.u3.reset_hidden(self.inp, dfac=8)
self.u5.reset_hidden(self.inp, dfac=4)
self.u2.reset_hidden(self.inp, dfac=2)
self.u1.reset_hidden(self.inp, dfac=1)
import numpy as np
import torch
import matplotlib.pyplot as plt
import glm
gvec_type = [ glm.dvec1, glm.dvec2, glm.dvec3, glm.dvec4 ]
gmat_type = [ [ glm.dmat2, glm.dmat2x3, glm.dmat2x4 ],
[ glm.dmat3x2, glm.dmat3, glm.dmat3x4 ],
[ glm.dmat4x2, glm.dmat4x3, glm.dmat4 ] ]
def Fov2Length(angle):
return np.tan(angle * np.pi / 360) * 2
def SmoothStep(x0, x1, x):
y = torch.clamp((x - x0) / (x1 - x0), 0, 1)
return y * y * (3 - 2 * y)
def MatImg2Tensor(img, permute = True, batch_dim = True):
batch_input = len(img.shape) == 4
if permute:
t = torch.from_numpy(np.transpose(img,
[0, 3, 1, 2] if batch_input else [2, 0, 1]))
else:
t = torch.from_numpy(img)
if not batch_input and batch_dim:
t = t.unsqueeze(0)
return t
def MatImg2Numpy(img, permute = True, batch_dim = True):
batch_input = len(img.shape) == 4
if permute:
t = np.transpose(img, [0, 3, 1, 2] if batch_input else [2, 0, 1])
else:
t = img
if not batch_input and batch_dim:
t = t.unsqueeze(0)
return t
def Tensor2MatImg(t):
img = t.squeeze().cpu().numpy()
batch_input = len(img.shape) == 4
if t.size()[batch_input] <= 4:
return np.transpose(img, [0, 2, 3, 1] if batch_input else [1, 2, 0])
return img
def ReadImageTensor(path, permute = True, rgb_only = True, batch_dim = True):
channels = 3 if rgb_only else 4
if isinstance(path,list):
first_image = plt.imread(path[0])[:, :, 0:channels]
b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
b_image[0] = first_image
for i in range(1, len(path)):
b_image[i] = plt.imread(path[i])[:, :, 0:channels]
return MatImg2Tensor(b_image, permute)
return MatImg2Tensor(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
def ReadImageNumpyArray(path, permute = True, rgb_only = True, batch_dim = True):
channels = 3 if rgb_only else 4
if isinstance(path,list):
first_image = plt.imread(path[0])[:, :, 0:channels]
b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
b_image[0] = first_image
for i in range(1, len(path)):
b_image[i] = plt.imread(path[i])[:, :, 0:channels]
return MatImg2Numpy(b_image, permute)
return MatImg2Numpy(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
def WriteImageTensor(t, path):
image = Tensor2MatImg(t)
if isinstance(path,list):
if len(image.shape) != 4 or image.shape[0] != len(path):
raise ValueError
for i in range(len(path)):
plt.imsave(path[i], image[i])
else:
if len(image.shape) == 4 and image.shape[0] != 1:
raise ValueError
plt.imsave(path, image)
def PlotImageTensor(t):
plt.imshow(Tensor2MatImg(t))
def Tensor2Glm(t):
t = t.squeeze()
size = t.size()
if len(size) == 1:
if size[0] <= 0 or size[0] > 4:
raise ValueError
return gvec_type[size[0] - 1](t.cpu().numpy())
if len(size) == 2:
if size[0] <= 1 or size[0] > 4 or size[1] <= 1 or size[1] > 4:
raise ValueError
return gmat_type[size[1] - 2][size[0] - 2](t.cpu().numpy())
raise ValueError
def Glm2Tensor(val):
return torch.from_numpy(np.array(val))
def MeshGrid(size, normalize=False):
y,x = torch.meshgrid(torch.tensor(range(size[0])),
torch.tensor(range(size[1])))
if normalize:
return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2)
return torch.stack([x, y], 2)
\ No newline at end of file
import torch
weightVarScale = 0.25
bias_stddev = 0.01
def weight_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.xavier_normal_(m.weight.data)
torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment