From 421085dfa6d310f0e12ff8e2941f55b772a89bbc Mon Sep 17 00:00:00 2001
From: BobYeah <635596704@qq.com>
Date: Fri, 25 Dec 2020 11:01:44 +0800
Subject: [PATCH] sync

---
 .gitignore   |   5 -
 Flow.py      |  83 --------------
 conf.py      |  68 ------------
 data.py      | 305 ---------------------------------------------------
 loss.py      |  80 --------------
 perc_loss.py |  38 -------
 ssim.py      |  72 ------------
 util.py      | 104 ------------------
 8 files changed, 755 deletions(-)
 delete mode 100644 Flow.py
 delete mode 100644 conf.py
 delete mode 100644 data.py
 delete mode 100644 loss.py
 delete mode 100644 perc_loss.py
 delete mode 100644 ssim.py
 delete mode 100644 util.py

diff --git a/.gitignore b/.gitignore
index b2b6ea4..e35eb78 100644
--- a/.gitignore
+++ b/.gitignore
@@ -54,11 +54,6 @@ coverage.xml
 *.log
 local_settings.py
 
-# log
-*.txt
-*.out
-*.ipynb
-
 # Flask stuff:
 instance/
 .webassets-cache
diff --git a/Flow.py b/Flow.py
deleted file mode 100644
index 83a6b6a..0000000
--- a/Flow.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import matplotlib.pyplot as plt
-import torch
-import util
-import numpy as np
-def FlowMap(b_last_frame, b_map):
-    '''
-    Map images using the flow data.
-
-    Parameters
-    --------
-    b_last_frame - B x 3 x H x W tensor, batch of images
-    b_map - B x H x W x 2, batch of map data recording pixel coords in the last frames
-
-    Returns
-    --------
-    B x 3 x H x W tensor, batch of images mapped by flow data
-    '''
-    return torch.nn.functional.grid_sample(b_last_frame, b_map, align_corners=False)
-
-class Flow(object):
-    '''
-    Class representing optical flow
-
-    Properties
-    --------
-    b_data - B x H x W x 2, batch of flow data
-    b_map - B x H x W x 2, batch of map data recording pixel coords in the last frames
-    b_invalid_mask - B x H x W, batch of masks indicating invalid elements in the corresponding flow data
-    '''
-    def Load(paths):
-        '''
-        Create a Flow instance using a batch of encoded data images loaded from paths
-
-        Parameters
-        --------
-        paths - list of encoded data image paths
-
-        Returns
-        --------
-        Flow instance
-        '''
-        b_encoded_image = util.ReadImageTensor(paths, rgb_only=False, permute=False, batch_dim=True)
-        return Flow(b_encoded_image)
-
-    def __init__(self, b_encoded_image):
-        '''
-        Initialize a Flow instance from a batch of encoded data images
-
-        Parameters
-        --------
-        b_encoded_image - batch of encoded data images
-        '''
-        b_encoded_image = b_encoded_image.mul(255)
-        # print("b_encoded_image:",b_encoded_image.shape)
-        self.b_invalid_mask = (b_encoded_image[:, :, :, 0] == 255)
-        self.b_data = (b_encoded_image[:, :, :, 0:2] / 254 + b_encoded_image[:, :, :, 2:4] - 127) / 127
-        self.b_data[:, :, :, 1] = -self.b_data[:, :, :, 1]
-        D = self.b_data.size()
-        grid = util.MeshGrid((D[1], D[2]), True)
-        self.b_map = (grid - self.b_data - 0.5) * 2
-        self.b_map[self.b_invalid_mask] = torch.tensor([ -2.0, -2.0 ])
-
-    def getMap(self):
-        return self.b_map
-
-    def Visualize(self, scale_factor = 1):
-        '''
-        Visualize the flow data by "color wheel".
-
-        Parameters
-        --------
-        scale_factor - scale factor of flow data to visualize, default is 1
-
-        Returns
-        --------
-        B x 3 x H x W tensor, visualization of flow data
-        '''
-        try:
-            Flow.b_color_wheel
-        except AttributeError:
-            Flow.b_color_wheel = util.ReadImageTensor('color_wheel.png')
-        return torch.nn.functional.grid_sample(Flow.b_color_wheel.expand(self.b_data.size()[0], -1, -1, -1),
-                                               (self.b_data * scale_factor), align_corners=False)
\ No newline at end of file
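The Flow class above inverts a four-channel packing: coarse flow in channels 2-3, a fine fraction (divided by 254) in channels 0-1, with a 255 in channel 0 flagging invalid pixels; invalid entries are mapped to (-2, -2), outside grid_sample's [-1, 1] range. A minimal usage sketch under that reading (the flow image path and previous-frame tensor are hypothetical, for illustration only):

    import torch
    from Flow import Flow, FlowMap

    # Hypothetical inputs: one encoded flow image and the previous frame's pixels.
    flow = Flow.Load([ 'flow_0001.png' ])     # assumed file name
    prev_frame = torch.zeros(1, 3, 320, 320)  # stand-in for the last rendered frame

    # Warp the previous frame onto the current one. Invalid pixels sample
    # outside [-1, 1], so grid_sample's default zero padding blanks them.
    warped = FlowMap(prev_frame, flow.getMap())
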
diff --git a/conf.py b/conf.py
deleted file mode 100644
index 2bb9180..0000000
--- a/conf.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import torch
-import util
-import numpy as np
-class Conf(object):
-    def __init__(self):
-        self.pupil_size = 0.02
-        self.retinal_res = torch.tensor([ 320, 320 ])
-        self.layer_res = torch.tensor([ 320, 320 ])
-        self.layer_hfov = 90                  # layers' horizontal FOV
-        self.eye_hfov = 80                    # eye's horizontal FOV (ignored in foveated rendering)
-        self.eye_enable_fovea = False         # enable foveated rendering
-        self.eye_fovea_angles = [ 40, 80 ]    # eye's foveation layers' angles
-        self.eye_fovea_downsamples = [ 1, 2 ] # eye's foveation layers' downsamples
-        self.d_layer = [ 1, 3 ]               # layers' distances
-        self.eye_fovea_blend = [ self._GenFoveaLayerBlend(0) ]
-                                              # blend maps of fovea layers
-        self.light_field_dim = 5
-    def GetNLayers(self):
-        return len(self.d_layer)
-
-    def GetLayerSize(self, i):
-        w = util.Fov2Length(self.layer_hfov)
-        h = w * self.layer_res[0] / self.layer_res[1]
-        return torch.tensor([ h, w ]) * self.d_layer[i]
-
-    def GetPixelSizeOfLayer(self, i):
-        '''
-        Get the pixel size of layer i
-        '''
-        return util.Fov2Length(self.layer_hfov) * self.d_layer[i] / self.layer_res[0]
-
-    def GetEyeViewportSize(self):
-        fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov
-        w = util.Fov2Length(fov)
-        h = w * self.retinal_res[0] / self.retinal_res[1]
-        return torch.tensor([ h, w ])
-
-    def GetRegionOfFoveaLayer(self, i):
-        '''
-        Get region of fovea layer i in retinal image
-
-        Returns
-        --------
-        slice object storing the start and end of the region
-        '''
-        roi_size = int(np.ceil(self.retinal_res[0] * self.eye_fovea_angles[i] / self.eye_fovea_angles[-1]))
-        roi_offset = int((self.retinal_res[0] - roi_size) / 2)
-        return slice(roi_offset, roi_offset + roi_size)
-
-    def _GenFoveaLayerBlend(self, i):
-        '''
-        Generate blend map for fovea layer i
-
-        Parameters
-        --------
-        i - index of fovea layer
-
-        Returns
-        --------
-        H[i] x W[i], blend map
-
-        '''
-        region = self.GetRegionOfFoveaLayer(i)
-        width = region.stop - region.start
-        R = width / 2
-        p = util.MeshGrid([ width, width ])
-        r = torch.linalg.norm(p - R, 2, dim=2, keepdim=False)
-        return util.SmoothStep(R, R * 0.6, r)
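With the defaults above (retinal_res = 320, eye_fovea_angles = [ 40, 80 ]), GetRegionOfFoveaLayer(0) spans ceil(320 * 40 / 80) = 160 pixels centered in the retinal image, i.e. slice(80, 240), and _GenFoveaLayerBlend(0) feathers that region radially via SmoothStep, from 1 inside radius 0.6 R down to 0 at radius R. A standalone sketch of the same arithmetic, for checking:

    import numpy as np

    # Mirror of Conf.GetRegionOfFoveaLayer(0) with the default configuration.
    retinal_res = 320
    eye_fovea_angles = [ 40, 80 ]
    roi_size = int(np.ceil(retinal_res * eye_fovea_angles[0] / eye_fovea_angles[-1]))  # 160
    roi_offset = int((retinal_res - roi_size) / 2)                                     # 80
    region = slice(roi_offset, roi_offset + roi_size)                                  # slice(80, 240)
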
diff --git a/data.py b/data.py
deleted file mode 100644
index 14b5a2c..0000000
--- a/data.py
+++ /dev/null
@@ -1,305 +0,0 @@
-import torch
-import os
-import glob
-import numpy as np
-import torchvision.transforms as transforms
-
-from torchvision import datasets
-from torch.utils.data import DataLoader
-
-import cv2
-import json
-from Flow import *
-from gen_image import *
-import util
-
-import time
-
-
-class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
-    def __init__(self, file_dir_path, file_json):
-        self.file_dir_path = file_dir_path
-        with open(file_json, encoding='utf-8') as file:
-            self.dataset_desc = json.loads(file.read())
-        self.input_img = []
-        for i in self.dataset_desc["train"]:
-            lf_element = os.path.join(self.file_dir_path, i)
-            lf_element = cv2.imread(lf_element, -cv2.IMREAD_ANYDEPTH)[:, :, 0:3]
-            lf_element = cv2.cvtColor(lf_element, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
-            lf_element = lf_element[:,:-1,:]
-            self.input_img.append(lf_element)
-        self.input_img = np.asarray(self.input_img)
-    def __len__(self):
-        return len(self.dataset_desc["gt"])
-
-    def __getitem__(self, idx):
-        gt, pos_row, pos_col = self.get_datum(idx)
-        return (self.input_img, gt, pos_row, pos_col)
-
-    def get_datum(self, idx):
-        fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
-        gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-        gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
-        gt = gt[:,:-1,:]
-        pos_col = self.dataset_desc["x"][idx]
-        pos_row = self.dataset_desc["y"][idx]
-        return gt, pos_row, pos_col
-
-class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
-    def __init__(self, file_dir_path, file_json, transforms=None):
-        self.file_dir_path = file_dir_path
-        self.transforms = transforms
-        # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
-        with open(file_json, encoding='utf-8') as file:
-            self.dataset_desc = json.loads(file.read())
-
-    def __len__(self):
-        return len(self.dataset_desc["focaldepth"])
-
-    def __getitem__(self, idx):
-        lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
-        if self.transforms:
-            lightfield_images = self.transforms(lightfield_images)
-        # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
-        return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
-
-    def get_datum(self, idx):
-        lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
-        # print(lf_image_paths)
-        fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
-        fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][idx])
-        # print(fd_gt_path)
-        lf_images = []
-        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-        lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
-
-        ## IF GrayScale
-        # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
-        # lf_image_big = np.expand_dims(lf_image_big, axis=-1)
-        # print(lf_image_big.shape)
-
-        for i in range(9):
-            lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
-            ## IF GrayScale
-            # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
-            # print(lf_image.shape)
-            lf_images.append(lf_image)
-        gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-        gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
-        gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-        gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB)
-        ## IF GrayScale
-        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
-        # gt = np.expand_dims(gt, axis=-1)
-
-        fd = self.dataset_desc["focaldepth"][idx]
-        gazeX = self.dataset_desc["gazeX"][idx]
-        gazeY = self.dataset_desc["gazeY"][idx]
-        sample_idx = self.dataset_desc["idx"][idx]
-        return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx
-
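lightFieldDataLoader assumes each training image is a 3 x 3 atlas of views, each IM_H x IM_W pixels (the constants come from gen_image via the star import). The slicing loop above is equivalent to a reshape/transpose, sketched standalone below with a stand-in atlas:

    import numpy as np

    IM_H, IM_W = 320, 320                                    # view size assumed by the loaders
    atlas = np.zeros((3 * IM_H, 3 * IM_W, 3), np.float32)    # stand-in for lf_image_big

    # Same result as the for-i-in-range(9) slicing (row-major: i // 3, i % 3):
    views = (atlas.reshape(3, IM_H, 3, IM_W, 3)
                  .transpose(0, 2, 1, 3, 4)
                  .reshape(9, IM_H, IM_W, 3))
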
-class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
-    def __init__(self, file_dir_path, file_json, transforms=None):
-        self.file_dir_path = file_dir_path
-        self.transforms = transforms
-        # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
-        with open(file_json, encoding='utf-8') as file:
-            self.dataset_desc = json.loads(file.read())
-
-    def __len__(self):
-        return len(self.dataset_desc["focaldepth"])
-
-    def __getitem__(self, idx):
-        lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
-        if self.transforms:
-            lightfield_images = self.transforms(lightfield_images)
-        # print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
-        return (lightfield_images, fd, gazeX, gazeY, sample_idx)
-
-    def get_datum(self, idx):
-        lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
-        # print(fd_gt_path)
-        lf_images = []
-        lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-        lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
-
-        ## IF GrayScale
-        # lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
-        # lf_image_big = np.expand_dims(lf_image_big, axis=-1)
-        # print(lf_image_big.shape)
-
-        for i in range(9):
-            lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
-            ## IF GrayScale
-            # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
-            # print(lf_image.shape)
-            lf_images.append(lf_image)
-        ## IF GrayScale
-        # gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
-        # gt = np.expand_dims(gt, axis=-1)
-
-        fd = self.dataset_desc["focaldepth"][idx]
-        gazeX = self.dataset_desc["gazeX"][idx]
-        gazeY = self.dataset_desc["gazeY"][idx]
-        sample_idx = self.dataset_desc["idx"][idx]
-        return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx
-
-class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
-    def __init__(self, file_dir_path, file_json, transforms=None):
-        self.file_dir_path = file_dir_path
-        self.transforms = transforms
-        with open(file_json, encoding='utf-8') as file:
-            self.dataset_desc = json.loads(file.read())
-
-    def __len__(self):
-        return len(self.dataset_desc["seq"])
-
-    def __getitem__(self, idx):
-        lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
-        fd = fd.astype(np.float32)
-        gazeX = gazeX.astype(np.float32)
-        gazeY = gazeY.astype(np.float32)
-        sample_idx = sample_idx.astype(np.int64)
-        # print(fd)
-        # print(gazeX)
-        # print(gazeY)
-        # print(sample_idx)
-
-        # print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype)
-        # print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape)
-        if self.transforms:
-            lightfield_images = self.transforms(lightfield_images)
-        return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
-
-    def get_datum(self, idx):
-        indices = self.dataset_desc["seq"][idx]
-        # print("indices:",indices)
-        lf_images = []
-        fd = []
-        gazeX = []
-        gazeY = []
-        sample_idx = []
-        gt = []
-        gt2 = []
-        for i in range(len(indices)):
-            lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
-            fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
-            fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
-            lf_image_one_sample = []
-            lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-            lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
-
-            for j in range(9):
-                lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3]
-                lf_image_one_sample.append(lf_image)
-
-            gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-            gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
-            gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-            gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
-
-            # print("indices[i]:",indices[i])
-            fd.append([self.dataset_desc["focaldepth"][indices[i]]])
-            gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
-            gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
-            sample_idx.append([self.dataset_desc["idx"][indices[i]]])
-            lf_images.append(lf_image_one_sample)
-        # lf_images: 5,9,320,320
-
-        return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx)
-
-class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
-    def __init__(self, file_dir_path, file_json, gen, conf, transforms=None):
-        self.file_dir_path = file_dir_path
-        self.file_json = file_json
-        self.gen = gen
-        self.conf = conf
-        self.transforms = transforms
-        with open(file_json, encoding='utf-8') as file:
-            self.dataset_desc = json.loads(file.read())
-
-    def __len__(self):
-        return len(self.dataset_desc["seq"])
-
-    def __getitem__(self, idx):
-        # start = time.time()
-        lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx = self.get_datum(idx)
-        fd = fd.astype(np.float32)
-        gazeX = gazeX.astype(np.float32)
-        gazeY = gazeY.astype(np.float32)
-        posX = posX.astype(np.float32)
-        posY = posY.astype(np.float32)
-        posZ = posZ.astype(np.float32)
-        sample_idx = sample_idx.astype(np.int64)
-
-        if self.transforms:
-            lightfield_images = self.transforms(lightfield_images)
-        # print("read once:",time.time() - start) # 8 ms
-        return (lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx)
-
-    def get_datum(self, idx):
-        IM_H = 320
-        IM_W = 320
-        indices = self.dataset_desc["seq"][idx]
-        # print("indices:",indices)
-        lf_images = []
-        fd = []
-        gazeX = []
-        gazeY = []
-        posX = []
-        posY = []
-        posZ = []
-        sample_idx = []
-        gt = []
-        # gt2 = []
-        phi = []
-        phi_invalid = []
-        retinal_invalid = []
-        for i in range(len(indices)): # 5
-            lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
-            fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
-            # fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
-            lf_image_one_sample = []
-            lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-            lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
-
-            lf_dim = int(self.conf.light_field_dim)
-            for j in range(lf_dim**2):
-                lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim*IM_H+IM_H,j%lf_dim*IM_W:j%lf_dim*IM_W+IM_W,0:3]
-                lf_image_one_sample.append(lf_image)
-
-            gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-            gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
-            # gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
-            # gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
-
-            # print("indices[i]:",indices[i])
-            fd.append([self.dataset_desc["focaldepth"][indices[i]]])
-            gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
-            gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
-            posX.append([self.dataset_desc["x"][indices[i]]])
-            posY.append([self.dataset_desc["y"][indices[i]]])
-            posZ.append([0.0])
-            sample_idx.append([self.dataset_desc["idx"][indices[i]]])
-            lf_images.append(lf_image_one_sample)
-
-            idx_i = sample_idx[i][0]
-            focaldepth_i = fd[i][0]
-            gazeX_i = gazeX[i][0]
-            gazeY_i = gazeY[i][0]
-            posX_i = posX[i][0]
-            posY_i = posY[i][0]
-            posZ_i = posZ[i][0]
-            # print("x:%.3f,y:%.3f,z:%.3f;gaze:%.4f,%.4f,focaldepth:%.3f."%(posX_i,posY_i,posZ_i,gazeX_i,gazeY_i,focaldepth_i))
-            phi_i,phi_invalid_i,retinal_invalid_i = self.gen.CalculateRetinal2LayerMappings(torch.tensor([posX_i, posY_i, posZ_i]),torch.tensor([gazeX_i, gazeY_i]),focaldepth_i)
-
-            phi.append(phi_i)
-            phi_invalid.append(phi_invalid_i)
-            retinal_invalid.append(retinal_invalid_i)
-        # lf_images: 5,9,320,320
-        flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"][indices[i-1]]) for i in range(1,len(indices))])
-        flow_map = flow.getMap()
-        flow_invalid_mask = flow.b_invalid_mask
-        # print("flow:",flow_map.shape)
-
-        return np.asarray(lf_images),np.asarray(gt), torch.stack(phi,dim=0), torch.stack(phi_invalid,dim=0),torch.stack(retinal_invalid,dim=0), flow_map, flow_invalid_mask, np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(posX),np.asarray(posY),np.asarray(posZ),np.asarray(sample_idx)
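lightFieldFlowSeqDataLoader precomputes the per-frame retinal-to-layer mappings (phi) through gen.CalculateRetinal2LayerMappings and loads the optical flow between consecutive frames, so a training loop can enforce temporal consistency between predictions. One plausible shape for such a term, sketched with hypothetical names (the function name and the 0.1 weight are illustrative, not from this patch):

    import torch
    import torch.nn.functional as F

    def temporal_loss(prev_output, cur_output, flow_map, flow_invalid_mask, weight=0.1):
        # Warp the previous frame's prediction onto the current frame using the
        # same grid_sample convention as Flow/FlowMap above.
        warped = F.grid_sample(prev_output, flow_map, align_corners=False)
        valid = (~flow_invalid_mask).unsqueeze(1).float()  # B x 1 x H x W
        # Compare only where the flow is valid.
        return weight * F.l1_loss(warped * valid, cur_output * valid)
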
diff --git a/loss.py b/loss.py
deleted file mode 100644
index 7307a67..0000000
--- a/loss.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import torch
-from ssim import *
-from perc_loss import *
-
-l1loss = torch.nn.L1Loss()
-perc_loss = VGGPerceptualLoss()
-perc_loss = perc_loss.to("cuda:1")
-
-##### LOSS #####
-def calImageGradients(images):
-    # images is a 4-D tensor
-    dx = images[:, :, 1:, :] - images[:, :, :-1, :]
-    dy = images[:, :, :, 1:] - images[:, :, :, :-1]
-    return dx, dy
-
-def loss_new(generated, gt):
-    mse_loss = torch.nn.MSELoss()
-    rmse_intensity = mse_loss(generated, gt)
-    psnr_intensity = torch.log10(rmse_intensity)
-    # print("psnr:",psnr_intensity)
-    # ssim_intensity = ssim(generated, gt)
-    labels_dx, labels_dy = calImageGradients(gt)
-    # print("generated:",generated.shape)
-    preds_dx, preds_dy = calImageGradients(generated)
-    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
-    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
-    # print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
-    p_loss = perc_loss(generated, gt)
-    # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
-    total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
-    # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
-    return total_loss
-
-def loss_without_perc(generated, gt):
-    mse_loss = torch.nn.MSELoss()
-    rmse_intensity = mse_loss(generated, gt)
-    psnr_intensity = torch.log10(rmse_intensity)
-    # print("psnr:",psnr_intensity)
-    # ssim_intensity = ssim(generated, gt)
-    labels_dx, labels_dy = calImageGradients(gt)
-    # print("generated:",generated.shape)
-    preds_dx, preds_dy = calImageGradients(generated)
-    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
-    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
-    # print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
-    # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
-    total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
-    # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
-    return total_loss
-##### LOSS #####
-
-
-class ReconstructionLoss(torch.nn.Module):
-    def __init__(self):
-        super(ReconstructionLoss, self).__init__()
-
-    def forward(self, generated, gt):
-        rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
-        psnr_intensity = torch.log10(rmse_intensity)
-        labels_dx, labels_dy = calImageGradients(gt)
-        preds_dx, preds_dy = calImageGradients(generated)
-        rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
-        psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
-        total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
-        return total_loss
-
-class PerceptionReconstructionLoss(torch.nn.Module):
-    def __init__(self):
-        super(PerceptionReconstructionLoss, self).__init__()
-
-    def forward(self, generated, gt):
-        rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
-        psnr_intensity = torch.log10(rmse_intensity)
-        labels_dx, labels_dy = calImageGradients(gt)
-        preds_dx, preds_dy = calImageGradients(generated)
-        rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
-        psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
-        p_loss = perc_loss(generated, gt)
-        total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
-        return total_loss
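Since images here are in [0, 1], the log10-of-MSE terms above are negative PSNRs up to a factor of -10, applied both to intensities and to horizontal/vertical image gradients, so minimizing them maximizes PSNR on both. A minimal usage sketch (shapes illustrative; note that importing loss.py instantiates VGGPerceptualLoss and moves it to cuda:1 at import time, so this assumes that device exists):

    import torch
    from loss import ReconstructionLoss

    criterion = ReconstructionLoss()            # intensity + gradient terms, no VGG
    generated = torch.rand(4, 3, 320, 320, requires_grad=True)
    gt = torch.rand(4, 3, 320, 320)
    criterion(generated, gt).backward()         # scalar loss; lower is better
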
x&y:",psnr_grad_x," ",psnr_grad_y) - # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss) - total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) - # total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss - return total_loss -##### LOSS ##### - - -class ReconstructionLoss(torch.nn.Module): - def __init__(self): - super(ReconstructionLoss, self).__init__() - - def forward(self, generated, gt): - rmse_intensity = torch.nn.functional.mse_loss(generated, gt) - psnr_intensity = torch.log10(rmse_intensity) - labels_dx, labels_dy = calImageGradients(gt) - preds_dx, preds_dy = calImageGradients(generated) - rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy) - psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y) - total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) - return total_loss - -class PerceptionReconstructionLoss(torch.nn.Module): - def __init__(self): - super(PerceptionReconstructionLoss, self).__init__() - - def forward(self, generated, gt): - rmse_intensity = torch.nn.functional.mse_loss(generated, gt) - psnr_intensity = torch.log10(rmse_intensity) - labels_dx, labels_dy = calImageGradients(gt) - preds_dx, preds_dy = calImageGradients(generated) - rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy) - psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y) - p_loss = perc_loss(generated,gt) - total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss - return total_loss diff --git a/perc_loss.py b/perc_loss.py deleted file mode 100644 index acad75a..0000000 --- a/perc_loss.py +++ /dev/null @@ -1,38 +0,0 @@ - -import torch -import torchvision - -class VGGPerceptualLoss(torch.nn.Module): - def __init__(self, resize=True): - super(VGGPerceptualLoss, self).__init__() - blocks = [] - blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) - blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) - blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) - blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) - for bl in blocks: - for p in bl: - p.requires_grad = False - self.blocks = torch.nn.ModuleList(blocks) - self.transform = torch.nn.functional.interpolate - self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)) - self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) - self.resize = resize - - def forward(self, input, target): - if input.shape[1] != 3: - input = input.repeat(1, 3, 1, 1) - target = target.repeat(1, 3, 1, 1) - input = (input-self.mean) / self.std - target = (target-self.mean) / self.std - if self.resize: - input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) - target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) - loss = 0.0 - x = input - y = target - for block in self.blocks: - x = block(x) - y = block(y) - loss += torch.nn.functional.l1_loss(x, y) - return loss \ No newline at end of file diff --git a/ssim.py b/ssim.py deleted file mode 100644 index 93f390b..0000000 --- a/ssim.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -import torch.nn.functional as F -from torch.autograd import Variable -import numpy as np -from math import exp - -def 
diff --git a/ssim.py b/ssim.py
deleted file mode 100644
index 93f390b..0000000
--- a/ssim.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch.autograd import Variable
-import numpy as np
-from math import exp
-
-def gaussian(window_size, sigma):
-    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
-    return gauss/gauss.sum()
-
-def create_window(window_size, channel):
-    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
-    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
-    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
-    return window
-
-def _ssim(img1, img2, window, window_size, channel, size_average = True):
-    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
-    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
-
-    mu1_sq = mu1.pow(2)
-    mu2_sq = mu2.pow(2)
-    mu1_mu2 = mu1*mu2
-
-    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
-    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
-    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
-
-    C1 = 0.01**2
-    C2 = 0.03**2
-
-    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
-
-    if size_average:
-        return ssim_map.mean()
-    else:
-        return ssim_map.mean(1).mean(1).mean(1)
-
-class SSIM(torch.nn.Module):
-    def __init__(self, window_size = 11, size_average = True):
-        super(SSIM, self).__init__()
-        self.window_size = window_size
-        self.size_average = size_average
-        self.channel = 1
-        self.window = create_window(window_size, self.channel)
-
-    def forward(self, img1, img2):
-        (_, channel, _, _) = img1.size()
-
-        if channel == self.channel and self.window.data.type() == img1.data.type():
-            window = self.window
-        else:
-            window = create_window(self.window_size, channel)
-
-            if img1.is_cuda:
-                window = window.cuda(img1.get_device())
-            window = window.type_as(img1)
-
-            self.window = window
-            self.channel = channel
-
-        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
-
-def ssim(img1, img2, window_size = 11, size_average = True):
-    (_, channel, _, _) = img1.size()
-    window = create_window(window_size, channel)
-
-    if img1.is_cuda:
-        window = window.cuda(img1.get_device())
-    window = window.type_as(img1)
-
-    return _ssim(img1, img2, window, window_size, channel, size_average)
\ No newline at end of file
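Both entry points return the mean SSIM over the batch (1.0 for identical images); the SSIM module caches its Gaussian window between calls, while the ssim function rebuilds it each time. A usage sketch:

    import torch
    from ssim import SSIM, ssim

    img1 = torch.rand(1, 3, 320, 320)
    img2 = img1.clone()
    print(ssim(img1, img2))                   # tensor(1.) up to numerical noise
    print(SSIM(window_size=11)(img1, img2))   # module form, window reused across calls
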
diff --git a/util.py b/util.py
deleted file mode 100644
index 7e6846b..0000000
--- a/util.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import numpy as np
-import torch
-import matplotlib.pyplot as plt
-import glm
-
-gvec_type = [ glm.dvec1, glm.dvec2, glm.dvec3, glm.dvec4 ]
-gmat_type = [ [ glm.dmat2, glm.dmat2x3, glm.dmat2x4 ],
-              [ glm.dmat3x2, glm.dmat3, glm.dmat3x4 ],
-              [ glm.dmat4x2, glm.dmat4x3, glm.dmat4 ] ]
-
-def Fov2Length(angle):
-    return np.tan(angle * np.pi / 360) * 2
-
-def SmoothStep(x0, x1, x):
-    y = torch.clamp((x - x0) / (x1 - x0), 0, 1)
-    return y * y * (3 - 2 * y)
-
-def MatImg2Tensor(img, permute = True, batch_dim = True):
-    batch_input = len(img.shape) == 4
-    if permute:
-        t = torch.from_numpy(np.transpose(img,
-            [0, 3, 1, 2] if batch_input else [2, 0, 1]))
-    else:
-        t = torch.from_numpy(img)
-    if not batch_input and batch_dim:
-        t = t.unsqueeze(0)
-    return t
-
-def MatImg2Numpy(img, permute = True, batch_dim = True):
-    batch_input = len(img.shape) == 4
-    if permute:
-        t = np.transpose(img, [0, 3, 1, 2] if batch_input else [2, 0, 1])
-    else:
-        t = img
-    if not batch_input and batch_dim:
-        t = t.unsqueeze(0)
-    return t
-
-def Tensor2MatImg(t):
-    img = t.squeeze().cpu().numpy()
-    batch_input = len(img.shape) == 4
-    if t.size()[batch_input] <= 4:
-        return np.transpose(img, [0, 2, 3, 1] if batch_input else [1, 2, 0])
-    return img
-
-def ReadImageTensor(path, permute = True, rgb_only = True, batch_dim = True):
-    channels = 3 if rgb_only else 4
-    if isinstance(path, list):
-        first_image = plt.imread(path[0])[:, :, 0:channels]
-        b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
-        b_image[0] = first_image
-        for i in range(1, len(path)):
-            b_image[i] = plt.imread(path[i])[:, :, 0:channels]
-        return MatImg2Tensor(b_image, permute)
-    return MatImg2Tensor(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
-
-def ReadImageNumpyArray(path, permute = True, rgb_only = True, batch_dim = True):
-    channels = 3 if rgb_only else 4
-    if isinstance(path, list):
-        first_image = plt.imread(path[0])[:, :, 0:channels]
-        b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
-        b_image[0] = first_image
-        for i in range(1, len(path)):
-            b_image[i] = plt.imread(path[i])[:, :, 0:channels]
-        return MatImg2Numpy(b_image, permute)
-    return MatImg2Numpy(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
-
-def WriteImageTensor(t, path):
-    image = Tensor2MatImg(t)
-    if isinstance(path, list):
-        if len(image.shape) != 4 or image.shape[0] != len(path):
-            raise ValueError
-        for i in range(len(path)):
-            plt.imsave(path[i], image[i])
-    else:
-        if len(image.shape) == 4 and image.shape[0] != 1:
-            raise ValueError
-        plt.imsave(path, image)
-
-def PlotImageTensor(t):
-    plt.imshow(Tensor2MatImg(t))
-
-def Tensor2Glm(t):
-    t = t.squeeze()
-    size = t.size()
-    if len(size) == 1:
-        if size[0] <= 0 or size[0] > 4:
-            raise ValueError
-        return gvec_type[size[0] - 1](t.cpu().numpy())
-    if len(size) == 2:
-        if size[0] <= 1 or size[0] > 4 or size[1] <= 1 or size[1] > 4:
-            raise ValueError
-        return gmat_type[size[1] - 2][size[0] - 2](t.cpu().numpy())
-    raise ValueError
-
-def Glm2Tensor(val):
-    return torch.from_numpy(np.array(val))
-
-def MeshGrid(size, normalize=False):
-    y, x = torch.meshgrid(torch.tensor(range(size[0])),
-                          torch.tensor(range(size[1])))
-    if normalize:
-        return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2)
-    return torch.stack([x, y], 2)
\ No newline at end of file
-- 
GitLab