Commit a52ea2e4 authored by BobYeah

implemented a simple encoder-transformer-decoder net

parent 0356508b
...@@ -3,52 +3,8 @@ import numpy as np
import torch
import glm
import time
from .my import util                 # replaces: import util
from .my import sample_in_pupil
def RandomGenSamplesInPupil(pupil_size, n_samples):
'''
Random sample n_samples positions in pupil region
Parameters
--------
conf - multi-layers' parameters configuration
n_samples - number of samples to generate
Returns
--------
a n_samples x 3 tensor with 3D sample position in each row
'''
samples = torch.empty(n_samples, 3)
i = 0
while i < n_samples:
s = (torch.rand(2) - 0.5) * pupil_size
if np.linalg.norm(s) > pupil_size / 2.:
continue
samples[i, :] = [ s[0], s[1], 0 ]
i += 1
return samples
def GenSamplesInPupil(pupil_size, circles):
'''
Sample positions on circles in pupil region
Parameters
--------
conf - multi-layers' parameters configuration
circles - number of circles to sample
Returns
--------
a n_samples x 3 tensor with 3D sample position in each row
'''
samples = torch.zeros(1, 3)
for i in range(1, circles):
r = pupil_size / 2. / (circles - 1) * i
n = 4 * i
for j in range(0, n):
angle = 2 * np.pi / n * j
samples = torch.cat([ samples, torch.tensor([[ r * np.cos(angle), r * np.sin(angle), 0 ]]) ], 0)
return samples
class RetinalGen(object):
    '''
...@@ -75,7 +31,7 @@ class RetinalGen(object):
        u - a M x 3 tensor stores M sample positions in pupil
        '''
        self.conf = conf
        self.u = sample_in_pupil.CircleGen(conf.pupil_size, 5)   # replaces: GenSamplesInPupil(conf.pupil_size, 5)
        # self.u = u.to(cuda_dev)
        # self.u = u # M x 3 M sample positions
        self.D_r = conf.retinal_res # retinal res 480 x 640
...
import torch
from .ssim import *
from .perc_loss import *
device=torch.device("cuda:2")
l1loss = torch.nn.L1Loss()
perc_loss = VGGPerceptualLoss().to(device)
##### LOSS #####
def calImageGradients(images):
# x is a 4-D tensor
dx = images[:, :, 1:, :] - images[:, :, :-1, :]
dy = images[:, :, :, 1:] - images[:, :, :, :-1]
return dx, dy
def loss_new(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(
labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(
rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
p_loss = perc_loss(generated, gt)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
def loss_without_perc(generated, gt):
mse_loss = torch.nn.MSELoss()
rmse_intensity = mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
# print("psnr:",psnr_intensity)
# ssim_intensity = ssim(generated, gt)
labels_dx, labels_dy = calImageGradients(gt)
# print("generated:",generated.shape)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = mse_loss(
labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(
rmse_grad_x), torch.log10(rmse_grad_y)
# print("psnr x&y:",psnr_grad_x," ",psnr_grad_y)
# print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
# total_loss = rmse_intensity + 0.5*(rmse_grad_x + rmse_grad_y) # + p_loss
return total_loss
##### LOSS #####
class ReconstructionLoss(torch.nn.Module):
def __init__(self):
super(ReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x, rmse_grad_y = torch.nn.functional.mse_loss(
labels_dx, preds_dx), torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x, psnr_grad_y = torch.log10(
rmse_grad_x), torch.log10(rmse_grad_y)
total_loss = psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y)
return total_loss
class PerceptionReconstructionLoss(torch.nn.Module):
def __init__(self):
super(PerceptionReconstructionLoss, self).__init__()
def forward(self, generated, gt):
rmse_intensity = torch.nn.functional.mse_loss(generated, gt)
psnr_intensity = torch.log10(rmse_intensity)
labels_dx, labels_dy = calImageGradients(gt)
preds_dx, preds_dy = calImageGradients(generated)
rmse_grad_x = torch.nn.functional.mse_loss(labels_dx, preds_dx)
rmse_grad_y = torch.nn.functional.mse_loss(labels_dy, preds_dy)
psnr_grad_x = torch.log10(rmse_grad_x)
psnr_grad_y = torch.log10(rmse_grad_y)
p_loss = perc_loss(generated, gt)
total_loss = psnr_intensity + 0.5 * (psnr_grad_x + psnr_grad_y) + p_loss
return total_loss
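For reference, a minimal usage sketch of the two loss modules above (illustrative only; the random tensors are stand-ins for real BCHW image batches and must live on the hard-coded cuda:2 device that perc_loss was created on):

if __name__ == "__main__":
    generated = torch.rand(2, 3, 128, 128, device=device)
    gt = torch.rand(2, 3, 128, 128, device=device)
    print(ReconstructionLoss()(generated, gt).item())            # log10(MSE) plus gradient terms
    print(PerceptionReconstructionLoss()(generated, gt).item())  # ... plus the VGG perceptual term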
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
class VGGPerceptualLoss(nn.Module):
def __init__(self, resize=True):
super(VGGPerceptualLoss, self).__init__()
blocks = []
blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            # Freeze the VGG feature extractor
            for p in bl.parameters():
                p.requires_grad = False
self.blocks = nn.ModuleList(blocks)
self.transform = F.interpolate
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
self.resize = resize
def forward(self, input, target):
if input.shape[1] != 3:
input = input.repeat(1, 3, 1, 1)
target = target.repeat(1, 3, 1, 1)
input = (input-self.mean) / self.std
target = (target-self.mean) / self.std
if self.resize:
input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
loss = 0.0
x = input
y = target
for block in self.blocks:
x = block(x)
y = block(y)
loss += F.l1_loss(x, y)
return loss
\ No newline at end of file
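A short sketch of how this module handles single-channel input (illustrative only; sizes are arbitrary): forward() repeats the channel to 3 before normalizing and comparing VGG16 features, so grayscale tensors work as-is.

vgg_loss = VGGPerceptualLoss()          # downloads VGG16 weights on first use
pred = torch.rand(1, 1, 64, 64)
target = torch.rand(1, 1, 64, 64)
print(vgg_loss(pred, target).item())    # L1 distance summed over four VGG feature blocks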
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window
def _ssim(img1, img2, window, window_size, channel, size_average = True):
mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
def __init__(self, window_size = 11, size_average = True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.data.type() == img1.data.type():
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size = 11, size_average = True):
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, size_average)
\ No newline at end of file
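Quick sanity sketch of the functional interface above (illustrative; shapes are arbitrary BCHW examples): identical images score exactly 1.0 and unrelated noise scores noticeably lower.

a = torch.rand(1, 3, 64, 64)
print(ssim(a, a).item())                   # 1.0
print(ssim(a, torch.rand_like(a)).item())  # well below 1.0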
...@@ -11,13 +11,13 @@ from torch.utils.data import DataLoader
from torch.autograd import Variable
import cv2
from .gen_image import *      # replaces: from gen_image import *
from .loss import *           # replaces: from loss import *
import json
from .conf import Conf        # replaces: from conf import Conf
from .baseline import *       # replaces: from baseline import *
from .data import *           # replaces: from data import *
import torch.autograd.profiler as profiler
# param
...
import torch
import numpy as np
from . import util
class Conf(object):
def __init__(self):
self.pupil_size = 0.02
self.retinal_res = torch.tensor([ 320, 320 ])
self.layer_res = torch.tensor([ 320, 320 ])
self.layer_hfov = 90 # layers' horizontal FOV
self.eye_hfov = 80 # eye's horizontal FOV (ignored in foveated rendering)
self.eye_enable_fovea = False # enable foveated rendering
self.eye_fovea_angles = [ 40, 80 ] # eye's foveation layers' angles
self.eye_fovea_downsamples = [ 1, 2 ] # eye's foveation layers' downsamples
self.d_layer = [ 1, 3 ] # layers' distance
self.eye_fovea_blend = [ self._GenFoveaLayerBlend(0) ]
# blend maps of fovea layers
self.light_field_dim = 5
def GetNLayers(self):
return len(self.d_layer)
def GetLayerSize(self, i):
w = util.Fov2Length(self.layer_hfov)
h = w * self.layer_res[0] / self.layer_res[1]
return torch.tensor([ h, w ]) * self.d_layer[i]
def GetPixelSizeOfLayer(self, i):
'''
Get pixel size of layer i
'''
return util.Fov2Length(self.layer_hfov) * self.d_layer[i] / self.layer_res[0]
def GetEyeViewportSize(self):
fov = self.eye_fovea_angles[-1] if self.eye_enable_fovea else self.eye_hfov
w = util.Fov2Length(fov)
h = w * self.retinal_res[0] / self.retinal_res[1]
return torch.tensor([ h, w ])
def GetRegionOfFoveaLayer(self, i):
'''
Get region of fovea layer i in retinal image
Returns
--------
slice object stores the start and end of region
'''
roi_size = int(np.ceil(self.retinal_res[0] * self.eye_fovea_angles[i] / self.eye_fovea_angles[-1]))
roi_offset = int((self.retinal_res[0] - roi_size) / 2)
return slice(roi_offset, roi_offset + roi_size)
def _GenFoveaLayerBlend(self, i):
'''
Generate blend map for fovea layer i
Parameters
--------
i - index of fovea layer
Returns
--------
H[i] x W[i], blend map
'''
region = self.GetRegionOfFoveaLayer(i)
width = region.stop - region.start
R = width / 2
p = util.MeshGrid([ width, width ])
r = torch.linalg.norm(p - R, 2, dim=2, keepdim=False)
return util.SmoothStep(R, R * 0.6, r)
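A small sketch of the quantities derived from the defaults above (illustrative; assumes the package imports resolve and a PyTorch recent enough for torch.linalg.norm): with a 90-degree horizontal FOV, Fov2Length gives a width of 2 units per unit distance, so the near layer (d = 1) spans 2 x 2 units and the far layer (d = 3) spans 6 x 6 units, both rendered at 320 x 320 pixels.

conf = Conf()
print(conf.GetNLayers())            # 2
print(conf.GetLayerSize(0))         # tensor([2., 2.])
print(conf.GetLayerSize(1))         # tensor([6., 6.])
print(conf.GetPixelSizeOfLayer(0))  # 2 / 320 = 0.00625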
import matplotlib.pyplot as plt
import torch
from . import util  # package-relative import, matching conf.py
import numpy as np
def FlowMap(b_last_frame, b_map):
'''
Map images using the flow data.
Parameters
--------
b_last_frame - B x 3 x H x W tensor, batch of images
b_map - B x H x W x 2, batch of map data records pixel coords in last frames
Returns
--------
B x 3 x H x W tensor, batch of images mapped by flow data
'''
return torch.nn.functional.grid_sample(b_last_frame, b_map, align_corners=False)
class Flow(object):
'''
Class representating optical flow
Properties
--------
b_data - B x H x W x 2, batch of flow data
b_map - B x H x W x 2, batch of map data records pixel coords in last frames
b_invalid_mask - B x H x W, batch of masks, indicate invalid elements in corresponding flow data
'''
def Load(paths):
'''
Create a Flow instance using a batch of encoded data images loaded from paths
Parameters
--------
paths - list of encoded data image paths
Returns
--------
Flow instance
'''
b_encoded_image = util.ReadImageTensor(paths, rgb_only=False, permute=False, batch_dim=True)
return Flow(b_encoded_image)
def __init__(self, b_encoded_image):
'''
Initialize a Flow instance from a batch of encoded data images
Parameters
--------
b_encoded_image - batch of encoded data images
'''
b_encoded_image = b_encoded_image.mul(255)
# print("b_encoded_image:",b_encoded_image.shape)
self.b_invalid_mask = (b_encoded_image[:, :, :, 0] == 255)
self.b_data = (b_encoded_image[:, :, :, 0:2] / 254 + b_encoded_image[:, :, :, 2:4] - 127) / 127
self.b_data[:, :, :, 1] = -self.b_data[:, :, :, 1]
D = self.b_data.size()
grid = util.MeshGrid((D[1], D[2]), True)
self.b_map = (grid - self.b_data - 0.5) * 2
self.b_map[self.b_invalid_mask] = torch.tensor([ -2.0, -2.0 ])
def getMap(self):
return self.b_map
def Visualize(self, scale_factor = 1):
'''
Visualize the flow data by "color wheel".
Parameters
--------
scale_factor - scale factor of flow data to visualize, default is 1
Returns
--------
B x 3 x H x W tensor, visualization of flow data
'''
try:
Flow.b_color_wheel
except AttributeError:
Flow.b_color_wheel = util.ReadImageTensor('color_wheel.png')
return torch.nn.functional.grid_sample(Flow.b_color_wheel.expand(self.b_data.size()[0], -1, -1, -1),
(self.b_data * scale_factor), align_corners=False)
\ No newline at end of file
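A construction sketch that avoids reading files (illustrative only; the random 1 x H x W x 4 batch stands in for what ReadImageTensor(..., rgb_only=False, permute=False, batch_dim=True) would return):

fake_encoded = torch.rand(1, 32, 32, 4)
flow = Flow(fake_encoded)
print(flow.getMap().shape)   # torch.Size([1, 32, 32, 2]), ready for grid_sample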
import torch
import numpy as np
def PrintNet(net):
model_parameters = filter(lambda p: p.requires_grad, net.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("%d" % params)
def LoadNet(path, model, solver=None, discriminator=None):
print('Load net from %s ...' % path)
whole_dict = torch.load(path)
model.load_state_dict(whole_dict['model'])
if solver:
solver.load_state_dict(whole_dict['solver'])
if discriminator:
discriminator.load_state_dict(whole_dict['discriminator'])
def SaveNet(path, model, solver=None, discriminator=None):
print('Saving net to %s ...' % path)
whole_dict = {
'model': model.state_dict()
}
if solver:
whole_dict.update({'solver': solver.state_dict()})
if discriminator:
whole_dict.update({'discriminator': discriminator.state_dict()})
torch.save(whole_dict, path)
\ No newline at end of file
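Round-trip sketch with a throwaway model (illustrative; the path is arbitrary):

net = torch.nn.Linear(4, 2)
SaveNet('/tmp/netio_example.pth', net)
LoadNet('/tmp/netio_example.pth', net)
PrintNet(net)   # 4*2 weights + 2 biases = 10 trainable parameters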
import torch
import numpy as np
def RandomGen(pupil_size: float, n_samples: int) -> torch.Tensor:
"""
Random sample n_samples positions in pupil region
    :param pupil_size: diameter of the pupil
:param n_samples: number of samples to generate
:return: n_samples x 3, with 3D sample position in each row
"""
samples = torch.empty(n_samples, 3)
i = 0
while i < n_samples:
s = (torch.rand(2) - 0.5) * pupil_size
if np.linalg.norm(s) > pupil_size / 2.:
continue
        samples[i, :] = torch.tensor([s[0], s[1], 0.])  # build the row explicitly as a tensor
i += 1
return samples
def CircleGen(pupil_size: float, circles: int) -> torch.Tensor:
"""
Sample positions on circles in pupil region
:param pupil_size: diameter of pupil
:param circles: number of circles to sample
:return: M x 3, with 3D sample position in each row
"""
samples = torch.zeros(1, 3)
for i in range(1, circles):
r = pupil_size / 2. / (circles - 1) * i
n = 4 * i
for j in range(0, n):
angle = 2 * np.pi / n * j
samples = torch.cat([samples, torch.tensor([[r * np.cos(angle), r * np.sin(angle), 0]])], 0)
return samples
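For reference (illustrative sketch): CircleGen(d, 5) returns the pupil center plus four concentric rings of 4, 8, 12 and 16 samples, i.e. 41 rows, which is the sampling RetinalGen now uses; RandomGen rejects points outside the pupil circle until it has collected n_samples.

print(CircleGen(0.02, 5).shape)   # torch.Size([41, 3])
print(RandomGen(0.02, 8).shape)   # torch.Size([8, 3])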
import torch.cuda
class SimplePerf(object):
    def __init__(self, enable, start=False) -> None:
        super().__init__()
        self.enable = enable
        # CUDA events are created lazily on the first Start() call
        self.start_event = None
        self.end_event = None
        if start:
            self.Start()
def Start(self):
if not self.enable:
return
        if self.start_event is None:
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
self.start_event.record()
def Checkpoint(self, name: str, end: bool = False):
if not self.enable:
return
self.end_event.record()
torch.cuda.synchronize()
print(name, ': ', self.start_event.elapsed_time(self.end_event))
if not end:
self.start_event.record()
\ No newline at end of file
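Usage sketch (illustrative; only meaningful on a CUDA machine, since the timing is done with torch.cuda.Event):

perf = SimplePerf(enable=torch.cuda.is_available(), start=True)
if torch.cuda.is_available():
    x = torch.rand(1024, 1024, device='cuda')
    y = x @ x
    perf.Checkpoint("matmul", end=True)   # prints the elapsed time in milliseconds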
from typing import Tuple
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
import glm
import os
from torchvision.utils import save_image
gvec_type = [glm.dvec1, glm.dvec2, glm.dvec3, glm.dvec4]
gmat_type = [[glm.dmat2, glm.dmat2x3, glm.dmat2x4],
[glm.dmat3x2, glm.dmat3, glm.dmat3x4],
[glm.dmat4x2, glm.dmat4x3, glm.dmat4]]
def Fov2Length(angle):
return np.tan(angle * np.pi / 360) * 2
def SmoothStep(x0, x1, x):
y = torch.clamp((x - x0) / (x1 - x0), 0, 1)
return y * y * (3 - 2 * y)
def MatImg2Tensor(img, permute=True, batch_dim=True):
batch_input = len(img.shape) == 4
if permute:
t = torch.from_numpy(np.transpose(img,
[0, 3, 1, 2] if batch_input else [2, 0, 1]))
else:
t = torch.from_numpy(img)
if not batch_input and batch_dim:
t = t.unsqueeze(0)
return t
def MatImg2Numpy(img, permute=True, batch_dim=True):
batch_input = len(img.shape) == 4
if permute:
t = np.transpose(img, [0, 3, 1, 2] if batch_input else [2, 0, 1])
else:
t = img
if not batch_input and batch_dim:
        t = np.expand_dims(t, 0)  # ndarray has no unsqueeze()
return t
def Tensor2MatImg(t: torch.Tensor) -> np.ndarray:
"""
Convert image tensor to numpy ndarray suitable for matplotlib
:param t: 2D (HW), 3D (CHW/HWC) or 4D (BCHW/BHWC) tensor
:return: numpy ndarray (...C), with channel transposed to the last dim
"""
img = t.squeeze().cpu().detach().numpy()
if len(img.shape) == 2: # Single channel image
return img
batch_input = len(img.shape) == 4
if t.size()[batch_input] <= 4:
return np.transpose(img, [0, 2, 3, 1] if batch_input else [1, 2, 0])
return img
def ReadImageTensor(path, permute=True, rgb_only=True, batch_dim=True):
channels = 3 if rgb_only else 4
if isinstance(path, list):
first_image = plt.imread(path[0])[:, :, 0:channels]
b_image = np.empty(
(len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
b_image[0] = first_image
for i in range(1, len(path)):
b_image[i] = plt.imread(path[i])[:, :, 0:channels]
return MatImg2Tensor(b_image, permute)
return MatImg2Tensor(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
def ReadImageNumpyArray(path, permute=True, rgb_only=True, batch_dim=True):
channels = 3 if rgb_only else 4
if isinstance(path, list):
first_image = plt.imread(path[0])[:, :, 0:channels]
b_image = np.empty(
(len(path), first_image.shape[0], first_image.shape[1], channels), dtype=np.float32)
b_image[0] = first_image
for i in range(1, len(path)):
b_image[i] = plt.imread(path[i])[:, :, 0:channels]
return MatImg2Numpy(b_image, permute)
return MatImg2Numpy(plt.imread(path)[:, :, 0:channels], permute, batch_dim)
def WriteImageTensor(t, path):
#image = Tensor2MatImg(t)
if isinstance(path, list):
if (len(t.size()) != 4 and len(path) != 1) or t.size()[0] != len(path):
raise ValueError
for i in range(len(path)):
save_image(t[i], path[i])
#plt.imsave(path[i], image[i])
else:
if len(t.squeeze().size()) >= 4:
raise ValueError
#plt.imsave(path, image)
save_image(t, path)
def PlotImageTensor(t):
plt.imshow(Tensor2MatImg(t))
def Tensor2Glm(t):
t = t.squeeze()
size = t.size()
if len(size) == 1:
if size[0] <= 0 or size[0] > 4:
raise ValueError
return gvec_type[size[0] - 1](t.cpu().numpy())
if len(size) == 2:
if size[0] <= 1 or size[0] > 4 or size[1] <= 1 or size[1] > 4:
raise ValueError
return gmat_type[size[1] - 2][size[0] - 2](t.cpu().numpy())
raise ValueError
def Glm2Tensor(val):
return torch.from_numpy(np.array(val))
def MeshGrid(size: Tuple[int, int], normalize: bool = False, swap_dim: bool = False):
"""
Generate a mesh grid
:param size: grid size (rows, columns)
:param normalize: return coords in normalized space? defaults to False
:param swap_dim: if True, return coords in (y, x) order, defaults to False
:return: rows x columns x 2 tensor
"""
y, x = torch.meshgrid(torch.tensor(range(size[0])),
torch.tensor(range(size[1])))
if swap_dim:
if normalize:
return torch.stack([y / (size[0] - 1.), x / (size[1] - 1.)], 2)
else:
return torch.stack([y, x], 2)
if normalize:
return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2)
else:
return torch.stack([x, y], 2)
def CreateDirIfNeed(path):
if not os.path.exists(path):
os.makedirs(path)
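A few spot checks of the helpers above (illustrative; values follow directly from the definitions):

print(Fov2Length(90))                         # ~2.0 (tan(45 deg) * 2)
g = MeshGrid((2, 3))                          # 2 x 3 x 2 grid of (x, y) coordinates
print(g[1, 2])                                # tensor([2, 1])
print(SmoothStep(0., 1., torch.tensor(0.5)))  # tensor(0.5000)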
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('/e/dengnc')\n",
"\n",
"from typing import List\n",
"import torch\n",
"from torch import nn\n",
"import matplotlib.pyplot as plt\n",
"from deeplightfield.data.lf_syn import LightFieldSynDataset\n",
"from deeplightfield.my import util\n",
"from deeplightfield.trans_unet import LatentSpaceTransformer\n",
"\n",
"device = torch.device(\"cuda:2\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test data loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DATA_DIR = '../data/lf_syn_2020.12.23'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
"\n",
"train_dataset = LightFieldSynDataset(TRAIN_DATA_DESC_FILE)\n",
"train_data_loader = torch.utils.data.DataLoader(\n",
" dataset=train_dataset,\n",
" batch_size=3,\n",
" num_workers=8,\n",
" pin_memory=True,\n",
" shuffle=True,\n",
" drop_last=False)\n",
"print(len(train_data_loader))\n",
"\n",
"print(train_dataset.cam_params)\n",
"print(train_dataset.sparse_view_positions)\n",
"print(train_dataset.diopter_of_layers)\n",
"plt.figure()\n",
"util.PlotImageTensor(train_dataset.sparse_view_images[0])\n",
"plt.figure()\n",
"util.PlotImageTensor(train_dataset.sparse_view_depths[0] / 255 * 10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test disparity wrapper"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
"transformer = LatentSpaceTransformer(train_dataset.sparse_view_images.size()[2],\n",
" train_dataset.cam_params,\n",
" train_dataset.diopter_of_layers,\n",
" train_dataset.sparse_view_positions)\n",
"novel_views = torch.stack([\n",
" train_dataset.view_positions[13],\n",
" train_dataset.view_positions[30],\n",
" train_dataset.view_positions[57],\n",
"], dim=0)\n",
"trans_images = transformer(train_dataset.sparse_view_images.to(device),\n",
" train_dataset.sparse_view_depths.to(device),\n",
" novel_views)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"mask = (torch.sum(trans_images[0], 1) > 1e-5).to(dtype=torch.float)\n",
"blended = torch.sum(trans_images[0], 0)\n",
"weight = torch.sum(mask, 0)\n",
"blended = blended / weight.unsqueeze(0)\n",
"\n",
"plt.figure(figsize=(6, 6))\n",
"util.PlotImageTensor(train_dataset.view_images[13])\n",
"plt.figure(figsize=(6, 6))\n",
"util.PlotImageTensor(blended)\n",
"plt.figure(figsize=(12, 6))\n",
"plt.subplot(2, 4, 1)\n",
"util.PlotImageTensor(train_dataset.sparse_view_images[0])\n",
"plt.subplot(2, 4, 2)\n",
"util.PlotImageTensor(train_dataset.sparse_view_images[1])\n",
"plt.subplot(2, 4, 3)\n",
"util.PlotImageTensor(train_dataset.sparse_view_images[2])\n",
"plt.subplot(2, 4, 4)\n",
"util.PlotImageTensor(train_dataset.sparse_view_images[3])\n",
"\n",
"plt.subplot(2, 4, 5)\n",
"util.PlotImageTensor(trans_images[0, 0])\n",
"plt.subplot(2, 4, 6)\n",
"util.PlotImageTensor(trans_images[0, 1])\n",
"plt.subplot(2, 4, 7)\n",
"util.PlotImageTensor(trans_images[0, 2])\n",
"plt.subplot(2, 4, 8)\n",
"util.PlotImageTensor(trans_images[0, 3])\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.6 64-bit ('pytorch': conda)",
"metadata": {
"interpreter": {
"hash": "a00413fa0fb6b0da754bf9fddd63461fcd32e367fc56a5d25240eae72261060e"
}
},
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
Subproject commit 10f49b1e7df38a58fd78451eac91d7ac1a21df64
import sys
sys.path.append('/e/dengnc')
__package__ = "deeplightfield"
import os
import torch
import torch.optim
import torchvision
from tensorboardX import SummaryWriter
from .loss.loss import PerceptionReconstructionLoss
from .my import netio
from .my import util
from .my.simple_perf import SimplePerf
from .data.lf_syn import LightFieldSynDataset
from .trans_unet import TransUnet
device = torch.device("cuda:2")
DATA_DIR = os.path.dirname(__file__) + '/data/lf_syn_2020.12.23'
TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'
OUTPUT_DIR = DATA_DIR + '/output_low_lr'
RUN_DIR = DATA_DIR + '/run_low_lr'
BATCH_SIZE = 1
TEST_BATCH_SIZE = 10
NUM_EPOCH = 1000
MODE = "Silence" # "Perf"
EPOCH_BEGIN = 500
def train():
# 1. Initialize data loader
print("Load dataset: " + TRAIN_DATA_DESC_FILE)
train_dataset = LightFieldSynDataset(TRAIN_DATA_DESC_FILE)
train_data_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=BATCH_SIZE,
pin_memory=True,
shuffle=True,
drop_last=False)
print(len(train_data_loader))
# 2. Initialize components
model = TransUnet(cam_params=train_dataset.cam_params,
view_images=train_dataset.sparse_view_images,
view_depths=train_dataset.sparse_view_depths,
view_positions=train_dataset.sparse_view_positions,
diopter_of_layers=train_dataset.diopter_of_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = PerceptionReconstructionLoss()
if EPOCH_BEGIN > 0:
netio.LoadNet('%s/model-epoch_%d.pth' % (RUN_DIR, EPOCH_BEGIN), model,
solver=optimizer)
# 3. Train
model.train()
epoch = EPOCH_BEGIN
iters = EPOCH_BEGIN * len(train_data_loader)
util.CreateDirIfNeed(RUN_DIR)
perf = SimplePerf(enable=(MODE == "Perf"), start=True)
writer = SummaryWriter(RUN_DIR)
print("Begin training...")
for epoch in range(EPOCH_BEGIN, NUM_EPOCH):
for _, view_images, _, view_positions in train_data_loader:
view_images = view_images.to(device)
perf.Checkpoint("Load")
out_view_images = model(view_positions)
perf.Checkpoint("Forward")
optimizer.zero_grad()
loss_value = loss(out_view_images, view_images)
perf.Checkpoint("Compute loss")
loss_value.backward()
perf.Checkpoint("Backward")
optimizer.step()
perf.Checkpoint("Update")
print("Epoch: ", epoch, ", Iter: ", iters,
", Loss: ", loss_value.item())
iters = iters + BATCH_SIZE
# Write tensorboard logs.
writer.add_scalar("loss", loss_value, iters)
if iters % len(train_data_loader) == 0:
output_vs_gt = torch.cat([out_view_images, view_images], dim=0)
writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
output_vs_gt, scale_each=True, normalize=False)
.cpu().detach().numpy(), iters)
# Save checkpoint
if ((epoch + 1) % 50 == 0):
netio.SaveNet('%s/model-epoch_%d.pth' % (RUN_DIR, epoch + 1), model,
solver=optimizer)
print("Train finished")
netio.SaveNet('%s/model-epoch_%d.pth' % (RUN_DIR, epoch + 1), model)
def test(net_file: str):
# 1. Load train dataset
print("Load dataset: " + TRAIN_DATA_DESC_FILE)
train_dataset = LightFieldSynDataset(TRAIN_DATA_DESC_FILE)
train_data_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=TEST_BATCH_SIZE,
pin_memory=True,
shuffle=False,
drop_last=False)
# 2. Load trained model
model = TransUnet(cam_params=train_dataset.cam_params,
view_images=train_dataset.sparse_view_images,
view_depths=train_dataset.sparse_view_depths,
view_positions=train_dataset.sparse_view_positions,
diopter_of_layers=train_dataset.diopter_of_layers).to(device)
netio.LoadNet(net_file, model)
# 3. Test on train dataset
print("Begin test on train dataset...")
util.CreateDirIfNeed(OUTPUT_DIR)
for view_idxs, view_images, _, view_positions in train_data_loader:
out_view_images = model(view_positions)
util.WriteImageTensor(
view_images,
['%s/gt_view%02d.png' % (OUTPUT_DIR, i) for i in view_idxs])
util.WriteImageTensor(
out_view_images,
['%s/out_view%02d.png' % (OUTPUT_DIR, i) for i in view_idxs])
if __name__ == "__main__":
#train()
test(RUN_DIR + '/model-epoch_1000.pth')
from typing import List
import torch
import torch.nn as nn
from .pytorch_prototyping.pytorch_prototyping import *
from .my import util
device = torch.device("cuda:2")
class Encoder(nn.Module):
def __init__(self, nf0, out_channels, input_resolution, output_sidelength):
"""
Initialize a encoder
:param nf0: number of outmost features
:param out_channels:
:param input_resolution: [description]
:param output_sidelength: [description]
"""
super().__init__()
norm = nn.BatchNorm2d
num_down_unet = int(math.log2(output_sidelength))
num_downsampling = int(math.log2(input_resolution)) - num_down_unet
self.net = nn.Sequential(
DownsamplingNet([nf0 * (2 ** i) for i in range(num_downsampling)],
in_channels=3,
use_dropout=False,
norm=norm),
Unet(in_channels=nf0 * (2 ** (num_downsampling-1)),
out_channels=out_channels,
nf0=nf0 * (2 ** (num_downsampling-1)),
use_dropout=False,
max_channels=8*nf0,
num_down=num_down_unet,
norm=norm)
)
self.depth_downsampler = DownsamplingNet([1 for i in range(num_downsampling)],
in_channels=1,
use_dropout=False,
norm=norm)
def forward(self, input, input_depth):
return self.net(input), torch.round(self.depth_downsampler(input_depth))
class LatentSpaceTransformer(nn.Module):
def __init__(self, feat_dim: int, cam_params,
diopter_of_layers: List[float],
view_positions: torch.Tensor):
"""
Initialize a latent space transformer
:param feat_dim: dimension of latent space
:param cam_params: camera parameters
:param diopter_of_layers: diopter of layers
:param view_positions: view positions of input sparse light field
"""
super().__init__()
self.feat_dim = feat_dim
self.f_cam = cam_params['f']
self.view_positions = view_positions
self.n_views = view_positions.size()[0]
self.diopter_of_layers = diopter_of_layers
self.feat_coords = util.MeshGrid(
(feat_dim, feat_dim)).to(device=device)
def forward(self, feats: torch.Tensor,
feat_depths: torch.Tensor,
novel_views: torch.Tensor) -> torch.Tensor:
        trans_feats = torch.zeros(novel_views.size()[0], feats.size()[0],
                                  feats.size()[1], feats.size()[2], feats.size()[3],
                                  device=device)
for i in range(novel_views.size()[0]):
for v in range(self.n_views):
for l in range(len(self.diopter_of_layers)):
disparity = self._DisparityFromDepth(novel_views[i],
self.view_positions[v],
self.diopter_of_layers[l])
src_window = (
slice(max(0, -int(disparity[1])),
min(feats.size()[2], feats.size()[2] - int(disparity[1]))),
slice(max(0, -int(disparity[0])),
min(feats.size()[3], feats.size()[3] - int(disparity[0])))
)
tgt_window = (
slice(max(0, int(disparity[1])),
min(feats.size()[2], feats.size()[2] + int(disparity[1]))),
slice(max(0, int(disparity[0])),
min(feats.size()[3], feats.size()[3] + int(disparity[0])))
)
mask = (feat_depths[v] == l)[:, src_window[0], src_window[1]][0]
trans_feats[i, v, :, tgt_window[0], tgt_window[1]][:, mask] = \
feats[v, :, src_window[0], src_window[1]][:, mask]
return trans_feats
def _DisparityFromDepth(self, tgt_view, src_view, diopter):
return torch.round((src_view - tgt_view) * diopter * self.f_cam * self.feat_dim)
class Decoder(nn.Module):
def __init__(self, nf0, in_channels, input_resolution, img_sidelength):
super().__init__()
num_down_unet = int(math.log2(input_resolution))
num_upsampling = int(math.log2(img_sidelength)) - num_down_unet
self.net = [
Unet(in_channels=in_channels,
out_channels=3 if num_upsampling <= 0 else 4 * nf0,
outermost_linear=True if num_upsampling <= 0 else False,
use_dropout=True,
dropout_prob=0.1,
nf0=nf0 * (2 ** num_upsampling),
norm=nn.BatchNorm2d,
max_channels=8 * nf0,
num_down=num_down_unet)
]
if num_upsampling > 0:
self.net += [
UpsamplingNet(per_layer_out_ch=num_upsampling * [nf0],
in_channels=4 * nf0,
upsampling_mode='transpose',
use_dropout=True,
dropout_prob=0.1),
Conv2dSame(nf0, out_channels=nf0 // 2,
kernel_size=3, bias=False),
nn.BatchNorm2d(nf0 // 2),
nn.ReLU(True),
Conv2dSame(nf0 // 2, 3, kernel_size=3)
]
self.net += [nn.Tanh()]
self.net = nn.Sequential(*self.net)
def forward(self, input):
return self.net(input)
class TransUnet(nn.Module):
def __init__(self, cam_params, view_images, view_depths, view_positions, diopter_of_layers):
super().__init__()
nf0 = 64 # Number of features to use in the outermost layer of all U-Nets
nf = 64 # Number of features in the latent space
latent_sidelength = 64 # The dimensions of the latent space
image_sidelength = view_images.size()[2]
self.view_images = view_images.to(device)
self.view_depths = view_depths.to(device)
self.n_views = view_images.size()[0]
self.encoder = Encoder(nf0=nf0,
out_channels=nf,
input_resolution=image_sidelength,
output_sidelength=latent_sidelength)
self.latent_space_transformer = LatentSpaceTransformer(feat_dim=latent_sidelength,
cam_params=cam_params,
view_positions=view_positions,
diopter_of_layers=diopter_of_layers)
self.decoder = Decoder(nf0=nf0,
in_channels=nf * 4,
input_resolution=latent_sidelength,
img_sidelength=image_sidelength)
def forward(self, novel_views):
if self.training:
self.feats, self.feat_depths = self.encoder(self.view_images,
self.view_depths)
transformed_feats = self.latent_space_transformer(self.feats,
self.feat_depths,
novel_views)
transformed_feats = torch.flatten(transformed_feats, 1, 2)
novel_view_images = self.decoder(transformed_feats)
return novel_view_images
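A worked example of the shift computed by _DisparityFromDepth, with made-up numbers (f = 0.1, views offset by 0.05 in x, a layer at 2 diopters, a 64-pixel latent map): 0.05 * 2 * 0.1 * 64 = 0.64, which rounds to a one-pixel shift along x.

src, tgt = torch.tensor([0.05, 0.0]), torch.tensor([0.0, 0.0])
print(torch.round((src - tgt) * 2.0 * 0.1 * 64))   # tensor([1., 0.])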