Commit 421085df authored by BobYeah's avatar BobYeah
Browse files

sync

parent a52ea2e4
......@@ -54,11 +54,6 @@ coverage.xml
*.log
local_settings.py
# log
*.txt
*.out
*.ipynb
# Flask stuff:
instance/
.webassets-cache
......
import matplotlib.pyplot as plt
import torch
import util
import numpy as np
def FlowMap(b_last_frame, b_map):
'''
Map images using the flow data.
Parameters
--------
b_last_frame - B x 3 x H x W tensor, batch of images
b_map - B x H x W x 2, batch of map data records pixel coords in last frames
Returns
--------
B x 3 x H x W tensor, batch of images mapped by flow data
'''
return torch.nn.functional.grid_sample(b_last_frame, b_map, align_corners=False)
class Flow(object):
'''
Class representating optical flow
Properties
--------
b_data - B x H x W x 2, batch of flow data
b_map - B x H x W x 2, batch of map data records pixel coords in last frames
b_invalid_mask - B x H x W, batch of masks, indicate invalid elements in corresponding flow data
'''
def Load(paths):
'''
Create a Flow instance using a batch of encoded data images loaded from paths
Parameters
--------
paths - list of encoded data image paths
Returns
--------
Flow instance
'''
b_encoded_image = util.ReadImageTensor(paths, rgb_only=False, permute=False, batch_dim=True)
return Flow(b_encoded_image)
def __init__(self, b_encoded_image):
'''
Initialize a Flow instance from a batch of encoded data images
Parameters
--------
b_encoded_image - batch of encoded data images
'''
b_encoded_image = b_encoded_image.mul(255)
# print("b_encoded_image:",b_encoded_image.shape)
self.b_invalid_mask = (b_encoded_image[:, :, :, 0] == 255)
self.b_data = (b_encoded_image[:, :, :, 0:2] / 254 + b_encoded_image[:, :, :, 2:4] - 127) / 127
self.b_data[:, :, :, 1] = -self.b_data[:, :, :, 1]
D = self.b_data.size()
grid = util.MeshGrid((D[1], D[2]), True)
self.b_map = (grid - self.b_data - 0.5) * 2
self.b_map[self.b_invalid_mask] = torch.tensor([ -2.0, -2.0 ])
def getMap(self):
return self.b_map
def Visualize(self, scale_factor = 1):
'''
Visualize the flow data by "color wheel".
Parameters
--------
scale_factor - scale factor of flow data to visualize, default is 1
Returns
--------
B x 3 x H x W tensor, visualization of flow data
'''
try:
Flow.b_color_wheel
except AttributeError:
Flow.b_color_wheel = util.ReadImageTensor('color_wheel.png')
return torch.nn.functional.grid_sample(Flow.b_color_wheel.expand(self.b_data.size()[0], -1, -1, -1),
(self.b_data * scale_factor), align_corners=False)
\ No newline at end of file
import torch
import util
import numpy as np
class Conf(object):
def __init__(self):
self.pupil_size = 0.02
self.retinal_res = torch.tensor([ 320, 320 ])
self.layer_res = torch.tensor([ 320, 320 ])
self.layer_hfov = 90 # layers' horizontal FOV
self.eye_hfov = 80 # eye's horizontal FOV (ignored in foveated rendering)
self.eye_enable_fovea = False # enable foveated rendering
self.eye_fovea_angles = [ 40, 80 ] # eye's foveation layers' angles
self.eye_fovea_downsamples = [ 1, 2 ] # eye's foveation layers' downsamples
self.d_layer = [ 1, 3 ] # layers' distance
self.eye_fovea_blend = [ self._GenFoveaLayerBlend(0) ]
# blend maps of fovea layers
self.light_field_dim = 5
def GetNLayers(self):
return len(self.d_layer)
def GetLayerSize(self, i):
w = util.Fov2Length(self.layer_hfov)
h = w * self.layer_res[0] / self.layer_res[1]
return torch.tensor([ h, w ]) * self.d_layer[i]
def GetPixelSizeOfLayer(self, i):
'''
Get pixel size of layer i
'''
return util.Fov2Length(self.layer_hfov) * self.d_layer[i] / self.layer_res[0]
def GetEyeViewportSize(self):
fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov
w = util.Fov2Length(fov)
h = w * self.retinal_res[0] / self.retinal_res[1]
return torch.tensor([ h, w ])
def GetRegionOfFoveaLayer(self, i):
'''
Get region of fovea layer i in retinal image
Returns
--------
slice object stores the start and end of region
'''
roi_size = int(np.ceil(self.retinal_res[0] * self.eye_fovea_angles[i] / self.eye_fovea_angles[-1]))
roi_offset = int((self.retinal_res[0] - roi_size) / 2)
return slice(roi_offset, roi_offset + roi_size)
def _GenFoveaLayerBlend(self, i):
'''
Generate blend map for fovea layer i
Parameters
--------
i - index of fovea layer
Returns
--------
H[i] x W[i], blend map
'''
region = self.GetRegionOfFoveaLayer(i)
width = region.stop - region.start
R = width / 2
p = util.MeshGrid([ width, width ])
r = torch.linalg.norm(p - R, 2, dim=2, keepdim=False)
return util.SmoothStep(R, R * 0.6, r)
import torch
import os
import glob
import numpy as np
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import cv2
import json
from Flow import *
from gen_image import *
import util
import time
class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path,file_json):
self.file_dir_path = file_dir_path
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
self.input_img = []
for i in self.dataset_desc["train"]:
lf_element = os.path.join(self.file_dir_path,i)
lf_element = cv2.imread(lf_element, -cv2.IMREAD_ANYDEPTH)[:, :, 0:3]
lf_element = cv2.cvtColor(lf_element, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
lf_element = lf_element[:,:-1,:]
self.input_img.append(lf_element)
self.input_img = np.asarray(self.input_img)
def __len__(self):
return len(self.dataset_desc["gt"])
def __getitem__(self, idx):
gt, pos_row, pos_col = self.get_datum(idx)
return (self.input_img, gt, pos_row, pos_col)
def get_datum(self, idx):
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
gt = gt[:,:-1,:]
pos_col = self.dataset_desc["x"][idx]
pos_row = self.dataset_desc["y"][idx]
return gt, pos_row, pos_col
class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
# print(lf_image_paths)
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][idx])
fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
gt = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt = cv2.cvtColor(gt,cv2.COLOR_BGR2RGB)
gt2 = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2 = cv2.cvtColor(gt2,cv2.COLOR_BGR2RGB)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),gt,gt2,fd,gazeX,gazeY,sample_idx
class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
# self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["focaldepth"])
def __getitem__(self, idx):
lightfield_images, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print(lightfield_images.shape,gt.shape,fd,gazeX,gazeY,sample_idx)
return (lightfield_images, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][idx])
# print(fd_gt_path)
lf_images = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
## IF GrayScale
# lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# lf_image_big = np.expand_dims(lf_image_big, axis=-1)
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:3]
## IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
lf_images.append(lf_image)
## IF GrayScale
# gt = cv2.imread(fd_gt_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
# gt = np.expand_dims(gt, axis=-1)
fd = self.dataset_desc["focaldepth"][idx]
gazeX = self.dataset_desc["gazeX"][idx]
gazeY = self.dataset_desc["gazeY"][idx]
sample_idx = self.dataset_desc["idx"][idx]
return np.asarray(lf_images),fd,gazeX,gazeY,sample_idx
class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, transforms=None):
self.file_dir_path = file_dir_path
self.transforms = transforms
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["seq"])
def __getitem__(self, idx):
lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx = self.get_datum(idx)
fd = fd.astype(np.float32)
gazeX = gazeX.astype(np.float32)
gazeY = gazeY.astype(np.float32)
sample_idx = sample_idx.astype(np.int64)
# print(fd)
# print(gazeX)
# print(gazeY)
# print(sample_idx)
# print(lightfield_images.dtype,gt.dtype, gt2.dtype, fd.dtype, gazeX.dtype, gazeY.dtype, sample_idx.dtype, delta.dtype)
# print(lightfield_images.shape,gt.shape, gt2.shape, fd.shape, gazeX.shape, gazeY.shape, sample_idx.shape, delta.shape)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
return (lightfield_images, gt, gt2, fd, gazeX, gazeY, sample_idx)
def get_datum(self, idx):
indices = self.dataset_desc["seq"][idx]
# print("indices:",indices)
lf_images = []
fd = []
gazeX = []
gazeY = []
sample_idx = []
gt = []
gt2 = []
for i in range(len(indices)):
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
lf_image_one_sample = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
for j in range(9):
lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,j%3*IM_W:j%3*IM_W+IM_W,0:3]
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
# print("indices[i]:",indices[i])
fd.append([self.dataset_desc["focaldepth"][indices[i]]])
gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
sample_idx.append([self.dataset_desc["idx"][indices[i]]])
lf_images.append(lf_image_one_sample)
#lf_images: 5,9,320,320
return np.asarray(lf_images),np.asarray(gt),np.asarray(gt2),np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(sample_idx)
class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
def __init__(self, file_dir_path, file_json, gen, conf, transforms=None):
self.file_dir_path = file_dir_path
self.file_json = file_json
self.gen = gen
self.conf = conf
self.transforms = transforms
with open(file_json, encoding='utf-8') as file:
self.dataset_desc = json.loads(file.read())
def __len__(self):
return len(self.dataset_desc["seq"])
def __getitem__(self, idx):
# start = time.time()
lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx = self.get_datum(idx)
fd = fd.astype(np.float32)
gazeX = gazeX.astype(np.float32)
gazeY = gazeY.astype(np.float32)
posX = posX.astype(np.float32)
posY = posY.astype(np.float32)
posZ = posZ.astype(np.float32)
sample_idx = sample_idx.astype(np.int64)
if self.transforms:
lightfield_images = self.transforms(lightfield_images)
# print("read once:",time.time() - start) # 8 ms
return (lightfield_images, gt, phi, phi_invalid, retinal_invalid, flow, flow_invalid_mask, fd, gazeX, gazeY, posX, posY, posZ, sample_idx)
def get_datum(self, idx):
IM_H = 320
IM_W = 320
indices = self.dataset_desc["seq"][idx]
# print("indices:",indices)
lf_images = []
fd = []
gazeX = []
gazeY = []
posX = []
posY = []
posZ = []
sample_idx = []
gt = []
# gt2 = []
phi = []
phi_invalid = []
retinal_invalid = []
for i in range(len(indices)): # 5
lf_image_paths = os.path.join(self.file_dir_path, self.dataset_desc["train"][indices[i]])
fd_gt_path = os.path.join(self.file_dir_path, self.dataset_desc["gt"][indices[i]])
# fd_gt_path2 = os.path.join(self.file_dir_path, self.dataset_desc["gt2"][indices[i]])
lf_image_one_sample = []
lf_image_big = cv2.imread(lf_image_paths, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
lf_image_big = cv2.cvtColor(lf_image_big,cv2.COLOR_BGR2RGB)
lf_dim = int(self.conf.light_field_dim)
for j in range(lf_dim**2):
lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim*IM_H+IM_H,j%lf_dim*IM_W:j%lf_dim*IM_W+IM_W,0:3]
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
gt.append(cv2.cvtColor(gt_i,cv2.COLOR_BGR2RGB))
# gt2_i = cv2.imread(fd_gt_path2, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
# gt2.append(cv2.cvtColor(gt2_i,cv2.COLOR_BGR2RGB))
# print("indices[i]:",indices[i])
fd.append([self.dataset_desc["focaldepth"][indices[i]]])
gazeX.append([self.dataset_desc["gazeX"][indices[i]]])
gazeY.append([self.dataset_desc["gazeY"][indices[i]]])
posX.append([self.dataset_desc["x"][indices[i]]])
posY.append([self.dataset_desc["y"][indices[i]]])
posZ.append([0.0])
sample_idx.append([self.dataset_desc["idx"][indices[i]]])
lf_images.append(lf_image_one_sample)
idx_i = sample_idx[i][0]
focaldepth_i = fd[i][0]
gazeX_i = gazeX[i][0]
gazeY_i = gazeY[i][0]
posX_i = posX[i][0]
posY_i = posY[i][0]
posZ_i = posZ[i][0]
# print("x:%.3f,y:%.3f,z:%.3f;gaze:%.4f,%.4f,focaldepth:%.3f."%(posX_i,posY_i,posZ_i,gazeX_i,gazeY_i,focaldepth_i))
phi_i,phi_invalid_i,retinal_invalid_i = self.gen.CalculateRetinal2LayerMappings(torch.tensor([posX_i, posY_i,posZ_i]),torch.tensor([gazeX_i, gazeY_i]),focaldepth_i)
phi.append(phi_i)
phi_invalid.append(phi_invalid_i)
retinal_invalid.append(retinal_invalid_i)
#lf_images: 5,9,320,320
flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"][indices[i-1]]) for i in range(1,len(indices))])
flow_map = flow.getMap()
flow_invalid_mask = flow.b_invalid_mask
# print("flow:",flow_map.shape)
return np.asarray(lf_images),np.asarray(gt), torch.stack(phi,dim=0), torch.stack(phi_invalid,dim=0),torch.stack(retinal_invalid,dim=0), flow_map, flow_invalid_mask, np.asarray(fd),np.asarray(gazeX),np.asarray(gazeY),np.asarray(posX),np.asarray(posY),np.asarray(posZ),np.asarray(sample_idx)
import torch
from ssim import *
from perc_loss import *
l1loss = torch.nn.L1Loss()
perc_loss = VGGPerceptualLoss()
perc_loss = perc_loss.to("cuda:1")
##### LOSS #####
def calImageGradients(images):
# x is a 4-D tensor
dx = images[:, :, 1:, :] - images[:, :, :-1, :]
dy = images[:, :, :, 1:] - images[:, :, :, :-1]
return dx, dy
def loss_new(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
p_loss = perc_loss(generated,gt)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
def loss_without_perc(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
##### LOSS #####
class ReconstructionLoss(torch.nn.Module):
def __init__(self):
super(ReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
return total_loss
class PerceptionReconstructionLoss(torch.nn.Module):
def __init__(self):
super(PerceptionReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
p_loss = perc_loss(generated,gt)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
return total_loss
import torch
import torchvision
class VGGPerceptualLoss(torch.nn.Module):
def __init__(self, resize=True):
super(VGGPerceptualLoss, self).__init__()
blocks = []
blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
for bl in blocks:
for p in bl:
p.requires_grad = False
self.blocks = torch.nn.ModuleList(blocks)
self.transform = torch.nn.functional.interpolate
self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
self.resize = resize
def forward(self, input, target):
if input.shape[1] != 3:
input = input.repeat(1, 3, 1, 1)
target = target.repeat(1, 3, 1, 1)
input = (input-self.mean) / self.std
target = (target-self.mean) / self.std
if self.resize:
input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
loss = 0.0
x = input
y = target
for block in self.blocks:
x = block(x)
y = block(y)
loss += torch.nn.functional.l1_loss(x, y)
return loss
\ No newline at end of file
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window
def _ssim(img1, img2, window, window_size, channel, size_average = True):
mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
def __init__(self, window_size = 11, size_average = True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.data.type() == img1.data.type():
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size = 11, size_average = True):
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, size_average)
\ No newline at end of file
import numpy as np
import torch
import matplotlib.pyplot as plt
import glm
gvec_type = [ glm.dvec1, glm.dvec2, glm.dvec3, glm.dvec4 ]
gmat_type = [ [ glm.dmat2, glm.dmat2x3, glm.dmat2x4 ],
[ glm.dmat3x2, glm.dmat3, glm.dmat3x4 ],
[ glm.dmat4x2, glm.dmat4x3, glm.dmat4 ] ]
def Fov2Length(angle):
return np.tan(angle * np.pi / 360) * 2
def SmoothStep(x0, x1, x):
y = torch.clamp((x - x0) / (x1 - x0), 0, 1)
return y * y * (3 - 2 * y)
def MatImg2Tensor(img, permute = True, batch_dim = True):
batch_input = len(img.shape) == 4
if permute:
t = torch.from_numpy(np.transpose(img,
[0, 3, 1, 2] if batch_input else [2, 0, 1]))
else:
t = torch.from_numpy(img)
if not batch_input and batch_dim:
t = t.unsqueeze(0)
return t
def MatImg2Numpy(img, permute = True, batch_dim = True):
batch_input = len(img.shape) == 4
if permute:
t = np.transpose(img, [0, 3, 1, 2] if batch_input else [2, 0, 1])
else:
t = img
if not batch_input and batch_dim:
t = t.unsqueeze(0)
return t
def Tensor2MatImg(t):
img = t.squeeze().cpu().numpy()
batch_input = len(img.shape) == 4
if t.size()[batch_input] <= 4:
return np.transpose(img, [0, 2, 3, 1] if batch_input else [1, 2, 0])
return img
def ReadImageTensor(path, permute = True, rgb_only = True, batch_dim = True):
channels = 3 if rgb_only else 4
if isinstance(path,list):
first_image = plt.imread(path[0])[:, :, 0:channels]
b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
b_image[0] = first_image
for i in range(1, len(path)):
b_image[i] = plt.imread(path[i])[:, :, 0:channels]
return MatImg2Tensor(b_image, permute)
return MatImg2Tensor(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
def ReadImageNumpyArray(path, permute = True, rgb_only = True, batch_dim = True):
channels = 3 if rgb_only else 4
if isinstance(path,list):
first_image = plt.imread(path[0])[:, :, 0:channels]
b_image = np.empty((len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
b_image[0] = first_image
for i in range(1, len(path)):
b_image[i] = plt.imread(path[i])[:, :, 0:channels]
return MatImg2Numpy(b_image, permute)
return MatImg2Numpy(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
def WriteImageTensor(t, path):
image = Tensor2MatImg(t)
if isinstance(path,list):
if len(image.shape) != 4 or image.shape[0] != len(path):
raise ValueError
for i in range(len(path)):
plt.imsave(path[i], image[i])
else:
if len(image.shape) == 4 and image.shape[0] != 1:
raise ValueError
plt.imsave(path, image)
def PlotImageTensor(t):
plt.imshow(Tensor2MatImg(t))
def Tensor2Glm(t):
t = t.squeeze()
size = t.size()
if len(size) == 1:
if size[0] <= 0 or size[0] > 4:
raise ValueError
return gvec_type[size[0] - 1](t.cpu().numpy())
if len(size) == 2:
if size[0] <= 1 or size[0] > 4 or size[1] <= 1 or size[1] > 4:
raise ValueError
return gmat_type[size[1] - 2][size[0] - 2](t.cpu().numpy())
raise ValueError
def Glm2Tensor(val):
return torch.from_numpy(np.array(val))
def MeshGrid(size, normalize=False):
y,x = torch.meshgrid(torch.tensor(range(size[0])),
torch.tensor(range(size[1])))
if normalize:
return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2)
return torch.stack([x, y], 2)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment