Commit 055dc0bb authored by BobYeah

First Stage

parent 648dfd2c
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# log
*.txt
*.out
*.ipynb
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# macOS
.DS_Store
# Output
output/
\ No newline at end of file
@@ -13,6 +13,8 @@ from torch.autograd import Variable
import cv2
from gen_image import *
import json
from ssim import *
from perc_loss import *
# param
BATCH_SIZE = 5
NUM_EPOCH = 5000
@@ -27,6 +29,7 @@ M = 2 # number of display layers
DATA_FILE = "/home/yejiannan/Project/LightField/data/try"
DATA_JSON = "/home/yejiannan/Project/LightField/data/data.json"
DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_val.json"
OUTPUT_DIR = "/home/yejiannan/Project/LightField/output"
class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
@@ -34,7 +37,7 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
        self.file_dir_path = file_dir_path
        self.transforms = transforms
        # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
-        with open(DATA_JSON, encoding='utf-8') as file:
        with open(file_json, encoding='utf-8') as file:
            self.dastset_desc = json.loads(file.read())

    def __len__(self):
@@ -147,7 +150,7 @@ class model(torch.nn.Module):
        self.output_layer = torch.nn.Sequential(
            torch.nn.Conv2d(OUT_CHANNELS_RB + 1, LAST_LAYER_CHANNELS, KERNEL_SIZE, stride=1, padding=1),
            torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
-            torch.nn.Tanh()
            torch.nn.Sigmoid()  # outputs in (0, 1)
        )
        self.deinterleave = deinterleave(INTERLEAVE_RATE)
@@ -164,7 +167,7 @@ class model(torch.nn.Module):
        depth_layer = torch.ones((output.shape[0], 1, output.shape[2], output.shape[3]))
        # print(df.shape[0])
        for i in range(focal_length.shape[0]):
-            depth_layer[i] = depth_layer[i] * focal_length[i]
            depth_layer[i] = 1. / focal_length[i]
        # print(depth_layer.shape)
        depth_layer = var_or_cuda(depth_layer)
        output = torch.cat((output, depth_layer), dim=1)
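        # The appended channel stores 1 / focal_length for each batch sample,
        # broadcast over H x W, so the network is conditioned on the reciprocal
        # focus distance (diopters, if focal_length is in meters).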
@@ -182,8 +185,8 @@ class Conf(object):
        self.retinal_res = torch.tensor([ 480, 640 ])
        self.layer_res = torch.tensor([ 480, 640 ])
        self.n_layers = 2
-        self.d_layer = [ 1.75, 3.5 ] # layers' distance
-        self.h_layer = [ 1., 2. ] # layers' height
        self.d_layer = [ 1., 3. ]  # layers' distances
        self.h_layer = [ 1. * 480. / 640., 3. * 480. / 640. ]  # layers' heights
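        # With these values h_layer / d_layer = 480 / 640 for both layers, so
        # each layer subtends the same field of view and matches the aspect
        # ratio of the 480 x 640 (H x W) layer resolution.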
#### Image Gen
conf = Conf()
@@ -223,14 +226,14 @@ def GenRetinalFromLayersBatch(layers, conf, df, v, u):
            torch.clamp_(pi[:, :, :, 1], 0, conf.layer_res[1] - 1)
            Phi[bs, :, :, i, :, :] = pi
    # print("Phi slice:",Phi[0, :, :, 0, 0, 0].shape)
-    retinal = torch.zeros(BS, 3, H_r, W_r)
    retinal = torch.ones(BS, 3, H_r, W_r)
    retinal = var_or_cuda(retinal)
    for bs in range(BS):
        for j in range(0, M):
-            retinal_view = torch.zeros(3, H_r, W_r)
            retinal_view = torch.ones(3, H_r, W_r)
            retinal_view = var_or_cuda(retinal_view)
            for i in range(0, N):
-                retinal_view.add_(layers[bs, (i * 3) : (i * 3 + 3), Phi[bs, :, :, i, j, 0], Phi[bs, :, :, i, j, 1]])
                retinal_view.mul_(layers[bs, (i * 3) : (i * 3 + 3), Phi[bs, :, :, i, j, 0], Phi[bs, :, :, i, j, 1]])
            retinal[bs, :, :, :].add_(retinal_view)
        retinal[bs, :, :, :].div_(M)
    return retinal
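# Composition note: each retinal view is now the per-pixel *product* of the
# sampled layer values (retinal_view starts at ones and is multiplied), which
# models a stack of attenuating display layers; the M views are then averaged.
# Since `retinal` is also initialized to ones, the result is (1 + sum of
# views) / M rather than the exact mean; initializing `retinal` to zeros
# would give the exact average over views.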
@@ -263,6 +266,42 @@ def var_or_cuda(x):
        x = x.cuda(non_blocking=True)
    return x
def calImageGradients(images):
    # images is a 4-D tensor (N, C, H, W); finite differences along the two
    # spatial axes (height for dx, width for dy).
    dx = images[:, :, 1:, :] - images[:, :, :-1, :]
    dy = images[:, :, :, 1:] - images[:, :, :, :-1]
    return dx, dy
perc_loss = VGGPerceptualLoss()
perc_loss = perc_loss.to("cuda")
def loss_new(generated, gt):
    mse_loss = torch.nn.MSELoss()
    rmse_intensity = mse_loss(generated, gt)
    RENORM_SCALE = torch.tensor(0.9)
    RENORM_SCALE = var_or_cuda(RENORM_SCALE)  # currently unused below
    psnr_intensity = torch.log10(rmse_intensity)  # log10 of the intensity MSE
    ssim_intensity = ssim(generated, gt)  # computed but not yet part of total_loss
    labels_dx, labels_dy = calImageGradients(gt)
    preds_dx, preds_dy = calImageGradients(generated)
    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
    p_loss = perc_loss(generated, gt)
    # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
    total_loss = 10 + psnr_intensity + 0.5 * (psnr_grad_x + psnr_grad_y) + p_loss
    return total_loss
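# Loss note: for images in [0, 1], PSNR = -10 * log10(MSE), so the log10 terms
# above are negated, rescaled PSNR surrogates over intensities and image
# gradients; minimizing them maximizes the corresponding PSNRs. The leading
# constant 10 only shifts the loss toward positive values and does not affect
# the gradients.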
def save_checkpoints(file_path, epoch_idx, model, model_solver):
    print('[INFO] Saving checkpoint to %s ...' % (file_path))
    checkpoint = {
        'epoch_idx': epoch_idx,
        'model_state_dict': model.state_dict(),
        'model_solver_state_dict': model_solver.state_dict()
    }
    torch.save(checkpoint, file_path)
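# A matching restore helper (a sketch added for symmetry; the eval branch
# below inlines torch.load / load_state_dict directly):
def load_checkpoints(file_path, model, model_solver=None):
    print('[INFO] Loading checkpoint from %s ...' % (file_path))
    checkpoint = torch.load(file_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    if model_solver is not None:
        model_solver.load_state_dict(checkpoint['model_solver_state_dict'])
    return checkpoint['epoch_idx']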
mode = "val"
if __name__ == "__main__":
    # test
    # train_dataset = lightFieldDataLoader(DATA_FILE,DATA_JSON)
@@ -271,40 +310,91 @@ if __name__ == "__main__":
    # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx)))
    # test end
    # train
    train_data_loader = torch.utils.data.DataLoader(
        dataset=lightFieldDataLoader(DATA_FILE, DATA_JSON),
        batch_size=BATCH_SIZE,
        num_workers=0,
        pin_memory=True,
-        shuffle=False,
        shuffle=True,
        drop_last=False)
    print(len(train_data_loader))

    val_data_loader = torch.utils.data.DataLoader(
        dataset=lightFieldDataLoader(DATA_FILE, DATA_VAL_JSON),
        batch_size=1,
        num_workers=0,
        pin_memory=True,
        shuffle=False,
        drop_last=False)
    print(len(val_data_loader))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lf_model = model()
    lf_model.apply(weight_init_normal)
    if torch.cuda.is_available():
        lf_model = torch.nn.DataParallel(lf_model).cuda()
    lf_model.train()
    optimizer = torch.optim.Adam(lf_model.parameters(), lr=5e-3, betas=(0.9, 0.999))
-    for epoch in range(NUM_EPOCH):
-        for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
    # val
    checkpoint = torch.load(os.path.join(OUTPUT_DIR, "ckpt-epoch-3001.pth"))
    lf_model.load_state_dict(checkpoint["model_state_dict"])
    lf_model.eval()
    print("Eval::")
    for sample_idx, (image_set, gt, df) in enumerate(val_data_loader):
        print("sample_idx::")
        with torch.no_grad():
            # reshape for input
            image_set = image_set.permute(0, 1, 4, 2, 3)  # N LF C H W
            image_set = image_set.reshape(image_set.shape[0], -1, image_set.shape[3], image_set.shape[4])  # N, LFxC, H, W
            image_set = var_or_cuda(image_set)
            # image_set.to(device)
            gt = gt.permute(0, 3, 1, 2)
            gt = var_or_cuda(gt)
            # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
-            optimizer.zero_grad()
            output = lf_model(image_set, df)
            # print("output:",output.shape," df:",df.shape)
            print("output:", output.shape, " df:", df)
            save_image(output[0][0:3].data, os.path.join(OUTPUT_DIR, "1113_interp_l1_%.3f.png" % (df[0].data)))
            save_image(output[0][3:6].data, os.path.join(OUTPUT_DIR, "1113_interp_l2_%.3f.png" % (df[0].data)))
            output = GenRetinalFromLayersBatch(output, conf, df, v, u)
-            loss = loss_two_images(output, gt)
-            print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
-            loss.backward()
-            optimizer.step()
-            for i in range(5):
-                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-3_insertmid_o%d_%d.png"%(epoch,i)))
            save_image(output[0][0:3].data, os.path.join(OUTPUT_DIR, "1113_interp_o%.3f.png" % (df[0].data)))
    exit()
    # train
    # print(lf_model)
    # exit()
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # lf_model = model()
    # lf_model.apply(weight_init_normal)
    # if torch.cuda.is_available():
    #     lf_model = torch.nn.DataParallel(lf_model).cuda()
    # lf_model.train()
    # optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-2,betas=(0.9,0.999))
    # for epoch in range(NUM_EPOCH):
    #     for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
    #         # reshape for input
    #         image_set = image_set.permute(0,1,4,2,3)  # N LF C H W
    #         image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4])  # N, LFxC, H, W
    #         image_set = var_or_cuda(image_set)
    #         # image_set.to(device)
    #         gt = gt.permute(0,3,1,2)
    #         gt = var_or_cuda(gt)
    #         # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
    #         optimizer.zero_grad()
    #         output = lf_model(image_set,df)
    #         # print("output:",output.shape," df:",df.shape)
    #         output = GenRetinalFromLayersBatch(output,conf,df,v,u)
    #         loss = loss_new(output,gt)
    #         print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
    #         loss.backward()
    #         optimizer.step()
    #         if (epoch % 100 == 0):
    #             for i in range(BATCH_SIZE):
    #                 save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-2_mul_dip_newloss_debug_conf_o%d_%d.png"%(epoch,i)))
    #         if (epoch % 1000 == 0):
    #             save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch + 1)),
    #                              epoch, lf_model, optimizer)
\ No newline at end of file
import torch
import torchvision
class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # First four convolutional stages of an ImageNet-pretrained VGG-16,
        # kept in eval mode and frozen, used as a fixed feature extractor.
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # ImageNet channel statistics used to normalize inputs before VGG.
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.resize = resize

    def forward(self, input, target):
        # Grayscale inputs are tiled to three channels to match VGG.
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        # Sum of L1 distances between feature maps at each VGG stage.
        loss = 0.0
        x = input
        y = target
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.l1_loss(x, y)
        return loss
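# Minimal usage sketch (assumes NCHW batches scaled to [0, 1]; the training
# script instead builds one instance and moves it to CUDA):
if __name__ == "__main__":
    vgg_loss = VGGPerceptualLoss()
    pred, target = torch.rand(2, 3, 128, 128), torch.rand(2, 3, 128, 128)
    print(vgg_loss(pred, target).item())  # scalar perceptual distance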
\ No newline at end of file
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    # 1-D Gaussian kernel, normalized to sum to 1.
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    # Separable 2-D Gaussian window expanded to `channel` depthwise filters.
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window
def _ssim(img1, img2, window, window_size, channel, size_average=True):
    # Local (Gaussian-weighted) means, one depthwise filter per channel.
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance.
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    # Stability constants (K1*L)^2 and (K2*L)^2 with K1=0.01, K2=0.03 and
    # dynamic range L=1, i.e. images are assumed to lie in [0, 1].
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # Reuse the cached window when the channel count and dtype still
        # match; otherwise rebuild it for the new input.
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
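# Minimal usage sketch (assumes NCHW inputs in [0, 1]): identical images
# score exactly 1.0, while unrelated noise scores far lower.
if __name__ == "__main__":
    a = torch.rand(1, 3, 64, 64)
    print(ssim(a, a).item())                   # 1.0
    print(ssim(a, torch.rand_like(a)).item())  # well below 1.0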
\ No newline at end of file