From 055dc0bba4224ae21de8be321ab9d9e17a74d21b Mon Sep 17 00:00:00 2001
From: BobYeah <635596704@qq.com>
Date: Sat, 21 Nov 2020 16:58:38 +0800
Subject: [PATCH] First Stage

---
 .gitignore   | 112 +++++++++++++++++++++++++++++++++++++++++
 main.py      | 140 ++++++++++++++++++++++++++++++++++++++++++---------
 perc_loss.py |  38 ++++++++++++++
 ssim.py      |  72 ++++++++++++++++++++++++++
 4 files changed, 337 insertions(+), 25 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 perc_loss.py
 create mode 100644 ssim.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6109683
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,112 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# log
+*.txt
+*.out
+*.ipynb
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# macOS
+.DS_Store
+
+# Output
+output/
\ No newline at end of file
diff --git a/main.py b/main.py
index 216fb7f..d75206b 100644
--- a/main.py
+++ b/main.py
@@ -13,6 +13,8 @@ from torch.autograd import Variable
 import cv2
 from gen_image import *
 import json
+from ssim import *
+from perc_loss import *
 # param
 BATCH_SIZE = 5
 NUM_EPOCH = 5000
@@ -27,6 +29,7 @@ M = 2 # number of display layers
 
 DATA_FILE = "/home/yejiannan/Project/LightField/data/try"
 DATA_JSON = "/home/yejiannan/Project/LightField/data/data.json"
+DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_val.json"
 OUTPUT_DIR = "/home/yejiannan/Project/LightField/output"
 
 class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
@@ -34,7 +37,7 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
         self.file_dir_path = file_dir_path
         self.transforms = transforms
         # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
-        with open(DATA_JSON, encoding='utf-8') as file:
+        with open(file_json, encoding='utf-8') as file:
             self.dastset_desc = json.loads(file.read())
 
     def __len__(self):
@@ -147,7 +150,7 @@ class model(torch.nn.Module):
         self.output_layer = torch.nn.Sequential(
             torch.nn.Conv2d(OUT_CHANNELS_RB+1,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
             torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
-            torch.nn.Tanh()
+            torch.nn.Sigmoid()
         )
         self.deinterleave = deinterleave(INTERLEAVE_RATE)
 
@@ -164,7 +167,7 @@ class model(torch.nn.Module):
         depth_layer = torch.ones((output.shape[0],1,output.shape[2],output.shape[3]))
         # print(df.shape[0])
         for i in range(focal_length.shape[0]):
-            depth_layer[i] = depth_layer[i] * focal_length[i]
+            depth_layer[i] = 1. / focal_length[i]
         # print(depth_layer.shape)
         depth_layer = var_or_cuda(depth_layer)
         output = torch.cat((output,depth_layer),dim=1)
@@ -182,8 +185,8 @@ class Conf(object):
         self.retinal_res = torch.tensor([ 480, 640 ])
         self.layer_res = torch.tensor([ 480, 640 ])
         self.n_layers = 2
-        self.d_layer = [ 1.75, 3.5 ] # layers' distance
-        self.h_layer = [ 1., 2. ] # layers' height
+        self.d_layer = [ 1., 3. ] # layers' distance
+        self.h_layer = [ 1. * 480. / 640., 3. * 480. / 640. ] # layers' height
 
 #### Image Gen
 conf = Conf()
@@ -223,14 +226,14 @@ def GenRetinalFromLayersBatch(layers, conf, df, v, u):
             torch.clamp_(pi[:, :, :, 1], 0, conf.layer_res[1] - 1)
             Phi[bs, :, :, i, :, :] = pi
     # print("Phi slice:",Phi[0, :, :, 0, 0, 0].shape)
-    retinal = torch.zeros(BS, 3, H_r, W_r)
+    retinal = torch.ones(BS, 3, H_r, W_r)
     retinal = var_or_cuda(retinal)
     for bs in range(BS):
         for j in range(0, M):
-            retinal_view = torch.zeros(3, H_r, W_r)
+            retinal_view = torch.ones(3, H_r, W_r)
             retinal_view = var_or_cuda(retinal_view)
             for i in range(0, N):
-                retinal_view.add_(layers[bs, (i * 3) : (i * 3 + 3), Phi[bs, :, :, i, j, 0], Phi[bs, :, :, i, j, 1]])
+                retinal_view.mul_(layers[bs, (i * 3) : (i * 3 + 3), Phi[bs, :, :, i, j, 0], Phi[bs, :, :, i, j, 1]])
             retinal[bs,:,:,:].add_(retinal_view)
         retinal[bs,:,:,:].div_(M)
     return retinal
@@ -263,6 +266,42 @@ def var_or_cuda(x):
         x = x.cuda(non_blocking=True)
     return x
 
+def calImageGradients(images):
+    # images is a 4-D tensor (N, C, H, W); finite differences along H and W
+    dx = images[:, :, 1:, :] - images[:, :, :-1, :]
+    dy = images[:, :, :, 1:] - images[:, :, :, :-1]
+    return dx, dy
+
+
+perc_loss = VGGPerceptualLoss()
+perc_loss = perc_loss.to("cuda")
+
+def loss_new(generated, gt):
+    mse_loss = torch.nn.MSELoss()
+    rmse_intensity = mse_loss(generated, gt)
+    RENORM_SCALE = torch.tensor(0.9)
+    RENORM_SCALE = var_or_cuda(RENORM_SCALE)
+    psnr_intensity = torch.log10(rmse_intensity)
+    ssim_intensity = ssim(generated, gt)
+    labels_dx, labels_dy = calImageGradients(gt)
+    preds_dx, preds_dy = calImageGradients(generated)
+    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
+    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
+    p_loss = perc_loss(generated,gt)
+    # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
+    total_loss = 10 + psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
+    return total_loss
+
+def save_checkpoints(file_path, epoch_idx, model, model_solver):
+    print('[INFO] Saving checkpoint to %s ...' % ( file_path))
+    checkpoint = {
+        'epoch_idx': epoch_idx,
+        'model_state_dict': model.state_dict(),
+        'model_solver_state_dict': model_solver.state_dict()
+    }
+    torch.save(checkpoint, file_path)
+
+mode = "val"
 if __name__ == "__main__":
     #test
     # train_dataset = lightFieldDataLoader(DATA_FILE,DATA_JSON)
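NOTE: loss_new above combines a log-scale intensity MSE, log-scale image-gradient
MSEs, and a frozen-VGG perceptual term (ssim_intensity and RENORM_SCALE are
computed but unused). A minimal standalone sketch of the same combination,
assuming NCHW float tensors in [0, 1]; the helper names here are illustrative
and not part of this patch:

    import torch
    import torch.nn.functional as F

    def sketch_loss(pred, gt, perc_fn):
        # log10 of the intensity MSE (more negative as images match better)
        intensity_term = torch.log10(F.mse_loss(pred, gt))
        # log10 of the MSE between image gradients along H and W
        def grads(t):
            return (t[:, :, 1:, :] - t[:, :, :-1, :],
                    t[:, :, :, 1:] - t[:, :, :, :-1])
        pdx, pdy = grads(pred)
        gdx, gdy = grads(gt)
        grad_term = 0.5 * (torch.log10(F.mse_loss(pdx, gdx)) +
                           torch.log10(F.mse_loss(pdy, gdy)))
        # perceptual distance from a frozen feature extractor (e.g. VGGPerceptualLoss)
        perc_term = perc_fn(pred, gt)
        # the +10 offsets the negative log terms, mirroring loss_new above
        return 10 + intensity_term + grad_term + perc_term

The constant offset only keeps typical values positive; it has no effect on
gradients.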
@@ -270,41 +309,92 @@ if __name__ == "__main__":
     # cv2.imwrite("test_crop0.png",train_dataset[0][1]*255.)
     # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx)))
     #test end
-
+
+    #train
     train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_JSON),
                                                     batch_size=BATCH_SIZE,
                                                     num_workers=0,
                                                     pin_memory=True,
-                                                    shuffle=False,
+                                                    shuffle=True,
                                                     drop_last=False)
     print(len(train_data_loader))
+
+    val_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_VAL_JSON),
+                                                  batch_size=1,
+                                                  num_workers=0,
+                                                  pin_memory=True,
+                                                  shuffle=False,
+                                                  drop_last=False)
+
+    print(len(val_data_loader))
+
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     lf_model = model()
-    lf_model.apply(weight_init_normal)
     if torch.cuda.is_available():
         lf_model = torch.nn.DataParallel(lf_model).cuda()
-    lf_model.train()
-    optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
-
-    for epoch in range(NUM_EPOCH):
-        for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
+
+    #val
+    checkpoint = torch.load(os.path.join(OUTPUT_DIR,"ckpt-epoch-3001.pth"))
+    lf_model.load_state_dict(checkpoint["model_state_dict"])
+    lf_model.eval()
+
+    print("Eval::")
+    for sample_idx, (image_set, gt, df) in enumerate(val_data_loader):
+        print("sample_idx::")
+        with torch.no_grad():
             #reshape for input
             image_set = image_set.permute(0,1,4,2,3) # N LF C H W
             image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
-
             image_set = var_or_cuda(image_set)
             # image_set.to(device)
             gt = gt.permute(0,3,1,2)
             gt = var_or_cuda(gt)
             # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
-            optimizer.zero_grad()
             output = lf_model(image_set,df)
-            # print("output:",output.shape," df:",df.shape)
+            print("output:",output.shape," df:",df)
+            save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_l1_%.3f.png"%(df[0].data)))
+            save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"1113_interp_l2_%.3f.png"%(df[0].data)))
             output = GenRetinalFromLayersBatch(output,conf,df,v,u)
-            loss = loss_two_images(output,gt)
-            print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
-            loss.backward()
-            optimizer.step()
-            for i in range(5):
-                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-3_insertmid_o%d_%d.png"%(epoch,i)))
+            save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.3f.png"%(df[0].data)))
+            exit()
+    # train
+    # print(lf_model)
+    # exit()
+
+    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # lf_model = model()
+    # lf_model.apply(weight_init_normal)
+
+    # if torch.cuda.is_available():
+    #     lf_model = torch.nn.DataParallel(lf_model).cuda()
+    # lf_model.train()
+    # optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-2,betas=(0.9,0.999))
+
+    # for epoch in range(NUM_EPOCH):
+    #     for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
+    #         #reshape for input
+    #         image_set = image_set.permute(0,1,4,2,3) # N LF C H W
+    #         image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
+
+    #         image_set = var_or_cuda(image_set)
+    #         # image_set.to(device)
+    #         gt = gt.permute(0,3,1,2)
+    #         gt = var_or_cuda(gt)
+    #         # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
+    #         optimizer.zero_grad()
+    #         output = lf_model(image_set,df)
+    #         # print("output:",output.shape," df:",df.shape)
+    #         output = GenRetinalFromLayersBatch(output,conf,df,v,u)
+    #         loss = loss_new(output,gt)
+    #         print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
+    #         loss.backward()
+    #         optimizer.step()
+    #         if (epoch%100 == 0):
+    #             for i in range(BATCH_SIZE):
+    #                 save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-2_mul_dip_newloss_debug_conf_o%d_%d.png"%(epoch,i)))
+    #         if (epoch%1000 == 0):
+    #             save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch + 1)),
+    #                 epoch,lf_model,optimizer)
+
+    
\ No newline at end of file
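NOTE: the validation branch above reloads a checkpoint produced by
save_checkpoints. A minimal sketch of that save/load round trip; the stand-in
module and file name below are illustrative, not from this patch:

    import torch

    model = torch.nn.Linear(4, 4)  # stand-in for model(); illustrative only
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)

    # Save: bundle epoch index, model weights and optimizer state,
    # as save_checkpoints does above.
    torch.save({'epoch_idx': 0,
                'model_state_dict': model.state_dict(),
                'model_solver_state_dict': optimizer.state_dict()},
               'ckpt-demo.pth')

    # Load: restore the weights, then switch to eval mode, as the #val branch does.
    # Caveat: weights saved from a torch.nn.DataParallel-wrapped model carry a
    # 'module.' key prefix, so they must be loaded into a model wrapped the same way.
    ckpt = torch.load('ckpt-demo.pth')
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()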
print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss) + # loss.backward() + # optimizer.step() + # if (epoch%100 == 0): + # for i in range(BATCH_SIZE): + # save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-2_mul_dip_newloss_debug_conf_o%d_%d.png"%(epoch,i))) + # if (epoch%1000 == 0): + # save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch + 1)), + # epoch,lf_model,optimizer) + + \ No newline at end of file diff --git a/perc_loss.py b/perc_loss.py new file mode 100644 index 0000000..acad75a --- /dev/null +++ b/perc_loss.py @@ -0,0 +1,38 @@ + +import torch +import torchvision + +class VGGPerceptualLoss(torch.nn.Module): + def __init__(self, resize=True): + super(VGGPerceptualLoss, self).__init__() + blocks = [] + blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) + for bl in blocks: + for p in bl: + p.requires_grad = False + self.blocks = torch.nn.ModuleList(blocks) + self.transform = torch.nn.functional.interpolate + self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)) + self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) + self.resize = resize + + def forward(self, input, target): + if input.shape[1] != 3: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + input = (input-self.mean) / self.std + target = (target-self.mean) / self.std + if self.resize: + input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) + target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) + loss = 0.0 + x = input + y = target + for block in self.blocks: + x = block(x) + y = block(y) + loss += torch.nn.functional.l1_loss(x, y) + return loss \ No newline at end of file diff --git a/ssim.py b/ssim.py new file mode 100644 index 0000000..93f390b --- /dev/null +++ b/ssim.py @@ -0,0 +1,72 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class 
diff --git a/ssim.py b/ssim.py
new file mode 100644
index 0000000..93f390b
--- /dev/null
+++ b/ssim.py
@@ -0,0 +1,72 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+    return gauss/gauss.sum()
+
+def create_window(window_size, channel):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+    return window
+
+def _ssim(img1, img2, window, window_size, channel, size_average = True):
+    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
+    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1*mu2
+
+    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
+    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
+    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
+
+    C1 = 0.01**2
+    C2 = 0.03**2
+
+    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
+
+    if size_average:
+        return ssim_map.mean()
+    else:
+        return ssim_map.mean(1).mean(1).mean(1)
+
+class SSIM(torch.nn.Module):
+    def __init__(self, window_size = 11, size_average = True):
+        super(SSIM, self).__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.channel = 1
+        self.window = create_window(window_size, self.channel)
+
+    def forward(self, img1, img2):
+        (_, channel, _, _) = img1.size()
+
+        if channel == self.channel and self.window.data.type() == img1.data.type():
+            window = self.window
+        else:
+            window = create_window(self.window_size, channel)
+
+            if img1.is_cuda:
+                window = window.cuda(img1.get_device())
+            window = window.type_as(img1)
+
+            self.window = window
+            self.channel = channel
+
+        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
+
+def ssim(img1, img2, window_size = 11, size_average = True):
+    (_, channel, _, _) = img1.size()
+    window = create_window(window_size, channel)
+
+    if img1.is_cuda:
+        window = window.cuda(img1.get_device())
+    window = window.type_as(img1)
+
+    return _ssim(img1, img2, window, window_size, channel, size_average)
\ No newline at end of file
-- 
GitLab
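NOTE: ssim.py provides the same metric in two forms; both return the mean SSIM
over an NCHW batch, approaching 1.0 for near-identical images. A usage sketch
with illustrative shapes and values:

    import torch
    from ssim import SSIM, ssim

    img1 = torch.rand(1, 3, 480, 640)  # NCHW, values in [0, 1]
    img2 = img1 * 0.9

    print(ssim(img1, img2))            # functional form: builds an 11x11 window per call
    criterion = SSIM(window_size=11)   # module form: caches the window between calls
    print(criterion(img1, img2))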