From c570c3b15b9766385408ea93ef06b26b0e909757 Mon Sep 17 00:00:00 2001
From: BobYeah <635596704@qq.com>
Date: Mon, 28 Dec 2020 21:04:08 +0800
Subject: [PATCH] checkpoint

---
 data/spherical_view_syn.py |  16 ++--
 image_scale.py             |  32 ++++++++
 msl_net.py                 | 152 ++++++++++---------------------------
 run_spherical_view_syn.py  |  72 +++++++++---------
 spher_net.py               |  39 +---------
 5 files changed, 119 insertions(+), 192 deletions(-)
 create mode 100644 image_scale.py

diff --git a/data/spherical_view_syn.py b/data/spherical_view_syn.py
index b37602a..e8c88af 100644
--- a/data/spherical_view_syn.py
+++ b/data/spherical_view_syn.py
@@ -2,7 +2,6 @@ import torch
 import torchvision.transforms.functional as trans_f
 import json
 from ..my import util
-from ..my import imgio
 
 
 class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
@@ -44,8 +43,11 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
         # Load dataset description file
         with open(dataset_desc_path, 'r', encoding='utf-8') as file:
             data_desc = json.loads(file.read())
-        self.view_file_pattern: str = self.data_dir + \
-            data_desc['view_file_pattern']
+        if data_desc['view_file_pattern'] == '':
+            self.load_images = False
+        else:
+            self.view_file_pattern: str = self.data_dir + \
+                data_desc['view_file_pattern']
         self.view_res = (data_desc['view_res']['y'],
                          data_desc['view_res']['x'])
         self.cam_params = data_desc['cam_params']
@@ -54,7 +56,7 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
             .view(-1, 3, 3)  # (N, 3, 3)
 
         # Load view images
-        if load_images:
+        if self.load_images:
             self.view_images = util.ReadImageTensor(
                 [self.view_file_pattern % i for i in range(self.view_centers.size(0))])
             if gray:
@@ -75,8 +77,8 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
 
         # Flatten rays if ray_as_item = True
         if ray_as_item:
-            self.view_pixels = self.view_images.permute(
-                0, 2, 3, 1).flatten(0, 2)
+            self.view_pixels = self.view_images.permute(0, 2, 3, 1).flatten(
+                0, 2) if self.view_images is not None else None
             self.ray_positions = self.ray_positions.flatten(0, 1)
             self.ray_directions = self.ray_directions.flatten(0, 1)
 
@@ -88,4 +90,4 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
             if self.ray_as_item:
                 return idx, self.view_pixels[idx], self.ray_positions[idx], self.ray_directions[idx]
             return idx, self.view_images[idx], self.ray_positions[idx], self.ray_directions[idx]
-        return idx, self.ray_positions[idx], self.ray_directions[idx]
+        return idx, False, self.ray_positions[idx], self.ray_directions[idx]
diff --git a/image_scale.py b/image_scale.py
new file mode 100644
index 0000000..22630e6
--- /dev/null
+++ b/image_scale.py
@@ -0,0 +1,32 @@
+import sys
+import os
+sys.path.append(os.path.abspath(sys.path[0] + '/../'))
+__package__ = "deeplightfield"
+
+import argparse
+from PIL import Image
+from .my import util
+
+
+def batch_scale(src, target, size):
+    """Resize each .jpg/.png in `src` to `size` (width, height), saving into `target`."""
+    util.CreateDirIfNeed(target)
+    for file_name in os.listdir(src):
+        postfix = os.path.splitext(file_name)[1]
+        if postfix == '.jpg' or postfix == '.png':
+            with Image.open(os.path.join(src, file_name)) as im:
+                im.resize(size).save(os.path.join(target, file_name))
+
+
+if __name__ == '__main__':
+    # Both dimensions are mandatory: batch_scale passes them straight to
+    # Image.resize, which cannot work with a missing (None) width or height.
+    parser = argparse.ArgumentParser()
+    parser.add_argument('src', type=str, help='Source directory.')
+    parser.add_argument('target', type=str, help='Target directory.')
+    parser.add_argument('--width', type=int, required=True,
+                        help='Width of output images (pixel)')
+    parser.add_argument('--height', type=int, required=True,
+                        help='Height of output images (pixel)')
+    opt = parser.parse_args()
+    batch_scale(opt.src, opt.target, (opt.width, opt.height))
diff --git a/msl_net.py b/msl_net.py
index c576743..247472d 100644
--- a/msl_net.py
+++ b/msl_net.py
@@ -1,42 +1,11 @@
 from typing import List, Tuple
-from math import pi
 import torch
 import torch.nn as nn
-from .pytorch_prototyping.pytorch_prototyping import *
+from .my import net_modules
 from .my import util
 from .my import device
 
 
-def CartesianToSpherical(cart: torch.Tensor) -> torch.Tensor:
-    """
-    Convert coordinates from Cartesian to Spherical
-
-    :param cart: ... x 3, coordinates in Cartesian
-    :return: ... x 3, coordinates in Spherical (r, theta, phi)
-    """
-    rho = torch.norm(cart, p=2, dim=-1)
-    theta = torch.atan2(cart[..., 2], cart[..., 0])
-    theta = theta + (theta < 0).type_as(theta) * (2 * pi)
-    phi = torch.acos(cart[..., 1] / rho)
-    return torch.stack([rho, theta, phi], dim=-1)
-
-
-def SphericalToCartesian(spher: torch.Tensor) -> torch.Tensor:
-    """
-    Convert coordinates from Spherical to Cartesian
-
-    :param spher: ... x 3, coordinates in Spherical
-    :return: ... x 3, coordinates in Cartesian (r, theta, phi)
-    """
-    rho = spher[..., 0]
-    sin_theta_phi = torch.sin(spher[..., 1:3])
-    cos_theta_phi = torch.cos(spher[..., 1:3])
-    x = rho * cos_theta_phi[..., 0] * sin_theta_phi[..., 1]
-    y = rho * cos_theta_phi[..., 1]
-    z = rho * sin_theta_phi[..., 0] * sin_theta_phi[..., 1]
-    return torch.stack([x, y, z], dim=-1)
-
-
 def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
     """
     Calculate intersections of each rays and each spheres
@@ -68,115 +37,74 @@ def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.T
     :return: B x B' x 3, spherical coordinates
     """
     p_on_spheres = RaySphereIntersect(p, v, r)
-    return CartesianToSpherical(p_on_spheres)
-
-
-class FcNet(nn.Module):
-
-    def __init__(self, in_chns: int, out_chns: int, nf: int, n_layers: int):
-        super().__init__()
-        self.layers = list()
-        self.layers += [
-            nn.Linear(in_chns, nf),
-            #nn.LayerNorm([nf]),
-            nn.ReLU()
-        ]
-        for _ in range(1, n_layers):
-            self.layers += [
-                nn.Linear(nf, nf),
-                #nn.LayerNorm([nf]),
-                nn.ReLU()
-            ]
-        self.layers.append(nn.Linear(nf, out_chns))
-        self.net = nn.Sequential(*self.layers)
-        self.net.apply(self.init_weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.net(x)
-
-    def init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            nn.init.xavier_normal_(m.weight)
-            nn.init.constant_(m.bias, 0.0)
+    return util.CartesianToSpherical(p_on_spheres)
 
 
 class Rendering(nn.Module):
 
-    def __init__(self, sphere_layers: List[float]):
+    def __init__(self):
         """
         Initialize a Rendering module
-
-        :param sphere_layers: L x 1, radius of sphere layers
         """
         super().__init__()
-        self.sphere_layers = torch.tensor(
-            sphere_layers, device=device.GetDevice())
 
-    def forward(self, net: FcNet, p: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    def forward(self, color_alpha: torch.Tensor) -> torch.Tensor:
         """
-        [summary]
+        Blend layers to get final color
 
-        :param net: the full-connected net
-        :param p: B x 3, positions of rays
-        :param v: B x 3, directions of rays
-        :return B x 1/3, view images by blended layers
+        :param color_alpha ```Tensor(B, L, C)```: RGB or gray with alpha channel
+        :return ```Tensor(B, C-1)``` blended pixels
         """
-        L = self.sphere_layers.size()[0]
-        sp = RayToSpherical(p, v, self.sphere_layers)  # B x L x 3
-        sp[..., 0] = 1 / sp[..., 0]                    # Radius to diopter
-        color_alpha: torch.Tensor = net(
-            sp.flatten(0, 1)).view(p.size()[0], L, -1)
-        if (color_alpha.size(-1) == 2):  # Grayscale
-            c = color_alpha[..., 0:1]
-            a = color_alpha[..., 1:2]
-        else:                           # RGB
-            c = color_alpha[..., 0:3]
-            a = color_alpha[..., 3:4]
+        c = color_alpha[..., :-1]
+        a = color_alpha[..., -1:]
         blended = c[:, 0, :] * a[:, 0, :]
-        for l in range(1, L):
+        for l in range(1, color_alpha.size(1)):
             blended = blended * (1 - a[:, l, :]) + c[:, l, :] * a[:, l, :]
         return blended
 
 
 class MslNet(nn.Module):
 
-    def __init__(self, cam_params, sphere_layers: List[float], out_res: Tuple[int, int], gray=False):
+    def __init__(self, cam_params, fc_params, sphere_layers: List[float],
+                 out_res: Tuple[int, int], gray=False, encode_to_dim: int = 0):
         """
         Initialize a multi-sphere-layer net
 
         :param cam_params: intrinsic parameters of camera
-        :param sphere_layers: L x 1, radius of sphere layers
+        :param fc_params: parameters of full-connection network
+        :param sphere_layers: list(L), radius of sphere layers
         :param out_res: resolution of output view image
+        :param gray: is grayscale mode
+        :param encode_to_dim: encode input to number of dimensions
         """
         super().__init__()
         self.cam_params = cam_params
+        self.sphere_layers = torch.tensor(sphere_layers,
+                                          dtype=torch.float,
+                                          device=device.GetDevice())
+        self.in_chns = 3
         self.out_res = out_res
-        self.v_local = util.GetLocalViewRays(self.cam_params, out_res, flatten=True) \
-            .to(device.GetDevice()) # N x 3
-        #self.net = FCBlock(hidden_ch=64,
-        #                   num_hidden_layers=4,
-        #                   in_features=3,
-        #                   out_features=2 if gray else 4,
-        #                   outermost_linear=True)
-        self.net = FcNet(in_chns=3, out_chns=2 if gray else 4, nf=256, n_layers=8)
-        self.rendering = Rendering(sphere_layers)
-
-    def forward(self, view_centers: torch.Tensor, view_rots: torch.Tensor) -> torch.Tensor:
+        self.input_encoder = net_modules.InputEncoder.Get(
+            encode_to_dim, self.in_chns)
+        # Merge into a new dict so the caller's fc_params is not mutated
+        self.net = net_modules.FcNet(**{**fc_params, 'in_chns': self.input_encoder.out_dim,
+                                        'out_chns': 2 if gray else 4})
+        self.rendering = Rendering()
+
+    def forward(self, ray_positions: torch.Tensor, ray_directions: torch.Tensor) -> torch.Tensor:
         """
-        T_view -> image
+        rays -> colors
 
-        :param view_centers: B x 3, centers of views
-        :param view_rots: B x 3 x 3, rotation matrices of views
-        :return: B x 1/3 x H_out x W_out, inferred images of views
+        :param ray_positions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray positions
+        :param ray_directions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray directions
+        :return: Tensor(B, 1|3, H, W)|Tensor(B, 1|3), inferred images/pixels
         """
-        # Transpose matrix so we can perform vec x mat
-        view_rots_t = view_rots.permute(0, 2, 1)
-
-        # p and v are B x N x 3 tensor
-        p = view_centers.unsqueeze(1).expand(-1, self.v_local.size(0), -1)
-        v = torch.matmul(self.v_local, view_rots_t)
-        c: torch.Tensor = self.rendering(
-            self.net, p.flatten(0, 1), v.flatten(0, 1))  # (BN) x 3
+        p = ray_positions.view(-1, 3)
+        v = ray_directions.view(-1, 3)
+        spher = RayToSpherical(p, v, self.sphere_layers).flatten(0, 1)
+        color_alpha = self.net(self.input_encoder(spher)).view(
+            p.size(0), self.sphere_layers.size(0), -1)
+        c: torch.Tensor = self.rendering(color_alpha)
         # unflatten
-        return c.view(view_centers.size(0), self.out_res[0],
-                      self.out_res[1], -1).permute(0, 3, 1, 2)
+        return c.view(ray_directions.size(0), self.out_res[0],
+                      self.out_res[1], -1).permute(0, 3, 1, 2) if len(ray_directions.size()) == 3 else c
diff --git a/run_spherical_view_syn.py b/run_spherical_view_syn.py
index c724ff9..8474bec 100644
--- a/run_spherical_view_syn.py
+++ b/run_spherical_view_syn.py
@@ -1,19 +1,18 @@
 import sys
-sys.path.append('/e/dengnc')
+import os
+sys.path.append(os.path.abspath(sys.path[0] + '/../'))
 __package__ = "deeplightfield"
 
 import argparse
 import torch
 import torch.optim
 import torchvision
-from typing import List, Tuple
 from tensorboardX import SummaryWriter
 from torch import nn
 from .my import netio
 from .my import util
 from .my import device
 from .my.simple_perf import SimplePerf
-from .loss.loss import PerceptionReconstructionLoss
 from .data.spherical_view_syn import SphericalViewSynDataset
 from .msl_net import MslNet
 from .spher_net import SpherNet
@@ -36,9 +35,9 @@ TRAIN_MODE = True
 EVAL_TIME_PERFORMANCE = False
 RAY_AS_ITEM = True
 # ========
-#GRAY = True
-ROT_ONLY = True
-TRAIN_MODE = False
+GRAY = True
+#ROT_ONLY = True
+#TRAIN_MODE = False
 #EVAL_TIME_PERFORMANCE = True
 #RAY_AS_ITEM = False
 
@@ -48,39 +47,39 @@ N_DEPTH_LAYERS = 10
 N_ENCODE_DIM = 10
 FC_PARAMS = {
     'nf': 128,
-    'n_layers': 6,
+    'n_layers': 8,
     'skips': [4]
 }
 
 # Train
+TRAIN_DATA_DESC_FILE = 'train.json'
 BATCH_SIZE = 2048 if RAY_AS_ITEM else 4
 EPOCH_RANGE = range(0, 500)
 SAVE_INTERVAL = 20
 
+# Test
+TEST_NET_NAME = 'model-epoch_500'
+TEST_DATA_DESC_FILE = 'test_fovea.json'
+TEST_BATCH_SIZE = 5
+
 # Paths
-DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.26_rotonly/'
+DATA_DIR = sys.path[0] + '/data/sp_view_syn_2020.12.28/'
 RUN_ID = '%s_ray_b%d_encode%d_fc%dx%d%s' % ('gray' if GRAY else 'rgb',
                                             BATCH_SIZE,
                                             N_ENCODE_DIM,
                                             FC_PARAMS['nf'],
                                             FC_PARAMS['n_layers'],
                                             '_skip_%d' % FC_PARAMS['skips'][0] if len(FC_PARAMS['skips']) > 0 else '')
-TRAIN_DATA_DESC_FILE = DATA_DIR + 'train.json'
 RUN_DIR = DATA_DIR + RUN_ID + '/'
 OUTPUT_DIR = RUN_DIR + 'output/'
 LOG_DIR = RUN_DIR + 'log/'
 
 
-# Test
-TEST_NET_NAME = 'model-epoch_100'
-TEST_BATCH_SIZE = 5
-
-
 def train():
     # 1. Initialize data loader
-    print("Load dataset: " + TRAIN_DATA_DESC_FILE)
-    train_dataset = SphericalViewSynDataset(
-        TRAIN_DATA_DESC_FILE, gray=GRAY, ray_as_item=RAY_AS_ITEM)
+    print("Load dataset: " + DATA_DIR + TRAIN_DATA_DESC_FILE)
+    train_dataset = SphericalViewSynDataset(DATA_DIR + TRAIN_DATA_DESC_FILE,
+                                            gray=GRAY, ray_as_item=RAY_AS_ITEM)
     train_data_loader = torch.utils.data.DataLoader(
         dataset=train_dataset,
         batch_size=BATCH_SIZE,
@@ -98,10 +97,12 @@ def train():
                          encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
     else:
         model = MslNet(cam_params=train_dataset.cam_params,
+                       fc_params=FC_PARAMS,
                        sphere_layers=util.GetDepthLayers(
                            DEPTH_RANGE, N_DEPTH_LAYERS),
                        out_res=train_dataset.view_res,
-                       gray=GRAY).to(device.GetDevice())
+                       gray=GRAY,
+                       encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
     optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
     loss = nn.MSELoss()
 
@@ -172,11 +173,11 @@ def train():
 
 def test(net_file: str):
     # 1. Load train dataset
-    print("Load dataset: " + TRAIN_DATA_DESC_FILE)
-    train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE,
-                                            load_images=True, gray=GRAY)
-    train_data_loader = torch.utils.data.DataLoader(
-        dataset=train_dataset,
+    print("Load dataset: " + DATA_DIR + TEST_DATA_DESC_FILE)
+    test_dataset = SphericalViewSynDataset(DATA_DIR + TEST_DATA_DESC_FILE,
+                                           load_images=True, gray=GRAY)
+    test_data_loader = torch.utils.data.DataLoader(
+        dataset=test_dataset,
         batch_size=TEST_BATCH_SIZE,
         pin_memory=True,
         shuffle=False,
@@ -184,37 +185,41 @@ def test(net_file: str):
 
     # 2. Load trained model
     if ROT_ONLY:
-        model = SpherNet(cam_params=train_dataset.cam_params,
+        model = SpherNet(cam_params=test_dataset.cam_params,
                          fc_params=FC_PARAMS,
-                         out_res=train_dataset.view_res,
+                         out_res=test_dataset.view_res,
                          gray=GRAY,
                          encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
     else:
-        model = MslNet(cam_params=train_dataset.cam_params,
-                       sphere_layers=_GetSphereLayers(
+        # Mirror the MslNet arguments used in train(): fc_params is required
+        model = MslNet(cam_params=test_dataset.cam_params,
+                       fc_params=FC_PARAMS,
+                       sphere_layers=util.GetDepthLayers(
                            DEPTH_RANGE, N_DEPTH_LAYERS),
-                       out_res=train_dataset.view_res,
-                       gray=GRAY).to(device.GetDevice())
+                       out_res=test_dataset.view_res,
+                       gray=GRAY,
+                       encode_to_dim=N_ENCODE_DIM).to(device.GetDevice())
     netio.LoadNet(net_file, model)
 
     # 3. Test on train dataset
     print("Begin test on train dataset, batch size is %d" % TEST_BATCH_SIZE)
-    util.CreateDirIfNeed(OUTPUT_DIR)
-    util.CreateDirIfNeed(OUTPUT_DIR + TEST_NET_NAME)
+    output_dir = '%s%s/%s/' % (OUTPUT_DIR, TEST_NET_NAME, TEST_DATA_DESC_FILE)
+    util.CreateDirIfNeed(output_dir)
     perf = SimplePerf(True, start=True)
     i = 0
-    for view_idxs, view_images, ray_positions, ray_directions in train_data_loader:
+    for view_idxs, view_images, ray_positions, ray_directions in test_data_loader:
         ray_positions = ray_positions.to(device.GetDevice())
         ray_directions = ray_directions.to(device.GetDevice())
         perf.Checkpoint("%d - Load" % i)
         out_view_images = model(ray_positions, ray_directions)
         perf.Checkpoint("%d - Infer" % i)
-        util.WriteImageTensor(
-            view_images,
-            ['%s%s/gt_view_%04d.png' % (OUTPUT_DIR, TEST_NET_NAME, i) for i in view_idxs])
+        if test_dataset.load_images:
+            util.WriteImageTensor(
+                view_images,
+                ['%sgt_view_%04d.png' % (output_dir, i) for i in view_idxs])
         util.WriteImageTensor(
             out_view_images,
-            ['%s%s/out_view_%04d.png' % (OUTPUT_DIR, TEST_NET_NAME, i) for i in view_idxs])
+            ['%sout_view_%04d.png' % (output_dir, i) for i in view_idxs])
         perf.Checkpoint("%d - Write" % i)
         i += 1
 
diff --git a/spher_net.py b/spher_net.py
index 4ab69a0..e1bc6e9 100644
--- a/spher_net.py
+++ b/spher_net.py
@@ -1,45 +1,8 @@
-from typing import List, Tuple
-from math import pi
+from typing import Tuple
 import torch
 import torch.nn as nn
-from .pytorch_prototyping.pytorch_prototyping import *
 from .my import net_modules
 from .my import util
-from .my import device
-
-
-def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
-    """
-    Calculate intersections of each rays and each spheres
-
-    :param p: B x 3, positions of rays
-    :param v: B x 3, directions of rays
-    :param r: B'(1D), radius of spheres
-    :return: B x B' x 3, points of intersection
-    """
-    # p, v: Expand to B x 1 x 3
-    p = p.unsqueeze(1)
-    v = v.unsqueeze(1)
-    # pp, vv, pv: B x 1
-    pp = (p * p).sum(dim=2)
-    vv = (v * v).sum(dim=2)
-    pv = (p * v).sum(dim=2)
-    # k: Expand to B x B' x 1
-    k = (((pv * pv - vv * (pp - r * r)).sqrt() - pv) / vv).unsqueeze(2)
-    return p + k * v
-
-
-def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
-    """
-    Calculate intersections of each rays and each spheres
-
-    :param p: B x 3, positions of rays
-    :param v: B x 3, directions of rays
-    :param r: B' x 1, radius of spheres
-    :return: B x B' x 3, spherical coordinates
-    """
-    p_on_spheres = RaySphereIntersect(p, v, r)
-    return util.CartesianToSpherical(p_on_spheres)
 
 
 class SpherNet(nn.Module):
-- 
GitLab