From a22cfc90006c0bb8fa7c165273ec461d804169da Mon Sep 17 00:00:00 2001
From: BobYeah <635596704@qq.com>
Date: Sat, 19 Dec 2020 15:32:04 +0800
Subject: [PATCH] Update CNN scripts

---
 Flow.py                          |  83 ++++
 model/baseline.py => baseline.py |  54 +-
 conf.py                          |  55 +-
 data.py                          | 305 ++++++++++++
 gen_image.py                     | 287 +++++++++--
 loss.py                          |  80 +++
 main.py                          | 826 +++++++++++++++----------------
 main_lf_syn.py                   |   6 +-
 model/__init__.py                |   0
 model/recurrent.py               | 146 ------
 util.py                          | 104 ++++
 weight_init.py                   |  12 +
 12 files changed, 1283 insertions(+), 675 deletions(-)
 create mode 100644 Flow.py
 rename model/baseline.py => baseline.py (75%)
 create mode 100644 loss.py
 delete mode 100644 model/__init__.py
 delete mode 100644 model/recurrent.py
 create mode 100644 util.py
 create mode 100644 weight_init.py

diff --git a/Flow.py b/Flow.py
new file mode 100644
index 0000000..83a6b6a
--- /dev/null
+++ b/Flow.py
@@ -0,0 +1,83 @@
+import matplotlib.pyplot as plt
+import torch
+import util
+import numpy as np
+def FlowMap(b_last_frame, b_map):
+    '''
+    Map images using the flow data.
+
+    Parameters
+    --------
+    b_last_frame - B x 3 x H x W tensor, batch of images
+    b_map - B x H x W x 2, batch of map data recording pixel coords in the last frame
+
+    Returns
+    --------
+    B x 3 x H x W tensor, batch of images mapped by flow data
+    '''
+    return torch.nn.functional.grid_sample(b_last_frame, b_map, align_corners=False)
+
+class Flow(object):
+    '''
+    Class representing optical flow
+
+    Properties
+    --------
+    b_data - B x H x W x 2, batch of flow data
+    b_map - B x H x W x 2, batch of map data recording pixel coords in the last frame
+    b_invalid_mask - B x H x W, batch of masks indicating invalid elements in the corresponding flow data
+    '''
+    def Load(paths):
+        '''
+        Create a Flow instance using a batch of encoded data images loaded from paths
+
+        Parameters
+        --------
+        paths - list of encoded data image paths
+
+        Returns
+        --------
+        Flow instance
+        '''
+        b_encoded_image = util.ReadImageTensor(paths, rgb_only=False, permute=False, batch_dim=True)
+        return Flow(b_encoded_image)
+
+    def __init__(self, b_encoded_image):
+        '''
+        Initialize a Flow instance from a batch of encoded data images
+
+        Parameters
+        --------
+        b_encoded_image - batch of encoded data images
+        '''
+        b_encoded_image = b_encoded_image.mul(255)
+        # print("b_encoded_image:",b_encoded_image.shape)
+        self.b_invalid_mask = (b_encoded_image[:, :, :, 0] == 255)
+        self.b_data = (b_encoded_image[:, :, :, 0:2] / 254 + b_encoded_image[:, :, :, 2:4] - 127) / 127
+        self.b_data[:, :, :, 1] = -self.b_data[:, :, :, 1]
+        D = self.b_data.size()
+        grid = util.MeshGrid((D[1], D[2]), True)
+        self.b_map = (grid - self.b_data - 0.5) * 2
+        self.b_map[self.b_invalid_mask] = torch.tensor([ -2.0, -2.0 ])
+
+    def getMap(self):
+        return self.b_map
+
+    def Visualize(self, scale_factor = 1):
+        '''
+        Visualize the flow data by "color wheel".
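# Editor's sketch (not part of the patch) of the warping FlowMap() performs:
# grid_sample expects a B x H x W x 2 grid of normalized coordinates in
# [-1, 1]; coordinates pushed outside that range (the class writes -2.0 for
# invalid flow pixels) sample zeros under the default zero padding. An
# identity grid reproduces the input frame:
import torch
frame = torch.rand(1, 3, 4, 4)                        # B x 3 x H x W
theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])  # identity affine transform
grid = torch.nn.functional.affine_grid(theta, (1, 3, 4, 4), align_corners=False)
warped = torch.nn.functional.grid_sample(frame, grid, align_corners=False)
assert torch.allclose(warped, frame, atol=1e-6)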
+ + Parameters + -------- + scale_factor - scale factor of flow data to visualize, default is 1 + + Returns + -------- + B x 3 x H x W tensor, visualization of flow data + ''' + try: + Flow.b_color_wheel + except AttributeError: + Flow.b_color_wheel = util.ReadImageTensor('color_wheel.png') + return torch.nn.functional.grid_sample(Flow.b_color_wheel.expand(self.b_data.size()[0], -1, -1, -1), + (self.b_data * scale_factor), align_corners=False) \ No newline at end of file diff --git a/model/baseline.py b/baseline.py similarity index 75% rename from model/baseline.py rename to baseline.py index d829464..0ba91ff 100644 --- a/model/baseline.py +++ b/baseline.py @@ -38,7 +38,7 @@ class residual_block(torch.nn.Module): def forward(self,input): if self.RNN: # print("input:",input.shape,"hidden:",self.hidden.shape) - inp = torch.cat((input,self.hidden),dim=1) + inp = torch.cat((input,self.hidden.detach()),dim=1) # print(inp.shape) output = self.layer1(inp) output = self.layer2(output) @@ -97,7 +97,7 @@ class interleave(torch.nn.Module): return output class model(torch.nn.Module): - def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE): + def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False): super(model, self).__init__() self.interleave = interleave(INTERLEAVE_RATE) @@ -108,13 +108,19 @@ class model(torch.nn.Module): ) self.residual_block1 = residual_block(OUT_CHANNELS_RB,0,KERNEL_SIZE_RB,False) - self.residual_block2 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,False) - self.residual_block3 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True) - self.residual_block4 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True) - self.residual_block5 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True) + self.residual_block2 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False) + self.residual_block3 = residual_block(OUT_CHANNELS_RB,2,KERNEL_SIZE_RB,False) + # if RNN: + # self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True) + # self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True) + # self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,True) + # else: + # self.residual_block3 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False) + # self.residual_block4 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False) + # self.residual_block5 = residual_block(OUT_CHANNELS_RB,6,KERNEL_SIZE_RB,False) self.output_layer = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB+3,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1), + torch.nn.Conv2d(OUT_CHANNELS_RB+2,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1), torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS), torch.nn.Sigmoid() ) @@ -125,7 +131,7 @@ class model(torch.nn.Module): self.residual_block4.reset_hidden(inp) self.residual_block5.reset_hidden(inp) - def forward(self, lightfield_images, focal_length, gazeX, gazeY): + def forward(self, lightfield_images, pos_row, pos_col): # lightfield_images: torch.Size([batch_size, channels * D, H, W]) # channels : RGB*D: 3*9, H:256, W:256 # print("lightfield_images:",lightfield_images.shape) @@ -136,32 +142,18 @@ class model(torch.nn.Module): # print("input_to_rb1:",input_to_rb.shape) output = self.residual_block1(input_to_rb) - depth_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) - gazeX_layer = 
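# Editor's sketch (not part of the patch) of the conditioning trick in
# forward(), continued just below: each per-sample scalar is broadcast to a
# constant H x W plane and concatenated to the feature maps, so every spatial
# position carries the view position. A loop-free equivalent (shapes
# illustrative):
import torch
feat = torch.rand(4, 6, 8, 8)                    # B x C x H x W features
pos_row = torch.tensor([0.1, 0.2, 0.3, 0.4])     # one scalar per sample
plane = pos_row.view(-1, 1, 1, 1).expand(-1, 1, feat.shape[2], feat.shape[3])
feat = torch.cat((feat, plane), dim=1)           # B x (C+1) x H x W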
torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) - gazeY_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) - # print("depth_layer:",depth_layer.shape) - # print("focal_depth:",focal_length," gazeX:",gazeX," gazeY:",gazeY, " gazeX norm:",(gazeX[0] - (-3.333)) / (3.333*2)) - for i in range(focal_length.shape[0]): - depth_layer[i] *= 1. / focal_length[i] - gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333*2) - gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333*2) + pos_row_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) + pos_col_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) + for i in range(pos_row.shape[0]): + pos_row_layer[i] *= pos_row[i] + pos_col_layer[i] *= pos_col[i] # print(depth_layer.shape) - depth_layer = var_or_cuda(depth_layer) - gazeX_layer = var_or_cuda(gazeX_layer) - gazeY_layer = var_or_cuda(gazeY_layer) - - output = torch.cat((output,depth_layer,gazeX_layer,gazeY_layer),dim=1) - # output = torch.cat((output,depth_layer),dim=1) - # print("output to rb2:",output.shape) + pos_row_layer = var_or_cuda(pos_row_layer) + pos_col_layer = var_or_cuda(pos_col_layer) + output = torch.cat((output,pos_row_layer,pos_col_layer),dim=1) output = self.residual_block2(output) - # print("output to rb3:",output.shape) output = self.residual_block3(output) - # print("output to rb4:",output.shape) - output = self.residual_block4(output) - # print("output to rb5:",output.shape) - output = self.residual_block5(output) - # output = output + input_to_net output = self.output_layer(output) output = self.deinterleave(output) - return output + return output \ No newline at end of file diff --git a/conf.py b/conf.py index a6aea74..2bb9180 100644 --- a/conf.py +++ b/conf.py @@ -1,27 +1,68 @@ import torch -from gen_image import * +import util +import numpy as np class Conf(object): def __init__(self): self.pupil_size = 0.02 self.retinal_res = torch.tensor([ 320, 320 ]) self.layer_res = torch.tensor([ 320, 320 ]) self.layer_hfov = 90 # layers' horizontal FOV - self.eye_hfov = 85 # eye's horizontal FOV (ignored in foveated rendering) - self.eye_enable_fovea = True # enable foveated rendering + self.eye_hfov = 80 # eye's horizontal FOV (ignored in foveated rendering) + self.eye_enable_fovea = False # enable foveated rendering self.eye_fovea_angles = [ 40, 80 ] # eye's foveation layers' angles self.eye_fovea_downsamples = [ 1, 2 ] # eye's foveation layers' downsamples self.d_layer = [ 1, 3 ] # layers' distance - + self.eye_fovea_blend = [ self._GenFoveaLayerBlend(0) ] + # blend maps of fovea layers + self.light_field_dim = 5 def GetNLayers(self): return len(self.d_layer) def GetLayerSize(self, i): - w = Fov2Length(self.layer_hfov) + w = util.Fov2Length(self.layer_hfov) h = w * self.layer_res[0] / self.layer_res[1] return torch.tensor([ h, w ]) * self.d_layer[i] + def GetPixelSizeOfLayer(self, i): + ''' + Get pixel size of layer i + ''' + return util.Fov2Length(self.layer_hfov) * self.d_layer[i] / self.layer_res[0] + def GetEyeViewportSize(self): fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov - w = Fov2Length(fov) + w = util.Fov2Length(fov) h = w * self.retinal_res[0] / self.retinal_res[1] - return torch.tensor([ h, w ]) \ No newline at end of file + return torch.tensor([ h, w ]) + + def GetRegionOfFoveaLayer(self, i): + ''' + Get region of fovea layer i in retinal image + + Returns + -------- + slice object stores the start and end of region 
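# Editor's worked example (not part of the patch) for the region arithmetic
# implemented just below, using the defaults above (retinal_res 320, fovea
# angles [40, 80]): fovea layer 0 spans 40/80 of the full FOV, so its ROI is
# a centred 160-pixel square, and the last layer covers the whole image:
import numpy as np
retinal_res, angles = 320, [40, 80]
roi_size = int(np.ceil(retinal_res * angles[0] / angles[-1]))  # 160
roi_offset = int((retinal_res - roi_size) / 2)                 # 80
region = slice(roi_offset, roi_offset + roi_size)              # slice(80, 240)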
+ ''' + roi_size = int(np.ceil(self.retinal_res[0] * self.eye_fovea_angles[i] / self.eye_fovea_angles[-1])) + roi_offset = int((self.retinal_res[0] - roi_size) / 2) + return slice(roi_offset, roi_offset + roi_size) + + def _GenFoveaLayerBlend(self, i): + ''' + Generate blend map for fovea layer i + + Parameters + -------- + i - index of fovea layer + + Returns + -------- + H[i] x W[i], blend map + + ''' + region = self.GetRegionOfFoveaLayer(i) + width = region.stop - region.start + R = width / 2 + p = util.MeshGrid([ width, width ]) + r = torch.linalg.norm(p - R, 2, dim=2, keepdim=False) + return util.SmoothStep(R, R * 0.6, r) diff --git a/data.py b/data.py index e69de29..14b5a2c 100644 --- a/data.py +++ b/data.py @@ -0,0 +1,305 @@ +import torch +import os +import glob +import numpy as np +import torchvision.transforms as transforms + +from torchvision import datasets +from torch.utils.data import DataLoader + +import cv2 +import json +from Flow import * +from gen_image import * +import util + +import time + + +class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset): + def __init__(self, file_dir_path,file_json): + self.file_dir_path = file_dir_path + with open(file_json, encoding='utf-8') as file: + self.dataset_desc = json.loads(file.read()) + self.input_img = [] + for i in self.dataset_desc["train"]: + lf_element = os.path.join(self.file_dir_path,i) + lf_element = cv2.imread(lf_element, -cv2.IMREAD_ANYDEPTH)[:, :, 0:3] + lf_element = cv2.cvtColor(lf_element, cv2.COLOR_BGR2RGB).astype(np.float32) / 255. + lf_element = lf_element[:,:-1,:] + self.input_img.append(lf_element) + self.input_img = np.asarray(self.input_img) + def __len__(self): + return len(self.dataset_desc["gt"]) + + def __getitem__(self, idx): + gt, pos_row, pos_col = self.get_datum(idx) + return (self.input_img, gt, pos_row, pos_col) + + def get_datum(self, idx): + fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx]) + gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB) + gt = gt[:,:-1,:] + pos_col = self.dataset_desc["x"][idx] + pos_row = self.dataset_desc["y"][idx] + return gt, pos_row, pos_col + +class lightFieldDataLoader(torch.utils.data.dataset.Dataset): + def __init__(self, file_dir_path, file_json, transforms=None): + self.file_dir_path = file_dir_path + self.transforms = transforms + # self.datum_list = glob.glob(os.path.join(file_dir_path,"*")) + with open(file_json, encoding='utf-8') as file: + self.dataset_desc = json.loads(file.read()) + + def __len__(self): + return len(self.dataset_desc["focaldepth"]) + + def __getitem__(self, idx): + lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx) + if self.transforms: + lightfield_images = self.transforms(lightfield_images) + # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx) + return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx) + + def get_datum(self, idx): + lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx]) + # print(lf_image_paths) + fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx]) + fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][idx]) + # print(fd_gt_path) + lf_images = [] + lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 
+ lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB) + + ## IF GrayScale + # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. + # lf_image_big = np.expand_dims(lf_image_big, axis=-1) + # print(lf_image_big.shape) + + for i in range(9): + lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3] + ## IF GrayScale + # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1] + # print(lf_image.shape) + lf_images.append(lf_image) + gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB) + gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB) + ## IF GrayScale + # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. + # gt = np.expand_dims(gt, axis=-1) + + fd = self.dataset_desc["focaldepth"][idx] + gazeX = self.dataset_desc["gazeX"][idx] + gazeY = self.dataset_desc["gazeY"][idx] + sample_idx = self.dataset_desc["idx"][idx] + return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx + +class lightFieldValDataLoader(torch.utils.data.dataset.Dataset): + def __init__(self, file_dir_path, file_json, transforms=None): + self.file_dir_path = file_dir_path + self.transforms = transforms + # self.datum_list = glob.glob(os.path.join(file_dir_path,"*")) + with open(file_json, encoding='utf-8') as file: + self.dataset_desc = json.loads(file.read()) + + def __len__(self): + return len(self.dataset_desc["focaldepth"]) + + def __getitem__(self, idx): + lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx) + if self.transforms: + lightfield_images = self.transforms(lightfield_images) + # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx) + return (lightfield_images, fd, gazeX, gazeY, sample_idx) + + def get_datum(self, idx): + lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx]) + # print(fd_gt_path) + lf_images = [] + lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB) + + ## IF GrayScale + # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. + # lf_image_big = np.expand_dims(lf_image_big, axis=-1) + # print(lf_image_big.shape) + + for i in range(9): + lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3] + ## IF GrayScale + # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1] + # print(lf_image.shape) + lf_images.append(lf_image) + ## IF GrayScale + # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. 
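# Editor's sketch (not part of the patch): the loops above cut a 3 x 3 grid of
# views out of one mosaic image with i//3 / i%3 row/column arithmetic. A
# vectorized equivalent, assuming the mosaic is exactly 3*IM_H x 3*IM_W x 3:
import numpy as np
IM_H = IM_W = 320
big = np.zeros((3 * IM_H, 3 * IM_W, 3), np.float32)
views = (big.reshape(3, IM_H, 3, IM_W, 3)   # split the row and column axes
            .transpose(0, 2, 1, 3, 4)       # bring the two grid axes together
            .reshape(9, IM_H, IM_W, 3))     # 9 views in row-major order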
+ # gt = np.expand_dims(gt, axis=-1) + + fd = self.dataset_desc["focaldepth"][idx] + gazeX = self.dataset_desc["gazeX"][idx] + gazeY = self.dataset_desc["gazeY"][idx] + sample_idx = self.dataset_desc["idx"][idx] + return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx + +class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset): + def __init__(self, file_dir_path, file_json, transforms=None): + self.file_dir_path = file_dir_path + self.transforms = transforms + with open(file_json, encoding='utf-8') as file: + self.dataset_desc = json.loads(file.read()) + + def __len__(self): + return len(self.dataset_desc["seq"]) + + def __getitem__(self, idx): + lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx) + fd = fd.astype(np.float32) + gazeX = gazeX.astype(np.float32) + gazeY = gazeY.astype(np.float32) + sample_idx = sample_idx.astype(np.int64) + # print(fd) + # print(gazeX) + # print(gazeY) + # print(sample_idx) + + # print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype) + # print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape) + if self.transforms: + lightfield_images = self.transforms(lightfield_images) + return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx) + + def get_datum(self, idx): + indices = self.dataset_desc["seq"][idx] + # print("indices:",indices) + lf_images = [] + fd = [] + gazeX = [] + gazeY = [] + sample_idx = [] + gt = [] + gt2 = [] + for i in range(len(indices)): + lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]]) + fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]]) + fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]]) + lf_image_one_sample = [] + lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB) + + for j in range(9): + lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3] + lf_image_one_sample.append(lf_image) + + gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB)) + gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 
+ gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB)) + + # print("indices[i]:",indices[i]) + fd.append([self.dataset_desc["focaldepth"][indices[i]]]) + gazeX.append([self.dataset_desc["gazeX"][indices[i]]]) + gazeY.append([self.dataset_desc["gazeY"][indices[i]]]) + sample_idx.append([self.dataset_desc["idx"][indices[i]]]) + lf_images.append(lf_image_one_sample) + #lf_images: 5,9,320,320 + + return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx) + +class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset): + def __init__(self, file_dir_path, file_json, gen, conf, transforms=None): + self.file_dir_path = file_dir_path + self.file_json = file_json + self.gen = gen + self.conf = conf + self.transforms = transforms + with open(file_json, encoding='utf-8') as file: + self.dataset_desc = json.loads(file.read()) + + def __len__(self): + return len(self.dataset_desc["seq"]) + + def __getitem__(self, idx): + # start = time.time() + lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx = self.get_datum(idx) + fd = fd.astype(np.float32) + gazeX = gazeX.astype(np.float32) + gazeY = gazeY.astype(np.float32) + posX = posX.astype(np.float32) + posY = posY.astype(np.float32) + posZ = posZ.astype(np.float32) + sample_idx = sample_idx.astype(np.int64) + + if self.transforms: + lightfield_images = self.transforms(lightfield_images) + # print("read once:",time.time() - start) # 8 ms + return (lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx) + + def get_datum(self, idx): + IM_H = 320 + IM_W = 320 + indices = self.dataset_desc["seq"][idx] + # print("indices:",indices) + lf_images = [] + fd = [] + gazeX = [] + gazeY = [] + posX = [] + posY = [] + posZ = [] + sample_idx = [] + gt = [] + # gt2 = [] + phi = [] + phi_invalid = [] + retinal_invalid = [] + for i in range(len(indices)): # 5 + lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]]) + fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]]) + # fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]]) + lf_image_one_sample = [] + lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB) + + lf_dim = int(self.conf.light_field_dim) + for j in range(lf_dim**2): + lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim*IM_H+IM_H,j%lf_dim*IM_W:j%lf_dim*IM_W+IM_W,0:3] + lf_image_one_sample.append(lf_image) + + gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB)) + # gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 
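# Editor's design note: this loader calls gen.CalculateRetinal2LayerMappings
# per sample inside get_datum() (see below), so with DataLoader(num_workers > 0)
# the heavy CPU-side geometry runs in worker processes and overlaps GPU
# training. A stripped-down sketch of the pattern (names illustrative):
import torch
def heavy_mapping(x):                       # stand-in for the per-sample geometry
    return x * x
class PrecomputeDataset(torch.utils.data.Dataset):
    def __init__(self, items):
        self.items = items
    def __len__(self):
        return len(self.items)
    def __getitem__(self, i):
        return self.items[i], heavy_mapping(self.items[i])  # runs in a worker
loader = torch.utils.data.DataLoader(PrecomputeDataset(list(range(8))),
                                     batch_size=2, num_workers=2)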
+ # gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB)) + + # print("indices[i]:",indices[i]) + fd.append([self.dataset_desc["focaldepth"][indices[i]]]) + gazeX.append([self.dataset_desc["gazeX"][indices[i]]]) + gazeY.append([self.dataset_desc["gazeY"][indices[i]]]) + posX.append([self.dataset_desc["x"][indices[i]]]) + posY.append([self.dataset_desc["y"][indices[i]]]) + posZ.append([0.0]) + sample_idx.append([self.dataset_desc["idx"][indices[i]]]) + lf_images.append(lf_image_one_sample) + + idx_i = sample_idx[i][0] + focaldepth_i = fd[i][0] + gazeX_i = gazeX[i][0] + gazeY_i = gazeY[i][0] + posX_i = posX[i][0] + posY_i = posY[i][0] + posZ_i = posZ[i][0] + # print("x:%.3f,y:%.3f,z:%.3f;gaze:%.4f,%.4f,focaldepth:%.3f."%(posX_i,posY_i,posZ_i,gazeX_i,gazeY_i,focaldepth_i)) + phi_i,phi_invalid_i,retinal_invalid_i = self.gen.CalculateRetinal2LayerMappings(torch.tensor([posX_i, posY_i,posZ_i]),torch.tensor([gazeX_i, gazeY_i]),focaldepth_i) + + phi.append(phi_i) + phi_invalid.append(phi_invalid_i) + retinal_invalid.append(retinal_invalid_i) + #lf_images: 5,9,320,320 + flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"][indices[i-1]]) for i in range(1,len(indices))]) + flow_map = flow.getMap() + flow_invalid_mask = flow.b_invalid_mask + # print("flow:",flow_map.shape) + + return np.asarray(lf_images),np.asarray(gt), torch.stack(phi,dim=0), torch.stack(phi_invalid,dim=0),torch.stack(retinal_invalid,dim=0), flow_map, flow_invalid_mask, np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(posX),np.asarray(posY),np.asarray(posZ),np.asarray(sample_idx) diff --git a/gen_image.py b/gen_image.py index 7ae1de8..83e7f4d 100644 --- a/gen_image.py +++ b/gen_image.py @@ -2,13 +2,8 @@ import matplotlib.pyplot as plt import numpy as np import torch import glm - -def Fov2Length(angle): - ''' - - ''' - return np.tan(angle * np.pi / 360) * 2 - +import time +import util def RandomGenSamplesInPupil(pupil_size, n_samples): ''' @@ -21,9 +16,9 @@ def RandomGenSamplesInPupil(pupil_size, n_samples): Returns -------- - a n_samples x 2 tensor with 2D sample position in each row + a n_samples x 3 tensor with 3D sample position in each row ''' - samples = torch.empty(n_samples, 2) + samples = torch.empty(n_samples, 3) i = 0 while i < n_samples: s = (torch.rand(2) - 0.5) * pupil_size @@ -44,7 +39,7 @@ def GenSamplesInPupil(pupil_size, circles): Returns -------- - a n_samples x 2 tensor with 2D sample position in each row + a n_samples x 3 tensor with 3D sample position in each row ''' samples = torch.zeros(1, 3) for i in range(1, circles): @@ -70,7 +65,7 @@ class RetinalGen(object): Methods -------- ''' - def __init__(self, conf, u): + def __init__(self, conf): ''' Initialize retinal generator instance @@ -80,58 +75,83 @@ class RetinalGen(object): u - a M x 3 tensor stores M sample positions in pupil ''' self.conf = conf + self.u = GenSamplesInPupil(conf.pupil_size, 5) # self.u = u.to(cuda_dev) - self.u = u # M x 3 M sample positions + # self.u = u # M x 3 M sample positions self.D_r = conf.retinal_res # retinal res 480 x 640 self.N = conf.GetNLayers() # 2 - self.M = u.size()[0] # samples - p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])), - torch.tensor(range(0, self.D_r[1]))) + self.M = self.u.size()[0] # samples + # p_rx, p_ry = torch.meshgrid(torch.tensor(range(0, self.D_r[0])), + # torch.tensor(range(0, self.D_r[1]))) + # self.p_r = torch.cat([ + # ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # 鐪肩悆瑙嗛噹 + # torch.ones(self.D_r[0], 
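# Editor's sketch (not part of the patch): RandomGenSamplesInPupil above is
# rejection sampling in a disc of diameter pupil_size, lifted to 3D with
# z = 0. A compact batched equivalent:
import torch
def disc_samples(pupil_size, n):
    out = torch.empty(0, 3)
    while out.shape[0] < n:
        s = (torch.rand(n, 2) - 0.5) * pupil_size   # candidates in the bounding square
        s = s[s.norm(dim=1) <= pupil_size / 2]      # keep those inside the disc
        out = torch.cat([out, torch.cat([s, torch.zeros(s.shape[0], 1)], dim=1)])
    return out[:n]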
self.D_r[1], 1) + # ], 2) + self.p_r = torch.cat([ - ((torch.stack([p_rx, p_ry], 2) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), # 鐪肩悆瑙嗛噹 + ((util.MeshGrid(self.D_r) + 0.5) / self.D_r - 0.5) * conf.GetEyeViewportSize(), torch.ones(self.D_r[0], self.D_r[1], 1) ], 2) # self.Phi = torch.empty(N, D_r[0], D_r[1], M, 2, device=cuda_dev, dtype=torch.long) # self.mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) # 2 x 480 x 640 x 41 x 2 - def CalculateRetinal2LayerMappings(self, df, gaze): + def CalculateRetinal2LayerMappings(self, position, gaze_dir, df): ''' Calculate the mapping matrix from retinal to layers. Parameters -------- - df - focus distance - gaze - 2 x 1 tensor, eye rotation angle (degs) in horizontal and vertical direction + position - 1 x 3 tensor, eye's position + gaze_dir - 1 x 2 tensor, gaze forward vector (with z normalized) + df - focus distance + Returns + -------- + phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers + phi_invalid - N x H_r x W_r x M x 1, indicates invalid (out-of-range) mapping + retinal_invalid - 1 x H_r x W_r, indicates invalid pixels in retinal image ''' - Phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long) # 2 x 480 x 640 x 41 x 2 - mask = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.float) + D = self.conf.layer_res + c = torch.tensor([ D[1] / 2, D[0] / 2 ]) # c: Center of layers (pixel) + D_r = self.conf.retinal_res # D_r: Resolution of retinal 480 640 V = self.conf.GetEyeViewportSize() # V: Viewport size of eye - c = (self.conf.layer_res / 2) # c: Center of layers (pixel) p_f = self.p_r * df # p_f: H x W x 3, focus positions of retinal pixels on focus plane - rot_forward = glm.dvec3(glm.tan(glm.radians(glm.dvec2(gaze[1], -gaze[0]))), 1) - rot_mat = torch.from_numpy(np.array( - glm.dmat3(glm.lookAtLH(glm.dvec3(), rot_forward, glm.dvec3(0, 1, 0))))) - rot_mat = rot_mat.float() - u_rot = torch.mm(self.u, rot_mat) - v_rot = torch.matmul(p_f, rot_mat).unsqueeze(2).expand( - -1, -1, self.u.size()[0], -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector - v_rot.div_(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot - - for i in range(0, self.conf.GetNLayers()): - dp_i = self.conf.GetLayerSize(i)[0] / self.conf.layer_res[0] # dp_i: Pixel size of layer i - d_i = self.conf.d_layer[i] # d_i: Distance of layer i + + # Calculate transformation from eye to display + gvec_lookat = glm.dvec3(gaze_dir[0], -gaze_dir[1], 1) + gmat_eye = glm.inverse(glm.lookAtLH(glm.dvec3(), gvec_lookat, glm.dvec3(0, 1, 0))) + eye_rot = util.Glm2Tensor(glm.dmat3(gmat_eye)) + eye_center = torch.tensor([ position[0], -position[1], position[2] ]) + + u_rot = torch.mm(self.u, eye_rot) + v_rot = torch.matmul(p_f, eye_rot).unsqueeze(2).expand( + -1, -1, self.M, -1) - u_rot # v_rot: H x W x M x 3, rotated rays' direction vector + u_rot.add_(eye_center) # translate by eye's center + v_rot = v_rot.div(v_rot[:, :, :, 2].unsqueeze(3)) # make z = 1 for each direction vector in v_rot + + phi = torch.empty(self.N, self.D_r[0], self.D_r[1], self.M, 2, dtype=torch.long) + + for i in range(0, self.N): + dp_i = self.conf.GetPixelSizeOfLayer(i) # dp_i: Pixel size of layer i + d_i = self.conf.d_layer[i] # d_i: Distance of layer i k = (d_i - u_rot[:, 2]).unsqueeze(1) pi_r = (u_rot[:, 0:2] + v_rot[:, :, :, 0:2] * k) / dp_i # pi_r: H x W x M x 2, rays' pixel coord on layer i - Phi[i, :, :, :, :] = torch.floor(pi_r + c) - mask[:, :, :, :, 0] = 
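# Editor's worked example (not part of the patch) for the intersection
# arithmetic in CalculateRetinal2LayerMappings above: a ray from pupil sample
# u with direction v (z normalized to 1) reaches layer i at depth d_i after
# travelling k = d_i - u_z, i.e. at u_xy + v_xy * k in metres; dividing by the
# layer pixel size dp_i and adding the centre offset c converts to pixels.
# Toy numbers: u = (0.01, 0, 0), v = (0.1, 0, 1), d_i = 1 give k = 1 and a hit
# at x = 0.01 + 0.1 * 1 = 0.11 m, i.e. pixel column floor(0.11 / dp_i + c_x).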
((Phi[:, :, :, :, 0] >= 0) & (Phi[:, :, :, :, 0] < self.conf.layer_res[0])).float()
-        mask[:, :, :, :, 1] = ((Phi[:, :, :, :, 1] >= 0) & (Phi[:, :, :, :, 1] < self.conf.layer_res[1])).float()
-        Phi[:, :, :, :, 0].clamp_(0, self.conf.layer_res[0] - 1)
-        Phi[:, :, :, :, 1].clamp_(0, self.conf.layer_res[1] - 1)
-        retinal_mask = mask.prod(0).prod(2).prod(2)
-        return [ Phi, retinal_mask ]
+            phi[i, :, :, :, :] = torch.floor(pi_r + c)
+
+        # Calculate the invalid mask (out-of-range elements in phi) and reduce it to a retinal-level mask
+        phi_invalid = (phi[:, :, :, :, 0] < 0) | (phi[:, :, :, :, 0] >= D[1]) | \
+            (phi[:, :, :, :, 1] < 0) | (phi[:, :, :, :, 1] >= D[0])
+        phi_invalid = phi_invalid.unsqueeze(4)
+        # print("phi_invalid:",phi_invalid.shape)
+        retinal_invalid = phi_invalid.amax((0, 3)).squeeze().unsqueeze(0)
+        # print("retinal_invalid:",retinal_invalid.shape)
+        # Fix invalid elements in phi
+        phi[phi_invalid.expand(-1, -1, -1, -1, 2)] = 0
+
+        return [ phi, phi_invalid, retinal_invalid ]
+

     def GenRetinalFromLayers(self, layers, Phi):
         '''
@@ -139,28 +159,159 @@ class RetinalGen(object):

         Parameters
         --------
-        layers - 3N x H_l x W_l tensor, stacked layer images, with 3 channels in each layer
+        layers - 3N x H x W, stacked layer images, with 3 channels in each layer
+        phi - N x H_r x W_r x M x 2, retinal to layers mapping, N is number of layers

         Returns
         --------
-        3 x H_r x W_r tensor, 3 channels retinal image
-        H_r x W_r tensor, retinal image mask, indicates pixels valid or not
-
+        3 x H_r x W_r, 3 channels retinal image
         '''
         # FOR GRAYSCALE 1 FOR RGB 3
         mapped_layers = torch.empty(self.N, 3, self.D_r[0], self.D_r[1], self.M) # 2 x 3 x 480 x 640 x 41
         # print("mapped_layers:",mapped_layers.shape)
         for i in range(0, Phi.size()[0]):
+            # torch.Size([3, 2, 320, 320, 2])
             # print("gather layers:",layers[(i * 3) : (i * 3 + 3),Phi[i, :, :, :, 0],Phi[i, :, :, :, 1]].shape)
             mapped_layers[i, :, :, :, :] = layers[(i * 3) : (i * 3 + 3),
-                Phi[i, :, :, :, 0],
-                Phi[i, :, :, :, 1]]
+                Phi[i, :, :, :, 1],
+                Phi[i, :, :, :, 0]]
         # print("mapped_layers:",mapped_layers.shape)
         retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3])
         # print("retinal:",retinal.shape)
         return retinal
+
+    def GenRetinalFromLayersBatch(self, layers, Phi):
+        '''
+        Generate retinal images from layers, using the precalculated mapping matrix
+
+        Parameters
+        --------
+        layers - B x 3N x H_l x W_l tensor, batches of stacked layer images, with 3 channels per layer
+
+        Returns
+        --------
+        B x 3 x H_r x W_r tensor, batch of 3-channel retinal images
+        '''
+        mapped_layers = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M) #BS x Layers x C x H x W x Sample
+
+        # truth = torch.empty(layers.size()[0], self.N, 3, self.D_r[0], self.D_r[1], self.M)
+        # layers_truth = layers.clone()
+        # Phi_truth = Phi.clone()
+        layers = torch.stack((layers[:,0:3,:,:],layers[:,3:6,:,:]),dim=1) ## torch.Size([BS, Layer, RGB 3, 320, 320])
+
+        # Phi = Phi[:,:,None,:,:,:,:].expand(-1,-1,3,-1,-1,-1,-1)
+        # print("mapped_layers:",mapped_layers.shape) #torch.Size([2, 2, 3, 320, 320, 41])
+        # print("input layers:",layers.shape) ## torch.Size([2, 2, 3, 320, 320])
+        # print("input Phi:",Phi.shape) #torch.Size([2, 2, 320, 320, 41, 2])
+
+        # # Unoptimized reference version
+        # for i in range(0, Phi_truth.size()[0]):
+        #     for j in range(0, Phi_truth.size()[1]):
+        #         truth[i, j, :, :, :, :] = layers_truth[i, (j * 3) : (j * 3 + 3),
+        #             Phi_truth[i, j, :, :, :, 0],
+        #             Phi_truth[i, j, :, :, :, 1]]
+
+        # Optimization 2
+        # start = time.time()
+        mapped_layers_op1 = mapped_layers.reshape(-1,
+            mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
+        # BatchSize x Layer, Channel 3, 320, 320, 41
+        layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4]) # 2x2 3 320 320
+        Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5]) # 2x2 320 320 41 2
+        x = Phi_op1[:,:,:,:,0] # 2x2 320 320 41
+        y = Phi_op1[:,:,:,:,1] # 2x2 320 320 41
+        # print("reshape:",time.time() - start)
+
+        # start = time.time()
+        mapped_layers_op1 = layers_op1[torch.arange(layers_op1.shape[0])[:, None, None, None], :, y, x] # note: x and y swap places here
+        # 2x2 320 320 41 3
+        # print("mapping one step:",time.time() - start)
+
+        # print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
+        # start = time.time()
+        mapped_layers_op1 = mapped_layers_op1.permute(0,4,1,2,3)
+        mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
+            mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
+        # print("reshape end:",time.time() - start)
+
+        # print("test:")
+        # print((truth.cpu() == mapped_layers.cpu()).all())
+        # Optimization 1
+        # start = time.time()
+        # mapped_layers_op1 = mapped_layers.reshape(-1,
+        #     mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
+        # layers_op1 = layers.reshape(-1,layers.shape[2],layers.shape[3],layers.shape[4])
+        # Phi_op1 = Phi.reshape(-1,Phi.shape[2],Phi.shape[3],Phi.shape[4],Phi.shape[5])
+        # print("reshape:",time.time() - start)
+
+        # for i in range(0, Phi_op1.size()[0]):
+        #     start = time.time()
+        #     mapped_layers_op1[i, :, :, :, :] = layers_op1[i,:,
+        #         Phi_op1[i, :, :, :, 0],
+        #         Phi_op1[i, :, :, :, 1]]
+        #     print("mapping one step:",time.time() - start)
+        # print("mapped_layers:",mapped_layers_op1.shape) # torch.Size([4, 3, 320, 320, 41])
+        # start = time.time()
+        # mapped_layers = mapped_layers_op1.reshape(mapped_layers.shape[0],mapped_layers.shape[1],
+        #     mapped_layers.shape[2],mapped_layers.shape[3],mapped_layers.shape[4],mapped_layers.shape[5])
+        # print("reshape end:",time.time() - start)
+
+        # print("mapped_layers:",mapped_layers.shape) # torch.Size([2, 2, 3, 320, 320, 41])
+        retinal = mapped_layers.prod(1).sum(4).div(Phi.size()[4])
+        # print("retinal:",retinal.shape) # torch.Size([BatchSize, 3, 320, 320])
+        return retinal
+
+    ## TO BE CHECKED
+    def GenFoveaLayers(self, b_retinal, is_mask):
+        '''
+        Generate foveated layers for retinal images or masks
+
+        Parameters
+        --------
+        b_retinal - B x C x H_r x W_r, batch of retinal images/masks
+        is_mask - whether b_retinal holds masks rather than images
+
+        Returns
+        --------
+        b_fovea_layers - N_f x (B x C x H[f] x W[f]), list of batches of foveated layers
+        '''
+        b_fovea_layers = []
+        for i in range(0, len(self.conf.eye_fovea_angles)):
+            k = self.conf.eye_fovea_downsamples[i]
+            region = self.conf.GetRegionOfFoveaLayer(i)
+            b_roi = b_retinal[:, :, region, region]
+            if k == 1:
+                b_fovea_layers.append(b_roi)
+            elif is_mask:
+                b_fovea_layers.append(torch.nn.functional.max_pool2d(b_roi.to(torch.float), k).to(torch.bool))
+            else:
+                b_fovea_layers.append(torch.nn.functional.avg_pool2d(b_roi, k))
+        return b_fovea_layers
+        # fovea_layers = []
+        # fovea_layer_masks = []
+        # fov = self.conf.eye_fovea_angles[-1]
+        # retinal_res = int(self.conf.retinal_res[0])
+        # for i in range(0, len(self.conf.eye_fovea_angles)):
+        #     angle = self.conf.eye_fovea_angles[i]
+        #     k = self.conf.eye_fovea_downsamples[i]
+        #     roi_size = 
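# Editor's toy version (scaled-down shapes, illustrative names) of the batched
# gather in GenRetinalFromLayersBatch above: advanced indexing pulls, for each
# batch element, the layer pixels hit by every (retinal pixel, pupil sample)
# ray in one shot; the product over layers and mean over pupil samples then
# form the simulated retinal image:
import torch
B, L, C, H, W, M = 2, 2, 3, 8, 8, 5
layers = torch.rand(B * L, C, H, W)                 # batch and layer dims flattened
x = torch.randint(0, W, (B * L, H, W, M))           # Phi[..., 0]: column indices
y = torch.randint(0, H, (B * L, H, W, M))           # Phi[..., 1]: row indices
g = layers[torch.arange(B * L)[:, None, None, None], :, y, x]   # (B*L) x H x W x M x C
g = g.permute(0, 4, 1, 2, 3).reshape(B, L, C, H, W, M)
retinal = g.prod(1).sum(4).div(M)                   # B x C x H x W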
int(np.ceil(retinal_res * angle / fov)) + # roi_offset = int((retinal_res - roi_size) / 2) + # roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] + # roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] + # if k == 1: + # fovea_layers.append(roi_img) + # fovea_layer_masks.append(roi_mask) + # else: + # fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0)) + # fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0)) + # return [ fovea_layers, fovea_layer_masks ] + + ## TO BE CHECK + def GenFoveaLayersBatch(self, retinal, retinal_mask): ''' Generate foveated layers and corresponding masks @@ -177,18 +328,48 @@ class RetinalGen(object): fovea_layers = [] fovea_layer_masks = [] fov = self.conf.eye_fovea_angles[-1] + # print("fov:",fov) retinal_res = int(self.conf.retinal_res[0]) + # print("retinal_res:",retinal_res) + # print("len(self.conf.eye_fovea_angles):",len(self.conf.eye_fovea_angles)) for i in range(0, len(self.conf.eye_fovea_angles)): angle = self.conf.eye_fovea_angles[i] k = self.conf.eye_fovea_downsamples[i] roi_size = int(np.ceil(retinal_res * angle / fov)) roi_offset = int((retinal_res - roi_size) / 2) - roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] - roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] + # [2, 3, 320, 320] + roi_img = retinal[:, :, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] + # print("roi_img:",roi_img.shape) + # [2, 320, 320] + roi_mask = retinal_mask[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] + # print("roi_mask:",roi_mask.shape) if k == 1: fovea_layers.append(roi_img) fovea_layer_masks.append(roi_mask) else: - fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0)) - fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0)) - return [ fovea_layers, fovea_layer_masks ] \ No newline at end of file + fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img, k)) + fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask), k)) + return [ fovea_layers, fovea_layer_masks ] + + ## TO BE CHECK + def GenFoveaRetinal(self, b_fovea_layers): + ''' + Generate foveated retinal image by blending fovea layers + **Note: current implementation only support two fovea layers** + + Parameters + -------- + b_fovea_layers - N_f x (B x 3 x H[f] x W[f]), list of batch of (masked) foveated layers + + Returns + -------- + B x 3 x H_r x W_r, batch of foveated retinal images + ''' + b_fovea_retinal = torch.nn.functional.interpolate(b_fovea_layers[1], + scale_factor=self.conf.eye_fovea_downsamples[1], + mode='bilinear', align_corners=False) + region = self.conf.GetRegionOfFoveaLayer(0) + blend = self.conf.eye_fovea_blend[0] + b_roi = b_fovea_retinal[:, :, region, region] + b_roi.mul_(1 - blend).add_(b_fovea_layers[0] * blend) + return b_fovea_retinal diff --git a/loss.py b/loss.py new file mode 100644 index 0000000..7307a67 --- /dev/null +++ b/loss.py @@ -0,0 +1,80 @@ +import torch +from ssim import * +from perc_loss import * + +l1loss = torch.nn.L1Loss() +perc_loss = VGGPerceptualLoss() +perc_loss = perc_loss.to("cuda:1") + +##### LOSS ##### +def calImageGradients(images): + # x is a 4-D tensor + dx = images[:, :, 1:, :] - images[:, :, :-1, :] + dy = images[:, :, :, 1:] - images[:, 
:, :, :-1] + return dx, dy + +def loss_new(generated, gt): + mse_loss = torch.nn.MSELoss() + rmse_intensity = mse_loss(generated, gt) + psnr_intensity = torch.log10(rmse_intensity) + # print("psnr:",psnr_intensity) + # ssim_intensity = ssim(generated, gt) + labels_dx, labels_dy = calImageGradients(gt) + # print("generated:",generated.shape) + preds_dx, preds_dy = calImageGradients(generated) + rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy) + psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y) + # print("psnr x&y:",psnr_grad_x," ",psnr_grad_y) + p_loss = perc_loss(generated,gt) + # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss) + total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss + # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss + return total_loss + +def loss_without_perc(generated, gt): + mse_loss = torch.nn.MSELoss() + rmse_intensity = mse_loss(generated, gt) + psnr_intensity = torch.log10(rmse_intensity) + # print("psnr:",psnr_intensity) + # ssim_intensity = ssim(generated, gt) + labels_dx, labels_dy = calImageGradients(gt) + # print("generated:",generated.shape) + preds_dx, preds_dy = calImageGradients(generated) + rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy) + psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y) + # print("psnr x&y:",psnr_grad_x," ",psnr_grad_y) + # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss) + total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss + return total_loss +##### LOSS ##### + + +class ReconstructionLoss(torch.nn.Module): + def __init__(self): + super(ReconstructionLoss, self).__init__() + + def forward(self, generated, gt): + rmse_intensity = torch.nn.functional.mse_loss(generated, gt) + psnr_intensity = torch.log10(rmse_intensity) + labels_dx, labels_dy = calImageGradients(gt) + preds_dx, preds_dy = calImageGradients(generated) + rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy) + psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y) + total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + return total_loss + +class PerceptionReconstructionLoss(torch.nn.Module): + def __init__(self): + super(PerceptionReconstructionLoss, self).__init__() + + def forward(self, generated, gt): + rmse_intensity = torch.nn.functional.mse_loss(generated, gt) + psnr_intensity = torch.log10(rmse_intensity) + labels_dx, labels_dy = calImageGradients(gt) + preds_dx, preds_dy = calImageGradients(generated) + rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy) + psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y) + p_loss = perc_loss(generated,gt) + total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss + return total_loss diff --git a/main.py b/main.py index 278ad74..104c174 100644 --- a/main.py +++ b/main.py @@ -12,212 +12,33 @@ from torch.autograd import Variable import cv2 from gen_image import * +from loss import * import json -from ssim import * -from perc_loss import * from conf import Conf -from model.baseline import * +from baseline import 
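# Editor's sketch (not part of the patch) of what the losses above optimize:
# log10(MSE) is PSNR up to sign and scale (PSNR = -10 * log10(MSE) for images
# in [0, 1]), so minimizing psnr_intensity + 0.5 * (psnr_grad_x + psnr_grad_y)
# jointly maximizes the PSNR of the intensities and of the image gradients:
import torch
gen, gt = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
mse = torch.nn.functional.mse_loss(gen, gt)
print(torch.log10(mse).item(), (-10 * torch.log10(mse)).item())  # loss term vs PSNR (dB)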
* +from data import * import torch.autograd.profiler as profiler # param -BATCH_SIZE = 2 -NUM_EPOCH = 300 - +BATCH_SIZE = 1 +NUM_EPOCH = 1001 INTERLEAVE_RATE = 2 - IM_H = 320 IM_W = 320 - Retinal_IM_H = 320 Retinal_IM_W = 320 - -N = 9 # number of input light field stack +N = 25 # number of input light field stack M = 2 # number of display layers - -DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_fovea" -DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq.json" -DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json" -OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq" - - +DATA_FILE = "/home/yejiannan/Project/LightField/data/FlowRPG1211" +DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq_flow_RPG.json" +# DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json" +OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq_flow_RPG_seq5_same_loss" OUT_CHANNELS_RB = 128 KERNEL_SIZE_RB = 3 KERNEL_SIZE = 3 - LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2 -FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2 - -class lightFieldDataLoader(torch.utils.data.dataset.Dataset): - def __init__(self, file_dir_path, file_json, transforms=None): - self.file_dir_path = file_dir_path - self.transforms = transforms - # self.datum_list = glob.glob(os.path.join(file_dir_path,"*")) - with open(file_json, encoding='utf-8') as file: - self.dataset_desc = json.loads(file.read()) - - def __len__(self): - return len(self.dataset_desc["focaldepth"]) - - def __getitem__(self, idx): - lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx) - if self.transforms: - lightfield_images = self.transforms(lightfield_images) - # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx) - return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx) - - def get_datum(self, idx): - lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx]) - # print(lf_image_paths) - fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx]) - fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][idx]) - # print(fd_gt_path) - lf_images = [] - lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. - lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB) - - ## IF GrayScale - # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. - # lf_image_big = np.expand_dims(lf_image_big, axis=-1) - # print(lf_image_big.shape) - - for i in range(9): - lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3] - ## IF GrayScale - # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1] - # print(lf_image.shape) - lf_images.append(lf_image) - gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. - gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB) - gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. - gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB) - ## IF GrayScale - # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. 
- # gt = np.expand_dims(gt, axis=-1) - - fd = self.dataset_desc["focaldepth"][idx] - gazeX = self.dataset_desc["gazeX"][idx] - gazeY = self.dataset_desc["gazeY"][idx] - sample_idx = self.dataset_desc["idx"][idx] - return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx - -class lightFieldValDataLoader(torch.utils.data.dataset.Dataset): - def __init__(self, file_dir_path, file_json, transforms=None): - self.file_dir_path = file_dir_path - self.transforms = transforms - # self.datum_list = glob.glob(os.path.join(file_dir_path,"*")) - with open(file_json, encoding='utf-8') as file: - self.dataset_desc = json.loads(file.read()) - - def __len__(self): - return len(self.dataset_desc["focaldepth"]) - - def __getitem__(self, idx): - lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx) - if self.transforms: - lightfield_images = self.transforms(lightfield_images) - # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx) - return (lightfield_images, fd, gazeX, gazeY, sample_idx) - - def get_datum(self, idx): - lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx]) - # print(fd_gt_path) - lf_images = [] - lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. - lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB) - - ## IF GrayScale - # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. - # lf_image_big = np.expand_dims(lf_image_big, axis=-1) - # print(lf_image_big.shape) - - for i in range(9): - lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3] - ## IF GrayScale - # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1] - # print(lf_image.shape) - lf_images.append(lf_image) - ## IF GrayScale - # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. 
- # gt = np.expand_dims(gt, axis=-1) - - fd = self.dataset_desc["focaldepth"][idx] - gazeX = self.dataset_desc["gazeX"][idx] - gazeY = self.dataset_desc["gazeY"][idx] - sample_idx = self.dataset_desc["idx"][idx] - return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx - -class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset): - def __init__(self, file_dir_path, file_json, transforms=None): - self.file_dir_path = file_dir_path - self.transforms = transforms - with open(file_json, encoding='utf-8') as file: - self.dataset_desc = json.loads(file.read()) - - def __len__(self): - return len(self.dataset_desc["seq"]) - - def __getitem__(self, idx): - lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx) - fd = fd.astype(np.float32) - gazeX = gazeX.astype(np.float32) - gazeY = gazeY.astype(np.float32) - sample_idx = sample_idx.astype(np.int64) - # print(fd) - # print(gazeX) - # print(gazeY) - # print(sample_idx) - - # print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype) - # print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape) - if self.transforms: - lightfield_images = self.transforms(lightfield_images) - return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx) - - def get_datum(self, idx): - indices = self.dataset_desc["seq"][idx] - # print("indices:",indices) - lf_images = [] - fd = [] - gazeX = [] - gazeY = [] - sample_idx = [] - gt = [] - gt2 = [] - for i in range(len(indices)): - lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][indices[i]]) - fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][indices[i]]) - fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][indices[i]]) - lf_image_one_sample = [] - lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. - lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB) - - for j in range(9): - lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3] - ## IF GrayScale - # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1] - # print(lf_image.shape) - lf_image_one_sample.append(lf_image) - - gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. - gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB)) - gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 
- gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB)) - - # print("indices[i]:",indices[i]) - fd.append([self.dataset_desc["focaldepth"][indices[i]]]) - gazeX.append([self.dataset_desc["gazeX"][indices[i]]]) - gazeY.append([self.dataset_desc["gazeY"][indices[i]]]) - sample_idx.append([self.dataset_desc["idx"][indices[i]]]) - lf_images.append(lf_image_one_sample) - #lf_images: 5,9,320,320 - - return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx) - -#### Image Gen -conf = Conf() -u = GenSamplesInPupil(conf.pupil_size, 5) -gen = RetinalGen(conf, u) +FIRSST_LAYER_CHANNELS = 75 * INTERLEAVE_RATE**2 def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): # layers: batchsize, 2*color, height, width @@ -248,6 +69,7 @@ def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): # retinal bs x color x height x width retinal_fovea = torch.empty(layers.shape[0], 6, 160, 160) mask_fovea = torch.empty(layers.shape[0], 2, 160, 160) + for i in range(0, layers.size()[0]): phi = phi_dict[int(sample_idx[i].data)] # print("phi_i:",phi.shape) @@ -261,12 +83,65 @@ def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): fovea_layers, fovea_layer_masks = gen.GenFoveaLayers(retinal_i,mask_i) retinal_fovea[i] = torch.cat([fovea_layers[0],fovea_layers[1]],dim=0) mask_fovea[i] = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=0) - + + + retinal_fovea = var_or_cuda(retinal_fovea) + mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width + # mask = torch.stack(mask,dim = 0).unsqueeze(1) + return retinal_fovea, mask_fovea + +import time +def GenRetinalGazeFromLayersBatchSpeed(layers, gen, phi, phi_invalid, retinal_invalid): + # layers: batchsize, 2*color, height, width + # Phi:torch.Size([batchsize, Layer, h, w, 41, 2]) + # df : batchsize,.. + # start1 = time.time() + # retinal bs x color x height x width + retinal_fovea = torch.empty((layers.shape[0], 6, 160, 160),device="cuda:2") + mask_fovea = torch.empty((layers.shape[0], 2, 160, 160),device="cuda:2") + # start = time.time() + retinal = gen.GenRetinalFromLayersBatch(layers,phi_batch) + # print("retinal:",retinal.shape) #retinal: torch.Size([2, 3, 320, 320]) + # print("t2:",time.time() - start) + + # start = time.time() + fovea_layers, fovea_layer_masks = gen.GenFoveaLayersBatch(retinal,mask_batch) + + mask_fovea = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=1) + retinal_fovea = torch.cat([fovea_layers[0],fovea_layers[1]],dim=1) + # print("t3:",time.time() - start) + retinal_fovea = var_or_cuda(retinal_fovea) mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width # mask = torch.stack(mask,dim = 0).unsqueeze(1) return retinal_fovea, mask_fovea +def MergeBatchSpeed(layers, gen, phi, phi_invalid, retinal_invalid): + # layers: batchsize, 2*color, height, width + # Phi:torch.Size([batchsize, Layer, h, w, 41, 2]) + # df : batchsize,.. 
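# Editor's sketch (not part of the patch): MergeBatchSpeed below zeroes
# retinal pixels whose rays miss every layer by multiplying with the negated
# boolean mask; the 1 x H x W mask broadcasts over B x 3 x H x W:
import torch
retinal = torch.rand(2, 3, 4, 4)
invalid = torch.zeros(1, 4, 4, dtype=torch.bool)
invalid[..., 0] = True             # e.g. first column out of range
retinal = retinal * (~invalid)     # bool promotes to 0/1, zeroing invalid pixels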
+ # start1 = time.time() + # retinal bs x color x height x width + # retinal_fovea = torch.empty((layers.shape[0], 6, 160, 160),device="cuda:2") + # mask_fovea = torch.empty((layers.shape[0], 2, 160, 160),device="cuda:2") + # start = time.time() + retinal = gen.GenRetinalFromLayersBatch(layers,phi) #retinal: torch.Size([BatchSize , 3, 320, 320]) + retinal.mul_(~retinal_invalid.to("cuda:2")) + # print("retinal:",retinal.shape) + # print("t2:",time.time() - start) + + # start = time.time() + # fovea_layers, fovea_layer_masks = gen.GenFoveaLayersBatch(retinal,mask_batch) + + # mask_fovea = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=1) + # retinal_fovea = torch.cat([fovea_layers[0],fovea_layers[1]],dim=1) + # print("t3:",time.time() - start) + + # retinal_fovea = var_or_cuda(retinal_fovea) + # mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width + # mask = torch.stack(mask,dim = 0).unsqueeze(1) + return retinal + def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask): # layers: batchsize, 2*color, height, width # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2]) @@ -285,47 +160,7 @@ def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask): return retinal.unsqueeze(0), mask_out #### Image Gen End -weightVarScale = 0.25 -bias_stddev = 0.01 - -def weight_init_normal(m): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - torch.nn.init.xavier_normal_(m.weight.data) - torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev) - elif classname.find("BatchNorm2d") != -1: - torch.nn.init.normal_(m.weight.data, 1.0, 0.02) - torch.nn.init.constant_(m.bias.data, 0.0) - - -def calImageGradients(images): - # x is a 4-D tensor - dx = images[:, :, 1:, :] - images[:, :, :-1, :] - dy = images[:, :, :, 1:] - images[:, :, :, :-1] - return dx, dy - - -perc_loss = VGGPerceptualLoss() -perc_loss = perc_loss.to("cuda:1") - -def loss_new(generated, gt): - mse_loss = torch.nn.MSELoss() - rmse_intensity = mse_loss(generated, gt) - - psnr_intensity = torch.log10(rmse_intensity) - # print("psnr:",psnr_intensity) - # ssim_intensity = ssim(generated, gt) - labels_dx, labels_dy = calImageGradients(gt) - # print("generated:",generated.shape) - preds_dx, preds_dy = calImageGradients(generated) - rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy) - psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y) - # print("psnr x&y:",psnr_grad_x," ",psnr_grad_y) - p_loss = perc_loss(generated,gt) - # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss) - total_loss = 10 + psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss - # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss - return total_loss +from weight_init import weight_init_normal def save_checkpoints(file_path, epoch_idx, model, model_solver): print('[INFO] Saving checkpoint to %s ...' 
% ( file_path)) @@ -336,93 +171,112 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver): } torch.save(checkpoint, file_path) -mode = "train" - -import pickle -def save_obj(obj, name ): - # with open('./outputF/dict/'+ name + '.pkl', 'wb') as f: - # pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) - torch.save(obj,'./outputF/dict/'+ name + '.pkl') -def load_obj(name): - # with open('./outputF/dict/' + name + '.pkl', 'rb') as f: - # return pickle.load(f) - return torch.load('./outputF/dict/'+ name + '.pkl') - -def hook_fn_back(m, i, o): - for grad in i: - try: - print("Input Grad:",m,grad.shape,grad.sum()) - except AttributeError: - print ("None found for Gradient") - for grad in o: - try: - print("Output Grad:",m,grad.shape,grad.sum()) - except AttributeError: - print ("None found for Gradient") - print("\n") - -def hook_fn_for(m, i, o): - for grad in i: - try: - print("Input Feats:",m,grad.shape,grad.sum()) - except AttributeError: - print ("None found for Gradient") - for grad in o: - try: - print("Output Feats:",m,grad.shape,grad.sum()) - except AttributeError: - print ("None found for Gradient") - print("\n") - -def generatePhiMaskDict(data_json, generator): - phi_dict = {} - mask_dict = {} - idx_info_dict = {} - with open(data_json, encoding='utf-8') as file: - dataset_desc = json.loads(file.read()) - for i in range(len(dataset_desc["focaldepth"])): - # if i == 2: - # break - idx = dataset_desc["idx"][i] - focaldepth = dataset_desc["focaldepth"][i] - gazeX = dataset_desc["gazeX"][i] - gazeY = dataset_desc["gazeY"][i] - print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY) - phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY])) - phi_dict[idx]=phi - mask_dict[idx]=mask - idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY] - return phi_dict,mask_dict,idx_info_dict - +# import pickle +# def save_obj(obj, name ): +# # with open('./outputF/dict/'+ name + '.pkl', 'wb') as f: +# # pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) +# torch.save(obj,'./outputF/dict/'+ name + '.pkl') +# def load_obj(name): +# # with open('./outputF/dict/' + name + '.pkl', 'rb') as f: +# # return pickle.load(f) +# return torch.load('./outputF/dict/'+ name + '.pkl') + +# def generatePhiMaskDict(data_json, generator): +# phi_dict = {} +# mask_dict = {} +# idx_info_dict = {} +# with open(data_json, encoding='utf-8') as file: +# dataset_desc = json.loads(file.read()) +# for i in range(len(dataset_desc["focaldepth"])): +# # if i == 2: +# # break +# idx = dataset_desc["idx"][i] +# focaldepth = dataset_desc["focaldepth"][i] +# gazeX = dataset_desc["gazeX"][i] +# gazeY = dataset_desc["gazeY"][i] +# print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY) +# phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY])) +# phi_dict[idx]=phi +# mask_dict[idx]=mask +# idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY] +# return phi_dict,mask_dict,idx_info_dict + +# def generatePhiMaskDictNew(data_json, generator): +# phi_dict = {} +# mask_dict = {} +# idx_info_dict = {} +# with open(data_json, encoding='utf-8') as file: +# dataset_desc = json.loads(file.read()) +# for i in range(len(dataset_desc["seq"])): +# for j in dataset_desc["seq"][i]: +# idx = dataset_desc["idx"][j] +# focaldepth = dataset_desc["focaldepth"][j] +# gazeX = dataset_desc["gazeX"][j] +# gazeY = dataset_desc["gazeY"][j] +# print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY) +# phi,mask = 
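# Editor's reading (an assumption, since the training loop is truncated in
# this patch): the w_frame / w_inter_frame weights defined just below suggest
# a two-term sequence objective combining per-frame reconstruction with a
# temporal-consistency term on flow-warped consecutive outputs, roughly
#     total = w_frame * loss1(out_t, gt_t)
#           + w_inter_frame * loss2(FlowMap(out_{t-1}, flow_t), out_t)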
generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY]))
+#             phi_dict[idx]=phi
+#             mask_dict[idx]=mask
+#             idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY]
+#     return phi_dict,mask_dict,idx_info_dict
+
+mode = "Silence" #"Perf"
+model_type = "RNN" # any other value selects the non-recurrent path below
+w_frame = 0.9
+w_inter_frame = 0.1
+batch_model = "NoSingle"
+loss1 = ReconstructionLoss()
+loss2 = ReconstructionLoss()
+l1loss = torch.nn.L1Loss() # still used by the inter-frame terms below; its old top-level definition was removed
 if __name__ == "__main__":
     ############################## generate phi and mask in pre-training
     # print("generating phi and mask...")
-    # phi_dict,mask_dict,idx_info_dict = generatePhiMaskDict(DATA_JSON,gen)
-    # save_obj(phi_dict,"phi_1204")
-    # save_obj(mask_dict,"mask_1204")
-    # save_obj(idx_info_dict,"idx_info_1204")
+    # phi_dict,mask_dict,idx_info_dict = generatePhiMaskDictNew(DATA_JSON,gen)
+    # # save_obj(phi_dict,"phi_1204")
+    # # save_obj(mask_dict,"mask_1204")
+    # # save_obj(idx_info_dict,"idx_info_1204")
     # print("generating phi and mask end.")
     # exit(0)
     ############################# load phi and mask in pre-training
-    print("loading phi and mask ...")
-    phi_dict = load_obj("phi_1204")
-    mask_dict = load_obj("mask_1204")
-    idx_info_dict = load_obj("idx_info_1204")
-    print(len(phi_dict))
-    print(len(mask_dict))
-    print("loading phi and mask end")
+    # print("loading phi and mask ...")
+    # phi_dict = load_obj("phi_1204")
+    # mask_dict = load_obj("mask_1204")
+    # idx_info_dict = load_obj("idx_info_1204")
+    # print(len(phi_dict))
+    # print(len(mask_dict))
+    # print("loading phi and mask end")
+
+    #### Image Gen and conf
+    conf = Conf()
+    gen = RetinalGen(conf)
     #train
-    train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSeqDataLoader(DATA_FILE,DATA_JSON),
+    train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldFlowSeqDataLoader(DATA_FILE,DATA_JSON, gen, conf),
                                                     batch_size=BATCH_SIZE,
-                                                    num_workers=0,
+                                                    num_workers=8,
                                                     pin_memory=True,
                                                     shuffle=True,
                                                     drop_last=False)
+    #Data loader test
     print(len(train_data_loader))
-    # exit(0)
+    # # lightfield_images, gt, flow, fd, gazeX, gazeY, posX, posY, sample_idx, phi, mask
+    # # lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, fd, gazeX, gazeY, posX, posY, posZ, sample_idx
+    # for batch_idx, (image_set, gt,phi, phi_invalid, retinal_invalid, flow, df, gazeX, gazeY, posX, posY, posZ, sample_idx) in enumerate(train_data_loader):
+    #     print(image_set.shape,type(image_set))
+    #     print(gt.shape,type(gt))
+    #     print(phi.shape,type(phi))
+    #     print(phi_invalid.shape,type(phi_invalid))
+    #     print(retinal_invalid.shape,type(retinal_invalid))
+    #     print(flow.shape,type(flow))
+    #     print(df.shape,type(df))
+    #     print(gazeX.shape,type(gazeX))
+    #     print(posX.shape,type(posX))
+    #     print(sample_idx.shape,type(sample_idx))
+    #     print("test train dataloader.")
+    # exit(0)
+    #Data loader test end
+
     ################################################ val #########################################################
@@ -464,21 +318,24 @@ if __name__ == "__main__":
     # print("output:",output.shape," df:",df[0].data, ",gazeX:",gazeX[0].data,",gazeY:", gazeY[0].data)
     # for i in range(output1.size()[0]):
-    #     save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
-    #     save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
-    #     save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data)))
-    # 
save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) + # save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) + # save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) + # save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) + # save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.5f_%.5f_%.5f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) - # # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.3f.png"%(df[0].data))) - # # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.3f.png"%(df[0].data))) + # # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.5f.png"%(df[0].data))) + # # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.5f.png"%(df[0].data))) # # output = GenRetinalFromLayersBatch(output,conf,df,v,u) - # # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.3f.png"%(df[0].data))) + # # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.5f.png"%(df[0].data))) # exit() ################################################ train ######################################################### - lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE) + if model_type == "RNN": + lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE) + else: + lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE,RNN=False) lf_model.apply(weight_init_normal) - + lf_model.train() epoch_begin = 0 ################################ load model file @@ -493,144 +350,243 @@ if __name__ == "__main__": if torch.cuda.is_available(): # lf_model = torch.nn.DataParallel(lf_model).cuda() - lf_model = lf_model.to('cuda:1') - lf_model.train() - optimizer = torch.optim.Adam(lf_model.parameters(),lr=1e-2,betas=(0.9,0.999)) - l1loss = torch.nn.L1Loss() + lf_model = lf_model.to('cuda:2') + + optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999)) + # lf_model.output_layer.register_backward_hook(hook_fn_back) + if mode=="Perf": + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() print("begin training....") for epoch in range(epoch_begin, NUM_EPOCH): - for batch_idx, (image_set, gt, gt2, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader): + for batch_idx, (image_set, gt,phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, df, gazeX, gazeY, posX, posY, posZ, sample_idx) in enumerate(train_data_loader): # print(sample_idx.shape,df.shape,gazeX.shape,gazeY.shape) # torch.Size([2, 5]) # print(image_set.shape,gt.shape,gt2.shape) #torch.Size([2, 5, 9, 320, 320, 3]) torch.Size([2, 5, 160, 160, 3]) torch.Size([2, 5, 160, 160, 3]) # print(delta.shape) # delta: torch.Size([2, 4, 160, 160, 3]) - + if mode=="Perf": + end.record() + torch.cuda.synchronize() + print("load:",start.elapsed_time(end)) + + start.record() #reshape for input - image_set = image_set.permute(0,1,2,5,3,4) # N S LF C H W + image_set = 
image_set.permute(0,1,2,5,3,4) # N Seq 5 LF 25 C 3 H W image_set = image_set.reshape(image_set.shape[0],image_set.shape[1],-1,image_set.shape[4],image_set.shape[5]) # N, LFxC, H, W + # N, Seq 5, LF 25 C 3, H, W image_set = var_or_cuda(image_set) - gt = gt.permute(0,1,4,2,3) # N S C H W + # print(image_set.shape) #torch.Size([2, 5, 75, 320, 320]) + + gt = gt.permute(0,1,4,2,3) # BS Seq 5 C 3 H W gt = var_or_cuda(gt) - gt2 = gt2.permute(0,1,4,2,3) - gt2 = var_or_cuda(gt2) + flow = var_or_cuda(flow) #BS,Seq-1,H,W,2 + + # gt2 = gt2.permute(0,1,4,2,3) + # gt2 = var_or_cuda(gt2) - gen1 = torch.empty(gt.shape) + gen1 = torch.empty(gt.shape) # BS Seq C H W gen1 = var_or_cuda(gen1) + # print(gen1.shape) #torch.Size([2, 5, 3, 320, 320]) - gen2 = torch.empty(gt2.shape) - gen2 = var_or_cuda(gen2) + # gen2 = torch.empty(gt2.shape) + # gen2 = var_or_cuda(gen2) - warped = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4]) + #BS, Seq - 1, C, H, W + warped = torch.empty(gt.shape[0],gt.shape[1]-1,gt.shape[2],gt.shape[3],gt.shape[4]) warped = var_or_cuda(warped) + gen_temp = torch.empty(warped.shape) + gen_temp = var_or_cuda(gen_temp) + # print("warped:",warped.shape) #warped: torch.Size([2, 4, 3, 320, 320]) + if mode=="Perf": + end.record() + torch.cuda.synchronize() + print("data prepare:",start.elapsed_time(end)) + + start.record() + if model_type == "RNN": + if batch_model != "Single": + for k in range(image_set.shape[1]): + if k == 0: + lf_model.reset_hidden(image_set[:,k]) + output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320 + output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k]) + gen1[:,k] = output1[:,0:3] + gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2")) + + if ((epoch%10 == 0 and epoch != 0) or epoch == 2): + for i in range(output.shape[0]): + save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) + save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) + + for i in range(1,gt.shape[1]): + # print(flow_invalid_mask.shape) #torch.Size([2, 4, 320, 320]) + # print(FlowMap(gen1[:,i-1],flow[:,i-1]).shape) #torch.Size([2, 3, 320, 320]) + warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) + gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) + else: + for k in range(image_set.shape[1]): + if k == 0: + lf_model.reset_hidden(image_set[:,k]) + output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320 + output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k]) + gen1[:,k] = output1[:,0:3] + gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2")) + + if k != image_set.shape[1]-1: + warped[:,k] = FlowMap(output1.detach(),flow[:,k]).mul(~flow_invalid_mask[:,k].unsqueeze(1).to("cuda:2")) + loss1 = loss_without_perc(output1,gt[:,k]) + loss = (w_frame * loss1) + if k==0: + loss.backward(retain_graph=False) + optimizer.step() + lf_model.zero_grad() + optimizer.zero_grad() + print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item()) + else: + output1mask = 
output1.mul(~flow_invalid_mask[:,k-1].unsqueeze(1).to("cuda:2")) + loss2 = l1loss(output1mask,warped[:,k-1]) + loss += (w_inter_frame * loss2) + loss.backward(retain_graph=False) + optimizer.step() + lf_model.zero_grad() + optimizer.zero_grad() + # print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item()) + print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",frame loss:",loss1.item(),",inter loss:",w_inter_frame * loss2.item()) + + + if ((epoch%10 == 0 and epoch != 0) or epoch == 2): + for i in range(output.shape[0]): + save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) + save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) + + # BSxSeq C H W + gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4]) + gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4]) + + if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3]) + for i in range(gt.size()[0]): + save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) + save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) + if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1): + save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer) + else: + if batch_model != "Single": + for k in range(image_set.shape[1]): + output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320 + output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k]) + gen1[:,k] = output1[:,0:3] + gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2")) + + if ((epoch%10 == 0 and epoch != 0) or epoch == 2): + for i in range(output.shape[0]): + save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) + save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) + for i in range(1,gt.shape[1]): + # print(flow_invalid_mask.shape) #torch.Size([2, 4, 320, 320]) + # print(FlowMap(gen1[:,i-1],flow[:,i-1]).shape) #torch.Size([2, 3, 320, 320]) + warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) + gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) + else: + for k in range(image_set.shape[1]): + output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k],posX[:,k],posY[:,k],posZ[:,k]) # batchsize, layer_num x 2 = 6, layer_res: 320, layer_res: 320 + output1 = MergeBatchSpeed(output, gen, phi[:,k], phi_invalid[:,k], retinal_invalid[:,k]) + gen1[:,k] = output1[:,0:3] + gt[:,k] = gt[:,k].mul_(~retinal_invalid[:,k].to("cuda:2")) + + # print(output1.shape) #torch.Size([BS, 3, 320, 320]) + loss1 = loss_without_perc(output1,gt[:,k]) + loss = (w_frame * loss1) + # 
print("loss:",loss1.item()) + loss.backward(retain_graph=False) + optimizer.step() + lf_model.zero_grad() + optimizer.zero_grad() + + print("Epoch:",epoch,",Iter:",batch_idx,",Seq:",k,",loss:",loss.item()) + if ((epoch%10 == 0 and epoch != 0) or epoch == 2): + for i in range(output.shape[0]): + save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) + save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data,posX[i][k].data,posY[i][k].data))) + for i in range(1,gt.shape[1]): + warped[:,i-1] = FlowMap(gen1[:,i-1],flow[:,i-1]).mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) + gen_temp[:,i-1] = gen1[:,i].mul(~flow_invalid_mask[:,i-1].unsqueeze(1).to("cuda:2")) + + warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4]) + gen_temp = gen_temp.reshape(-1,gen_temp.shape[2],gen_temp.shape[3],gen_temp.shape[4]) + loss3 = l1loss(warped,gen_temp) + print("Epoch:",epoch,",Iter:",batch_idx,",inter-frame loss:",w_inter_frame *loss3.item()) + + # BSxSeq C H W + gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4]) + gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4]) + + if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3]) + for i in range(gt.size()[0]): + save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) + save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) + if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1): + save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer) + + + if batch_model != "Single": + if mode=="Perf": + end.record() + torch.cuda.synchronize() + print("forward:",start.elapsed_time(end)) + + start.record() + optimizer.zero_grad() - delta = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4]) - delta = var_or_cuda(delta) - - for k in range(image_set.shape[1]): - if k == 0: - lf_model.reset_hidden(image_set[:,k]) - # start = torch.cuda.Event(enable_timing=True) - # end = torch.cuda.Event(enable_timing=True) - # start.record() - output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k]) - # end.record() - # torch.cuda.synchronize() - # print("Model Forward:",start.elapsed_time(end)) - # print("output:",output.shape) # [2, 6, 320, 320] - # exit() - ########################### Use Pregen Phi and Mask ################### - # start.record() - output1,mask = GenRetinalGazeFromLayersBatch(output, gen, sample_idx[:,k], phi_dict, mask_dict) - # end.record() - # torch.cuda.synchronize() - # print("Merge:",start.elapsed_time(end)) - - # print("output1 shape:",output1.shape, "mask shape:",mask.shape) - # output1 shape: torch.Size([2, 6, 160, 160]) mask shape: torch.Size([2, 2, 160, 160]) - for i in range(0, 2): - output1[:,i*3:i*3+3].mul_(mask[:,i:i+1]) - if i == 0: - gt[:,k].mul_(mask[:,i:i+1]) - if i == 1: - gt2[:,k].mul_(mask[:,i:i+1]) + # BSxSeq C H W + gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4]) + # gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4]) + # BSxSeq C H W + gt = 
gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4]) + # gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4]) + + # BSx(Seq-1) C H W + warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4]) + gen_temp = gen_temp.reshape(-1,gen_temp.shape[2],gen_temp.shape[3],gen_temp.shape[4]) - gen1[:,k] = output1[:,0:3] - gen2[:,k] = output1[:,3:6] - if ((epoch%5== 0) or epoch == 2): - for i in range(output.shape[0]): - save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data))) - save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data))) - - ########################### Update ################### - for i in range(1,image_set.shape[1]): - delta[:,i-1] = gt2[:,i] - gt2[:,i] - warped[:,i-1] = gen2[:,i]-gen2[:,i-1] - - optimizer.zero_grad() - - # # N S C H W - gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4]) - gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4]) - gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4]) - gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4]) - warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4]) - delta = delta.reshape(-1,delta.shape[2],delta.shape[3],delta.shape[4]) - - - # start = torch.cuda.Event(enable_timing=True) - # end = torch.cuda.Event(enable_timing=True) - # start.record() - loss1 = loss_new(gen1,gt) - loss2 = loss_new(gen2,gt2) - loss3 = l1loss(warped,delta) - loss = loss1+loss2+loss3 - # end.record() - # torch.cuda.synchronize() - # print("loss comp:",start.elapsed_time(end)) - - - # start.record() - loss.backward() - # end.record() - # torch.cuda.synchronize() - # print("backward:",start.elapsed_time(end)) - - # start.record() - optimizer.step() - # end.record() - # torch.cuda.synchronize() - # print("optimizer step:",start.elapsed_time(end)) - - ## Update Prev - print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss) - ########################### Save ##################### - if ((epoch%5== 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3]) - for i in range(gt.size()[0]): - # df 2,5 - save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) - save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) - save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) - save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) - if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1): - save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),epoch,lf_model,optimizer) - - ########################## test Phi and Mask ########################## - # phi,mask = gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]])) - # # print("gaze Online:",gazeX[0]," ,",gazeY[0]) - # # print("df Online:",df[0]) - # # print("idx:",int(sample_idx[0].data)) - # phi_t = phi_dict[int(sample_idx[0].data)] - # mask_t = mask_dict[int(sample_idx[0].data)] - # # print("idx info:",idx_info_dict[int(sample_idx[0].data)]) - # # print("phi online:", phi.shape, " phi_t:", phi_t.shape) - # # print("mask online:", mask.shape, " mask_t:", 
mask_t.shape) - # print("phi delta:", (phi-phi_t).sum()," mask delta:",(mask -mask_t).sum()) - # exit(0) - - ###########################Gen Batch 1 by 1################### - # phi,mask = gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]])) - # # print(phi.shape) # 2,240,320,41,2 - # output1, mask = GenRetinalFromLayersBatch_Online(output, gen, phi, mask) - ###########################Gen Batch 1 by 1################### \ No newline at end of file + loss1_value = loss1(gen1,gt) + loss2_value = loss2(warped,gen_temp) + if model_type == "RNN": + loss = (w_frame * loss1_value)+ (w_inter_frame * loss2_value) + else: + loss = (w_frame * loss1_value) + + if mode=="Perf": + end.record() + torch.cuda.synchronize() + print("compute loss:",start.elapsed_time(end)) + + start.record() + loss.backward() + if mode=="Perf": + end.record() + torch.cuda.synchronize() + print("backward:",start.elapsed_time(end)) + + start.record() + optimizer.step() + if mode=="Perf": + end.record() + torch.cuda.synchronize() + print("update:",start.elapsed_time(end)) + + print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss.item(),",frame loss:",loss1_value.item(),",inter-frame loss:",loss2_value.item()) + + # exit(0) + ########################### Save ##################### + if ((epoch%10== 0 and epoch != 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3]) + for i in range(gt.size()[0]): + # df 2,5 + save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) + # save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) + save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.5f_%.5f_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data,posX[i//5][i%5].data,posY[i//5][i%5].data))) + # save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.5f_%.5f_%.5f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) + if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1): + save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch)),epoch,lf_model,optimizer) \ No newline at end of file diff --git a/main_lf_syn.py b/main_lf_syn.py index f6de1e1..f84ab8f 100644 --- a/main_lf_syn.py +++ b/main_lf_syn.py @@ -31,7 +31,7 @@ M = 1 # number of display layers DATA_FILE = "/home/yejiannan/Project/LightField/data/lf_syn" DATA_JSON = "/home/yejiannan/Project/LightField/data/data_lf_syn_full.json" # DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json" -OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/lf_syn_full_perc" +OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/lf_syn_full" OUT_CHANNELS_RB = 128 KERNEL_SIZE_RB = 3 KERNEL_SIZE = 3 @@ -50,7 +50,7 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver): torch.save(checkpoint, file_path) mode = "Silence" #"Perf" -w_frame = 0.9 +w_frame = 1.0 loss1 = PerceptionReconstructionLoss() if __name__ == "__main__": #train @@ -70,7 +70,7 @@ if __name__ == "__main__": if torch.cuda.is_available(): # lf_model = torch.nn.DataParallel(lf_model).cuda() - lf_model = lf_model.to('cuda:3') + lf_model = lf_model.to('cuda:1') optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999)) diff --git a/model/__init__.py b/model/__init__.py deleted file mode 100644 index e69de29..0000000 
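
Aside on the main.py training loop above: the w_inter_frame term warps the previous generated frame onto the current one with FlowMap and compares the two frames only where the optical flow is valid. Below is a minimal, self-contained sketch of that masking pattern; all tensors are random stand-ins whose shapes follow the loop's comments, the warp uses torch.nn.functional.grid_sample as FlowMap does, and nothing here is the patch's actual data.

import torch

B, C, H, W = 2, 3, 320, 320
prev_gen = torch.rand(B, C, H, W)           # generated frame k-1
cur_gen = torch.rand(B, C, H, W)            # generated frame k
flow_map = torch.rand(B, H, W, 2) * 2 - 1   # stand-in sampling grid in [-1, 1]
invalid = torch.rand(B, H, W) > 0.9         # stand-in for flow_invalid_mask

# Warp frame k-1 toward frame k by sampling it at the flow-mapped coordinates.
warped = torch.nn.functional.grid_sample(prev_gen, flow_map, align_corners=False)

# Zero both sides wherever the flow is invalid, then penalize the difference;
# this mirrors the .mul(~flow_invalid_mask.unsqueeze(1)) pattern in main.py.
valid = (~invalid).unsqueeze(1).float()
inter_frame_loss = torch.nn.L1Loss()(warped * valid, cur_gen * valid)

The masking is what keeps the term meaningful: grid_sample returns defined values everywhere, so without it the loss would also penalize pixels whose flow the renderer marked unusable.
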
diff --git a/model/recurrent.py b/model/recurrent.py deleted file mode 100644 index 3aeaf6a..0000000 --- a/model/recurrent.py +++ /dev/null @@ -1,146 +0,0 @@ -import torch, os, sys, cv2 -import torch.nn as nn -from torch.nn import init -import functools -import torch.optim as optim - -from torch.utils.data import Dataset, DataLoader -from torch.nn import functional as func -from PIL import Image - -import torchvision.transforms as transforms -import numpy as np -import torch - - -class RecurrentBlock(nn.Module): - - def __init__(self, input_nc, output_nc, downsampling=False, bottleneck=False, upsampling=False): - super(RecurrentBlock, self).__init__() - - self.input_nc = input_nc - self.output_nc = output_nc - - self.downsampling = downsampling - self.upsampling = upsampling - self.bottleneck = bottleneck - - self.hidden = None - - if self.downsampling: - self.l1 = nn.Sequential( - nn.Conv2d(input_nc, output_nc, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1) - ) - self.l2 = nn.Sequential( - nn.Conv2d(2 * output_nc, output_nc, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1), - nn.Conv2d(output_nc, output_nc, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1), - ) - elif self.upsampling: - self.l1 = nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), - nn.Conv2d(2 * input_nc, output_nc, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1), - nn.Conv2d(output_nc, output_nc, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1), - ) - elif self.bottleneck: - self.l1 = nn.Sequential( - nn.Conv2d(input_nc, output_nc, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1) - ) - self.l2 = nn.Sequential( - nn.Conv2d(2 * output_nc, output_nc, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1), - nn.Conv2d(output_nc, output_nc, 3, padding=1), - nn.LeakyReLU(negative_slope=0.1), - ) - - def forward(self, inp): - - if self.downsampling: - op1 = self.l1(inp) - op2 = self.l2(torch.cat((op1, self.hidden), dim=1)) - - self.hidden = op2 - - return op2 - elif self.upsampling: - op1 = self.l1(inp) - - return op1 - elif self.bottleneck: - op1 = self.l1(inp) - op2 = self.l2(torch.cat((op1, self.hidden), dim=1)) - - self.hidden = op2 - - return op2 - - def reset_hidden(self, inp, dfac): - size = list(inp.size()) - size[1] = self.output_nc - size[2] /= dfac - size[3] /= dfac - - self.hidden_size = size - self.hidden = torch.zeros(*(size)).to('cuda:0') - - - -class RecurrentAE(nn.Module): - - def __init__(self, input_nc): - super(RecurrentAE, self).__init__() - - self.d1 = RecurrentBlock(input_nc=input_nc, output_nc=32, downsampling=True) - self.d2 = RecurrentBlock(input_nc=32, output_nc=43, downsampling=True) - self.d3 = RecurrentBlock(input_nc=43, output_nc=57, downsampling=True) - self.d4 = RecurrentBlock(input_nc=57, output_nc=76, downsampling=True) - self.d5 = RecurrentBlock(input_nc=76, output_nc=101, downsampling=True) - - self.bottleneck = RecurrentBlock(input_nc=101, output_nc=101, bottleneck=True) - - self.u5 = RecurrentBlock(input_nc=101, output_nc=76, upsampling=True) - self.u4 = RecurrentBlock(input_nc=76, output_nc=57, upsampling=True) - self.u3 = RecurrentBlock(input_nc=57, output_nc=43, upsampling=True) - self.u2 = RecurrentBlock(input_nc=43, output_nc=32, upsampling=True) - self.u1 = RecurrentBlock(input_nc=32, output_nc=3, upsampling=True) - - def set_input(self, inp): - self.inp = inp['A'] - - def forward(self): - d1 = func.max_pool2d(input=self.d1(self.inp), kernel_size=2) - d2 = func.max_pool2d(input=self.d2(d1), kernel_size=2) - d3 = func.max_pool2d(input=self.d3(d2), kernel_size=2) - d4 = 
func.max_pool2d(input=self.d4(d3), kernel_size=2) - d5 = func.max_pool2d(input=self.d5(d4), kernel_size=2) - - b = self.bottleneck(d5) - - u5 = self.u5(torch.cat((b, d5), dim=1)) - u4 = self.u4(torch.cat((u5, d4), dim=1)) - u3 = self.u3(torch.cat((u4, d3), dim=1)) - u2 = self.u2(torch.cat((u3, d2), dim=1)) - u1 = self.u1(torch.cat((u2, d1), dim=1)) - - return u1 - - def reset_hidden(self): - self.d1.reset_hidden(self.inp, dfac=1) - self.d2.reset_hidden(self.inp, dfac=2) - self.d3.reset_hidden(self.inp, dfac=4) - self.d4.reset_hidden(self.inp, dfac=8) - self.d5.reset_hidden(self.inp, dfac=16) - - self.bottleneck.reset_hidden(self.inp, dfac=32) - - self.u4.reset_hidden(self.inp, dfac=16) - self.u3.reset_hidden(self.inp, dfac=8) - self.u5.reset_hidden(self.inp, dfac=4) - self.u2.reset_hidden(self.inp, dfac=2) - self.u1.reset_hidden(self.inp, dfac=1) - diff --git a/util.py b/util.py new file mode 100644 index 0000000..7e6846b --- /dev/null +++ b/util.py @@ -0,0 +1,104 @@ +import numpy as np +import torch +import matplotlib.pyplot as plt +import glm + +gvec_type = [ glm.dvec1, glm.dvec2, glm.dvec3, glm.dvec4 ] +gmat_type = [ [ glm.dmat2, glm.dmat2x3, glm.dmat2x4 ], + [ glm.dmat3x2, glm.dmat3, glm.dmat3x4 ], + [ glm.dmat4x2, glm.dmat4x3, glm.dmat4 ] ] + +def Fov2Length(angle): + return np.tan(angle * np.pi / 360) * 2 + +def SmoothStep(x0, x1, x): + y = torch.clamp((x - x0) / (x1 - x0), 0, 1) + return y * y * (3 - 2 * y) + +def MatImg2Tensor(img, permute = True, batch_dim = True): + batch_input = len(img.shape) == 4 + if permute: + t = torch.from_numpy(np.transpose(img, + [0, 3, 1, 2] if batch_input else [2, 0, 1])) + else: + t = torch.from_numpy(img) + if not batch_input and batch_dim: + t = t.unsqueeze(0) + return t + +def MatImg2Numpy(img, permute = True, batch_dim = True): + batch_input = len(img.shape) == 4 + if permute: + t = np.transpose(img, [0, 3, 1, 2] if batch_input else [2, 0, 1]) + else: + t = img + if not batch_input and batch_dim: + t = t.unsqueeze(0) + return t + +def Tensor2MatImg(t): + img = t.squeeze().cpu().numpy() + batch_input = len(img.shape) == 4 + if t.size()[batch_input] <= 4: + return np.transpose(img, [0, 2, 3, 1] if batch_input else [1, 2, 0]) + return img + +def ReadImageTensor(path, permute = True, rgb_only = True, batch_dim = True): + channels = 3 if rgb_only else 4 + if isinstance(path,list): + first_image = plt.imread(path[0])[:, :, 0:channels] + b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32) + b_image[0] = first_image + for i in range(1, len(path)): + b_image[i] = plt.imread(path[i])[:, :, 0:channels] + return MatImg2Tensor(b_image, permute) + return MatImg2Tensor(plt.imread(path)[:, :, 0:channels], permute, batch_dim) + +def ReadImageNumpyArray(path, permute = True, rgb_only = True, batch_dim = True): + channels = 3 if rgb_only else 4 + if isinstance(path,list): + first_image = plt.imread(path[0])[:, :, 0:channels] + b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32) + b_image[0] = first_image + for i in range(1, len(path)): + b_image[i] = plt.imread(path[i])[:, :, 0:channels] + return MatImg2Numpy(b_image, permute) + return MatImg2Numpy(plt.imread(path)[:, :, 0:channels], permute, batch_dim) + +def WriteImageTensor(t, path): + image = Tensor2MatImg(t) + if isinstance(path,list): + if len(image.shape) != 4 or image.shape[0] != len(path): + raise ValueError + for i in range(len(path)): + plt.imsave(path[i], image[i]) + else: + if len(image.shape) == 4 
and image.shape[0] != 1: + raise ValueError + plt.imsave(path, image) + +def PlotImageTensor(t): + plt.imshow(Tensor2MatImg(t)) + +def Tensor2Glm(t): + t = t.squeeze() + size = t.size() + if len(size) == 1: + if size[0] <= 0 or size[0] > 4: + raise ValueError + return gvec_type[size[0] - 1](t.cpu().numpy()) + if len(size) == 2: + if size[0] <= 1 or size[0] > 4 or size[1] <= 1 or size[1] > 4: + raise ValueError + return gmat_type[size[1] - 2][size[0] - 2](t.cpu().numpy()) + raise ValueError + +def Glm2Tensor(val): + return torch.from_numpy(np.array(val)) + +def MeshGrid(size, normalize=False): + y,x = torch.meshgrid(torch.tensor(range(size[0])), + torch.tensor(range(size[1]))) + if normalize: + return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2) + return torch.stack([x, y], 2) \ No newline at end of file diff --git a/weight_init.py b/weight_init.py new file mode 100644 index 0000000..9c0a20a --- /dev/null +++ b/weight_init.py @@ -0,0 +1,12 @@ +import torch +weightVarScale = 0.25 +bias_stddev = 0.01 + +def weight_init_normal(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + torch.nn.init.xavier_normal_(m.weight.data) + torch.nn.init.normal_(m.bias.data,mean = 0.0, std=bias_stddev) + elif classname.find("BatchNorm2d") != -1: + torch.nn.init.normal_(m.weight.data, 1.0, 0.02) + torch.nn.init.constant_(m.bias.data, 0.0) \ No newline at end of file -- GitLab
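
Postscript on the new weight_init.py: weight_init_normal gives every Conv layer Xavier-normal weights and N(0, 0.01) biases, and every BatchNorm2d layer N(1, 0.02) weights with zero biases; main.py wires it in through lf_model.apply(weight_init_normal). A small usage sketch follows; the three-layer net is illustrative only, not the patch's model, and it assumes the conv layer keeps its default bias=True, since the initializer touches m.bias unconditionally.

import torch
from weight_init import weight_init_normal

net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),  # weight -> Xavier normal, bias -> N(0, 0.01)
    torch.nn.BatchNorm2d(16),              # weight -> N(1, 0.02), bias -> 0
    torch.nn.Sigmoid(),                    # parameter-free, left untouched
)
net.apply(weight_init_normal)              # .apply() visits every submodule recursively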