Commit 5069f8ae authored by BobYeah's avatar BobYeah
Browse files

Gaze

parent 055dc0bb
import matplotlib.pyplot as plt
import numpy as np
import torch
import glm
def RandomGenSamplesInPupil(conf, n_samples):
def Fov2Length(angle):
'''
'''
return np.tan(angle * np.pi / 360) * 2
def RandomGenSamplesInPupil(pupil_size, n_samples):
'''
Random sample n_samples positions in pupil region
......@@ -18,14 +26,14 @@ def RandomGenSamplesInPupil(conf, n_samples):
samples = torch.empty(n_samples, 2)
i = 0
while i < n_samples:
s = (torch.rand(2) - 0.5) * conf.pupil_size
if np.linalg.norm(s) > conf.pupil_size / 2.:
s = (torch.rand(2) - 0.5) * pupil_size
if np.linalg.norm(s) > pupil_size / 2.:
continue
samples[i, :] = s
samples[i, :] = [ s[0], s[1], 0 ]
i += 1
return samples
def GenSamplesInPupil(conf, circles):
def GenSamplesInPupil(pupil_size, circles):
'''
Sample positions on circles in pupil region
......@@ -38,68 +46,116 @@ def GenSamplesInPupil(conf, circles):
--------
a n_samples x 2 tensor with 2D sample position in each row
'''
samples = torch.tensor([[ 0., 0. ]])
samples = torch.zeros(1, 3)
for i in range(1, circles):
r = conf.pupil_size / 2. / (circles - 1) * i
r = pupil_size / 2. / (circles - 1) * i
n = 4 * i
for j in range(0, n):
angle = 2 * np.pi / n * j
samples = torch.cat((samples, torch.tensor([[ r * np.cos(angle), r * np.sin(angle)]])),dim=0)
samples = torch.cat([ samples, torch.tensor([[ r * np.cos(angle), r * np.sin(angle), 0 ]]) ], 0)
return samples
def GenRetinal2LayerMappings(conf, df, v, u):
class RetinalGen(object):
'''
Generate the mapping matrix from retinal to layers.
Class for retinal generation process
Parameters
Properties
--------
conf - multi-layers' parameters configuration
df - focal distance
v - a 1 x 2 tensor stores half viewport
u - a M x 2 tensor stores M sample positions on pupil
u - M x 3 tensor, M sample positions in pupil
p_r - H_r x W_r x 3 tensor, retinal pixel grid, [H_r, W_r] is the retinal resolution
Phi - N x H_r x W_r x M x 2 tensor, retinal to layers mapping, N is number of layers
mask - N x H_r x W_r x M x 2 tensor, indicates invalid (out-of-range) mapping
Returns
Methods
--------
The mapping matrix
'''
H_r = conf.retinal_res[0]
W_r = conf.retinal_res[1]
D_r = conf.retinal_res.double()
N = conf.n_layers
M = u.size()[0] #41
Phi = torch.empty(H_r, W_r, N, M, 2, dtype=torch.long)
p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, H_r)),
torch.tensor(range(0, W_r)))
p_r = torch.stack([p_rx, p_ry], 2).unsqueeze(2).expand(-1, -1, M, -1)
# print(p_r.shape) #torch.Size([480, 640, 41, 2])
for i in range(0, N):
dpi = conf.h_layer[i] / conf.layer_res[0] # 1 / 480
ci = conf.layer_res / 2 # [240,320]
di = conf.d_layer[i] # 深度
pi_r = di * v * (1. / D_r * (p_r + 0.5) - 0.5) / dpi # [480, 640, 41, 2]
wi = (1 - di / df) / dpi # (1 - 深度/聚焦) / dpi df = 2.625 di = 1.75
pi = torch.floor(pi_r + ci + wi * u)
torch.clamp_(pi[:, :, :, 0], 0, conf.layer_res[0] - 1)
torch.clamp_(pi[:, :, :, 1], 0, conf.layer_res[1] - 1)
Phi[:, :, i, :, :] = pi
return Phi
def __init__(self, conf, u):
'''
Initialize retinal generator instance
def GenRetinalFromLayers(layers, Phi):
# layers: 2, color, height, width
# Phi:torch.Size([480, 640, 2, 41, 2])
M = Phi.size()[3] # 41
N = Phi.size()[2] # 2
# print(layers.shape)# torch.Size([2, 3, 480, 640])
# print(Phi.shape)# torch.Size([480, 640, 2, 41, 2])
# retinal image: 3channels x retinal_size
retinal = torch.zeros(3, Phi.size()[0], Phi.size()[1])
for j in range(0, M):
retinal_view = torch.zeros(3, Phi.size()[0], Phi.size()[1])
for i in range(0, N):
retinal_view.add_(layers[i,:, Phi[:, :, i, j, 0], Phi[:, :, i, j, 1]])
retinal.add_(retinal_view)
retinal.div_(M)
return retinal
Parameters
--------
conf - multi-layers' parameters configuration
u - a M x 3 tensor stores M sample positions in pupil
'''
self.conf = conf
# self.u = u.to(cuda_dev)
self.u = u # M x 3 M sample positions
self.D_r = conf.retinal_res # retinal res 480 x 640
self.N = conf.GetNLayers() # 2
self.M = u.size()[0] # samples
p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])),
torch.tensor(range(0, self.D_r[1])))
self.p_r = torch.cat([
((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # 眼球视野
torch.ones(self.D_r[0], self.D_r[1], 1)
], 2)
# self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long)
# self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2
def CalculateRetinal2LayerMappings(self, df, gaze):
'''
Calculate the mapping matrix from retinal to layers.
\ No newline at end of file
Parameters
--------
df - focus distance
gaze - 2 x 1 tensor, eye rotation angle (degs) in horizontal and vertical direction
'''
Phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long) # 2 x 480 x 640 x 41 x 2
mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float)
D_r = self.conf.retinal_res # D_r: Resolution of retinal 480 640
V = self.conf.GetEyeViewportSize() # V: Viewport size of eye
c = (self.conf.layer_res / 2) # c: Center of layers (pixel)
p_f = self.p_r * df # p_f: H x W x 3, focus positions of retinal pixels on focus plane
rot_forward = glm.dvec3(glm.tan(glm.radians(glm.dvec2(gaze[1], -gaze[0]))), 1)
rot_mat = torch.from_numpy(np.array(
glm.dmat3(glm.lookAtLH(glm.dvec3(), rot_forward, glm.dvec3(0, 1, 0)))))
rot_mat = rot_mat.float()
u_rot = torch.mm(self.u, rot_mat)
v_rot = torch.matmul(p_f, rot_mat).unsqueeze(2).expand(
-1, -1, self.u.size()[0], -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector
v_rot.div_(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot
for i in range(0, self.conf.GetNLayers()):
dp_i = self.conf.GetLayerSize(i)[0] / self.conf.layer_res[0] # dp_i: Pixel size of layer i
d_i = self.conf.d_layer[i] # d_i: Distance of layer i
k = (d_i - u_rot[:, 2]).unsqueeze(1)
pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i # pi_r: H x W x M x 2, rays' pixel coord on layer i
Phi[i, :, :, :, :] = torch.floor(pi_r + c)
mask[:, :, :, :, 0] = ((Phi[:, :, :, :, 0] >= 0) & (Phi[:, :, :, :, 0] < self.conf.layer_res[0])).float()
mask[:, :, :, :, 1] = ((Phi[:, :, :, :, 1] >= 0) & (Phi[:, :, :, :, 1] < self.conf.layer_res[1])).float()
Phi[:, :, :, :, 0].clamp_(0, self.conf.layer_res[0] - 1)
Phi[:, :, :, :, 1].clamp_(0, self.conf.layer_res[1] - 1)
retinal_mask = mask.prod(0).prod(2).prod(2)
return [ Phi, retinal_mask ]
def GenRetinalFromLayers(self, layers, Phi):
'''
Generate retinal image from layers, using precalculated mapping matrix
Parameters
--------
layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
Returns
--------
3 x H_r x W_r tensor, 3 channels retinal image
H_r x W_r tensor, retinal image mask, indicates pixels valid or not
'''
# FOR GRAYSCALE 1 FOR RGB 3
mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M) # 2 x 3 x 480 x 640 x 41
# print("mapped_layers:",mapped_layers.shape)
for i in range(0, Phi.size()[0]):
# print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape)
mapped_layers[i, :, :, :, :] = layers[(i * 3) : (i * 3 + 3),
Phi[i, :, :, :, 0],
Phi[i, :, :, :, 1]]
# print("mapped_layers:",mapped_layers.shape)
retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
# print("retinal:",retinal.shape)
return retinal
\ No newline at end of file
......@@ -16,55 +16,64 @@ import json
from ssim import *
from perc_loss import *
# param
BATCH_SIZE = 5
NUM_EPOCH = 5000
BATCH_SIZE = 16
NUM_EPOCH = 1000
INTERLEAVE_RATE = 2
IM_H = 480
IM_W = 640
IM_H = 320
IM_W = 320
Retinal_IM_H = 320
Retinal_IM_W = 320
N = 9 # number of input light field stack
M = 2 # number of display layers
DATA_FILE = "/home/yejiannan/Project/LightField/data/try"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data.json"
DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_small_nar_new"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_low_new.json"
DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/output"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/output/gaze_low_new_1125_minibatch"
class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dastset_desc = json.loads(file.read())
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dastset_desc["focaldepth"])
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, gt, fd = self.get_datum(idx)
lightfield_images, gt, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
return (lightfield_images, gt, fd)
return (lightfield_images, gt, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(DATA_FILE, self.dastset_desc["train"][idx])
# print(lf_image_paths)
fd_gt_path = os.path.join(DATA_FILE, self.dastset_desc["gt"][idx])
# print(fd_gt_path)
lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx])
fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx])
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
fd = self.dastset_desc["focaldepth"][idx]
return (np.asarray(lf_images),gt,fd)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),gt,fd,gazeX,gazeY,sample_idx
OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
......@@ -128,7 +137,6 @@ class interleave(torch.nn.Module):
output = output.permute(0, 3, 1, 2)
return output
LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2
FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2
......@@ -144,37 +152,39 @@ class model(torch.nn.Module):
)
self.residual_block1 = residual_block(0)
self.residual_block2 = residual_block(1)
self.residual_block3 = residual_block(1)
self.residual_block2 = residual_block(3)
self.residual_block3 = residual_block(3)
self.residual_block4 = residual_block(3)
self.residual_block5 = residual_block(3)
self.output_layer = torch.nn.Sequential(
torch.nn.Conv2d(OUT_CHANNELS_RB+1,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
torch.nn.Conv2d(OUT_CHANNELS_RB+3,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
torch.nn.Sigmoid()
)
self.deinterleave = deinterleave(INTERLEAVE_RATE)
def forward(self, lightfield_images, focal_length):
# lightfield_images: torch.Size([batch_size, channels * D, H, W])
# channels : RGB*D: 3*9, H:256, W:256
def forward(self, lightfield_images, focal_length, gazeX, gazeY):
input_to_net = self.interleave(lightfield_images)
# print("after interleave:",input_to_net.shape)
input_to_rb = self.first_layer(input_to_net)
output = self.residual_block1(input_to_rb)
# print("output1:",output.shape)
depth_layer = torch.ones((output.shape[0],1,output.shape[2],output.shape[3]))
# print(df.shape[0])
depth_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
gazeX_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
gazeY_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3]))
for i in range(focal_length.shape[0]):
depth_layer[i] = 1. / focal_length[i]
# print(depth_layer.shape)
depth_layer[i] *= 1. / focal_length[i]
gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333*2)
gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333*2)
depth_layer = var_or_cuda(depth_layer)
output = torch.cat((output,depth_layer),dim=1)
gazeX_layer = var_or_cuda(gazeX_layer)
gazeY_layer = var_or_cuda(gazeY_layer)
output = torch.cat((output,depth_layer,gazeX_layer,gazeY_layer),dim=1)
output = self.residual_block2(output)
output = self.residual_block3(output)
# output = output + input_to_net
output = self.residual_block4(output)
output = self.residual_block5(output)
output = self.output_layer(output)
output = self.deinterleave(output)
return output
......@@ -182,72 +192,65 @@ class model(torch.nn.Module):
class Conf(object):
def __init__(self):
self.pupil_size = 0.02 # 2cm
self.retinal_res = torch.tensor([ 480, 640 ])
self.layer_res = torch.tensor([ 480, 640 ])
self.n_layers = 2
self.d_layer = [ 1., 3. ] # layers' distance
self.h_layer = [ 1. * 480. / 640., 3. * 480. / 640. ] # layers' height
self.retinal_res = torch.tensor([ Retinal_IM_H, Retinal_IM_W ])
self.layer_res = torch.tensor([ IM_H, IM_W ])
self.layer_hfov = 90 # layers' horizontal FOV
self.eye_hfov = 85 # eye's horizontal FOV
self.d_layer = [ 1, 3 ] # layers' distance
def GetNLayers(self):
return len(self.d_layer)
def GetLayerSize(self, i):
w = Fov2Length(self.layer_hfov)
h = w * self.layer_res[0] / self.layer_res[1]
return torch.tensor([ h, w ]) * self.d_layer[i]
def GetEyeViewportSize(self):
w = Fov2Length(self.eye_hfov)
h = w * self.retinal_res[0] / self.retinal_res[1]
return torch.tensor([ h, w ])
#### Image Gen
conf = Conf()
v = torch.tensor([conf.h_layer[0] / conf.d_layer[0],
conf.h_layer[0] / conf.d_layer[0] * conf.layer_res[1] / conf.layer_res[0]])
u = GenSamplesInPupil(conf.pupil_size, 5)
u = GenSamplesInPupil(conf, 5)
gen = RetinalGen(conf, u)
def GenRetinalFromLayersBatch(layers, conf, df, v, u):
# layers: batchsize, 2, color, height, width
def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict):
# layers: batchsize, 2*color, height, width
# Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
# df : batchsize,..
H_r = conf.retinal_res[0]
W_r = conf.retinal_res[1]
D_r = conf.retinal_res.double()
N = conf.n_layers
M = u.size()[0] #41
BS = df.shape[0]
Phi = torch.empty(BS, H_r, W_r, N, M, 2, dtype=torch.long)
# print("Phi:",Phi.shape)
p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, H_r)),
torch.tensor(range(0, W_r)))
p_r = torch.stack([p_rx, p_ry], 2).unsqueeze(2).expand(-1, -1, M, -1)
# print("p_r:",p_r.shape) #torch.Size([480, 640, 41, 2])
for bs in range(BS):
for i in range(0, N):
dpi = conf.h_layer[i] / float(conf.layer_res[0]) # 1 / 480
# print("dpi:",dpi)
ci = conf.layer_res / 2 # [240,320]
di = conf.d_layer[i] # 深度
pi_r = di * v * (1. / D_r * (p_r + 0.5) - 0.5) / dpi # [480, 640, 41, 2]
wi = (1 - di / df[bs]) / dpi # (1 - 深度/聚焦) / dpi df = 2.625 di = 1.75
pi = torch.floor(pi_r + ci + wi * u)
torch.clamp_(pi[:, :, :, 0], 0, conf.layer_res[0] - 1)
torch.clamp_(pi[:, :, :, 1], 0, conf.layer_res[1] - 1)
Phi[bs, :, :, i, :, :] = pi
# print("Phi slice:",Phi[0, :, :, 0, 0, 0].shape)
retinal = torch.ones(BS, 3, H_r, W_r)
# retinal bs x color x height x width
retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W)
mask = [] # mask shape 480 x 640
for i in range(0, layers.size()[0]):
phi = phi_dict[int(sample_idx[i].data)]
phi = var_or_cuda(phi)
phi.requires_grad = False
retinal[i] = gen.GenRetinalFromLayers(layers[i],phi)
mask.append(mask_dict[int(sample_idx[i].data)])
retinal = var_or_cuda(retinal)
for bs in range(BS):
for j in range(0, M):
retinal_view = torch.ones(3, H_r, W_r)
retinal_view = var_or_cuda(retinal_view)
for i in range(0, N):
retinal_view.mul_(layers[bs, (i * 3) : (i * 3 + 3), Phi[bs, :, :, i, j, 0], Phi[bs, :, :, i, j, 1]])
retinal[bs,:,:,:].add_(retinal_view)
retinal[bs,:,:,:].div_(M)
return retinal
#### Image Gen End
mask = torch.stack(mask,dim = 0).unsqueeze(1) # batch x 1 x height x width
return retinal, mask
def merge_two(near,far):
df = conf.d_layer[0] + (conf.d_layer[1] - conf.d_layer[0]) / 2.
# Phi = GenRetinal2LayerMappings(conf, df, v, u)
# retinal = GenRetinalFromLayers(layers, Phi)
return near[:,0:3,:,:] + far[:,3:6,:,:] / 2.0
def loss_two_images(generated, gt):
l1_loss = torch.nn.L1Loss()
return l1_loss(generated, gt)
def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask):
# layers: batchsize, 2*color, height, width
# Phi:torch.Size([batchsize, 480, 640, 2, 41, 2])
# df : batchsize,..
# retinal bs x color x height x width
# retinal = torch.zeros(layers.shape[0], 3, Retinal_IM_H, Retinal_IM_W)
# retinal = var_or_cuda(retinal)
phi = var_or_cuda(phi)
phi.requires_grad = False
retinal = gen.GenRetinalFromLayers(layers[0],phi)
retinal = var_or_cuda(retinal)
mask_out = mask.unsqueeze(0).unsqueeze(0)
return retinal.unsqueeze(0), mask_out
#### Image Gen End
weightVarScale = 0.25
bias_stddev = 0.01
......@@ -269,7 +272,7 @@ def var_or_cuda(x):
def calImageGradients(images):
# x is a 4-D tensor
dx = images[:, :, 1:, :] - images[:, :, :-1, :]
dy = images[:, 1:, :, :] - images[:, :-1, :, :]
dy = images[:, :, :, 1:] - images[:, :, :, :-1]
return dx, dy
......@@ -279,16 +282,13 @@ perc_loss = perc_loss.to("cuda")
def loss_new(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
RENORM_SCALE = torch.tensor(0.9)
RENORM_SCALE = var_or_cuda(RENORM_SCALE)
psnr_intensity = torch.log10(rmse_intensity)
ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
p_loss = perc_loss(generated,gt)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = 10 + psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
return total_loss
......@@ -301,15 +301,56 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver):
}
torch.save(checkpoint, file_path)
mode = "val"
def hook_fn_back(m, i, o):
for grad in i:
try:
print("Input Grad:",m,grad.shape,grad.sum())
except AttributeError:
print ("None found for Gradient")
for grad in o:
try:
print("Output Grad:",m,grad.shape,grad.sum())
except AttributeError:
print ("None found for Gradient")
print("\n")
def hook_fn_for(m, i, o):
for grad in i:
try:
print("Input Feats:",m,grad.shape,grad.sum())
except AttributeError:
print ("None found for Gradient")
for grad in o:
try:
print("Output Feats:",m,grad.shape,grad.sum())
except AttributeError:
print ("None found for Gradient")
print("\n")
if __name__ == "__main__":
#test
# train_dataset = lightFieldDataLoader(DATA_FILE,DATA_JSON)
# print(train_dataset[0][0].shape)
# cv2.imwrite("test_crop0.png",train_dataset[0][1]*255.)
# save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx)))
#test end
############################## generate phi and mask in pre-training
phi_dict = {}
mask_dict = {}
idx_info_dict = {}
print("generating phi and mask...")
with open(DATA_JSON, encoding='utf-8') as file:
dataset_desc = json.loads(file.read())
for i in range(len(dataset_desc["focaldepth"])):
# if i == 2:
# break
idx = dataset_desc["idx"][i]
focaldepth = dataset_desc["focaldepth"][i]
gazeX = dataset_desc["gazeX"][i]
gazeY = dataset_desc["gazeY"][i]
# print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY)
phi,mask = gen.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
phi_dict[idx]=phi
mask_dict[idx]=mask
idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
print("generating phi and mask end.")
# exit(0)
#train
train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_JSON),
batch_size=BATCH_SIZE,
......@@ -319,82 +360,51 @@ if __name__ == "__main__":
drop_last=False)
print(len(train_data_loader))
val_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_VAL_JSON),
batch_size=1,
num_workers=0,
pin_memory=True,
shuffle=False,
drop_last=False)
print(len(val_data_loader))
# exit(0)
################################################ train #########################################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lf_model = model()
lf_model.apply(weight_init_normal)
epoch_begin = 0
if torch.cuda.is_available():
lf_model = torch.nn.DataParallel(lf_model).cuda()
lf_model.train()
optimizer = torch.optim.Adam(lf_model.parameters(),lr=1e-2,betas=(0.9,0.999))
#val
checkpoint = torch.load(os.path.join(OUTPUT_DIR,"ckpt-epoch-3001.pth"))
lf_model.load_state_dict(checkpoint["model_state_dict"])
lf_model.eval()
print("Eval::")
for sample_idx, (image_set, gt, df) in enumerate(val_data_loader):
print("sample_idx::")
with torch.no_grad():
print("begin training....")
for epoch in range(epoch_begin, NUM_EPOCH):
for batch_idx, (image_set, gt, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader):
#reshape for input
image_set = image_set.permute(0,1,4,2,3) # N LF C H W
image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
image_set = var_or_cuda(image_set)
# image_set.to(device)
gt = gt.permute(0,3,1,2)
gt = var_or_cuda(gt)
# print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
output = lf_model(image_set,df)
print("output:",output.shape," df:",df)
save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_l1_%.3f.png"%(df[0].data)))
save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"1113_interp_l2_%.3f.png"%(df[0].data)))
output = GenRetinalFromLayersBatch(output,conf,df,v,u)
save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.3f.png"%(df[0].data)))
exit()
# train
# print(lf_model)
# exit()
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# lf_model = model()
# lf_model.apply(weight_init_normal)
# if torch.cuda.is_available():
# lf_model = torch.nn.DataParallel(lf_model).cuda()
# lf_model.train()
# optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-2,betas=(0.9,0.999))
# for epoch in range(NUM_EPOCH):
# for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
# #reshape for input
# image_set = image_set.permute(0,1,4,2,3) # N LF C H W
# image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
# image_set = var_or_cuda(image_set)
# # image_set.to(device)
# gt = gt.permute(0,3,1,2)
# gt = var_or_cuda(gt)
# # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
# optimizer.zero_grad()
# output = lf_model(image_set,df)
# # print("output:",output.shape," df:",df.shape)
# output = GenRetinalFromLayersBatch(output,conf,df,v,u)
# loss = loss_new(output,gt)
# print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
# loss.backward()
# optimizer.step()
# if (epoch%100 == 0):
# for i in range(BATCH_SIZE):
# save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-2_mul_dip_newloss_debug_conf_o%d_%d.png"%(epoch,i)))
# if (epoch%1000 == 0):
# save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch + 1)),
# epoch,lf_model,optimizer)
\ No newline at end of file
optimizer.zero_grad()
output = lf_model(image_set,df,gazeX,gazeY)
########################### Use Pregen Phi and Mask ###################
output1,mask = GenRetinalFromLayersBatch(output, gen, sample_idx, phi_dict, mask_dict)
mask = var_or_cuda(mask)
mask.requires_grad = False
output_f = output1 * mask
gt = gt * mask
loss = loss_new(output_f,gt)
print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
########################### Update ###################
loss.backward()
optimizer.step()
########################### Save #####################
if ((epoch%50== 0) or epoch == 5):
for i in range(output_f.size()[0]):
save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
save_image(output_f[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_test1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
if ((epoch%200 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1):
save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),
epoch,lf_model,optimizer)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment