import torch
import os
import glob
import numpy as np
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import cv2
import json
from Flow import *
from gen_image import *
import util
import time

# NOTE: IM_H and IM_W (per-view height/width in pixels) are expected to be
# provided by the wildcard imports above (Flow / gen_image).


class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
    def __init__(self, file_dir_path, file_json):
        self.file_dir_path = file_dir_path
        with open(file_json, encoding='utf-8') as file:
            self.dataset_desc = json.loads(file.read())
        # Pre-load all input views into a single array.
        self.input_img = []
        for i in self.dataset_desc["train"]:
            lf_element = os.path.join(self.file_dir_path, i)
            # Was `-cv2.IMREAD_ANYDEPTH` (an invalid negative flag);
            # IMREAD_UNCHANGED matches the other loaders in this file.
            lf_element = cv2.imread(lf_element, cv2.IMREAD_UNCHANGED)[:, :, 0:3]
            lf_element = cv2.cvtColor(lf_element, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
            lf_element = lf_element[:, :-1, :]  # drop the last pixel column
            self.input_img.append(lf_element)
        self.input_img = np.asarray(self.input_img)

    def __len__(self):
        return len(self.dataset_desc["gt"])

    def __getitem__(self, idx):
        gt, pos_row, pos_col = self.get_datum(idx)
        return (self.input_img, gt, pos_row, pos_col)

    def get_datum(self, idx):
        fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
        gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
        gt = gt[:, :-1, :]  # drop the last pixel column, as in __init__
        pos_col = self.dataset_desc["x"][idx]
        pos_row = self.dataset_desc["y"][idx]
        return gt, pos_row, pos_col


class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
    def __init__(self, file_dir_path, file_json, transforms=None):
        self.file_dir_path = file_dir_path
        self.transforms = transforms
        with open(file_json, encoding='utf-8') as file:
            self.dataset_desc = json.loads(file.read())

    def __len__(self):
        return len(self.dataset_desc["focaldepth"])

    def __getitem__(self, idx):
        lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
        if self.transforms:
            lightfield_images = self.transforms(lightfield_images)
        return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)

    def get_datum(self, idx):
        lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
        fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
        fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][idx])
        lf_images = []
        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
        ## IF GrayScale:
        # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        # lf_image_big = np.expand_dims(lf_image_big, axis=-1)
        # Cut the 3x3 atlas into 9 individual IM_H x IM_W views.
        for i in range(9):
            lf_image = lf_image_big[i // 3 * IM_H:i // 3 * IM_H + IM_H,
                                    i % 3 * IM_W:i % 3 * IM_W + IM_W, 0:3]
            # (for grayscale, slice channels 0:1 instead)
            lf_images.append(lf_image)
        gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
        gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        gt2 = cv2.cvtColor(gt2, cv2.COLOR_BGR2RGB)
        ## IF GrayScale:
        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        # gt = np.expand_dims(gt, axis=-1)
        fd = self.dataset_desc["focaldepth"][idx]
        gazeX = self.dataset_desc["gazeX"][idx]
        gazeY = self.dataset_desc["gazeY"][idx]
        sample_idx = self.dataset_desc["idx"][idx]
        return np.asarray(lf_images), gt, gt2, fd, gazeX, gazeY, sample_idx
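# --- Usage sketch (illustrative addition, not part of the original pipeline) ---
# A minimal sketch of wrapping lightFieldDataLoader in a torch DataLoader; the
# dataset root and JSON path below are hypothetical placeholders.
#
#   dataset = lightFieldDataLoader("data/lf", "data/lf/train.json")
#   loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
#   for lf_images, gt, gt2, fd, gazeX, gazeY, sample_idx in loader:
#       # lf_images: [B, 9, IM_H, IM_W, 3], float32 in [0, 1]
#       pass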
class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
    def __init__(self, file_dir_path, file_json, transforms=None):
        self.file_dir_path = file_dir_path
        self.transforms = transforms
        with open(file_json, encoding='utf-8') as file:
            self.dataset_desc = json.loads(file.read())

    def __len__(self):
        return len(self.dataset_desc["focaldepth"])

    def __getitem__(self, idx):
        lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
        if self.transforms:
            lightfield_images = self.transforms(lightfield_images)
        return (lightfield_images, fd, gazeX, gazeY, sample_idx)

    def get_datum(self, idx):
        # Same as lightFieldDataLoader.get_datum, but without ground-truth images.
        lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
        lf_images = []
        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
        ## IF GrayScale:
        # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        # lf_image_big = np.expand_dims(lf_image_big, axis=-1)
        for i in range(9):
            lf_image = lf_image_big[i // 3 * IM_H:i // 3 * IM_H + IM_H,
                                    i % 3 * IM_W:i % 3 * IM_W + IM_W, 0:3]
            lf_images.append(lf_image)
        fd = self.dataset_desc["focaldepth"][idx]
        gazeX = self.dataset_desc["gazeX"][idx]
        gazeY = self.dataset_desc["gazeY"][idx]
        sample_idx = self.dataset_desc["idx"][idx]
        return np.asarray(lf_images), fd, gazeX, gazeY, sample_idx
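# The tiling loops above can also be written as one reshape/transpose. This
# helper is a sketch added for illustration (the loaders do not call it); it
# assumes the atlas is a row-major grid x grid layout of equally sized tiles.
def split_lf_atlas(atlas, grid, tile_h, tile_w):
    """Split a (grid*tile_h, grid*tile_w, C) atlas into (grid*grid, tile_h, tile_w, C) views."""
    c = atlas.shape[2]
    views = atlas.reshape(grid, tile_h, grid, tile_w, c)  # split rows and cols into tiles
    views = views.transpose(0, 2, 1, 3, 4)                # group the two grid axes first
    return views.reshape(grid * grid, tile_h, tile_w, c)  # same view order as the i//3, i%3 loop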
class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
    def __init__(self, file_dir_path, file_json, transforms=None):
        self.file_dir_path = file_dir_path
        self.transforms = transforms
        with open(file_json, encoding='utf-8') as file:
            self.dataset_desc = json.loads(file.read())

    def __len__(self):
        return len(self.dataset_desc["seq"])

    def __getitem__(self, idx):
        lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
        fd = fd.astype(np.float32)
        gazeX = gazeX.astype(np.float32)
        gazeY = gazeY.astype(np.float32)
        sample_idx = sample_idx.astype(np.int64)
        if self.transforms:
            lightfield_images = self.transforms(lightfield_images)
        return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)

    def get_datum(self, idx):
        indices = self.dataset_desc["seq"][idx]
        lf_images = []
        fd = []
        gazeX = []
        gazeY = []
        sample_idx = []
        gt = []
        gt2 = []
        for i in range(len(indices)):
            lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
            fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
            fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
            lf_image_one_sample = []
            lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
            lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
            for j in range(9):
                lf_image = lf_image_big[j // 3 * IM_H:j // 3 * IM_H + IM_H,
                                        j % 3 * IM_W:j % 3 * IM_W + IM_W, 0:3]
                lf_image_one_sample.append(lf_image)
            gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
            gt.append(cv2.cvtColor(gt_i, cv2.COLOR_BGR2RGB))
            gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
            gt2.append(cv2.cvtColor(gt2_i, cv2.COLOR_BGR2RGB))
            fd.append([self.dataset_desc["focaldepth"][indices[i]]])
            gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
            gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
            sample_idx.append([self.dataset_desc["idx"][indices[i]]])
            lf_images.append(lf_image_one_sample)
        # lf_images: seq_len(5) x 9 x IM_H x IM_W x 3
        return (np.asarray(lf_images), np.asarray(gt), np.asarray(gt2),
                np.asarray(fd), np.asarray(gazeX), np.asarray(gazeY), np.asarray(sample_idx))
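# --- Shape notes (illustrative addition, based on the loader above) ---
# With the 3x3 layout, sequences of length 5, and batch size B, a batch drawn
# from lightFieldSeqDataLoader is expected to look like:
#   lightfield_images:  [B, 5, 9, IM_H, IM_W, 3]  float32
#   gt / gt2:           [B, 5, H, W, 3]           float32
#   fd / gazeX / gazeY: [B, 5, 1]                 float32
#   sample_idx:         [B, 5, 1]                 int64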
class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
    def __init__(self, file_dir_path, file_json, gen, conf, transforms=None):
        self.file_dir_path = file_dir_path
        self.file_json = file_json
        self.gen = gen
        self.conf = conf
        self.transforms = transforms
        with open(file_json, encoding='utf-8') as file:
            self.dataset_desc = json.loads(file.read())

    def __len__(self):
        return len(self.dataset_desc["seq"])

    def __getitem__(self, idx):
        (lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask,
         fd, gazeX, gazeY, posX, posY, posZ, sample_idx) = self.get_datum(idx)
        fd = fd.astype(np.float32)
        gazeX = gazeX.astype(np.float32)
        gazeY = gazeY.astype(np.float32)
        posX = posX.astype(np.float32)
        posY = posY.astype(np.float32)
        posZ = posZ.astype(np.float32)
        sample_idx = sample_idx.astype(np.int64)
        if self.transforms:
            lightfield_images = self.transforms(lightfield_images)
        # Reading one sequence takes about 8 ms.
        return (lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask,
                fd, gazeX, gazeY, posX, posY, posZ, sample_idx)

    def get_datum(self, idx):
        IM_H = 320
        IM_W = 320
        indices = self.dataset_desc["seq"][idx]
        lf_images = []
        fd = []
        gazeX = []
        gazeY = []
        posX = []
        posY = []
        posZ = []
        sample_idx = []
        gt = []
        phi = []
        phi_invalid = []
        retinal_invalid = []
        for i in range(len(indices)):  # sequence length, e.g. 5
            lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
            fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
            lf_image_one_sample = []
            lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
            lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
            # Cut the lf_dim x lf_dim atlas into individual views.
            lf_dim = int(self.conf.light_field_dim)
            for j in range(lf_dim ** 2):
                lf_image = lf_image_big[j // lf_dim * IM_H:j // lf_dim * IM_H + IM_H,
                                        j % lf_dim * IM_W:j % lf_dim * IM_W + IM_W, 0:3]
                lf_image_one_sample.append(lf_image)
            gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
            gt.append(cv2.cvtColor(gt_i, cv2.COLOR_BGR2RGB))
            fd.append([self.dataset_desc["focaldepth"][indices[i]]])
            gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
            gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
            posX.append([self.dataset_desc["x"][indices[i]]])
            posY.append([self.dataset_desc["y"][indices[i]]])
            posZ.append([0.0])  # z position is fixed at 0
            sample_idx.append([self.dataset_desc["idx"][indices[i]]])
            lf_images.append(lf_image_one_sample)
            focaldepth_i = fd[i][0]
            gazeX_i = gazeX[i][0]
            gazeY_i = gazeY[i][0]
            posX_i = posX[i][0]
            posY_i = posY[i][0]
            posZ_i = posZ[i][0]
            # Per-frame retinal-to-layer mapping and validity masks.
            phi_i, phi_invalid_i, retinal_invalid_i = self.gen.CalculateRetinal2LayerMappings(
                torch.tensor([posX_i, posY_i, posZ_i]),
                torch.tensor([gazeX_i, gazeY_i]),
                focaldepth_i)
            phi.append(phi_i)
            phi_invalid.append(phi_invalid_i)
            retinal_invalid.append(retinal_invalid_i)
        # Optical flow between consecutive frames of the sequence.
        flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"][indices[i - 1]])
                          for i in range(1, len(indices))])
        flow_map = flow.getMap()
        flow_invalid_mask = flow.b_invalid_mask
        # lf_images: seq_len x lf_dim^2 x IM_H x IM_W x 3
        return (np.asarray(lf_images), np.asarray(gt),
                torch.stack(phi, dim=0), torch.stack(phi_invalid, dim=0),
                torch.stack(retinal_invalid, dim=0), flow_map, flow_invalid_mask,
                np.asarray(fd), np.asarray(gazeX), np.asarray(gazeY),
                np.asarray(posX), np.asarray(posY), np.asarray(posZ), np.asarray(sample_idx))
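# --- Smoke-test sketch (illustrative addition, not part of the original file) ---
# Exercises lightFieldFlowSeqDataLoader end to end. The paths below and the
# construction of `gen` (a retinal-mapping generator from gen_image) and `conf`
# (a config exposing `light_field_dim`) are hypothetical placeholders.
if __name__ == "__main__":
    # conf = ...  # config with light_field_dim
    # gen = ...   # generator providing CalculateRetinal2LayerMappings
    # dataset = lightFieldFlowSeqDataLoader("data/lf", "data/lf/seq.json", gen, conf)
    # loader = DataLoader(dataset, batch_size=1, shuffle=False)
    # sample = next(iter(loader))
    # print([tuple(t.shape) for t in sample if hasattr(t, "shape")])
    pass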