From 69e1d015d57ae8b3e789706f5659273bf904f2ed Mon Sep 17 00:00:00 2001 From: BobYeah <635596704@qq.com> Date: Sat, 5 Dec 2020 01:08:20 +0800 Subject: [PATCH] Update1205ForHPC --- conf.py | 27 +++ data.py | 0 gen_image.py | 35 ++- main.py | 590 +++++++++++++++++++++++++++++++-------------- model/__init__.py | 0 model/baseline.py | 167 +++++++++++++ model/recurrent.py | 146 +++++++++++ 7 files changed, 782 insertions(+), 183 deletions(-) create mode 100644 conf.py create mode 100644 data.py create mode 100644 model/__init__.py create mode 100644 model/baseline.py create mode 100644 model/recurrent.py diff --git a/conf.py b/conf.py new file mode 100644 index 0000000..a6aea74 --- /dev/null +++ b/conf.py @@ -0,0 +1,27 @@ +import torch +from gen_image import * +class Conf(object): + def __init__(self): + self.pupil_size = 0.02 + self.retinal_res = torch.tensor([ 320, 320 ]) + self.layer_res = torch.tensor([ 320, 320 ]) + self.layer_hfov = 90 # layers' horizontal FOV + self.eye_hfov = 85 # eye's horizontal FOV (ignored in foveated rendering) + self.eye_enable_fovea = True # enable foveated rendering + self.eye_fovea_angles = [ 40, 80 ] # eye's foveation layers' angles + self.eye_fovea_downsamples = [ 1, 2 ] # eye's foveation layers' downsamples + self.d_layer = [ 1, 3 ] # layers' distance + + def GetNLayers(self): + return len(self.d_layer) + + def GetLayerSize(self, i): + w = Fov2Length(self.layer_hfov) + h = w * self.layer_res[0] / self.layer_res[1] + return torch.tensor([ h, w ]) * self.d_layer[i] + + def GetEyeViewportSize(self): + fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov + w = Fov2Length(fov) + h = w * self.retinal_res[0] / self.retinal_res[1] + return torch.tensor([ h, w ]) \ No newline at end of file diff --git a/data.py b/data.py new file mode 100644 index 0000000..e69de29 diff --git a/gen_image.py b/gen_image.py index 1c72bac..7ae1de8 100644 --- a/gen_image.py +++ b/gen_image.py @@ -158,4 +158,37 @@ class RetinalGen(object): # print("mapped_layers:",mapped_layers.shape) retinal = mapped_layers.prod(0).sum(3).div(Phi.size()[3]) # print("retinal:",retinal.shape) - return retinal \ No newline at end of file + return retinal + + def GenFoveaLayers(self, retinal, retinal_mask): + ''' + Generate foveated layers and corresponding masks + + Parameters + -------- + retinal - Retinal image generated by GenRetinalFromLayers() + retinal_mask - Mask of retinal image, also generated by GenRetinalFromLayers() + + Returns + -------- + fovea_layers - list of foveated layers + fovea_layer_masks - list of mask images, corresponding to foveated layers + ''' + fovea_layers = [] + fovea_layer_masks = [] + fov = self.conf.eye_fovea_angles[-1] + retinal_res = int(self.conf.retinal_res[0]) + for i in range(0, len(self.conf.eye_fovea_angles)): + angle = self.conf.eye_fovea_angles[i] + k = self.conf.eye_fovea_downsamples[i] + roi_size = int(np.ceil(retinal_res * angle / fov)) + roi_offset = int((retinal_res - roi_size) / 2) + roi_img = retinal[:, roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] + roi_mask = retinal_mask[roi_offset:(roi_offset + roi_size), roi_offset:(roi_offset + roi_size)] + if k == 1: + fovea_layers.append(roi_img) + fovea_layer_masks.append(roi_mask) + else: + fovea_layers.append(torch.nn.functional.avg_pool2d(roi_img.unsqueeze(0), k).squeeze(0)) + fovea_layer_masks.append(1 - torch.nn.functional.max_pool2d((1 - roi_mask).unsqueeze(0), k).squeeze(0)) + return [ fovea_layers, fovea_layer_masks ] \ No newline at end of file diff --git 
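# ---------------------------------------------------------------------------
# A minimal, standalone sketch of what GenFoveaLayers() above computes, for
# reference. Sizes and the mask here are illustrative assumptions (square
# 320x320 retinal image, fovea angles [40, 80] deg, downsamples [1, 2], as in
# conf.py), not part of the patch:
import numpy as np
import torch
import torch.nn.functional as F

retinal = torch.rand(3, 320, 320)            # C x H x W retinal image
mask = (torch.rand(320, 320) > 0.1).float()  # 1 where retinal pixels are valid

fov = 80                                     # widest foveation angle
for angle, k in [(40, 1), (80, 2)]:
    roi = int(np.ceil(320 * angle / fov))    # side length of the concentric crop
    off = (320 - roi) // 2
    img = retinal[:, off:off + roi, off:off + roi]
    m = mask[off:off + roi, off:off + roi]
    if k > 1:
        img = F.avg_pool2d(img.unsqueeze(0), k).squeeze(0)
        # 1 - maxpool(1 - m) is a min-pool: a k x k cell stays valid only if
        # every retinal pixel inside it was valid.
        m = 1 - F.max_pool2d((1 - m).unsqueeze(0), k).squeeze(0)
    # both layers come out 160x160: 320 * 40/80 = 160 and 320 / 2 = 160
# ---------------------------------------------------------------------------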
a/main.py b/main.py index 9743a61..278ad74 100644 --- a/main.py +++ b/main.py @@ -15,9 +15,14 @@ from gen_image import * import json from ssim import * from perc_loss import * +from conf import Conf + +from model.baseline import * + +import torch.autograd.profiler as profiler # param -BATCH_SIZE = 16 -NUM_EPOCH = 1000 +BATCH_SIZE = 2 +NUM_EPOCH = 300 INTERLEAVE_RATE = 2 @@ -30,15 +35,24 @@ Retinal_IM_W = 320 N = 9 # number of input light field stack M = 2 # number of display layers -DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_small_nar_new" -DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_low_new.json" -DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_val.json" -OUTPUT_DIR = "/home/yejiannan/Project/LightField/output/gaze_low_new_1125_minibatch" +DATA_FILE = "/home/yejiannan/Project/LightField/data/gaze_fovea" +DATA_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_seq.json" +DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_gaze_fovea_val.json" +OUTPUT_DIR = "/home/yejiannan/Project/LightField/outputE/gaze_fovea_seq" + + +OUT_CHANNELS_RB = 128 +KERNEL_SIZE_RB = 3 +KERNEL_SIZE = 3 + +LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2 +FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2 class lightFieldDataLoader(torch.utils.data.dataset.Dataset): def __init__(self, file_dir_path, file_json, transforms=None): self.file_dir_path = file_dir_path self.transforms = transforms + # self.datum_list = glob.glob(os.path.join(file_dir_path,"*")) with open(file_json, encoding='utf-8') as file: self.dataset_desc = json.loads(file.read()) @@ -46,17 +60,27 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset): return len(self.dataset_desc["focaldepth"]) def __getitem__(self, idx): - lightfield_images, gt, fd, gazeX, gazeY, sample_idx = self.get_datum(idx) + lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx) if self.transforms: lightfield_images = self.transforms(lightfield_images) - return (lightfield_images, gt, fd, gazeX, gazeY, sample_idx) + # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx) + return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx) def get_datum(self, idx): lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx]) + # print(lf_image_paths) fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][idx]) + fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][idx]) + # print(fd_gt_path) lf_images = [] lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB) + + ## IF GrayScale + # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. + # lf_image_big = np.expand_dims(lf_image_big, axis=-1) + # print(lf_image_big.shape) + for i in range(9): lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3] ## IF GrayScale @@ -65,6 +89,8 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset): lf_images.append(lf_image) gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB) + gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB) ## IF GrayScale # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. 
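# ---------------------------------------------------------------------------
# Note on the view-slicing loop above: each sample is stored as a 3x3 atlas of
# views in one image, and the loop cuts it into 9 views of IM_H x IM_W. An
# equivalent vectorized slicing (a sketch, assuming the atlas is exactly
# (3*IM_H, 3*IM_W, 3)):
import numpy as np
IM_H = IM_W = 320  # per-view size, matching the constants above
atlas = np.zeros((3 * IM_H, 3 * IM_W, 3), dtype=np.float32)
views = (atlas.reshape(3, IM_H, 3, IM_W, 3)  # block rows, H, block cols, W, C
              .transpose(0, 2, 1, 3, 4)      # block rows, block cols, H, W, C
              .reshape(9, IM_H, IM_W, 3))    # one entry per view, row-major
# views[i] == atlas[i//3*IM_H:(i//3+1)*IM_H, i%3*IM_W:(i%3+1)*IM_W], i.e. the
# same slice the loop takes.
# ---------------------------------------------------------------------------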
# gt = np.expand_dims(gt, axis=-1) @@ -73,149 +99,124 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset): gazeX = self.dataset_desc["gazeX"][idx] gazeY = self.dataset_desc["gazeY"][idx] sample_idx = self.dataset_desc["idx"][idx] - return np.asarray(lf_images),gt,fd,gazeX,gazeY,sample_idx + return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx -OUT_CHANNELS_RB = 128 -KERNEL_SIZE_RB = 3 -KERNEL_SIZE = 3 +class lightFieldValDataLoader(torch.utils.data.dataset.Dataset): + def __init__(self, file_dir_path, file_json, transforms=None): + self.file_dir_path = file_dir_path + self.transforms = transforms + # self.datum_list = glob.glob(os.path.join(file_dir_path,"*")) + with open(file_json, encoding='utf-8') as file: + self.dataset_desc = json.loads(file.read()) -class residual_block(torch.nn.Module): - def __init__(self,delta_channel_dim): - super(residual_block,self).__init__() - self.layer1 = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), - torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), - torch.nn.ELU() - ) - self.layer2 = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), - torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), - torch.nn.ELU() - ) - - def forward(self,input): - output = self.layer1(input) - output = self.layer2(output) - output = input+output - return output - -class deinterleave(torch.nn.Module): - def __init__(self, block_size): - super(deinterleave, self).__init__() - self.block_size = block_size - self.block_size_sq = block_size*block_size - - def forward(self, input): - output = input.permute(0, 2, 3, 1) - (batch_size, d_height, d_width, d_depth) = output.size() - s_depth = int(d_depth / self.block_size_sq) - s_width = int(d_width * self.block_size) - s_height = int(d_height * self.block_size) - t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth) - spl = t_1.split(self.block_size, 3) - stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl] - output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size, s_height, s_width, s_depth) - output = output.permute(0, 3, 1, 2) - return output - -class interleave(torch.nn.Module): - def __init__(self, block_size): - super(interleave, self).__init__() - self.block_size = block_size - self.block_size_sq = block_size*block_size - - def forward(self, input): - output = input.permute(0, 2, 3, 1) - (batch_size, s_height, s_width, s_depth) = output.size() - d_depth = s_depth * self.block_size_sq - d_width = int(s_width / self.block_size) - d_height = int(s_height / self.block_size) - t_1 = output.split(self.block_size, 2) - stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1] - output = torch.stack(stack, 1) - output = output.permute(0, 2, 1, 3) - output = output.permute(0, 3, 1, 2) - return output + def __len__(self): + return len(self.dataset_desc["focaldepth"]) -LAST_LAYER_CHANNELS = 6 * INTERLEAVE_RATE**2 -FIRSST_LAYER_CHANNELS = 27 * INTERLEAVE_RATE**2 + def __getitem__(self, idx): + lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx) + if self.transforms: + lightfield_images = self.transforms(lightfield_images) + # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx) + return (lightfield_images, fd, gazeX, gazeY, sample_idx) -class model(torch.nn.Module): - def __init__(self): - super(model, 
self).__init__() - self.interleave = interleave(INTERLEAVE_RATE) + def get_datum(self, idx): + lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][idx]) + # print(fd_gt_path) + lf_images = [] + lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB) - self.first_layer = torch.nn.Sequential( - torch.nn.Conv2d(FIRSST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,stride=1,padding=1), - torch.nn.BatchNorm2d(OUT_CHANNELS_RB), - torch.nn.ELU() - ) - - self.residual_block1 = residual_block(0) - self.residual_block2 = residual_block(3) - self.residual_block3 = residual_block(3) - self.residual_block4 = residual_block(3) - self.residual_block5 = residual_block(3) - - self.output_layer = torch.nn.Sequential( - torch.nn.Conv2d(OUT_CHANNELS_RB+3,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1), - torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS), - torch.nn.Sigmoid() - ) - self.deinterleave = deinterleave(INTERLEAVE_RATE) - - - def forward(self, lightfield_images, focal_length, gazeX, gazeY): - input_to_net = self.interleave(lightfield_images) - input_to_rb = self.first_layer(input_to_net) - output = self.residual_block1(input_to_rb) - depth_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) - gazeX_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) - gazeY_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) - for i in range(focal_length.shape[0]): - depth_layer[i] *= 1. / focal_length[i] - gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333*2) - gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333*2) - depth_layer = var_or_cuda(depth_layer) - gazeX_layer = var_or_cuda(gazeX_layer) - gazeY_layer = var_or_cuda(gazeY_layer) - - output = torch.cat((output,depth_layer,gazeX_layer,gazeY_layer),dim=1) - output = self.residual_block2(output) - output = self.residual_block3(output) - output = self.residual_block4(output) - output = self.residual_block5(output) - output = self.output_layer(output) - output = self.deinterleave(output) - return output - -class Conf(object): - def __init__(self): - self.pupil_size = 0.02 # 2cm - self.retinal_res = torch.tensor([ Retinal_IM_H, Retinal_IM_W ]) - self.layer_res = torch.tensor([ IM_H, IM_W ]) - self.layer_hfov = 90 # layers' horizontal FOV - self.eye_hfov = 85 # eye's horizontal FOV - self.d_layer = [ 1, 3 ] # layers' distance - - def GetNLayers(self): - return len(self.d_layer) - - def GetLayerSize(self, i): - w = Fov2Length(self.layer_hfov) - h = w * self.layer_res[0] / self.layer_res[1] - return torch.tensor([ h, w ]) * self.d_layer[i] + ## IF GrayScale + # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. + # lf_image_big = np.expand_dims(lf_image_big, axis=-1) + # print(lf_image_big.shape) + + for i in range(9): + lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3] + ## IF GrayScale + # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1] + # print(lf_image.shape) + lf_images.append(lf_image) + ## IF GrayScale + # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255. 
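# ---------------------------------------------------------------------------
# Sketch of the JSON layout consumed by lightFieldSeqDataLoader (defined just
# below), inferred from the keys the loader reads; the concrete values are
# invented placeholders:
example_desc = {
    "seq": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],  # one list of frame indices per sequence
    "train": ["lf_0000.png", "lf_0001.png"],    # 3x3 light-field atlas per frame
    "gt": ["gt1_0000.png", "gt1_0001.png"],     # ground truth for the inner (fovea) layer
    "gt2": ["gt2_0000.png", "gt2_0001.png"],    # ground truth for the outer layer
    "focaldepth": [1.25, 2.5],                  # per-frame focal depth
    "gazeX": [0.0, 1.5],                        # per-frame gaze position
    "gazeY": [0.0, -1.5],
    "idx": [0, 1],                              # key into the precomputed phi/mask dicts
}
# ---------------------------------------------------------------------------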
+ # gt = np.expand_dims(gt, axis=-1) + + fd = self.dataset_desc["focaldepth"][idx] + gazeX = self.dataset_desc["gazeX"][idx] + gazeY = self.dataset_desc["gazeY"][idx] + sample_idx = self.dataset_desc["idx"][idx] + return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx - def GetEyeViewportSize(self): - w = Fov2Length(self.eye_hfov) - h = w * self.retinal_res[0] / self.retinal_res[1] - return torch.tensor([ h, w ]) +class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset): + def __init__(self, file_dir_path, file_json, transforms=None): + self.file_dir_path = file_dir_path + self.transforms = transforms + with open(file_json, encoding='utf-8') as file: + self.dataset_desc = json.loads(file.read()) + + def __len__(self): + return len(self.dataset_desc["seq"]) + + def __getitem__(self, idx): + lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx) + fd = fd.astype(np.float32) + gazeX = gazeX.astype(np.float32) + gazeY = gazeY.astype(np.float32) + sample_idx = sample_idx.astype(np.int64) + # print(fd) + # print(gazeX) + # print(gazeY) + # print(sample_idx) + + # print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype) + # print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape) + if self.transforms: + lightfield_images = self.transforms(lightfield_images) + return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx) + + def get_datum(self, idx): + indices = self.dataset_desc["seq"][idx] + # print("indices:",indices) + lf_images = [] + fd = [] + gazeX = [] + gazeY = [] + sample_idx = [] + gt = [] + gt2 = [] + for i in range(len(indices)): + lf_image_paths = os.path.join(DATA_FILE, self.dataset_desc["train"][indices[i]]) + fd_gt_path = os.path.join(DATA_FILE, self.dataset_desc["gt"][indices[i]]) + fd_gt_path2 = os.path.join(DATA_FILE, self.dataset_desc["gt2"][indices[i]]) + lf_image_one_sample = [] + lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB) + + for j in range(9): + lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3] + ## IF GrayScale + # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1] + # print(lf_image.shape) + lf_image_one_sample.append(lf_image) + + gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. + gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB)) + gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 
+ gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB)) + + # print("indices[i]:",indices[i]) + fd.append([self.dataset_desc["focaldepth"][indices[i]]]) + gazeX.append([self.dataset_desc["gazeX"][indices[i]]]) + gazeY.append([self.dataset_desc["gazeY"][indices[i]]]) + sample_idx.append([self.dataset_desc["idx"][indices[i]]]) + lf_images.append(lf_image_one_sample) + #lf_images: 5,9,320,320 + + return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx) #### Image Gen conf = Conf() - u = GenSamplesInPupil(conf.pupil_size, 5) - gen = RetinalGen(conf, u) def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): @@ -228,14 +229,44 @@ def GenRetinalFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): mask = [] # mask shape 480 x 640 for i in range(0, layers.size()[0]): phi = phi_dict[int(sample_idx[i].data)] + # print("phi_i:",phi.shape) phi = var_or_cuda(phi) phi.requires_grad = False + # print("layers[i]:",layers[i].shape) + # print("retinal[i]:",retinal[i].shape) retinal[i] = gen.GenRetinalFromLayers(layers[i],phi) mask.append(mask_dict[int(sample_idx[i].data)]) retinal = var_or_cuda(retinal) mask = torch.stack(mask,dim = 0).unsqueeze(1) # batch x 1 x height x width return retinal, mask +def GenRetinalGazeFromLayersBatch(layers, gen, sample_idx, phi_dict, mask_dict): + # layers: batchsize, 2*color, height, width + # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2]) + # df : batchsize,.. + + # retinal bs x color x height x width + retinal_fovea = torch.empty(layers.shape[0], 6, 160, 160) + mask_fovea = torch.empty(layers.shape[0], 2, 160, 160) + for i in range(0, layers.size()[0]): + phi = phi_dict[int(sample_idx[i].data)] + # print("phi_i:",phi.shape) + phi = var_or_cuda(phi) + phi.requires_grad = False + mask_i = var_or_cuda(mask_dict[int(sample_idx[i].data)]) + mask_i.requires_grad = False + # print("layers[i]:",layers[i].shape) + # print("retinal[i]:",retinal[i].shape) + retinal_i = gen.GenRetinalFromLayers(layers[i],phi) + fovea_layers, fovea_layer_masks = gen.GenFoveaLayers(retinal_i,mask_i) + retinal_fovea[i] = torch.cat([fovea_layers[0],fovea_layers[1]],dim=0) + mask_fovea[i] = torch.stack([fovea_layer_masks[0],fovea_layer_masks[1]],dim=0) + + retinal_fovea = var_or_cuda(retinal_fovea) + mask_fovea = var_or_cuda(mask_fovea) # batch x 2 x height x width + # mask = torch.stack(mask,dim = 0).unsqueeze(1) + return retinal_fovea, mask_fovea + def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask): # layers: batchsize, 2*color, height, width # Phi:torch.Size([batchsize, 480, 640, 2, 41, 2]) @@ -249,6 +280,8 @@ def GenRetinalFromLayersBatch_Online(layers, gen, phi, mask): retinal = gen.GenRetinalFromLayers(layers[0],phi) retinal = var_or_cuda(retinal) mask_out = mask.unsqueeze(0).unsqueeze(0) + # print("maskOUt:",mask_out.shape) # 1,1,240,320 + # mask_out = torch.stack(mask,dim = 0).unsqueeze(1) # batch x 1 x height x width return retinal.unsqueeze(0), mask_out #### Image Gen End @@ -264,10 +297,6 @@ def weight_init_normal(m): torch.nn.init.normal_(m.weight.data, 1.0, 0.02) torch.nn.init.constant_(m.bias.data, 0.0) -def var_or_cuda(x): - if torch.cuda.is_available(): - x = x.cuda(non_blocking=True) - return x def calImageGradients(images): # x is a 4-D tensor @@ -277,19 +306,25 @@ def calImageGradients(images): perc_loss = VGGPerceptualLoss() -perc_loss = perc_loss.to("cuda") +perc_loss = perc_loss.to("cuda:1") def loss_new(generated, gt): mse_loss = torch.nn.MSELoss() rmse_intensity = 
mse_loss(generated, gt) psnr_intensity = torch.log10(rmse_intensity) + # print("psnr:",psnr_intensity) + # ssim_intensity = ssim(generated, gt) labels_dx, labels_dy = calImageGradients(gt) + # print("generated:",generated.shape) preds_dx, preds_dy = calImageGradients(generated) rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy) psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y) + # print("psnr x&y:",psnr_grad_x," ",psnr_grad_y) p_loss = perc_loss(generated,gt) + # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss) total_loss = 10 + psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss + # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss return total_loss def save_checkpoints(file_path, epoch_idx, model, model_solver): @@ -301,6 +336,18 @@ def save_checkpoints(file_path, epoch_idx, model, model_solver): } torch.save(checkpoint, file_path) +mode = "train" + +import pickle +def save_obj(obj, name ): + # with open('./outputF/dict/'+ name + '.pkl', 'wb') as f: + # pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) + torch.save(obj,'./outputF/dict/'+ name + '.pkl') +def load_obj(name): + # with open('./outputF/dict/' + name + '.pkl', 'rb') as f: + # return pickle.load(f) + return torch.load('./outputF/dict/'+ name + '.pkl') + def hook_fn_back(m, i, o): for grad in i: try: @@ -327,14 +374,11 @@ def hook_fn_for(m, i, o): print ("None found for Gradient") print("\n") -if __name__ == "__main__": - - ############################## generate phi and mask in pre-training +def generatePhiMaskDict(data_json, generator): phi_dict = {} mask_dict = {} idx_info_dict = {} - print("generating phi and mask...") - with open(DATA_JSON, encoding='utf-8') as file: + with open(data_json, encoding='utf-8') as file: dataset_desc = json.loads(file.read()) for i in range(len(dataset_desc["focaldepth"])): # if i == 2: @@ -343,16 +387,34 @@ if __name__ == "__main__": focaldepth = dataset_desc["focaldepth"][i] gazeX = dataset_desc["gazeX"][i] gazeY = dataset_desc["gazeY"][i] - # print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY) - phi,mask = gen.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY])) + print("focaldepth:",focaldepth," idx:",idx," gazeX:",gazeX," gazeY:",gazeY) + phi,mask = generator.CalculateRetinal2LayerMappings(focaldepth,torch.tensor([gazeX, gazeY])) phi_dict[idx]=phi mask_dict[idx]=mask idx_info_dict[idx]=[idx,focaldepth,gazeX,gazeY] - print("generating phi and mask end.") + return phi_dict,mask_dict,idx_info_dict + +if __name__ == "__main__": + ############################## generate phi and mask in pre-training + # print("generating phi and mask...") + # phi_dict,mask_dict,idx_info_dict = generatePhiMaskDict(DATA_JSON,gen) + # save_obj(phi_dict,"phi_1204") + # save_obj(mask_dict,"mask_1204") + # save_obj(idx_info_dict,"idx_info_1204") + # print("generating phi and mask end.") # exit(0) + ############################# load phi and mask in pre-training + print("loading phi and mask ...") + phi_dict = load_obj("phi_1204") + mask_dict = load_obj("mask_1204") + idx_info_dict = load_obj("idx_info_1204") + print(len(phi_dict)) + print(len(mask_dict)) + print("loading phi and mask end") + #train - train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_JSON), + train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldSeqDataLoader(DATA_FILE,DATA_JSON), 
batch_size=BATCH_SIZE, num_workers=0, pin_memory=True, @@ -362,49 +424,213 @@ if __name__ == "__main__": # exit(0) + + ################################################ val ######################################################### + # val_data_loader = torch.utils.data.DataLoader(dataset=lightFieldValDataLoader(DATA_FILE,DATA_VAL_JSON), + # batch_size=1, + # num_workers=0, + # pin_memory=True, + # shuffle=False, + # drop_last=False) + + # print(len(val_data_loader)) + + # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # lf_model = baseline.model() + # if torch.cuda.is_available(): + # lf_model = torch.nn.DataParallel(lf_model).cuda() + + # checkpoint = torch.load(os.path.join(OUTPUT_DIR,"gaze-ckpt-epoch-0201.pth")) + # lf_model.load_state_dict(checkpoint["model_state_dict"]) + # lf_model.eval() + + # print("Eval::") + # for sample_idx, (image_set, df, gazeX, gazeY, sample_idx) in enumerate(val_data_loader): + # print("sample_idx::",sample_idx) + # with torch.no_grad(): + + # #reshape for input + # image_set = image_set.permute(0,1,4,2,3) # N LF C H W + # image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W + # image_set = var_or_cuda(image_set) + + # # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape) + # output = lf_model(image_set,df,gazeX,gazeY) + # output1,mask = GenRetinalGazeFromLayersBatch(output, gen, sample_idx, phi_dict, mask_dict) + + # for i in range(0, 2): + # output1[:,i*3:i*3+3].mul_(mask[:,i:i+1]) + # output1[:,i*3:i*3+3].clamp_(0., 1.) + + # print("output:",output.shape," df:",df[0].data, ",gazeX:",gazeX[0].data,",gazeY:", gazeY[0].data) + # for i in range(output1.size()[0]): + # save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) + # save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) + # save_image(output1[i][0:3].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) + # save_image(output1[i][3:6].data,os.path.join(OUTPUT_DIR,"test_interp_gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) + + # # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l1_%.3f.png"%(df[0].data))) + # # save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fovea_interp_l2_%.3f.png"%(df[0].data))) + # # output = GenRetinalFromLayersBatch(output,conf,df,v,u) + # # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.3f.png"%(df[0].data))) + # exit() + ################################################ train ######################################################### - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - lf_model = model() + lf_model = model(FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE) lf_model.apply(weight_init_normal) epoch_begin = 0 + ################################ load model file + # WEIGHTS = os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (101)) + # print('[INFO] Recovering from %s ...' 
% (WEIGHTS)) + # checkpoint = torch.load(WEIGHTS) + # init_epoch = checkpoint['epoch_idx'] + # lf_model.load_state_dict(checkpoint['model_state_dict']) + # epoch_begin = init_epoch + 1 + # print(lf_model) + ############################################################ + if torch.cuda.is_available(): - lf_model = torch.nn.DataParallel(lf_model).cuda() + # lf_model = torch.nn.DataParallel(lf_model).cuda() + lf_model = lf_model.to('cuda:1') lf_model.train() optimizer = torch.optim.Adam(lf_model.parameters(),lr=1e-2,betas=(0.9,0.999)) - + l1loss = torch.nn.L1Loss() + # lf_model.output_layer.register_backward_hook(hook_fn_back) print("begin training....") for epoch in range(epoch_begin, NUM_EPOCH): - for batch_idx, (image_set, gt, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader): + for batch_idx, (image_set, gt, gt2, df, gazeX, gazeY, sample_idx) in enumerate(train_data_loader): + # print(sample_idx.shape,df.shape,gazeX.shape,gazeY.shape) # torch.Size([2, 5]) + # print(image_set.shape,gt.shape,gt2.shape) #torch.Size([2, 5, 9, 320, 320, 3]) torch.Size([2, 5, 160, 160, 3]) torch.Size([2, 5, 160, 160, 3]) + # print(delta.shape) # delta: torch.Size([2, 4, 160, 160, 3]) + #reshape for input - image_set = image_set.permute(0,1,4,2,3) # N LF C H W - image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W + image_set = image_set.permute(0,1,2,5,3,4) # N S LF C H W + image_set = image_set.reshape(image_set.shape[0],image_set.shape[1],-1,image_set.shape[4],image_set.shape[5]) # N, LFxC, H, W image_set = var_or_cuda(image_set) - gt = gt.permute(0,3,1,2) + gt = gt.permute(0,1,4,2,3) # N S C H W gt = var_or_cuda(gt) - optimizer.zero_grad() - output = lf_model(image_set,df,gazeX,gazeY) - ########################### Use Pregen Phi and Mask ################### - output1,mask = GenRetinalFromLayersBatch(output, gen, sample_idx, phi_dict, mask_dict) - mask = var_or_cuda(mask) - mask.requires_grad = False - output_f = output1 * mask - gt = gt * mask - loss = loss_new(output_f,gt) - print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss) + gt2 = gt2.permute(0,1,4,2,3) + gt2 = var_or_cuda(gt2) + + gen1 = torch.empty(gt.shape) + gen1 = var_or_cuda(gen1) + + gen2 = torch.empty(gt2.shape) + gen2 = var_or_cuda(gen2) + + warped = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4]) + warped = var_or_cuda(warped) + + delta = torch.empty(gt2.shape[0],gt2.shape[1]-1,gt2.shape[2],gt2.shape[3],gt2.shape[4]) + delta = var_or_cuda(delta) + + for k in range(image_set.shape[1]): + if k == 0: + lf_model.reset_hidden(image_set[:,k]) + + # start = torch.cuda.Event(enable_timing=True) + # end = torch.cuda.Event(enable_timing=True) + # start.record() + output = lf_model(image_set[:,k],df[:,k],gazeX[:,k],gazeY[:,k]) + # end.record() + # torch.cuda.synchronize() + # print("Model Forward:",start.elapsed_time(end)) + # print("output:",output.shape) # [2, 6, 320, 320] + # exit() + ########################### Use Pregen Phi and Mask ################### + # start.record() + output1,mask = GenRetinalGazeFromLayersBatch(output, gen, sample_idx[:,k], phi_dict, mask_dict) + # end.record() + # torch.cuda.synchronize() + # print("Merge:",start.elapsed_time(end)) + + # print("output1 shape:",output1.shape, "mask shape:",mask.shape) + # output1 shape: torch.Size([2, 6, 160, 160]) mask shape: torch.Size([2, 2, 160, 160]) + for i in range(0, 2): + output1[:,i*3:i*3+3].mul_(mask[:,i:i+1]) + if i == 0: + gt[:,k].mul_(mask[:,i:i+1]) + if i == 1: + 
gt2[:,k].mul_(mask[:,i:i+1]) + + gen1[:,k] = output1[:,0:3] + gen2[:,k] = output1[:,3:6] + if ((epoch%5== 0) or epoch == 2): + for i in range(output.shape[0]): + save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data))) + save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i][k].data,gazeX[i][k].data,gazeY[i][k].data))) ########################### Update ################### + for i in range(1,image_set.shape[1]): + delta[:,i-1] = gt2[:,i] - gt2[:,i-1] # ground-truth frame-to-frame difference, matched against warped below + warped[:,i-1] = gen2[:,i]-gen2[:,i-1] + + optimizer.zero_grad() + + # # N S C H W + gen1 = gen1.reshape(-1,gen1.shape[2],gen1.shape[3],gen1.shape[4]) + gen2 = gen2.reshape(-1,gen2.shape[2],gen2.shape[3],gen2.shape[4]) + gt = gt.reshape(-1,gt.shape[2],gt.shape[3],gt.shape[4]) + gt2 = gt2.reshape(-1,gt2.shape[2],gt2.shape[3],gt2.shape[4]) + warped = warped.reshape(-1,warped.shape[2],warped.shape[3],warped.shape[4]) + delta = delta.reshape(-1,delta.shape[2],delta.shape[3],delta.shape[4]) + + + # start = torch.cuda.Event(enable_timing=True) + # end = torch.cuda.Event(enable_timing=True) + # start.record() + loss1 = loss_new(gen1,gt) + loss2 = loss_new(gen2,gt2) + loss3 = l1loss(warped,delta) + loss = loss1+loss2+loss3 + # end.record() + # torch.cuda.synchronize() + # print("loss comp:",start.elapsed_time(end)) + + + # start.record() loss.backward() - optimizer.step() + # end.record() + # torch.cuda.synchronize() + # print("backward:",start.elapsed_time(end)) + # start.record() + optimizer.step() + # end.record() + # torch.cuda.synchronize() + # print("optimizer step:",start.elapsed_time(end)) + + ## Update Prev + print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss) ########################### Save ##################### - if ((epoch%50== 0) or epoch == 5): - for i in range(output_f.size()[0]): - save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_fac1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) - save_image(output[i][3:6].data,os.path.join(OUTPUT_DIR,"gaze_fac2_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) - save_image(output_f[i][0:3].data,os.path.join(OUTPUT_DIR,"gaze_test1_o_%.3f_%.3f_%.3f.png"%(df[i].data,gazeX[i].data,gazeY[i].data))) - if ((epoch%200 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1): - save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)), - epoch,lf_model,optimizer) \ No newline at end of file + if ((epoch%5== 0) or epoch == 2): # torch.Size([2, 5, 160, 160, 3]) + for i in range(gt.size()[0]): + # df 2,5 + save_image(gen1[i].data,os.path.join(OUTPUT_DIR,"gaze_out1_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) + save_image(gen2[i].data,os.path.join(OUTPUT_DIR,"gaze_out2_o_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) + save_image(gt[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt0_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) + save_image(gt2[i].data,os.path.join(OUTPUT_DIR,"gaze_test1_gt1_%.3f_%.3f_%.3f.png"%(df[i//5][i%5].data,gazeX[i//5][i%5].data,gazeY[i//5][i%5].data))) + if ((epoch%100 == 0) and epoch != 0 and batch_idx==len(train_data_loader)-1): + save_checkpoints(os.path.join(OUTPUT_DIR, 'gaze-ckpt-epoch-%04d.pth' % (epoch + 1)),epoch,lf_model,optimizer) + + ########################## test Phi and Mask ########################## + # phi,mask =
gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]])) + # # print("gaze Online:",gazeX[0]," ,",gazeY[0]) + # # print("df Online:",df[0]) + # # print("idx:",int(sample_idx[0].data)) + # phi_t = phi_dict[int(sample_idx[0].data)] + # mask_t = mask_dict[int(sample_idx[0].data)] + # # print("idx info:",idx_info_dict[int(sample_idx[0].data)]) + # # print("phi online:", phi.shape, " phi_t:", phi_t.shape) + # # print("mask online:", mask.shape, " mask_t:", mask_t.shape) + # print("phi delta:", (phi-phi_t).sum()," mask delta:",(mask -mask_t).sum()) + # exit(0) + + ###########################Gen Batch 1 by 1################### + # phi,mask = gen.CalculateRetinal2LayerMappings(df[0],torch.tensor([gazeX[0], gazeY[0]])) + # # print(phi.shape) # 2,240,320,41,2 + # output1, mask = GenRetinalFromLayersBatch_Online(output, gen, phi, mask) + ###########################Gen Batch 1 by 1################### \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/baseline.py b/model/baseline.py new file mode 100644 index 0000000..d829464 --- /dev/null +++ b/model/baseline.py @@ -0,0 +1,167 @@ +import torch +def var_or_cuda(x): + if torch.cuda.is_available(): + # x = x.cuda(non_blocking=True) + x = x.to('cuda:1') + return x + +class residual_block(torch.nn.Module): + def __init__(self, OUT_CHANNELS_RB, delta_channel_dim,KERNEL_SIZE_RB,RNN=False): + super(residual_block,self).__init__() + self.delta_channel_dim = delta_channel_dim + self.out_channels_rb = OUT_CHANNELS_RB + self.hidden = None + self.RNN = RNN + if self.RNN: + self.layer1 = torch.nn.Sequential( + torch.nn.Conv2d((OUT_CHANNELS_RB+delta_channel_dim)*2,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), + torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), + torch.nn.ELU() + ) + self.layer2 = torch.nn.Sequential( + torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), + torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), + torch.nn.ELU() + ) + else: + self.layer1 = torch.nn.Sequential( + torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), + torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), + torch.nn.ELU() + ) + self.layer2 = torch.nn.Sequential( + torch.nn.Conv2d(OUT_CHANNELS_RB+delta_channel_dim,OUT_CHANNELS_RB+delta_channel_dim,KERNEL_SIZE_RB,stride=1,padding = 1), + torch.nn.BatchNorm2d(OUT_CHANNELS_RB+delta_channel_dim), + torch.nn.ELU() + ) + + def forward(self,input): + if self.RNN: + # print("input:",input.shape,"hidden:",self.hidden.shape) + inp = torch.cat((input,self.hidden),dim=1) + # print(inp.shape) + output = self.layer1(inp) + output = self.layer2(output) + output = input+output + self.hidden = output + else: + output = self.layer1(input) + output = self.layer2(output) + output = input+output + return output + + def reset_hidden(self, inp): + size = list(inp.size()) + size[1] = self.delta_channel_dim + self.out_channels_rb + size[2] = size[2]//2 + size[3] = size[3]//2 + hidden = torch.zeros(*(size)) + self.hidden = var_or_cuda(hidden) + +class deinterleave(torch.nn.Module): + def __init__(self, block_size): + super(deinterleave, self).__init__() + self.block_size = block_size + self.block_size_sq = block_size*block_size + + def forward(self, input): + output = input.permute(0, 2, 3, 1) + (batch_size, d_height, d_width, d_depth) = output.size() + s_depth = 
int(d_depth / self.block_size_sq) + s_width = int(d_width * self.block_size) + s_height = int(d_height * self.block_size) + t_1 = output.reshape(batch_size, d_height, d_width, self.block_size_sq, s_depth) + spl = t_1.split(self.block_size, 3) + stack = [t_t.reshape(batch_size, d_height, s_width, s_depth) for t_t in spl] + output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).reshape(batch_size, s_height, s_width, s_depth) + output = output.permute(0, 3, 1, 2) + return output + +class interleave(torch.nn.Module): + def __init__(self, block_size): + super(interleave, self).__init__() + self.block_size = block_size + self.block_size_sq = block_size*block_size + + def forward(self, input): + output = input.permute(0, 2, 3, 1) + (batch_size, s_height, s_width, s_depth) = output.size() + d_depth = s_depth * self.block_size_sq + d_width = int(s_width / self.block_size) + d_height = int(s_height / self.block_size) + t_1 = output.split(self.block_size, 2) + stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1] + output = torch.stack(stack, 1) + output = output.permute(0, 2, 1, 3) + output = output.permute(0, 3, 1, 2) + return output + +class model(torch.nn.Module): + def __init__(self,FIRSST_LAYER_CHANNELS,LAST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,KERNEL_SIZE_RB,INTERLEAVE_RATE): + super(model, self).__init__() + self.interleave = interleave(INTERLEAVE_RATE) + + self.first_layer = torch.nn.Sequential( + torch.nn.Conv2d(FIRSST_LAYER_CHANNELS,OUT_CHANNELS_RB,KERNEL_SIZE,stride=1,padding=1), + torch.nn.BatchNorm2d(OUT_CHANNELS_RB), + torch.nn.ELU() + ) + + self.residual_block1 = residual_block(OUT_CHANNELS_RB,0,KERNEL_SIZE_RB,False) + self.residual_block2 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,False) + self.residual_block3 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True) + self.residual_block4 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True) + self.residual_block5 = residual_block(OUT_CHANNELS_RB,3,KERNEL_SIZE_RB,True) + + self.output_layer = torch.nn.Sequential( + torch.nn.Conv2d(OUT_CHANNELS_RB+3,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1), + torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS), + torch.nn.Sigmoid() + ) + self.deinterleave = deinterleave(INTERLEAVE_RATE) + + def reset_hidden(self,inp): + self.residual_block3.reset_hidden(inp) + self.residual_block4.reset_hidden(inp) + self.residual_block5.reset_hidden(inp) + + def forward(self, lightfield_images, focal_length, gazeX, gazeY): + # lightfield_images: torch.Size([batch_size, channels * D, H, W]) + # channels : RGB*D: 3*9, H:256, W:256 + # print("lightfield_images:",lightfield_images.shape) + input_to_net = self.interleave(lightfield_images) + # print("after interleave:",input_to_net.shape) + input_to_rb = self.first_layer(input_to_net) + + # print("input_to_rb1:",input_to_rb.shape) + output = self.residual_block1(input_to_rb) + + depth_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) + gazeX_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) + gazeY_layer = torch.ones((input_to_rb.shape[0],1,input_to_rb.shape[2],input_to_rb.shape[3])) + # print("depth_layer:",depth_layer.shape) + # print("focal_depth:",focal_length," gazeX:",gazeX," gazeY:",gazeY, " gazeX norm:",(gazeX[0] - (-3.333)) / (3.333*2)) + for i in range(focal_length.shape[0]): + depth_layer[i] *= 1. 
/ focal_length[i] + gazeX_layer[i] *= (gazeX[i] - (-3.333)) / (3.333*2) + gazeY_layer[i] *= (gazeY[i] - (-3.333)) / (3.333*2) + # print(depth_layer.shape) + depth_layer = var_or_cuda(depth_layer) + gazeX_layer = var_or_cuda(gazeX_layer) + gazeY_layer = var_or_cuda(gazeY_layer) + + output = torch.cat((output,depth_layer,gazeX_layer,gazeY_layer),dim=1) + # output = torch.cat((output,depth_layer),dim=1) + # print("output to rb2:",output.shape) + + output = self.residual_block2(output) + # print("output to rb3:",output.shape) + output = self.residual_block3(output) + # print("output to rb4:",output.shape) + output = self.residual_block4(output) + # print("output to rb5:",output.shape) + output = self.residual_block5(output) + # output = output + input_to_net + output = self.output_layer(output) + output = self.deinterleave(output) + return output diff --git a/model/recurrent.py b/model/recurrent.py new file mode 100644 index 0000000..3aeaf6a --- /dev/null +++ b/model/recurrent.py @@ -0,0 +1,146 @@ +import torch, os, sys, cv2 +import torch.nn as nn +from torch.nn import init +import functools +import torch.optim as optim + +from torch.utils.data import Dataset, DataLoader +from torch.nn import functional as func +from PIL import Image + +import torchvision.transforms as transforms +import numpy as np +import torch + + +class RecurrentBlock(nn.Module): + + def __init__(self, input_nc, output_nc, downsampling=False, bottleneck=False, upsampling=False): + super(RecurrentBlock, self).__init__() + + self.input_nc = input_nc + self.output_nc = output_nc + + self.downsampling = downsampling + self.upsampling = upsampling + self.bottleneck = bottleneck + + self.hidden = None + + if self.downsampling: + self.l1 = nn.Sequential( + nn.Conv2d(input_nc, output_nc, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1) + ) + self.l2 = nn.Sequential( + nn.Conv2d(2 * output_nc, output_nc, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1), + nn.Conv2d(output_nc, output_nc, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1), + ) + elif self.upsampling: + self.l1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv2d(2 * input_nc, output_nc, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1), + nn.Conv2d(output_nc, output_nc, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1), + ) + elif self.bottleneck: + self.l1 = nn.Sequential( + nn.Conv2d(input_nc, output_nc, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1) + ) + self.l2 = nn.Sequential( + nn.Conv2d(2 * output_nc, output_nc, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1), + nn.Conv2d(output_nc, output_nc, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1), + ) + + def forward(self, inp): + + if self.downsampling: + op1 = self.l1(inp) + op2 = self.l2(torch.cat((op1, self.hidden), dim=1)) + + self.hidden = op2 + + return op2 + elif self.upsampling: + op1 = self.l1(inp) + + return op1 + elif self.bottleneck: + op1 = self.l1(inp) + op2 = self.l2(torch.cat((op1, self.hidden), dim=1)) + + self.hidden = op2 + + return op2 + + def reset_hidden(self, inp, dfac): + size = list(inp.size()) + size[1] = self.output_nc + size[2] = size[2] // dfac # integer division: torch.zeros() needs integral sizes + size[3] = size[3] // dfac + + self.hidden_size = size + self.hidden = torch.zeros(*(size)).to('cuda:0') + + + +class RecurrentAE(nn.Module): + + def __init__(self, input_nc): + super(RecurrentAE, self).__init__() + + self.d1 = RecurrentBlock(input_nc=input_nc, output_nc=32, downsampling=True) + self.d2 = RecurrentBlock(input_nc=32, output_nc=43, downsampling=True) + self.d3 = RecurrentBlock(input_nc=43, output_nc=57,
downsampling=True) + self.d4 = RecurrentBlock(input_nc=57, output_nc=76, downsampling=True) + self.d5 = RecurrentBlock(input_nc=76, output_nc=101, downsampling=True) + + self.bottleneck = RecurrentBlock(input_nc=101, output_nc=101, bottleneck=True) + + self.u5 = RecurrentBlock(input_nc=101, output_nc=76, upsampling=True) + self.u4 = RecurrentBlock(input_nc=76, output_nc=57, upsampling=True) + self.u3 = RecurrentBlock(input_nc=57, output_nc=43, upsampling=True) + self.u2 = RecurrentBlock(input_nc=43, output_nc=32, upsampling=True) + self.u1 = RecurrentBlock(input_nc=32, output_nc=3, upsampling=True) + + def set_input(self, inp): + self.inp = inp['A'] + + def forward(self): + d1 = func.max_pool2d(input=self.d1(self.inp), kernel_size=2) + d2 = func.max_pool2d(input=self.d2(d1), kernel_size=2) + d3 = func.max_pool2d(input=self.d3(d2), kernel_size=2) + d4 = func.max_pool2d(input=self.d4(d3), kernel_size=2) + d5 = func.max_pool2d(input=self.d5(d4), kernel_size=2) + + b = self.bottleneck(d5) + + u5 = self.u5(torch.cat((b, d5), dim=1)) + u4 = self.u4(torch.cat((u5, d4), dim=1)) + u3 = self.u3(torch.cat((u4, d3), dim=1)) + u2 = self.u2(torch.cat((u3, d2), dim=1)) + u1 = self.u1(torch.cat((u2, d1), dim=1)) + + return u1 + + def reset_hidden(self): + self.d1.reset_hidden(self.inp, dfac=1) + self.d2.reset_hidden(self.inp, dfac=2) + self.d3.reset_hidden(self.inp, dfac=4) + self.d4.reset_hidden(self.inp, dfac=8) + self.d5.reset_hidden(self.inp, dfac=16) + + self.bottleneck.reset_hidden(self.inp, dfac=32) + + self.u4.reset_hidden(self.inp, dfac=16) + self.u3.reset_hidden(self.inp, dfac=8) + self.u5.reset_hidden(self.inp, dfac=4) + self.u2.reset_hidden(self.inp, dfac=2) + self.u1.reset_hidden(self.inp, dfac=1) + -- GitLab
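# ---------------------------------------------------------------------------
# How the pieces above fit together at train time: the model from
# model/baseline.py is stepped frame by frame, and reset_hidden() zeroes the
# recurrent residual blocks' hidden state at the start of every sequence. A
# minimal sketch of that calling pattern (tensor sizes assumed to match the
# constants in main.py: batch 2, sequence length 5, 9 RGB views, 320x320):
import torch
from model.baseline import model

net = model(27 * 2**2, 6 * 2**2, 128, 3, 3, 2)  # (FIRSST_LAYER_CHANNELS, LAST_LAYER_CHANNELS,
                                                #  OUT_CHANNELS_RB, KERNEL_SIZE, KERNEL_SIZE_RB,
                                                #  INTERLEAVE_RATE), with INTERLEAVE_RATE = 2
seq = torch.rand(2, 5, 27, 320, 320)            # N x S x (9 views * RGB) x H x W
df = torch.ones(2, 5)                           # focal depths (nonzero: the model uses 1/df)
gx = torch.zeros(2, 5)
gy = torch.zeros(2, 5)
if torch.cuda.is_available():
    # baseline.var_or_cuda pins tensors to cuda:1, so inputs and weights must live there too
    net = net.to('cuda:1')
    seq, df, gx, gy = (t.to('cuda:1') for t in (seq, df, gx, gy))

outputs = []
for k in range(seq.shape[1]):
    if k == 0:
        net.reset_hidden(seq[:, k])             # zero hidden state at sequence start
    outputs.append(net(seq[:, k], df[:, k], gx[:, k], gy[:, k]))  # each: N x 6 x 320 x 320
# ---------------------------------------------------------------------------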