From f6604bd24a91b07531f25089b7d426b0f363a2c7 Mon Sep 17 00:00:00 2001
From: Nianchen Deng <dengnianchen@sjtu.edu.cn>
Date: Tue, 16 Mar 2021 16:57:50 +0800
Subject: [PATCH] rebuttal version

---
 configs/fovea_nmsl4.py                 |   2 +-
 configs/fovea_rgb.py                   |   2 +-
 configs/new_fovea_rgb.py               |   2 +-
 configs/periph_rgb.py                  |   2 +-
 configs/spherical_view_syn.py          |  13 +-
 configs/us_fovea.py                    |   2 +-
 configs/us_periph.py                   |   2 +-
 configs/us_periph_new.py               |   2 +-
 dash_test.py                           | 162 +++++++++++++---------
 data/lf_syn.py                         |  99 -------------
 data/loader.py                         |   2 +-
 data/other.py                          | 125 +++++++++++++++--
 data/spherical_view_syn.py             |  15 +-
 nets/modules.py                        |   4 +-
 nets/msl_net.py                        |   5 +-
 nets/msl_net_new.py                    |   4 +-
 nets/msl_net_new_export.py             |   8 +-
 nets/spher_net.py                      |   2 +-
 nets/trans_unet.py                     |   6 +-
 notebook/test_spherical_view_syn.ipynb | 185 ++++++++++++-------------
 run_spherical_view_syn.py              |  21 ++-
 21 files changed, 344 insertions(+), 321 deletions(-)
 delete mode 100644 data/lf_syn.py

diff --git a/configs/fovea_nmsl4.py b/configs/fovea_nmsl4.py
index 8df32f1..cf0c353 100644
--- a/configs/fovea_nmsl4.py
+++ b/configs/fovea_nmsl4.py
@@ -1,4 +1,4 @@
-from ..my import color_mode
+from my import color_mode
 
 def update_config(config):
     # Dataset settings
diff --git a/configs/fovea_rgb.py b/configs/fovea_rgb.py
index 836037f..c11dda9 100644
--- a/configs/fovea_rgb.py
+++ b/configs/fovea_rgb.py
@@ -1,4 +1,4 @@
-from ..my import color_mode
+from my import color_mode
 
 def update_config(config):
     # Dataset settings
diff --git a/configs/new_fovea_rgb.py b/configs/new_fovea_rgb.py
index d868caa..0729785 100644
--- a/configs/new_fovea_rgb.py
+++ b/configs/new_fovea_rgb.py
@@ -1,4 +1,4 @@
-from ..my import color_mode
+from my import color_mode
 
 def update_config(config):
     # Dataset settings
diff --git a/configs/periph_rgb.py b/configs/periph_rgb.py
index 3f0bfd1..2bb8619 100644
--- a/configs/periph_rgb.py
+++ b/configs/periph_rgb.py
@@ -1,4 +1,4 @@
-from ..my import color_mode
+from my import color_mode
 
 def update_config(config):
     # Dataset settings
diff --git a/configs/spherical_view_syn.py b/configs/spherical_view_syn.py
index fc012ae..bb8f258 100644
--- a/configs/spherical_view_syn.py
+++ b/configs/spherical_view_syn.py
@@ -1,10 +1,8 @@
 import os
 import importlib
-from os.path import join
-from ..my import color_mode
-from ..nets.msl_net import MslNet
-from ..nets.msl_net_new import NewMslNet
-from ..nets.spher_net import SpherNet
+from my import color_mode
+from nets.msl_net import MslNet
+from nets.msl_net_new import NewMslNet
 
 
 class SphericalViewSynConfig(object):
@@ -36,14 +34,13 @@ class SphericalViewSynConfig(object):
 
     def load(self, path):
         module_name = os.path.splitext(path)[0].replace('/', '.')
-        config_module = importlib.import_module(
-            'deep_view_syn.' + module_name)
+        config_module = importlib.import_module(module_name)
         config_module.update_config(self)
         self.name = module_name.split('.')[-1]
 
     def load_by_name(self, name):
         config_module = importlib.import_module(
-            'deep_view_syn.configs.' + name)
+            'configs.' + name)
         config_module.update_config(self)
         self.name = name
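+        # For example, load_by_name('periph_rgb') imports configs/periph_rgb.py
+        # and applies its update_config() to this config instance.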
 
diff --git a/configs/us_fovea.py b/configs/us_fovea.py
index 620d1d5..09edf37 100644
--- a/configs/us_fovea.py
+++ b/configs/us_fovea.py
@@ -1,4 +1,4 @@
-from ..my import color_mode
+from my import color_mode
 
 def update_config(config):
     # Dataset settings
diff --git a/configs/us_periph.py b/configs/us_periph.py
index 617b821..93450af 100644
--- a/configs/us_periph.py
+++ b/configs/us_periph.py
@@ -1,4 +1,4 @@
-from ..my import color_mode
+from my import color_mode
 
 def update_config(config):
     # Dataset settings
diff --git a/configs/us_periph_new.py b/configs/us_periph_new.py
index a86ca6f..9a372d1 100644
--- a/configs/us_periph_new.py
+++ b/configs/us_periph_new.py
@@ -1,4 +1,4 @@
-from ..my import color_mode
+from my import color_mode
 
 def update_config(config):
     # Dataset settings
diff --git a/dash_test.py b/dash_test.py
index 7882b4e..4ba8ebe 100644
--- a/dash_test.py
+++ b/dash_test.py
@@ -10,7 +10,7 @@ import plotly.express as px
 import pandas as pd
 from dash.dependencies import Input, Output
 
-sys.path.append(os.path.abspath(sys.path[0] + '/../'))
+#sys.path.append(os.path.abspath(sys.path[0] + '/../'))
 #__package__ = "deep_view_syn"
 
 if __name__ == '__main__':
@@ -24,23 +24,30 @@ if __name__ == '__main__':
     print("Set CUDA:%d as current device." % torch.cuda.current_device())
 torch.autograd.set_grad_enabled(False)
 
-from deep_view_syn.data.spherical_view_syn import *
-from deep_view_syn.configs.spherical_view_syn import SphericalViewSynConfig
-from deep_view_syn.my import netio
-from deep_view_syn.my import util
-from deep_view_syn.my import device
-from deep_view_syn.my import view
-from deep_view_syn.my.gen_final import GenFinal
-from deep_view_syn.nets.modules import Sampler
-
-
-datadir = None
+from data.spherical_view_syn import *
+from configs.spherical_view_syn import SphericalViewSynConfig
+from my import netio
+from my import util
+from my import device
+from my import view
+from my.gen_final import GenFinal
+from nets.modules import Sampler
+
+
+datadir = 'data/__0_user_study/us_gas_periph_r135x135_t0.3_2021.01.16/'
+data_desc_file = 'train.json'
+net_config = 'periph_rgb@msl-rgb_e10_fc96x4_d1.00-50.00_s16'
+net_path = datadir + net_config + '/model-epoch_200.pth'
+fov = 45
+res = (256, 256)
+view_idx = 4
+center = (0, 0)
 
 
 def load_net(path):
     print(path)
     config = SphericalViewSynConfig()
-    config.from_id(os.path.splitext(os.path.basename(path))[0])
+    config.from_id(net_config)
     config.SAMPLE_PARAMS['perturb_sample'] = False
     net = config.create_net().to(device.GetDevice())
     netio.LoadNet(path, net)
@@ -64,24 +71,25 @@ def load_views(data_desc_file) -> view.Trans:
         return view.Trans(view_centers, view_rots)
 
 
-scenes = {
-    'gas': '__0_user_study/us_gas_all_in_one',
-    'mc': '__0_user_study/us_mc_all_in_one',
-    'bedroom': 'bedroom_all_in_one',
-    'gallery': 'gallery_all_in_one',
-    'lobby': 'lobby_all_in_one'
-}
-
-fov_list = [20, 45, 110]
-res_list = [(128, 128), (256, 256), (256, 230)]
-res_full = (1600, 1440)
+cam = view.CameraParam({
+    'fov': fov,
+    'cx': 0.5,
+    'cy': 0.5,
+    'normalized': True
+}, res, device=device.GetDevice())
+net = load_net(net_path)
+sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
+                  spherical=True, lindisp=True, inverse_r=True)
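+# The sampler draws 32 samples per ray between depths 1 and 50; spherical=True
+# and lindisp=True presumably place the samples on spherical shells spaced
+# linearly in disparity.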
+x = y = None
 
+views = load_views(data_desc_file)
+print('%d Views loaded.' % views.size()[0])
 
-scene = 'gas'
-view_file = 'views.json'
+test_view = views.get(view_idx)
+rays_o, rays_d = cam.get_global_rays(test_view, True)
+image = net(rays_o.view(-1, 3), rays_d.view(-1, 3)).view(1,
+                                                         res[0], res[1], -1).permute(0, 3, 1, 2)
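+# image is now (1, C, res[0], res[1]); util.Tensor2MatImg below presumably
+# converts it back to an H x W x C array for px.imshow.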
 
-app = dash.Dash(__name__, external_stylesheets=[
-                'https://codepen.io/chriddyp/pen/bWLwgP.css'])
 
 styles = {
     'pre': {
@@ -90,35 +98,17 @@ styles = {
         'overflowX': 'scroll'
     }
 }
-
-datadir = 'data/' + scenes[scene] + '/'
-
-fovea_net = load_net_by_name('fovea')
-periph_net = load_net_by_name('periph')
-gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,
-               device=device.GetDevice())
-
-sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
-                  spherical=True, lindisp=True, inverse_r=True)
-x = y = None
-
-views = load_views(view_file)
-print('%d Views loaded.', views.size())
-
-view_idx = 27
-center = (0, 0)
-
-test_view = views.get(view_idx)
-images = gen(center, test_view)
-
-fig = px.imshow(util.Tensor2MatImg(images['fovea']))
+fig = px.imshow(util.Tensor2MatImg(image))
 fig1 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
-
+fig2 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
+app = dash.Dash(__name__, external_stylesheets=[
+                'https://codepen.io/chriddyp/pen/bWLwgP.css'])
 app.layout = html.Div([
     html.H3("Drag and draw annotations"),
     html.Div(className='row', children=[
         dcc.Graph(id='image', figure=fig),  # , config=config),
         dcc.Graph(id='scatter', figure=fig1),  # , config=config),
+        dcc.Graph(id='scatter1', figure=fig2),  # , config=config),
         dcc.Slider(id='samples-slider', min=4, max=128, step=None,
                    marks={
                        4: '4',
@@ -128,43 +118,91 @@ app.layout = html.Div([
                        64: '64',
                        128: '128',
                    },
-                   value=32,
+                   value=33,
                    updatemode='drag'
                    )
     ])
 ])
 
 
+def raw2alpha(raw, dists, act_fn=torch.relu):
+    """
+    Compute alpha (opacity) from the model's predicted density.
+    The returned value lies in [0, 1).
+    """
+    print('act_fn: ', act_fn(raw))
+    print('act_fn * dists: ', act_fn(raw) * dists)
+    return -torch.exp(-act_fn(raw) * dists) + 1.0
+
+
+def raw2color(raw: torch.Tensor, z_vals: torch.Tensor):
+    """
+    Convert raw values inferred by the model into per-sample color and alpha
+
+    :param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output
+    :param z_vals ```Tensor(N.rays, N.samples)```: integration time
+    :return ```Tensor(N.rays, N.samples, 1|3)```: color
+    :return ```Tensor(N.rays, N.samples)```: alpha
+    """
+
+    # Compute 'distance' (in time) between each integration time along a ray.
+    # The 'distance' from the last integration time is infinity.
+    # dists: (N_rays, N_samples)
+    dists = z_vals[..., 1:] - z_vals[..., :-1]
+    last_dist = z_vals[..., 0:1] * 0 + 1e10
+
+    dists = torch.cat([dists, last_dist], -1)
+    print('dists: ', dists)
+
+    # Extract RGB of each sample position along each ray.
+    color = torch.sigmoid(raw[..., :-1])  # (N_rays, N_samples, 1|3)
+    alpha = raw2alpha(raw[..., -1], dists)
+
+    return color, alpha
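+
+
+# Illustrative sketch only (an assumption; nothing in this script calls it):
+# the per-sample colors and alphas from raw2color can be blended into a final
+# ray color with the usual transmittance-weighted sum from volume rendering.
+def composite_color(colors: torch.Tensor, alphas: torch.Tensor) -> torch.Tensor:
+    # weight_i = alpha_i * prod_{j < i} (1 - alpha_j)
+    ones = torch.ones_like(alphas[..., :1])
+    trans = torch.cumprod(
+        torch.cat([ones, 1.0 - alphas[..., :-1] + 1e-10], -1), -1)
+    return torch.sum((alphas * trans)[..., None] * colors, -2)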
+
+
 def draw_scatter():
-    global fig1
-    p = torch.tensor([x, y], device=gen.layer_cams[0].c.device)
-    ray_d = test_view.trans_vector(gen.layer_cams[0].unproj(p))
+    global fig1, fig2
+    p = torch.tensor([x, y], device=device.GetDevice())
+    ray_d = test_view.trans_vector(cam.unproj(p))
     ray_o = test_view.t
-    raw, depths = fovea_net.sample_and_infer(ray_o, ray_d, sampler=sampler)
-    colors, alphas = fovea_net.rendering.raw2color(raw, depths)
+    raw, depths = net.sample_and_infer(ray_o, ray_d, sampler=sampler)
+    colors, alphas = raw2color(raw, depths)
 
     scatter_x = (1 / depths[0]).cpu().detach().numpy()
     scatter_y = alphas[0].cpu().detach().numpy()
+    scatter_y1 = raw[0, :, 3].cpu().detach().numpy()
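+    # scatter_x is disparity (1/depth) along the clicked ray, scatter_y the
+    # per-sample alpha, and scatter_y1 the raw density channel (assuming a
+    # 4-channel RGB + density output).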
     scatter_color = colors[0].cpu().detach().numpy() * 255
     marker_colors = [
-        i#'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
+        # 'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
+        i
         for i in range(scatter_color.shape[0])
     ]
     marker_colors_str = [
-        'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
+        'rgb(%d,%d,%d)' % (scatter_color[i][0],
+                           scatter_color[i][1], scatter_color[i][2])
         for i in range(scatter_color.shape[0])
     ]
 
-    fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors, color_continuous_scale=marker_colors_str)#, color_discrete_map='identity')
+    fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors,
+                      color_continuous_scale=marker_colors_str)  # , color_discrete_map='identity')
     fig1.update_traces(mode='lines+markers')
     fig1.update_xaxes(showgrid=False)
     fig1.update_yaxes(type='linear')
     fig1.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
 
+    fig2 = px.scatter(x=scatter_x, y=scatter_y1, color=marker_colors,
+                      color_continuous_scale=marker_colors_str)  # , color_discrete_map='identity')
+    fig2.update_traces(mode='lines+markers')
+    fig2.update_xaxes(showgrid=False)
+    fig2.update_yaxes(type='linear')
+    fig2.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
+
 
 @app.callback(
     [Output('image', 'figure'),
-     Output('scatter', 'figure')],
+     Output('scatter', 'figure'),
+     Output('scatter1', 'figure')],
     [Input('image', 'clickData'),
      dash.dependencies.Input('samples-slider', 'value')]
 )
@@ -194,7 +232,7 @@ def display_hover_data(clickData, samples):
                           color="LightSeaGreen",
                           width=3,
                       ))
-    return fig, fig1
+    return fig, fig1, fig2
 
 
 if __name__ == '__main__':
diff --git a/data/lf_syn.py b/data/lf_syn.py
deleted file mode 100644
index 75fd45b..0000000
--- a/data/lf_syn.py
+++ /dev/null
@@ -1,99 +0,0 @@
-from typing import List, Tuple
-import torch
-import json
-from ..my import util
-
-
-def ReadLightField(path: str, views: Tuple[int, int], flatten_views: bool = False) -> torch.Tensor:
-    input_img = util.ReadImageTensor(path, batch_dim=False)
-    h = input_img.size()[1] // views[0]
-    w = input_img.size()[2] // views[1]
-    if flatten_views:
-        lf = torch.empty(views[0] * views[1], 3, h, w)
-        for y_i in range(views[0]):
-            for x_i in range(views[1]):
-                lf[y_i * views[1] + x_i, :, :, :] = \
-                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
-    else:
-        lf = torch.empty(views[0], views[1], 3, h, w)
-        for y_i in range(views[0]):
-            for x_i in range(views[1]):
-                lf[y_i, x_i, :, :, :] = \
-                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
-    return lf
-
-
-def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor:
-    return depth_images[:, 0].unsqueeze(1).mul(255) / 10
-
-
-class LightFieldSynDataset(torch.utils.data.dataset.Dataset):
-    """
-    Data loader for light field synthesis task
-
-    Attributes
-    --------
-    data_dir ```string```: the directory of dataset\n
-    n_views ```tuple(int, int)```: rows and columns of views\n
-    num_views ```int```: number of views\n
-    view_images ```N x H x W Tensor```: images of views\n
-    view_depths ```N x H x W Tensor```: depths of views\n
-    view_positions ```N x 3 Tensor```: positions of views\n
-    sparse_view_images ```N' x H x W Tensor```: images of sparse views\n
-    sparse_view_depths ```N' x H x W Tensor```: depths of sparse views\n
-    sparse_view_positions ```N' x 3 Tensor```: positions of sparse views\n
-    """
-
-    def __init__(self, data_desc_path: str):
-        """
-        Initialize data loader for light field synthesis task
-
-        The data description file is a JSON file with following fields:
-
-        - lf: string, the path of light field image
-        - lf_depth: string, the path of light field depth image
-        - n_views: { "x",  "y" }, columns and rows of views
-        - cam_params: { "f", "c" }, the focal and center of camera (in normalized image space)
-        - depth_range: [ min, max ], the range of depth in depth maps
-        - depth_layers: int, number of layers in depth maps
-        - view_positions: [ [ x, y, z ], ... ], positions of views
-
-        :param data_desc_path: path to the data description file
-        """
-        self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/'
-        with open(data_desc_path, 'r', encoding='utf-8') as file:
-            self.data_desc = json.loads(file.read())
-        self.n_views = (self.data_desc['n_views']
-                        ['y'], self.data_desc['n_views']['x'])
-        self.num_views = self.n_views[0] * self.n_views[1]
-        self.view_images = ReadLightField(
-            self.data_dir + self.data_desc['lf'], self.n_views, True)
-        self.view_depths = DecodeDepth(ReadLightField(
-            self.data_dir + self.data_desc['lf_depth'], self.n_views, True))
-        self.cam_params = self.data_desc['cam_params']
-        self.depth_range = self.data_desc['depth_range']
-        self.depth_layers = self.data_desc['depth_layers']
-        self.view_positions = torch.tensor(self.data_desc['view_positions'])
-        _, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \
-            = self._GetCornerViews()
-        self.diopter_of_layers = self._GetDiopterOfLayers()
-
-    def __len__(self):
-        return self.num_views
-
-    def __getitem__(self, idx):
-        return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx]
-
-    def _GetCornerViews(self):
-        corner_selector = torch.zeros(self.num_views, dtype=torch.bool)
-        corner_selector[0] = corner_selector[self.n_views[1] - 1] \
-            = corner_selector[self.num_views - self.n_views[1]] \
-            = corner_selector[self.num_views - 1] = True
-        return self.__getitem__(corner_selector)
-
-    def _GetDiopterOfLayers(self) -> List[float]:
-        diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0])
-        step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1)
-        diopter_of_layers = [diopter_range[0] + step * i for i in range(self.depth_layers)]
-        diopter_of_layers.insert(0, 0)
-        return diopter_of_layers
diff --git a/data/loader.py b/data/loader.py
index cd4d146..4849cb6 100644
--- a/data/loader.py
+++ b/data/loader.py
@@ -1,6 +1,6 @@
 import torch
 import math
-from ..my import device
+from my import device
 
 
 class FastDataLoader(object):
diff --git a/data/other.py b/data/other.py
index 7b25b43..4bc1bf7 100644
--- a/data/other.py
+++ b/data/other.py
@@ -1,19 +1,114 @@
 import torch
 import os
+import json
 import glob
+import cv2
 import numpy as np
 import torchvision.transforms as transforms
 
+from typing import List, Tuple
 from torchvision import datasets
 from torch.utils.data import DataLoader
 
-import cv2
-import json
-from Flow import *
-from gen_image import *
-import util
+from my.flow import *
+from my.gen_image import *
+from my import util
+
+
+def ReadLightField(path: str, views: Tuple[int, int], flatten_views: bool = False) -> torch.Tensor:
+    input_img = util.ReadImageTensor(path, batch_dim=False)
+    h = input_img.size()[1] // views[0]
+    w = input_img.size()[2] // views[1]
+    if flatten_views:
+        lf = torch.empty(views[0] * views[1], 3, h, w)
+        for y_i in range(views[0]):
+            for x_i in range(views[1]):
+                lf[y_i * views[1] + x_i, :, :, :] = \
+                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
+    else:
+        lf = torch.empty(views[0], views[1], 3, h, w)
+        for y_i in range(views[0]):
+            for x_i in range(views[1]):
+                lf[y_i, x_i, :, :, :] = \
+                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
+    return lf
+
+
+def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor:
+    return depth_images[:, 0].unsqueeze(1).mul(255) / 10
+
+
+class LightFieldSynDataset(torch.utils.data.dataset.Dataset):
+    """
+    Data loader for light field synthesis task
+
+    Attributes
+    --------
+    data_dir ```string```: the directory of dataset\n
+    n_views ```tuple(int, int)```: rows and columns of views\n
+    num_views ```int```: number of views\n
+    view_images ```N x H x W Tensor```: images of views\n
+    view_depths ```N x H x W Tensor```: depths of views\n
+    view_positions ```N x 3 Tensor```: positions of views\n
+    sparse_view_images ```N' x H x W Tensor```: images of sparse views\n
+    sparse_view_depths ```N' x H x W Tensor```: depths of sparse views\n
+    sparse_view_positions ```N' x 3 Tensor```: positions of sparse views\n
+    """
+
+    def __init__(self, data_desc_path: str):
+        """
+        Initialize data loader for light field synthesis task
+
+        The data description file is a JSON file with following fields:
+
+        - lf: string, the path of light field image
+        - lf_depth: string, the path of light field depth image
+        - n_views: { "x",  "y" }, columns and rows of views
+        - cam_params: { "f", "c" }, the focal and center of camera (in normalized image space)
+        - depth_range: [ min, max ], the range of depth in depth maps
+        - depth_layers: int, number of layers in depth maps
+        - view_positions: [ [ x, y, z ], ... ], positions of views
+
+        :param data_desc_path: path to the data description file
+        """
+        self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/'
+        with open(data_desc_path, 'r', encoding='utf-8') as file:
+            self.data_desc = json.loads(file.read())
+        self.n_views = (self.data_desc['n_views']
+                        ['y'], self.data_desc['n_views']['x'])
+        self.num_views = self.n_views[0] * self.n_views[1]
+        self.view_images = ReadLightField(
+            self.data_dir + self.data_desc['lf'], self.n_views, True)
+        self.view_depths = DecodeDepth(ReadLightField(
+            self.data_dir + self.data_desc['lf_depth'], self.n_views, True))
+        self.cam_params = self.data_desc['cam_params']
+        self.depth_range = self.data_desc['depth_range']
+        self.depth_layers = self.data_desc['depth_layers']
+        self.view_positions = torch.tensor(self.data_desc['view_positions'])
+        _, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \
+            = self._GetCornerViews()
+        self.diopter_of_layers = self._GetDiopterOfLayers()
+
+    def __len__(self):
+        return self.num_views
+
+    def __getitem__(self, idx):
+        return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx]
+
+    def _GetCornerViews(self):
+        corner_selector = torch.zeros(self.num_views, dtype=torch.bool)
+        corner_selector[0] = corner_selector[self.n_views[1] - 1] \
+            = corner_selector[self.num_views - self.n_views[1]] \
+            = corner_selector[self.num_views - 1] = True
+        return self.__getitem__(corner_selector)
 
-import time
+    def _GetDiopterOfLayers(self) -> List[float]:
+        diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0])
+        step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1)
+        diopter_of_layers = [diopter_range[0] +
+                             step * i for i in range(self.depth_layers)]
+        diopter_of_layers.insert(0, 0)
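+        # e.g. depth_range = [1, 10] and depth_layers = 10 yields diopters
+        # [0, 0.1, 0.2, ..., 1.0]; the prepended 0 is a layer at infinity.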
+        return diopter_of_layers
 
 
 class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
@@ -90,8 +185,8 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
         # print(lf_image_big.shape)
 
         for i in range(9):
-            lf_image = lf_image_big[i//3*IM_H:i//3 *
-                                    IM_H+IM_H, i % 3*IM_W:i % 3*IM_W+IM_W, 0:3]
+            lf_image = lf_image_big[i // 3 * IM_H:i // 3 *
+                                    IM_H + IM_H, i % 3 * IM_W:i % 3 * IM_W + IM_W, 0:3]
             # IF GrayScale
             # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
             # print(lf_image.shape)
@@ -146,8 +241,8 @@ class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
         # print(lf_image_big.shape)
 
         for i in range(9):
-            lf_image = lf_image_big[i//3*IM_H:i//3 *
-                                    IM_H+IM_H, i % 3*IM_W:i % 3*IM_W+IM_W, 0:3]
+            lf_image = lf_image_big[i // 3 * IM_H:i // 3 *
+                                    IM_H + IM_H, i % 3 * IM_W:i % 3 * IM_W + IM_W, 0:3]
             # IF GrayScale
             # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
             # print(lf_image.shape)
@@ -214,8 +309,8 @@ class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
             lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
 
             for j in range(9):
-                lf_image = lf_image_big[j//3*IM_H:j//3 *
-                                        IM_H+IM_H, j % 3*IM_W:j % 3*IM_W+IM_W, 0:3]
+                lf_image = lf_image_big[j // 3 * IM_H:j // 3 *
+                                        IM_H + IM_H, j % 3 * IM_W:j % 3 * IM_W + IM_W, 0:3]
                 lf_image_one_sample.append(lf_image)
 
             gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(
@@ -297,8 +392,8 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
 
             lf_dim = int(self.conf.light_field_dim)
             for j in range(lf_dim**2):
-                lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim *
-                                        IM_H+IM_H, j % lf_dim*IM_W:j % lf_dim*IM_W+IM_W, 0:3]
+                lf_image = lf_image_big[j // lf_dim * IM_H:j // lf_dim *
+                                        IM_H + IM_H, j % lf_dim * IM_W:j % lf_dim * IM_W + IM_W, 0:3]
                 lf_image_one_sample.append(lf_image)
 
             gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(
@@ -333,7 +428,7 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
             retinal_invalid.append(retinal_invalid_i)
         # lf_images: 5,9,320,320
         flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"]
-                                       [indices[i-1]]) for i in range(1, len(indices))])
+                                       [indices[i - 1]]) for i in range(1, len(indices))])
         flow_map = flow.getMap()
         flow_invalid_mask = flow.b_invalid_mask
         # print("flow:",flow_map.shape)
diff --git a/data/spherical_view_syn.py b/data/spherical_view_syn.py
index 1e65634..8072446 100644
--- a/data/spherical_view_syn.py
+++ b/data/spherical_view_syn.py
@@ -4,10 +4,10 @@ import torch
 import torchvision.transforms.functional as trans_f
 import torch.nn.functional as nn_f
 from typing import Tuple, Union
-from ..my import util
-from ..my import device
-from ..my import view
-from ..my import color_mode
+from my import util
+from my import device
+from my import view
+from my import color_mode
 
 
 class SphericalViewSynDataset(object):
@@ -129,6 +129,13 @@ class SphericalViewSynDataset(object):
         self.n_views = self.view_centers.size(0)
         self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]
 
+        if 'gl_coord' in data_desc and data_desc['gl_coord']:
+            print('Converting from OpenGL to DirectX coordinates (i.e. right-handed to left-handed)')
+            # Negate the y focal length, the z component of the view centers,
+            # and the z row and z column of each view rotation matrix.
+            self.cam_params.f[1] *= -1
+            self.view_centers[:, 2] *= -1
+            self.view_rots[:, 2] *= -1
+            self.view_rots[..., 2] *= -1
+
     def set_patch_size(self, patch_size: Union[int, Tuple[int, int]],
                        offset: Union[int, Tuple[int, int]] = 0):
         """
diff --git a/nets/modules.py b/nets/modules.py
index 87cf09b..e16a128 100644
--- a/nets/modules.py
+++ b/nets/modules.py
@@ -1,8 +1,8 @@
 from typing import List, Tuple
 import torch
 import torch.nn as nn
-from ..my import device
-from ..my import util
+from my import device
+from my import util
 
 
 class FcLayer(nn.Module):
diff --git a/nets/msl_net.py b/nets/msl_net.py
index 9836c11..594c543 100644
--- a/nets/msl_net.py
+++ b/nets/msl_net.py
@@ -2,9 +2,8 @@ import math
 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import util
-from ..my import color_mode
-
+from my import util
+from my import color_mode
 
 class MslNet(nn.Module):
 
diff --git a/nets/msl_net_new.py b/nets/msl_net_new.py
index e9d7b5e..8301e32 100644
--- a/nets/msl_net_new.py
+++ b/nets/msl_net_new.py
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import color_mode
-from ..my.simple_perf import SimplePerf
+from my import color_mode
+from my.simple_perf import SimplePerf
 
 
 class NewMslNet(nn.Module):
diff --git a/nets/msl_net_new_export.py b/nets/msl_net_new_export.py
index 9f418fe..1fa52d3 100644
--- a/nets/msl_net_new_export.py
+++ b/nets/msl_net_new_export.py
@@ -2,10 +2,10 @@ from typing import Tuple
 import math
 import torch
 import torch.nn as nn
-from ..my import net_modules
-from ..my import util
-from ..my import device
-from ..my import color_mode
+from my import net_modules
+from my import util
+from my import device
+from my import color_mode
 from .msl_net_new import NewMslNet
 
 
diff --git a/nets/spher_net.py b/nets/spher_net.py
index 548c6df..3472ebb 100644
--- a/nets/spher_net.py
+++ b/nets/spher_net.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import util
+from my import util
 
 
 class SpherNet(nn.Module):
diff --git a/nets/trans_unet.py b/nets/trans_unet.py
index 0ff08ae..260371e 100644
--- a/nets/trans_unet.py
+++ b/nets/trans_unet.py
@@ -1,9 +1,9 @@
 from typing import List
 import torch
 import torch.nn as nn
-from ..pytorch_prototyping.pytorch_prototyping import *
-from ..my import util
-from ..my import device
+from pytorch_prototyping.pytorch_prototyping import *
+from my import util
+from my import device
 
 
 class Encoder(nn.Module):
diff --git a/notebook/test_spherical_view_syn.ipynb b/notebook/test_spherical_view_syn.ipynb
index 75b374e..d029218 100644
--- a/notebook/test_spherical_view_syn.ipynb
+++ b/notebook/test_spherical_view_syn.ipynb
@@ -2,20 +2,28 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Set CUDA:2 as current device.\n"
+     ]
+    }
+   ],
    "source": [
     "import sys\n",
     "import os\n",
-    "sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
+    "sys.path.append(os.path.abspath(sys.path[0] + '/../'))\n",
     "\n",
     "import torch\n",
     "import math\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "from deep_view_syn.my import util\n",
-    "from deep_view_syn.msl_net import *\n",
+    "from my import util\n",
+    "from nets.msl_net import *\n",
     "\n",
     "# Select device\n",
     "torch.cuda.set_device(2)\n",
@@ -23,8 +31,10 @@
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": 1,
    "metadata": {},
+   "outputs": [],
    "source": [
     "# Test Ray-Sphere Intersection & Cartesian-Spherical Conversion"
    ]
@@ -56,14 +66,15 @@
     "v = torch.tensor([[0.0, -1.0, 1.0]])\n",
     "r = torch.tensor([[2.5]])\n",
     "v = v / torch.norm(v) * r * 2\n",
-    "p_on_sphere_ = RaySphereIntersect(p, v, r)[0]\n",
+    "p_on_sphere_ = util.RaySphereIntersect(p, v, r)[0][0]\n",
     "print(p_on_sphere_)\n",
     "print(p_on_sphere_.norm())\n",
-    "spher_coord = RayToSpherical(p, v, r)\n",
+    "spher_coord = util.CartesianToSpherical(p_on_sphere_)\n",
     "print(spher_coord[..., 1:3].rad2deg())\n",
-    "p_on_sphere = util.SphericalToCartesian(spher_coord)[0]\n",
+    "p_on_sphere = util.SphericalToCartesian(spher_coord)\n",
+    "print(p_on_sphere_.size())\n",
     "\n",
-    "fig = plt.figure(figsize=(6, 6))\n",
+    "fig = plt.figure(figsize=(8, 8))\n",
     "ax = fig.gca(projection='3d')\n",
     "plt.xlabel('x')\n",
     "plt.ylabel('z')\n",
@@ -109,8 +120,10 @@
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
     "# Test Dataset Loader & View-Spherical Transform"
    ]
@@ -121,26 +134,26 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset\n",
-    "from deep_view_syn.data.spherical_view_syn import FastDataLoader\n",
+    "from data.spherical_view_syn import SphericalViewSynDataset\n",
+    "from data.loader import FastDataLoader\n",
     "\n",
-    "DATA_DIR = '../data/sp_view_syn_2020.12.28'\n",
+    "DATA_DIR = '../data/nerf_fern'\n",
     "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
     "\n",
-    "dataset = FastSphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
+    "dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
     "dataset.set_patch_size((64, 64))\n",
     "data_loader = FastDataLoader(dataset=dataset, batch_size=4, shuffle=False, drop_last=False)\n",
     "print(len(dataset))\n",
-    "plt.figure()\n",
+    "fig = plt.figure(figsize=(12, 6.5))\n",
     "i = 0\n",
     "for indices, patches, rays_o, rays_d in data_loader:\n",
     "    print(i, patches.size(), rays_o.size(), rays_d.size())\n",
     "    for idx in range(len(indices)):\n",
-    "        plt.subplot(4, 4, i + 1)\n",
+    "        plt.subplot(4, 7, i + 1)\n",
     "        util.PlotImageTensor(patches[idx])\n",
     "        i += 1\n",
-    "    if i == 16:\n",
-    "        break\n"
+    "    if i == 28:\n",
+    "        break"
    ]
   },
   {
@@ -149,13 +162,15 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+    "from data.spherical_view_syn import SphericalViewSynDataset\n",
+    "from data.loader import FastDataLoader\n",
     "\n",
-    "DATA_DIR = '../data/sp_view_syn_2020.12.26'\n",
+    "DATA_DIR = '../data/nerf_fern'\n",
     "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
     "DEPTH_RANGE = (1, 10)\n",
     "N_DEPTH_LAYERS = 10\n",
     "\n",
+    "\n",
     "def _GetSphereLayers(depth_range: Tuple[float, float], n_layers: int) -> torch.Tensor:\n",
     "    diopter_range = (1 / depth_range[1], 1 / depth_range[0])\n",
     "    step = (diopter_range[1] - diopter_range[0]) / (n_layers - 1)\n",
@@ -163,74 +178,54 @@
     "    depths += [1 / (diopter_range[0] + step * i) for i in range(n_layers)]\n",
     "    return torch.tensor(depths, device=device.GetDevice()).view(-1, 1)\n",
     "\n",
-    "train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
-    "train_data_loader = torch.utils.data.DataLoader(\n",
-    "    dataset=train_dataset,\n",
-    "    batch_size=4,\n",
-    "    num_workers=8,\n",
-    "    pin_memory=True,\n",
-    "    shuffle=True,\n",
-    "    drop_last=False)\n",
-    "print(len(train_data_loader))\n",
-    "\n",
-    "print(\"view_res\", train_dataset.view_res)\n",
-    "print(\"cam_params\", train_dataset.cam_params)\n",
-    "\n",
-    "msl_net = MslNet(train_dataset.cam_params,\n",
-    "                 _GetSphereLayers(DEPTH_RANGE, N_DEPTH_LAYERS),\n",
-    "                 train_dataset.view_res).to(device.GetDevice())\n",
-    "print(\"sphere layers\", msl_net.rendering.sphere_layers)\n",
-    "\n",
-    "p = None\n",
-    "v = None\n",
-    "centers = None\n",
-    "plt.figure(figsize=(6, 6))\n",
-    "for _, view_images, ray_positions, ray_directions in train_data_loader:\n",
-    "    p = ray_positions\n",
-    "    v = ray_directions\n",
-    "    plt.subplot(2, 2, 1)\n",
-    "    util.PlotImageTensor(view_images[0])\n",
-    "    plt.subplot(2, 2, 2)\n",
-    "    util.PlotImageTensor(view_images[1])\n",
-    "    plt.subplot(2, 2, 3)\n",
-    "    util.PlotImageTensor(view_images[2])\n",
-    "    plt.subplot(2, 2, 4)\n",
-    "    util.PlotImageTensor(view_images[3])\n",
-    "    break\n",
-    "p_ = util.SphericalToCartesian(RayToSpherical(p.flatten(0, 1), v.flatten(0, 1),\n",
-    "                                         torch.tensor([[1]], device=device.GetDevice()))) \\\n",
-    "    .view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)\n",
-    "v = v.view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n",
-    "p_ = p_[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n",
-    "\n",
-    "fig = plt.figure(figsize=(6, 6))\n",
-    "ax = fig.gca(projection='3d')\n",
-    "plt.xlabel('x')\n",
-    "plt.ylabel('z')\n",
     "\n",
-    "PlotSphere(ax, 1)\n",
+    "dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
+    "dataset.set_patch_size(1)\n",
+    "data_loader = FastDataLoader(\n",
+    "    dataset=dataset, batch_size=4096*16, shuffle=True, drop_last=False)\n",
     "\n",
-    "ax.scatter([0], [0], [0], color=\"k\", s=10)  # Center\n",
+    "print(\"view_res\", dataset.view_res)\n",
+    "print(\"cam_params\", dataset.cam_params)\n",
     "\n",
-    "colors = [ 'r', 'g', 'b', 'y' ]\n",
-    "for i in range(4):\n",
-    "    ax.scatter(p_[i, :, 0], p_[i, :, 2], p_[i, :, 1], color=colors[i], s=3)\n",
-    "    for j in range(p_.shape[1]):\n",
-    "        ax.plot([centers[i, 0], centers[i, 0] + v[i, j, 0]],\n",
-    "                [centers[i, 2], centers[i, 2] + v[i, j, 2]],\n",
-    "                [centers[i, 1], centers[i, 1] + v[i, j, 1]],\n",
-    "                color=colors[i], linewidth=0.5, alpha=0.6)\n",
+    "fig = plt.figure(figsize=(16, 40))\n",
     "\n",
-    "ax.set_xlim(-1, 1)\n",
-    "ax.set_ylim(-1, 1)\n",
-    "ax.set_zlim(-1, 1)\n",
-    "\n",
-    "plt.show()\n"
+    "for ri in range(0, 10):\n",
+    "    r = ri * 0.2 + 1\n",
+    "    p = None\n",
+    "    centers = None\n",
+    "    pixels = None\n",
+    "    for indices, patches, rays_o, rays_d in data_loader:\n",
+    "        p = util.RaySphereIntersect(\n",
+    "            rays_o, rays_d, torch.tensor([[r]], device=device.GetDevice()))[0] \\\n",
+    "            .view(-1, 3).cpu().numpy()\n",
+    "        centers = rays_o.view(-1, 3).cpu().numpy()\n",
+    "        pixels = patches.view(-1, 3).cpu().numpy()\n",
+    "        break\n",
+    "    \n",
+    "    ax = plt.subplot(5, 2, ri + 1, projection='3d')\n",
+    "    #ax = plt.gca(projection='3d')\n",
+    "    #ax = fig.gca()\n",
+    "    plt.xlabel('x')\n",
+    "    plt.ylabel('z')\n",
+    "    plt.title('r = %f' % r)\n",
+    "\n",
+    "    # PlotSphere(ax, 1)\n",
+    "\n",
+    "    ax.scatter([0], [0], [0], color=\"k\", s=10)\n",
+    "    ax.scatter(p[:, 0], p[:, 2], p[:, 1], color=pixels, s=0.5)\n",
+    "\n",
+    "    #ax.set_xlim(-1, 1)\n",
+    "    #ax.set_ylim(-1, 1)\n",
+    "    #ax.set_zlim(-1, 1)\n",
+    "    ax.view_init(elev=0,azim=-90)\n",
+    "\n"
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": 3,
    "metadata": {},
+   "outputs": [],
    "source": [
     "# Test Sampler"
    ]
@@ -241,7 +236,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+    "from data.spherical_view_syn import SphericalViewSynDataset\n",
     "\n",
     "DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'\n",
     "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
@@ -304,7 +299,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+    "from data.spherical_view_syn import SphericalViewSynDataset\n",
     "\n",
     "DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'\n",
     "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
@@ -367,8 +362,10 @@
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
     "# Test Spherical View Synthesis"
    ]
@@ -381,9 +378,9 @@
    "source": [
     "import ipywidgets as widgets  # 控件库\n",
     "from IPython.display import display  # 显示控件的方法\n",
-    "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
-    "from deep_view_syn.spher_net import SpherNet\n",
-    "from deep_view_syn.my import netio\n",
+    "from data.spherical_view_syn import SphericalViewSynDataset\n",
+    "from nets.spher_net import SpherNet\n",
+    "from my import netio\n",
     "\n",
     "DATA_DIR = '../data/sp_view_syn_2020.12.28_small'\n",
     "DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
@@ -455,20 +452,12 @@
     "})\n",
     "display(slider_x, slider_y, slider_z, slider_theta, slider_phi, out)\n"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
+   "display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
+   "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
   },
   "language_info": {
    "codemirror_mode": {
@@ -480,7 +469,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.6-final"
+   "version": "3.7.9"
   }
  },
  "nbformat": 4,
diff --git a/run_spherical_view_syn.py b/run_spherical_view_syn.py
index e49373f..406c8f4 100644
--- a/run_spherical_view_syn.py
+++ b/run_spherical_view_syn.py
@@ -6,9 +6,6 @@ import torch.optim
 from tensorboardX import SummaryWriter
 from torch import nn
 
-sys.path.append(os.path.abspath(sys.path[0] + '/../'))
-__package__ = "deep_view_syn"
-
 parser = argparse.ArgumentParser()
 parser.add_argument('--device', type=int, default=3,
                     help='Which CUDA device to use.')
@@ -46,15 +43,15 @@ if opt.res:
 torch.cuda.set_device(opt.device)
 print("Set CUDA:%d as current device." % torch.cuda.current_device())
 
-from .my import netio
-from .my import util
-from .my import device
-from .my import loss
-from .my.progress_bar import progress_bar
-from .my.simple_perf import SimplePerf
-from .data.spherical_view_syn import *
-from .data.loader import FastDataLoader
-from .configs.spherical_view_syn import SphericalViewSynConfig
+from my import netio
+from my import util
+from my import device
+from my import loss
+from my.progress_bar import progress_bar
+from my.simple_perf import SimplePerf
+from data.spherical_view_syn import *
+from data.loader import FastDataLoader
+from configs.spherical_view_syn import SphericalViewSynConfig
 
 
 config = SphericalViewSynConfig()
-- 
GitLab