From 055dc0bba4224ae21de8be321ab9d9e17a74d21b Mon Sep 17 00:00:00 2001
From: BobYeah <635596704@qq.com>
Date: Sat, 21 Nov 2020 16:58:38 +0800
Subject: [PATCH] First Stage

---
 .gitignore   | 112 +++++++++++++++++++++++++++++++++++++++++
 main.py      | 140 ++++++++++++++++++++++++++++++++++++++++++---------
 perc_loss.py |  38 ++++++++++++++
 ssim.py      |  72 ++++++++++++++++++++++++++
 4 files changed, 337 insertions(+), 25 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 perc_loss.py
 create mode 100644 ssim.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6109683
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,112 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+env/
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# log
+*.txt
+*.out
+*.ipynb
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# macOS
+.DS_Store
+
+# Output
+output/
\ No newline at end of file
diff --git a/main.py b/main.py
index 216fb7f..d75206b 100644
--- a/main.py
+++ b/main.py
@@ -13,6 +13,8 @@ from torch.autograd import Variable
 import cv2
 from gen_image import *
 import json
+from ssim import *
+from perc_loss import * 
 # param
 BATCH_SIZE = 5
 NUM_EPOCH = 5000
@@ -27,6 +29,7 @@ M = 2 # number of display layers
 
 DATA_FILE = "/home/yejiannan/Project/LightField/data/try"
 DATA_JSON = "/home/yejiannan/Project/LightField/data/data.json"
+DATA_VAL_JSON = "/home/yejiannan/Project/LightField/data/data_val.json"
 OUTPUT_DIR = "/home/yejiannan/Project/LightField/output"
 
 class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
@@ -34,7 +37,7 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
         self.file_dir_path = file_dir_path
         self.transforms = transforms
         # self.datum_list = glob.glob(os.path.join(file_dir_path,"*"))
-        with open(DATA_JSON, encoding='utf-8') as file:
+        with open(file_json, encoding='utf-8') as file:
             self.dastset_desc = json.loads(file.read())
 
     def __len__(self):
@@ -147,7 +150,7 @@ class model(torch.nn.Module):
         self.output_layer = torch.nn.Sequential(
             torch.nn.Conv2d(OUT_CHANNELS_RB+1,LAST_LAYER_CHANNELS,KERNEL_SIZE,stride=1,padding=1),
             torch.nn.BatchNorm2d(LAST_LAYER_CHANNELS),
-            torch.nn.Tanh()
+            torch.nn.Sigmoid()
         )
         self.deinterleave = deinterleave(INTERLEAVE_RATE)
 
@@ -164,7 +167,7 @@ class model(torch.nn.Module):
         depth_layer = torch.ones((output.shape[0],1,output.shape[2],output.shape[3]))
         # print(df.shape[0])
         for i in range(focal_length.shape[0]):
-            depth_layer[i] = depth_layer[i] * focal_length[i]
+            depth_layer[i] = 1. / focal_length[i]
             # print(depth_layer.shape)
         depth_layer = var_or_cuda(depth_layer)
         output = torch.cat((output,depth_layer),dim=1)
@@ -182,8 +185,8 @@ class Conf(object):
         self.retinal_res = torch.tensor([ 480, 640 ])
         self.layer_res = torch.tensor([ 480, 640 ])
         self.n_layers = 2
-        self.d_layer = [ 1.75, 3.5 ] # layers' distance
-        self.h_layer = [ 1., 2. ] # layers' height
+        self.d_layer = [ 1., 3. ] # layers' distance
+        self.h_layer = [ 1. * 480. / 640., 3. * 480. / 640. ] # layers' height
 
 #### Image Gen
 conf = Conf()
@@ -223,14 +226,14 @@ def GenRetinalFromLayersBatch(layers, conf, df, v, u):
             torch.clamp_(pi[:, :, :, 1], 0, conf.layer_res[1] - 1)
             Phi[bs, :, :, i, :, :] = pi
     # print("Phi slice:",Phi[0, :, :, 0, 0, 0].shape)
-    retinal = torch.zeros(BS, 3, H_r, W_r)
+    retinal = torch.ones(BS, 3, H_r, W_r)
     retinal = var_or_cuda(retinal)
     for bs in range(BS):
         for j in range(0, M):
-            retinal_view = torch.zeros(3, H_r, W_r)
+            retinal_view = torch.ones(3, H_r, W_r)
             retinal_view = var_or_cuda(retinal_view)
             for i in range(0, N):
-                retinal_view.add_(layers[bs, (i * 3) : (i * 3 + 3), Phi[bs, :, :, i, j, 0], Phi[bs, :, :, i, j, 1]])
+                retinal_view.mul_(layers[bs, (i * 3) : (i * 3 + 3), Phi[bs, :, :, i, j, 0], Phi[bs, :, :, i, j, 1]])
             retinal[bs,:,:,:].add_(retinal_view)
         retinal[bs,:,:,:].div_(M)
     return retinal
@@ -263,6 +266,42 @@ def var_or_cuda(x):
         x = x.cuda(non_blocking=True)
     return x
 
+def calImageGradients(images):
+    # x is a 4-D tensor
+    dx = images[:, :, :, 1:] - images[:, :, :, :-1]
+    dy = images[:, :, 1:, :] - images[:, :, :-1, :]
+    return dx, dy
+
+
+perc_loss = VGGPerceptualLoss() 
+perc_loss = perc_loss.to("cuda")
+
+def loss_new(generated, gt):
+    mse_loss = torch.nn.MSELoss()
+    rmse_intensity = mse_loss(generated, gt)
+    RENORM_SCALE = torch.tensor(0.9)
+    RENORM_SCALE = var_or_cuda(RENORM_SCALE)
+    psnr_intensity = torch.log10(rmse_intensity)
+    ssim_intensity = ssim(generated, gt)
+    labels_dx, labels_dy = calImageGradients(gt)
+    preds_dx, preds_dy = calImageGradients(generated)
+    rmse_grad_x, rmse_grad_y = mse_loss(labels_dx, preds_dx), mse_loss(labels_dy, preds_dy)
+    psnr_grad_x, psnr_grad_y = torch.log10(rmse_grad_x), torch.log10(rmse_grad_y)
+    p_loss = perc_loss(generated,gt)
+    # print("-psnr:",-psnr_intensity,",0.5*(psnr_grad_x + psnr_grad_y):",0.5*(psnr_grad_x + psnr_grad_y),",perc_loss:",p_loss)
+    total_loss = 10 + psnr_intensity + 0.5*(psnr_grad_x + psnr_grad_y) + p_loss
+    return total_loss
+
+def save_checkpoints(file_path, epoch_idx, model, model_solver):
+    print('[INFO] Saving checkpoint to %s ...' % ( file_path))
+    checkpoint = {
+        'epoch_idx': epoch_idx,
+        'model_state_dict': model.state_dict(),
+        'model_solver_state_dict': model_solver.state_dict()
+    }
+    torch.save(checkpoint, file_path)
+
+mode = "val"
 if __name__ == "__main__":
     #test
     # train_dataset = lightFieldDataLoader(DATA_FILE,DATA_JSON)
@@ -270,41 +309,92 @@ if __name__ == "__main__":
     # cv2.imwrite("test_crop0.png",train_dataset[0][1]*255.)
     # save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"o%d_%d.png"%(epoch,batch_idx)))
     #test end
-
+    
+    #train
     train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_JSON),
                                                     batch_size=BATCH_SIZE,
                                                     num_workers=0,
                                                     pin_memory=True,
-                                                    shuffle=False,
+                                                    shuffle=True,
                                                     drop_last=False)
     print(len(train_data_loader))
+
+    val_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE,DATA_VAL_JSON),
+                                                    batch_size=1,
+                                                    num_workers=0,
+                                                    pin_memory=True,
+                                                    shuffle=False,
+                                                    drop_last=False)
+
+    print(len(val_data_loader))
+
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     lf_model = model()
-    lf_model.apply(weight_init_normal)
     if torch.cuda.is_available():
         lf_model = torch.nn.DataParallel(lf_model).cuda()
-    lf_model.train()
-    optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-3,betas=(0.9,0.999))
-    
-    for epoch in range(NUM_EPOCH):
-        for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
+
+    #val 
+    checkpoint = torch.load(os.path.join(OUTPUT_DIR,"ckpt-epoch-3001.pth"))
+    lf_model.load_state_dict(checkpoint["model_state_dict"])
+    lf_model.eval()
+
+    print("Eval::")
+    for sample_idx, (image_set, gt, df) in enumerate(val_data_loader):
+        print("sample_idx::")
+        with torch.no_grad():
             #reshape for input
             image_set = image_set.permute(0,1,4,2,3) # N LF C H W
             image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
-            
             image_set = var_or_cuda(image_set)
             # image_set.to(device)
             gt = gt.permute(0,3,1,2)
             gt = var_or_cuda(gt)
             # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
-            optimizer.zero_grad()
             output = lf_model(image_set,df)
-            # print("output:",output.shape," df:",df.shape)
+            print("output:",output.shape," df:",df)
+            save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_l1_%.3f.png"%(df[0].data)))
+            save_image(output[0][3:6].data,os.path.join(OUTPUT_DIR,"1113_interp_l2_%.3f.png"%(df[0].data)))
             output = GenRetinalFromLayersBatch(output,conf,df,v,u)
-            loss = loss_two_images(output,gt)
-            print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
-            loss.backward()
-            optimizer.step()
-            for i in range(5):
-                save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-3_insertmid_o%d_%d.png"%(epoch,i)))
+            save_image(output[0][0:3].data,os.path.join(OUTPUT_DIR,"1113_interp_o%.3f.png"%(df[0].data)))
+    exit()
 
+    # train
+    # print(lf_model)
+    # exit()
+
+    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # lf_model = model()
+    # lf_model.apply(weight_init_normal)
+
+    # if torch.cuda.is_available():
+    #     lf_model = torch.nn.DataParallel(lf_model).cuda()
+    # lf_model.train()
+    # optimizer = torch.optim.Adam(lf_model.parameters(),lr=5e-2,betas=(0.9,0.999))
+    
+    # for epoch in range(NUM_EPOCH):
+    #     for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
+    #         #reshape for input
+    #         image_set = image_set.permute(0,1,4,2,3) # N LF C H W
+    #         image_set = image_set.reshape(image_set.shape[0],-1,image_set.shape[3],image_set.shape[4]) # N, LFxC, H, W
+            
+    #         image_set = var_or_cuda(image_set)
+    #         # image_set.to(device)
+    #         gt = gt.permute(0,3,1,2)
+    #         gt = var_or_cuda(gt)
+    #         # print("Epoch:",epoch,",Iter:",batch_idx,",Input shape:",image_set.shape, ",Input gt:",gt.shape)
+    #         optimizer.zero_grad()
+    #         output = lf_model(image_set,df)
+    #         # print("output:",output.shape," df:",df.shape)
+    #         output = GenRetinalFromLayersBatch(output,conf,df,v,u)
+    #         loss = loss_new(output,gt)
+    #         print("Epoch:",epoch,",Iter:",batch_idx,",loss:",loss)
+    #         loss.backward()
+    #         optimizer.step()
+    #         if (epoch%100 == 0):
+    #             for i in range(BATCH_SIZE):
+    #                 save_image(output[i][0:3].data,os.path.join(OUTPUT_DIR,"cuda_lr_5e-2_mul_dip_newloss_debug_conf_o%d_%d.png"%(epoch,i)))
+    #         if (epoch%1000 == 0):
+    #             save_checkpoints(os.path.join(OUTPUT_DIR, 'ckpt-epoch-%04d.pth' % (epoch + 1)),
+    #                             epoch,lf_model,optimizer)
+                
+                
\ No newline at end of file
diff --git a/perc_loss.py b/perc_loss.py
new file mode 100644
index 0000000..acad75a
--- /dev/null
+++ b/perc_loss.py
@@ -0,0 +1,38 @@
+
+import torch
+import torchvision
+
+class VGGPerceptualLoss(torch.nn.Module):
+    def __init__(self, resize=True):
+        super(VGGPerceptualLoss, self).__init__()
+        blocks = []
+        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
+        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
+        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
+        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
+        for bl in blocks:
+            for p in bl.parameters():
+                p.requires_grad = False
+        self.blocks = torch.nn.ModuleList(blocks)
+        self.transform = torch.nn.functional.interpolate
+        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1), requires_grad=False)
+        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1), requires_grad=False)
+        self.resize = resize
+
+    def forward(self, input, target):
+        if input.shape[1] != 3:
+            input = input.repeat(1, 3, 1, 1)
+            target = target.repeat(1, 3, 1, 1)
+        input = (input-self.mean) / self.std
+        target = (target-self.mean) / self.std
+        if self.resize:
+            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
+            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
+        loss = 0.0
+        x = input
+        y = target
+        for block in self.blocks:
+            x = block(x)
+            y = block(y)
+            loss += torch.nn.functional.l1_loss(x, y)
+        return loss
\ No newline at end of file
diff --git a/ssim.py b/ssim.py
new file mode 100644
index 0000000..93f390b
--- /dev/null
+++ b/ssim.py
@@ -0,0 +1,72 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+    return gauss/gauss.sum()
+
+def create_window(window_size, channel):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+    return window
+
+def _ssim(img1, img2, window, window_size, channel, size_average = True):
+    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
+    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1*mu2
+
+    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
+    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
+    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
+
+    C1 = 0.01**2
+    C2 = 0.03**2
+
+    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
+
+    if size_average:
+        return ssim_map.mean()
+    else:
+        return ssim_map.mean(1).mean(1).mean(1)
+
+class SSIM(torch.nn.Module):
+    def __init__(self, window_size = 11, size_average = True):
+        super(SSIM, self).__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.channel = 1
+        self.window = create_window(window_size, self.channel)
+
+    def forward(self, img1, img2):
+        (_, channel, _, _) = img1.size()
+
+        if channel == self.channel and self.window.data.type() == img1.data.type():
+            window = self.window
+        else:
+            window = create_window(self.window_size, channel)
+            
+            if img1.is_cuda:
+                window = window.cuda(img1.get_device())
+            window = window.type_as(img1)
+            
+            self.window = window
+            self.channel = channel
+            
+        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
+
+def ssim(img1, img2, window_size = 11, size_average = True):
+    (_, channel, _, _) = img1.size()
+    window = create_window(window_size, channel)
+    
+    if img1.is_cuda:
+        window = window.cuda(img1.get_device())
+    window = window.type_as(img1)
+    
+    return _ssim(img1, img2, window, window_size, channel, size_average)
\ No newline at end of file
-- 
GitLab