Commit 5069f8ae authored by BobYeah

Gaze

parent 055dc0bb
import matplotlib.pyplot as plt
import numpy as np
import torch
import glm
def Fov2Length(angle):
    '''
    Convert a full field-of-view angle (in degrees) into the extent of the
    viewed plane at unit distance: length = 2 * tan(angle / 2).
    '''
    return np.tan(angle * np.pi / 360) * 2
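# Quick sanity check (ours, not part of the commit): Fov2Length(90) = 2 * tan(45 deg) = 2.0,
# i.e. a 90-degree FOV spans 2 world units on a plane at distance 1.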
def RandomGenSamplesInPupil(pupil_size, n_samples):
    '''
    Random sample n_samples positions in pupil region
@@ -18,14 +26,14 @@ def RandomGenSamplesInPupil(conf, n_samples):
    samples = torch.empty(n_samples, 3)
    i = 0
    while i < n_samples:
        # rejection sampling: keep only points inside the pupil disk
        s = (torch.rand(2) - 0.5) * pupil_size
        if np.linalg.norm(s) > pupil_size / 2.:
            continue
        samples[i, :] = torch.tensor([ s[0], s[1], 0. ])
        i += 1
    return samples
def GenSamplesInPupil(pupil_size, circles):
    '''
    Sample positions on circles in pupil region
@@ -38,68 +46,116 @@ def GenSamplesInPupil(conf, circles):
    --------
    a n_samples x 3 tensor with 3D sample position (z = 0) in each row
    '''
    samples = torch.zeros(1, 3)
    for i in range(1, circles):
        r = pupil_size / 2. / (circles - 1) * i
        n = 4 * i
        for j in range(0, n):
            angle = 2 * np.pi / n * j
            samples = torch.cat([ samples, torch.tensor([[ r * np.cos(angle), r * np.sin(angle), 0 ]]) ], 0)
    return samples
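# For example (ours): circles = 5 yields 1 + 4 + 8 + 12 + 16 = 41 samples,
# which matches the M = 41 pupil samples assumed in the shape comments below.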
class RetinalGen(object):
    '''
    Class for retinal generation process

    Properties
    --------
    conf - multi-layers' parameters configuration
    u    - M x 3 tensor, M sample positions in pupil
    p_r  - H_r x W_r x 3 tensor, retinal pixel grid, [H_r, W_r] is the retinal resolution
    Phi  - N x H_r x W_r x M x 2 tensor, retinal to layers mapping, N is the number of layers
    mask - N x H_r x W_r x M x 2 tensor, indicates whether each mapping is in range

    Methods
    --------
    '''
    def __init__(self, conf, u):
        '''
        Initialize retinal generator instance

        Parameters
        --------
        conf - multi-layers' parameters configuration
        u    - M x 3 tensor storing M sample positions in the pupil
        '''
        self.conf = conf
        self.u = u                      # M x 3, M sample positions
        self.D_r = conf.retinal_res     # retinal resolution [H_r, W_r]
        self.N = conf.GetNLayers()      # number of layers (2)
        self.M = u.size()[0]            # number of pupil samples
        p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])),
                                    torch.tensor(range(0, self.D_r[1])))
        self.p_r = torch.cat([
            ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(),  # eye viewport
            torch.ones(self.D_r[0], self.D_r[1], 1)
        ], 2)
    def CalculateRetinal2LayerMappings(self, df, gaze):
        '''
        Calculate the mapping matrix from retinal to layers.

        Parameters
        --------
        df   - focus distance
        gaze - 2 x 1 tensor, eye rotation angle (degs) in horizontal and vertical direction

        Returns
        --------
        Phi          - N x H_r x W_r x M x 2 long tensor, retinal-to-layer pixel mapping
        retinal_mask - H_r x W_r tensor, 1 where all sample rays land inside every layer
        '''
        Phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long)   # N x H_r x W_r x M x 2
        mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float)
        c = (self.conf.layer_res / 2)   # c: center of layers (pixel)
        p_f = self.p_r * df             # p_f: H x W x 3, focus positions of retinal pixels on focus plane
        # build the gaze rotation from the horizontal/vertical gaze angles
        rot_forward = glm.dvec3(glm.tan(glm.radians(glm.dvec2(gaze[1], -gaze[0]))), 1)
        rot_mat = torch.from_numpy(np.array(
            glm.dmat3(glm.lookAtLH(glm.dvec3(), rot_forward, glm.dvec3(0, 1, 0)))))
        rot_mat = rot_mat.float()
        u_rot = torch.mm(self.u, rot_mat)
        v_rot = torch.matmul(p_f, rot_mat).unsqueeze(2).expand(
            -1, -1, self.u.size()[0], -1) - u_rot    # v_rot: H x W x M x 3, rotated rays' direction vectors
        v_rot.div_(v_rot[:, :, :, 2].unsqueeze(3))   # make z = 1 for each direction vector in v_rot

        for i in range(0, self.conf.GetNLayers()):
            dp_i = self.conf.GetLayerSize(i)[0] / self.conf.layer_res[0]   # dp_i: pixel size of layer i
            d_i = self.conf.d_layer[i]                                     # d_i: distance of layer i
            k = (d_i - u_rot[:, 2]).unsqueeze(1)
            pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i        # pi_r: H x W x M x 2, rays' pixel coords on layer i
            Phi[i, :, :, :, :] = torch.floor(pi_r + c)
        mask[:, :, :, :, 0] = ((Phi[:, :, :, :, 0] >= 0) & (Phi[:, :, :, :, 0] < self.conf.layer_res[0])).float()
        mask[:, :, :, :, 1] = ((Phi[:, :, :, :, 1] >= 0) & (Phi[:, :, :, :, 1] < self.conf.layer_res[1])).float()
        Phi[:, :, :, :, 0].clamp_(0, self.conf.layer_res[0] - 1)
        Phi[:, :, :, :, 1].clamp_(0, self.conf.layer_res[1] - 1)
        retinal_mask = mask.prod(0).prod(2).prod(2)
        return [ Phi, retinal_mask ]
    def GenRetinalFromLayers(self, layers, Phi):
        '''
        Generate retinal image from layers, using precalculated mapping matrix

        Parameters
        --------
        layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
        Phi    - N x H_r x W_r x M x 2 long tensor, precalculated retinal-to-layer mapping

        Returns
        --------
        3 x H_r x W_r tensor, 3 channels retinal image
        '''
        # 3 channels for RGB (1 for grayscale)
        mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M)   # N x 3 x H_r x W_r x M
        for i in range(0, Phi.size()[0]):
            mapped_layers[i, :, :, :, :] = layers[(i * 3) : (i * 3 + 3),
                                                  Phi[i, :, :, :, 0],
                                                  Phi[i, :, :, :, 1]]
        # multiply the (transmissive) layers per ray, then average over the M pupil samples
        retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
        return retinal
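# A minimal usage sketch (ours, not part of the commit), assuming a Conf object
# like the one defined in the training script below:
#
#   conf = Conf()
#   u = GenSamplesInPupil(conf.pupil_size, 5)     # 41 x 3 pupil samples
#   gen = RetinalGen(conf, u)
#   Phi, mask = gen.CalculateRetinal2LayerMappings(df=2.0, gaze=torch.tensor([0., 0.]))
#   retinal = gen.GenRetinalFromLayers(layers, Phi)   # layers: 3N x H_l x W_l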
@@ -16,55 +16,64 @@ import json
from ssim import *
from perc_loss import *

# param
BATCH_SIZE = 16
NUM_EPOCH = 1000
INTERLEAVE_RATE = 2
IM_H = 320
IM_W = 320
Retinal_IM_H = 320
Retinal_IM_W = 320
N = 9 # number of views in the input light field stack
M = 2 # number of display layers
DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_small_nar_new"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_low_new.json"
DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/output/gaze_low_new_1125_minibatch"
class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
    def __init__(self, file_dir_path, file_json, transforms=None):
        self.file_dir_path = file_dir_path
        self.transforms = transforms
        with open(file_json, encoding='utf-8') as file:
            self.dataset_desc = json.loads(file.read())

    def __len__(self):
        return len(self.dataset_desc["focaldepth"])

    def __getitem__(self, idx):
        lightfield_images, gt, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
        if self.transforms:
            lightfield_images = self.transforms(lightfield_images)
        return (lightfield_images, gt, fd, gazeX, gazeY, sample_idx)
    def get_datum(self, idx):
        lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
        fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx])
        lf_images = []
        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
        for i in range(9):
            # crop the i-th view out of the 3 x 3 light field atlas
            lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H, i%3*IM_W:i%3*IM_W+IM_W, 0:3]
            ## IF GrayScale
            # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H, i%3*IM_W:i%3*IM_W+IM_W, 0:1]
            lf_images.append(lf_image)
        gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
        ## IF GrayScale
        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        # gt = np.expand_dims(gt, axis=-1)
        fd = self.dataset_desc["focaldepth"][idx]
        gazeX = self.dataset_desc["gazeX"][idx]
        gazeY = self.dataset_desc["gazeY"][idx]
        sample_idx = self.dataset_desc["idx"][idx]
        return np.asarray(lf_images), gt, fd, gazeX, gazeY, sample_idx
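# The dataset JSON is read as parallel arrays indexed per sample. A hypothetical
# minimal file consistent with the keys used above (illustration only):
#
# {
#   "train": ["lf_0000.png", ...],
#   "gt": ["gt_0000.png", ...],
#   "focaldepth": [2.625, ...],
#   "gazeX": [1.2, ...],
#   "gazeY": [-0.8, ...],
#   "idx": [0, ...]
# }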
OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
@@ -128,7 +137,6 @@ class interleave(torch.nn.Module):
        output = output.permute(0, 3, 1, 2)
        return output

LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2
@@ -144,37 +152,39 @@ class model(torch.nn.Module):
        )
        self.residual_block1 = residual_block(0)
        self.residual_block2 = residual_block(3)
        self.residual_block3 = residual_block(3)
        self.residual_block4 = residual_block(3)
        self.residual_block5 = residual_block(3)
        self.output_layer = torch.nn.Sequential(
            torch.nn.Conv2d(OUT_CHANNELS_RB+3, LAST_LAYER_CHANNELS, KERNEL_SIZE, stride=1, padding=1),
            torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
            torch.nn.Sigmoid()
        )
        self.deinterleave = deinterleave(INTERLEAVE_RATE)
    def forward(self, lightfield_images, focal_length, gazeX, gazeY):
        # lightfield_images: batch_size x (LF * C) x H x W
        input_to_net = self.interleave(lightfield_images)
        input_to_rb = self.first_layer(input_to_net)
        output = self.residual_block1(input_to_rb)
        # broadcast focus depth and gaze into constant feature maps
        depth_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
        gazeX_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
        gazeY_layer = torch.ones((input_to_rb.shape[0], 1, input_to_rb.shape[2], input_to_rb.shape[3]))
        for i in range(focal_length.shape[0]):
            depth_layer[i] *= 1. / focal_length[i]
            # normalize gaze angles from [-3.333, 3.333] degrees to [0, 1]
            gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333 * 2)
            gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333 * 2)
        depth_layer = var_or_cuda(depth_layer)
        gazeX_layer = var_or_cuda(gazeX_layer)
        gazeY_layer = var_or_cuda(gazeY_layer)
        output = torch.cat((output, depth_layer, gazeX_layer, gazeY_layer), dim=1)
        output = self.residual_block2(output)
        output = self.residual_block3(output)
        output = self.residual_block4(output)
        output = self.residual_block5(output)
        output = self.output_layer(output)
        output = self.deinterleave(output)
        return output
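# The conditioning above is a plain constant-channel trick (our reading of the
# code): each scalar (inverse depth, normalized gazeX/gazeY) is broadcast into a
# full H x W map and concatenated on the channel axis, which is why the later
# blocks and the output layer take OUT_CHANNELS_RB + 3 input channels, e.g.:
#
#   cond = torch.full((x.shape[0], 1, x.shape[2], x.shape[3]), value)
#   x = torch.cat((x, cond), dim=1)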
@@ -182,72 +192,65 @@ class model(torch.nn.Module):
class Conf(object):
    def __init__(self):
        self.pupil_size = 0.02    # 2cm
        self.retinal_res = torch.tensor([ Retinal_IM_H, Retinal_IM_W ])
        self.layer_res = torch.tensor([ IM_H, IM_W ])
        self.layer_hfov = 90      # layers' horizontal FOV
        self.eye_hfov = 85        # eye's horizontal FOV
        self.d_layer = [ 1, 3 ]   # layers' distances

    def GetNLayers(self):
        return len(self.d_layer)

    def GetLayerSize(self, i):
        w = Fov2Length(self.layer_hfov)
        h = w * self.layer_res[0] / self.layer_res[1]
        return torch.tensor([ h, w ]) * self.d_layer[i]

    def GetEyeViewportSize(self):
        w = Fov2Length(self.eye_hfov)
        h = w * self.retinal_res[0] / self.retinal_res[1]
        return torch.tensor([ h, w ])
#### Image Gen
conf = Conf()
u = GenSamplesInPupil(conf.pupil_size, 5)
gen = RetinalGen(conf, u)
def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
    # layers: batchsize x (2 * color) x height x width
    # phi_dict[idx]: precalculated N x H_r x W_r x M x 2 mapping for sample idx
    # retinal: batchsize x color x height x width
    retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W)
    mask = []   # per-sample H_r x W_r retinal masks
    for i in range(0, layers.size()[0]):
        phi = phi_dict[int(sample_idx[i].data)]
        phi = var_or_cuda(phi)
        phi.requires_grad = False
        retinal[i] = gen.GenRetinalFromLayers(layers[i], phi)
        mask.append(mask_dict[int(sample_idx[i].data)])
    retinal = var_or_cuda(retinal)
    mask = torch.stack(mask, dim=0).unsqueeze(1)   # batchsize x 1 x height x width
    return retinal, mask
def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
    # layers: batchsize x (2 * color) x height x width
    # phi: N x H_r x W_r x M x 2 mapping, computed online for the current sample
    phi = var_or_cuda(phi)
    phi.requires_grad = False
    retinal = gen.GenRetinalFromLayers(layers[0], phi)
    retinal = var_or_cuda(retinal)
    mask_out = mask.unsqueeze(0).unsqueeze(0)   # 1 x 1 x H_r x W_r
    return retinal.unsqueeze(0), mask_out
#### Image Gen End
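# Note (ours): the two paths trade memory for latency. The batch variant looks
# up a phi precalculated per training sample (see phi_dict below), while the
# _Online variant recomputes phi for an arbitrary focus depth and gaze at
# inference time.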
weightVarScale = 0.25
bias_stddev = 0.01
@@ -269,7 +272,7 @@ def var_or_cuda(x):
def calImageGradients(images):
    # images is a 4-D tensor (N x C x H x W)
    dx = images[:, :, 1:, :] - images[:, :, :-1, :]   # finite difference along height
    dy = images[:, :, :, 1:] - images[:, :, :, :-1]   # finite difference along width
    return dx, dy
@@ -279,16 +282,13 @@ perc_loss = perc_loss.to("cuda")
def loss_new(generated, gt):
    mse_loss = torch.nn.MSELoss()
    rmse_intensity = mse_loss(generated, gt)
    psnr_intensity = torch.log10(rmse_intensity)   # log10(MSE): minimizing it maximizes PSNR
    labels_dx, labels_dy = calImageGradients(gt)
    preds_dx, preds_dy = calImageGradients(generated)
    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
    p_loss = perc_loss(generated, gt)
    total_loss = 10 + psnr_intensity + 0.5 * (psnr_grad_x + psnr_grad_y) + p_loss
    return total_loss
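# Loss breakdown (ours): since PSNR = -10 * log10(MSE) for signals in [0, 1],
# adding log10(MSE) terms for intensity and for image gradients drives PSNR up
# on both; perc_loss adds a perceptual term, and the constant 10 only offsets
# the sum so the printed loss stays positive.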
@@ -301,15 +301,56 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
    }
    torch.save(checkpoint, file_path)

def hook_fn_back(m, i, o):
    for grad in i:
        try:
            print("Input Grad:", m, grad.shape, grad.sum())
        except AttributeError:
            print("None found for Gradient")
    for grad in o:
        try:
            print("Output Grad:", m, grad.shape, grad.sum())
        except AttributeError:
            print("None found for Gradient")
    print("\n")

def hook_fn_for(m, i, o):
    for feat in i:
        try:
            print("Input Feats:", m, feat.shape, feat.sum())
        except AttributeError:
            print("None found for Features")
    for feat in o:
        try:
            print("Output Feats:", m, feat.shape, feat.sum())
        except AttributeError:
            print("None found for Features")
    print("\n")
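# These hooks are defined but not attached anywhere in this commit; a sketch
# (ours) of how they would be wired up for debugging, using the standard
# PyTorch hook API:
#
#   lf_model.residual_block1.register_forward_hook(hook_fn_for)
#   lf_model.residual_block1.register_backward_hook(hook_fn_back)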
if __name__ == "__main__":
    ############################## generate phi and mask in pre-training
    phi_dict = {}
    mask_dict = {}
    idx_info_dict = {}
    print("generating phi and mask...")
    with open(DATA_JSON, encoding='utf-8') as file:
        dataset_desc = json.loads(file.read())
    for i in range(len(dataset_desc["focaldepth"])):
        idx = dataset_desc["idx"][i]
        focaldepth = dataset_desc["focaldepth"][i]
        gazeX = dataset_desc["gazeX"][i]
        gazeY = dataset_desc["gazeY"][i]
        phi, mask = gen.CalculateRetinal2LayerMappings(focaldepth, torch.tensor([gazeX, gazeY]))
        phi_dict[idx] = phi
        mask_dict[idx] = mask
        idx_info_dict[idx] = [idx, focaldepth, gazeX, gazeY]
    print("generating phi and mask end.")
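    # Size note (ours): each phi is an N x H_r x W_r x M x 2 int64 tensor, i.e.
    # 2 * 320 * 320 * 41 * 2 = 16.8M entries (~134 MB) per sample, so holding a
    # dict of them is only practical for small datasets like this one.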
    #train
    train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE, DATA_JSON),
                                                    batch_size=BATCH_SIZE,
@@ -319,82 +360,51 @@ if __name__ == "__main__":
                                                    drop_last=False)
    print(len(train_data_loader))
    ################################################ train #########################################################
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lf_model = model()
    lf_model.apply(weight_init_normal)
    epoch_begin = 0
    if torch.cuda.is_available():
        lf_model = torch.nn.DataParallel(lf_model).cuda()
    lf_model.train()
    optimizer = torch.optim.Adam(lf_model.parameters(), lr=1e-2, betas=(0.9, 0.999))

    print("begin training....")
    for epoch in range(epoch_begin, NUM_EPOCH):
        for batch_idx, (image_set, gt, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader):
            #reshape for input
            image_set = image_set.permute(0, 1, 4, 2, 3)   # N x LF x C x H x W
            image_set = image_set.reshape(image_set.shape[0], -1, image_set.shape[3], image_set.shape[4])   # N x (LF * C) x H x W
            image_set = var_or_cuda(image_set)

            gt = gt.permute(0, 3, 1, 2)
            gt = var_or_cuda(gt)
            optimizer.zero_grad()
            output = lf_model(image_set, df, gazeX, gazeY)

            ########################### Use Pregen Phi and Mask ###################
            output1, mask = GenRetinalFromLayersBatch(output, gen, sample_idx, phi_dict, mask_dict)
            mask = var_or_cuda(mask)
            mask.requires_grad = False
            output_f = output1 * mask
            gt = gt * mask
            loss = loss_new(output_f, gt)
            print("Epoch:", epoch, ",Iter:", batch_idx, ",loss:", loss)

            ########################### Update ###################
            loss.backward()
            optimizer.step()

            ########################### Save #####################
            if ((epoch % 50 == 0) or epoch == 5):
                for i in range(output_f.size()[0]):
                    save_image(output[i][0:3].data, os.path.join(OUTPUT_DIR, "gaze_fac1_o_%.3f_%.3f_%.3f.png" % (df[i].data, gazeX[i].data, gazeY[i].data)))
                    save_image(output[i][3:6].data, os.path.join(OUTPUT_DIR, "gaze_fac2_o_%.3f_%.3f_%.3f.png" % (df[i].data, gazeX[i].data, gazeY[i].data)))
                    save_image(output_f[i][0:3].data, os.path.join(OUTPUT_DIR, "gaze_test1_o_%.3f_%.3f_%.3f.png" % (df[i].data, gazeX[i].data, gazeY[i].data)))
            if ((epoch % 200 == 0) and epoch != 0 and batch_idx == len(train_data_loader) - 1):
                save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),
                                 epoch, lf_model, optimizer)