From f6604bd24a91b07531f25089b7d426b0f363a2c7 Mon Sep 17 00:00:00 2001 From: Nianchen Deng <dengnianchen@sjtu.edu.cn> Date: Tue, 16 Mar 2021 16:57:50 +0800 Subject: [PATCH] rebuttal version --- configs/fovea_nmsl4.py | 2 +- configs/fovea_rgb.py | 2 +- configs/new_fovea_rgb.py | 2 +- configs/periph_rgb.py | 2 +- configs/spherical_view_syn.py | 13 +- configs/us_fovea.py | 2 +- configs/us_periph.py | 2 +- configs/us_periph_new.py | 2 +- dash_test.py | 162 +++++++++++++--------- data/lf_syn.py | 99 ------------- data/loader.py | 2 +- data/other.py | 125 +++++++++++++++-- data/spherical_view_syn.py | 15 +- nets/modules.py | 4 +- nets/msl_net.py | 5 +- nets/msl_net_new.py | 4 +- nets/msl_net_new_export.py | 8 +- nets/spher_net.py | 2 +- nets/trans_unet.py | 6 +- notebook/test_spherical_view_syn.ipynb | 185 ++++++++++++------------- run_spherical_view_syn.py | 21 ++- 21 files changed, 344 insertions(+), 321 deletions(-) delete mode 100644 data/lf_syn.py diff --git a/configs/fovea_nmsl4.py b/configs/fovea_nmsl4.py index 8df32f1..cf0c353 100644 --- a/configs/fovea_nmsl4.py +++ b/configs/fovea_nmsl4.py @@ -1,4 +1,4 @@ -from ..my import color_mode +from my import color_mode def update_config(config): # Dataset settings diff --git a/configs/fovea_rgb.py b/configs/fovea_rgb.py index 836037f..c11dda9 100644 --- a/configs/fovea_rgb.py +++ b/configs/fovea_rgb.py @@ -1,4 +1,4 @@ -from ..my import color_mode +from my import color_mode def update_config(config): # Dataset settings diff --git a/configs/new_fovea_rgb.py b/configs/new_fovea_rgb.py index d868caa..0729785 100644 --- a/configs/new_fovea_rgb.py +++ b/configs/new_fovea_rgb.py @@ -1,4 +1,4 @@ -from ..my import color_mode +from my import color_mode def update_config(config): # Dataset settings diff --git a/configs/periph_rgb.py b/configs/periph_rgb.py index 3f0bfd1..2bb8619 100644 --- a/configs/periph_rgb.py +++ b/configs/periph_rgb.py @@ -1,4 +1,4 @@ -from ..my import color_mode +from my import color_mode def update_config(config): # Dataset settings diff --git a/configs/spherical_view_syn.py b/configs/spherical_view_syn.py index fc012ae..bb8f258 100644 --- a/configs/spherical_view_syn.py +++ b/configs/spherical_view_syn.py @@ -1,10 +1,8 @@ import os import importlib -from os.path import join -from ..my import color_mode -from ..nets.msl_net import MslNet -from ..nets.msl_net_new import NewMslNet -from ..nets.spher_net import SpherNet +from my import color_mode +from nets.msl_net import MslNet +from nets.msl_net_new import NewMslNet class SphericalViewSynConfig(object): @@ -36,14 +34,13 @@ class SphericalViewSynConfig(object): def load(self, path): module_name = os.path.splitext(path)[0].replace('/', '.') - config_module = importlib.import_module( - 'deep_view_syn.' + module_name) + config_module = importlib.import_module(module_name) config_module.update_config(self) self.name = module_name.split('.')[-1] def load_by_name(self, name): config_module = importlib.import_module( - 'deep_view_syn.configs.' + name) + 'configs.' 
+ name) config_module.update_config(self) self.name = name diff --git a/configs/us_fovea.py b/configs/us_fovea.py index 620d1d5..09edf37 100644 --- a/configs/us_fovea.py +++ b/configs/us_fovea.py @@ -1,4 +1,4 @@ -from ..my import color_mode +from my import color_mode def update_config(config): # Dataset settings diff --git a/configs/us_periph.py b/configs/us_periph.py index 617b821..93450af 100644 --- a/configs/us_periph.py +++ b/configs/us_periph.py @@ -1,4 +1,4 @@ -from ..my import color_mode +from my import color_mode def update_config(config): # Dataset settings diff --git a/configs/us_periph_new.py b/configs/us_periph_new.py index a86ca6f..9a372d1 100644 --- a/configs/us_periph_new.py +++ b/configs/us_periph_new.py @@ -1,4 +1,4 @@ -from ..my import color_mode +from my import color_mode def update_config(config): # Dataset settings diff --git a/dash_test.py b/dash_test.py index 7882b4e..4ba8ebe 100644 --- a/dash_test.py +++ b/dash_test.py @@ -10,7 +10,7 @@ import plotly.express as px import pandas as pd from dash.dependencies import Input, Output -sys.path.append(os.path.abspath(sys.path[0] + '/../')) +#sys.path.append(os.path.abspath(sys.path[0] + '/../')) #__package__ = "deep_view_syn" if __name__ == '__main__': @@ -24,23 +24,30 @@ if __name__ == '__main__': print("Set CUDA:%d as current device." % torch.cuda.current_device()) torch.autograd.set_grad_enabled(False) -from deep_view_syn.data.spherical_view_syn import * -from deep_view_syn.configs.spherical_view_syn import SphericalViewSynConfig -from deep_view_syn.my import netio -from deep_view_syn.my import util -from deep_view_syn.my import device -from deep_view_syn.my import view -from deep_view_syn.my.gen_final import GenFinal -from deep_view_syn.nets.modules import Sampler - - -datadir = None +from data.spherical_view_syn import * +from configs.spherical_view_syn import SphericalViewSynConfig +from my import netio +from my import util +from my import device +from my import view +from my.gen_final import GenFinal +from nets.modules import Sampler + + +datadir = 'data/__0_user_study/us_gas_periph_r135x135_t0.3_2021.01.16/' +data_desc_file = 'train.json' +net_config = 'periph_rgb@msl-rgb_e10_fc96x4_d1.00-50.00_s16' +net_path = datadir + net_config + '/model-epoch_200.pth' +fov = 45 +res = (256, 256) +view_idx = 4 +center = (0, 0) def load_net(path): print(path) config = SphericalViewSynConfig() - config.from_id(os.path.splitext(os.path.basename(path))[0]) + config.from_id(net_config) config.SAMPLE_PARAMS['perturb_sample'] = False net = config.create_net().to(device.GetDevice()) netio.LoadNet(path, net) @@ -64,24 +71,25 @@ def load_views(data_desc_file) -> view.Trans: return view.Trans(view_centers, view_rots) -scenes = { - 'gas': '__0_user_study/us_gas_all_in_one', - 'mc': '__0_user_study/us_mc_all_in_one', - 'bedroom': 'bedroom_all_in_one', - 'gallery': 'gallery_all_in_one', - 'lobby': 'lobby_all_in_one' -} - -fov_list = [20, 45, 110] -res_list = [(128, 128), (256, 256), (256, 230)] -res_full = (1600, 1440) +cam = view.CameraParam({ + 'fov': fov, + 'cx': 0.5, + 'cy': 0.5, + 'normalized': True +}, res, device=device.GetDevice()) +net = load_net(net_path) +sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False, + spherical=True, lindisp=True, inverse_r=True) +x = y = None +views = load_views(data_desc_file) +print('%d Views loaded.' 
% views.size()[0]) -scene = 'gas' -view_file = 'views.json' +test_view = views.get(view_idx) +rays_o, rays_d = cam.get_global_rays(test_view, True) +image = net(rays_o.view(-1, 3), rays_d.view(-1, 3)).view(1, + res[0], res[1], -1).permute(0, 3, 1, 2) -app = dash.Dash(__name__, external_stylesheets=[ - 'https://codepen.io/chriddyp/pen/bWLwgP.css']) styles = { 'pre': { @@ -90,35 +98,17 @@ styles = { 'overflowX': 'scroll' } } - -datadir = 'data/' + scenes[scene] + '/' - -fovea_net = load_net_by_name('fovea') -periph_net = load_net_by_name('periph') -gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net, - device=device.GetDevice()) - -sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False, - spherical=True, lindisp=True, inverse_r=True) -x = y = None - -views = load_views(view_file) -print('%d Views loaded.', views.size()) - -view_idx = 27 -center = (0, 0) - -test_view = views.get(view_idx) -images = gen(center, test_view) - -fig = px.imshow(util.Tensor2MatImg(images['fovea'])) +fig = px.imshow(util.Tensor2MatImg(image)) fig1 = px.scatter(x=[0, 1, 2], y=[2, 0, 1]) - +fig2 = px.scatter(x=[0, 1, 2], y=[2, 0, 1]) +app = dash.Dash(__name__, external_stylesheets=[ + 'https://codepen.io/chriddyp/pen/bWLwgP.css']) app.layout = html.Div([ html.H3("Drag and draw annotations"), html.Div(className='row', children=[ dcc.Graph(id='image', figure=fig), # , config=config), dcc.Graph(id='scatter', figure=fig1), # , config=config), + dcc.Graph(id='scatter1', figure=fig2), # , config=config), dcc.Slider(id='samples-slider', min=4, max=128, step=None, marks={ 4: '4', @@ -128,43 +118,91 @@ app.layout = html.Div([ 64: '64', 128: '128', }, - value=32, + value=33, updatemode='drag' ) ]) ]) +def raw2alpha(raw, dists, act_fn=torch.relu): + """ + Function for computing density from model prediction. + This value is strictly between [0, 1]. + """ + print('act_fn: ', act_fn(raw)) + print('act_fn * dists: ', act_fn(raw) * dists) + return -torch.exp(-act_fn(raw) * dists) + 1.0 + + +def raw2color(raw: torch.Tensor, z_vals: torch.Tensor): + """ + Raw value inferred from model to color and alpha + + :param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output + :param z_vals ```Tensor(N.rays, N.samples)```: integration time + :return ```Tensor(N.rays, N.samples, 1|3)```: color + :return ```Tensor(N.rays, N.samples)```: alpha + """ + + # Compute 'distance' (in time) between each integration time along a ray. + # The 'distance' from the last integration time is infinity. + # dists: (N_rays, N_samples) + dists = z_vals[..., 1:] - z_vals[..., :-1] + last_dist = z_vals[..., 0:1] * 0 + 1e10 + + dists = torch.cat([dists, last_dist], -1) + print('dists: ', dists) + + # Extract RGB of each sample position along each ray. 
+ color = torch.sigmoid(raw[..., :-1]) # (N_rays, N_samples, 1|3) + alpha = raw2alpha(raw[..., -1], dists) + + return color, alpha + + def draw_scatter(): - global fig1 - p = torch.tensor([x, y], device=gen.layer_cams[0].c.device) - ray_d = test_view.trans_vector(gen.layer_cams[0].unproj(p)) + global fig1, fig2 + p = torch.tensor([x, y], device=device.GetDevice()) + ray_d = test_view.trans_vector(cam.unproj(p)) ray_o = test_view.t - raw, depths = fovea_net.sample_and_infer(ray_o, ray_d, sampler=sampler) - colors, alphas = fovea_net.rendering.raw2color(raw, depths) + raw, depths = net.sample_and_infer(ray_o, ray_d, sampler=sampler) + colors, alphas = raw2color(raw, depths) scatter_x = (1 / depths[0]).cpu().detach().numpy() scatter_y = alphas[0].cpu().detach().numpy() + scatter_y1 = raw[0, :, 3].cpu().detach().numpy() scatter_color = colors[0].cpu().detach().numpy() * 255 marker_colors = [ - i#'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2]) + # 'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2]) + i for i in range(scatter_color.shape[0]) ] marker_colors_str = [ - 'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2]) + 'rgb(%d,%d,%d)' % (scatter_color[i][0], + scatter_color[i][1], scatter_color[i][2]) for i in range(scatter_color.shape[0]) ] - fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors, color_continuous_scale=marker_colors_str)#, color_discrete_map='identity') + fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors, + color_continuous_scale=marker_colors_str) # , color_discrete_map='identity') fig1.update_traces(mode='lines+markers') fig1.update_xaxes(showgrid=False) fig1.update_yaxes(type='linear') fig1.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10}) + fig2 = px.scatter(x=scatter_x, y=scatter_y1, color=marker_colors, + color_continuous_scale=marker_colors_str) # , color_discrete_map='identity') + fig2.update_traces(mode='lines+markers') + fig2.update_xaxes(showgrid=False) + fig2.update_yaxes(type='linear') + fig2.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10}) + @app.callback( [Output('image', 'figure'), - Output('scatter', 'figure')], + Output('scatter', 'figure'), + Output('scatter1', 'figure')], [Input('image', 'clickData'), dash.dependencies.Input('samples-slider', 'value')] ) @@ -194,7 +232,7 @@ def display_hover_data(clickData, samples): color="LightSeaGreen", width=3, )) - return fig, fig1 + return fig, fig1, fig2 if __name__ == '__main__': diff --git a/data/lf_syn.py b/data/lf_syn.py deleted file mode 100644 index 75fd45b..0000000 --- a/data/lf_syn.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import List, Tuple -import torch -import json -from ..my import util - - -def ReadLightField(path: str, views: Tuple[int, int], flatten_views: bool = False) -> torch.Tensor: - input_img = util.ReadImageTensor(path, batch_dim=False) - h = input_img.size()[1] // views[0] - w = input_img.size()[2] // views[1] - if flatten_views: - lf = torch.empty(views[0] * views[1], 3, h, w) - for y_i in range(views[0]): - for x_i in range(views[1]): - lf[y_i * views[1] + x_i, :, :, :] = \ - input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w] - else: - lf = torch.empty(views[0], views[1], 3, h, w) - for y_i in range(views[0]): - for x_i in range(views[1]): - lf[y_i, x_i, :, :, :] = \ - input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w] - return lf - - -def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor: - return 
depth_images[:, 0].unsqueeze(1).mul(255) / 10 - - -class LightFieldSynDataset(torch.utils.data.dataset.Dataset): - """ - Data loader for light field synthesis task - - Attributes - -------- - data_dir ```string```: the directory of dataset\n - n_views ```tuple(int, int)```: rows and columns of views\n - num_views ```int```: number of views\n - view_images ```N x H x W Tensor```: images of views\n - view_depths ```N x H x W Tensor```: depths of views\n - view_positions ```N x 3 Tensor```: positions of views\n - sparse_view_images ```N' x H x W Tensor```: images of sparse views\n - sparse_view_depths ```N' x H x W Tensor```: depths of sparse views\n - sparse_view_positions ```N' x 3 Tensor```: positions of sparse views\n - """ - - def __init__(self, data_desc_path: str): - """ - Initialize data loader for light field synthesis task - - The data description file is a JSON file with following fields: - - - lf: string, the path of light field image - - lf_depth: string, the path of light field depth image - - n_views: { "x", "y" }, columns and rows of views - - cam_params: { "f", "c" }, the focal and center of camera (in normalized image space) - - depth_range: [ min, max ], the range of depth in depth maps - - depth_layers: int, number of layers in depth maps - - view_positions: [ [ x, y, z ], ... ], positions of views - - :param data_desc_path: path to the data description file - """ - self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/' - with open(data_desc_path, 'r', encoding='utf-8') as file: - self.data_desc = json.loads(file.read()) - self.n_views = (self.data_desc['n_views'] - ['y'], self.data_desc['n_views']['x']) - self.num_views = self.n_views[0] * self.n_views[1] - self.view_images = ReadLightField( - self.data_dir + self.data_desc['lf'], self.n_views, True) - self.view_depths = DecodeDepth(ReadLightField( - self.data_dir + self.data_desc['lf_depth'], self.n_views, True)) - self.cam_params = self.data_desc['cam_params'] - self.depth_range = self.data_desc['depth_range'] - self.depth_layers = self.data_desc['depth_layers'] - self.view_positions = torch.tensor(self.data_desc['view_positions']) - _, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \ - = self._GetCornerViews() - self.diopter_of_layers = self._GetDiopterOfLayers() - - def __len__(self): - return self.num_views - - def __getitem__(self, idx): - return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx] - - def _GetCornerViews(self): - corner_selector = torch.zeros(self.num_views, dtype=torch.bool) - corner_selector[0] = corner_selector[self.n_views[1] - 1] \ - = corner_selector[self.num_views - self.n_views[1]] \ - = corner_selector[self.num_views - 1] = True - return self.__getitem__(corner_selector) - - def _GetDiopterOfLayers(self) -> List[float]: - diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0]) - step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1) - diopter_of_layers = [diopter_range[0] + step * i for i in range(self.depth_layers)] - diopter_of_layers.insert(0, 0) - return diopter_of_layers diff --git a/data/loader.py b/data/loader.py index cd4d146..4849cb6 100644 --- a/data/loader.py +++ b/data/loader.py @@ -1,6 +1,6 @@ import torch import math -from ..my import device +from my import device class FastDataLoader(object): diff --git a/data/other.py b/data/other.py index 7b25b43..4bc1bf7 100644 --- a/data/other.py +++ b/data/other.py @@ -1,19 +1,114 @@ import torch import os +import json import glob +import cv2 import numpy as 
np import torchvision.transforms as transforms +from typing import List, Tuple from torchvision import datasets from torch.utils.data import DataLoader -import cv2 -import json -from Flow import * -from gen_image import * -import util +from my.flow import * +from my.gen_image import * +from my import util + + +def ReadLightField(path: str, views: Tuple[int, int], flatten_views: bool = False) -> torch.Tensor: + input_img = util.ReadImageTensor(path, batch_dim=False) + h = input_img.size()[1] // views[0] + w = input_img.size()[2] // views[1] + if flatten_views: + lf = torch.empty(views[0] * views[1], 3, h, w) + for y_i in range(views[0]): + for x_i in range(views[1]): + lf[y_i * views[1] + x_i, :, :, :] = \ + input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w] + else: + lf = torch.empty(views[0], views[1], 3, h, w) + for y_i in range(views[0]): + for x_i in range(views[1]): + lf[y_i, x_i, :, :, :] = \ + input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w] + return lf + + +def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor: + return depth_images[:, 0].unsqueeze(1).mul(255) / 10 + + +class LightFieldSynDataset(torch.utils.data.dataset.Dataset): + """ + Data loader for light field synthesis task + + Attributes + -------- + data_dir ```string```: the directory of dataset\n + n_views ```tuple(int, int)```: rows and columns of views\n + num_views ```int```: number of views\n + view_images ```N x H x W Tensor```: images of views\n + view_depths ```N x H x W Tensor```: depths of views\n + view_positions ```N x 3 Tensor```: positions of views\n + sparse_view_images ```N' x H x W Tensor```: images of sparse views\n + sparse_view_depths ```N' x H x W Tensor```: depths of sparse views\n + sparse_view_positions ```N' x 3 Tensor```: positions of sparse views\n + """ + + def __init__(self, data_desc_path: str): + """ + Initialize data loader for light field synthesis task + + The data description file is a JSON file with following fields: + + - lf: string, the path of light field image + - lf_depth: string, the path of light field depth image + - n_views: { "x", "y" }, columns and rows of views + - cam_params: { "f", "c" }, the focal and center of camera (in normalized image space) + - depth_range: [ min, max ], the range of depth in depth maps + - depth_layers: int, number of layers in depth maps + - view_positions: [ [ x, y, z ], ... 
], positions of views + + :param data_desc_path: path to the data description file + """ + self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/' + with open(data_desc_path, 'r', encoding='utf-8') as file: + self.data_desc = json.loads(file.read()) + self.n_views = (self.data_desc['n_views'] + ['y'], self.data_desc['n_views']['x']) + self.num_views = self.n_views[0] * self.n_views[1] + self.view_images = ReadLightField( + self.data_dir + self.data_desc['lf'], self.n_views, True) + self.view_depths = DecodeDepth(ReadLightField( + self.data_dir + self.data_desc['lf_depth'], self.n_views, True)) + self.cam_params = self.data_desc['cam_params'] + self.depth_range = self.data_desc['depth_range'] + self.depth_layers = self.data_desc['depth_layers'] + self.view_positions = torch.tensor(self.data_desc['view_positions']) + _, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \ + = self._GetCornerViews() + self.diopter_of_layers = self._GetDiopterOfLayers() + + def __len__(self): + return self.num_views + + def __getitem__(self, idx): + return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx] + + def _GetCornerViews(self): + corner_selector = torch.zeros(self.num_views, dtype=torch.bool) + corner_selector[0] = corner_selector[self.n_views[1] - 1] \ + = corner_selector[self.num_views - self.n_views[1]] \ + = corner_selector[self.num_views - 1] = True + return self.__getitem__(corner_selector) -import time + def _GetDiopterOfLayers(self) -> List[float]: + diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0]) + step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1) + diopter_of_layers = [diopter_range[0] + + step * i for i in range(self.depth_layers)] + diopter_of_layers.insert(0, 0) + return diopter_of_layers class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset): @@ -90,8 +185,8 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset): # print(lf_image_big.shape) for i in range(9): - lf_image = lf_image_big[i//3*IM_H:i//3 * - IM_H+IM_H, i % 3*IM_W:i % 3*IM_W+IM_W, 0:3] + lf_image = lf_image_big[i // 3 * IM_H:i // 3 * + IM_H + IM_H, i % 3 * IM_W:i % 3 * IM_W + IM_W, 0:3] # IF GrayScale # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1] # print(lf_image.shape) @@ -146,8 +241,8 @@ class lightFieldValDataLoader(torch.utils.data.dataset.Dataset): # print(lf_image_big.shape) for i in range(9): - lf_image = lf_image_big[i//3*IM_H:i//3 * - IM_H+IM_H, i % 3*IM_W:i % 3*IM_W+IM_W, 0:3] + lf_image = lf_image_big[i // 3 * IM_H:i // 3 * + IM_H + IM_H, i % 3 * IM_W:i % 3 * IM_W + IM_W, 0:3] # IF GrayScale # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1] # print(lf_image.shape) @@ -214,8 +309,8 @@ class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset): lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB) for j in range(9): - lf_image = lf_image_big[j//3*IM_H:j//3 * - IM_H+IM_H, j % 3*IM_W:j % 3*IM_W+IM_W, 0:3] + lf_image = lf_image_big[j // 3 * IM_H:j // 3 * + IM_H + IM_H, j % 3 * IM_W:j % 3 * IM_W + IM_W, 0:3] lf_image_one_sample.append(lf_image) gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype( @@ -297,8 +392,8 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset): lf_dim = int(self.conf.light_field_dim) for j in range(lf_dim**2): - lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim * - IM_H+IM_H, j % lf_dim*IM_W:j % lf_dim*IM_W+IM_W, 0:3] + lf_image = lf_image_big[j // lf_dim * IM_H:j // lf_dim * + IM_H + IM_H, j % 
lf_dim * IM_W:j % lf_dim * IM_W + IM_W, 0:3] lf_image_one_sample.append(lf_image) gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype( @@ -333,7 +428,7 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset): retinal_invalid.append(retinal_invalid_i) # lf_images: 5,9,320,320 flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"] - [indices[i-1]]) for i in range(1, len(indices))]) + [indices[i - 1]]) for i in range(1, len(indices))]) flow_map = flow.getMap() flow_invalid_mask = flow.b_invalid_mask # print("flow:",flow_map.shape) diff --git a/data/spherical_view_syn.py b/data/spherical_view_syn.py index 1e65634..8072446 100644 --- a/data/spherical_view_syn.py +++ b/data/spherical_view_syn.py @@ -4,10 +4,10 @@ import torch import torchvision.transforms.functional as trans_f import torch.nn.functional as nn_f from typing import Tuple, Union -from ..my import util -from ..my import device -from ..my import view -from ..my import color_mode +from my import util +from my import device +from my import view +from my import color_mode class SphericalViewSynDataset(object): @@ -129,6 +129,13 @@ class SphericalViewSynDataset(object): self.n_views = self.view_centers.size(0) self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1] + if 'gl_coord' in data_desc and data_desc['gl_coord'] == True: + print('Convert from OGL coordinate to DX coordinate (i. e. right-hand to left-hand)') + self.cam_params.f[1] *= -1 + self.view_centers[:, 2] *= -1 + self.view_rots[:, 2] *= -1 + self.view_rots[..., 2] *= -1 + def set_patch_size(self, patch_size: Union[int, Tuple[int, int]], offset: Union[int, Tuple[int, int]] = 0): """ diff --git a/nets/modules.py b/nets/modules.py index 87cf09b..e16a128 100644 --- a/nets/modules.py +++ b/nets/modules.py @@ -1,8 +1,8 @@ from typing import List, Tuple import torch import torch.nn as nn -from ..my import device -from ..my import util +from my import device +from my import util class FcLayer(nn.Module): diff --git a/nets/msl_net.py b/nets/msl_net.py index 9836c11..594c543 100644 --- a/nets/msl_net.py +++ b/nets/msl_net.py @@ -2,9 +2,8 @@ import math import torch import torch.nn as nn from .modules import * -from ..my import util -from ..my import color_mode - +from my import util +from my import color_mode class MslNet(nn.Module): diff --git a/nets/msl_net_new.py b/nets/msl_net_new.py index e9d7b5e..8301e32 100644 --- a/nets/msl_net_new.py +++ b/nets/msl_net_new.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn from .modules import * -from ..my import color_mode -from ..my.simple_perf import SimplePerf +from my import color_mode +from my.simple_perf import SimplePerf class NewMslNet(nn.Module): diff --git a/nets/msl_net_new_export.py b/nets/msl_net_new_export.py index 9f418fe..1fa52d3 100644 --- a/nets/msl_net_new_export.py +++ b/nets/msl_net_new_export.py @@ -2,10 +2,10 @@ from typing import Tuple import math import torch import torch.nn as nn -from ..my import net_modules -from ..my import util -from ..my import device -from ..my import color_mode +from my import net_modules +from my import util +from my import device +from my import color_mode from .msl_net_new import NewMslNet diff --git a/nets/spher_net.py b/nets/spher_net.py index 548c6df..3472ebb 100644 --- a/nets/spher_net.py +++ b/nets/spher_net.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from .modules import * -from ..my import util +from my import util class SpherNet(nn.Module): diff --git a/nets/trans_unet.py b/nets/trans_unet.py index 0ff08ae..260371e 
100644 --- a/nets/trans_unet.py +++ b/nets/trans_unet.py @@ -1,9 +1,9 @@ from typing import List import torch import torch.nn as nn -from ..pytorch_prototyping.pytorch_prototyping import * -from ..my import util -from ..my import device +from pytorch_prototyping.pytorch_prototyping import * +from my import util +from my import device class Encoder(nn.Module): diff --git a/notebook/test_spherical_view_syn.ipynb b/notebook/test_spherical_view_syn.ipynb index 75b374e..d029218 100644 --- a/notebook/test_spherical_view_syn.ipynb +++ b/notebook/test_spherical_view_syn.ipynb @@ -2,20 +2,28 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Set CUDA:2 as current device.\n" + ] + } + ], "source": [ "import sys\n", "import os\n", - "sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n", + "sys.path.append(os.path.abspath(sys.path[0] + '/../'))\n", "\n", "import torch\n", "import math\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from deep_view_syn.my import util\n", - "from deep_view_syn.msl_net import *\n", + "from my import util\n", + "from nets.msl_net import *\n", "\n", "# Select device\n", "torch.cuda.set_device(2)\n", @@ -23,8 +31,10 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 1, "metadata": {}, + "outputs": [], "source": [ "# Test Ray-Sphere Intersection & Cartesian-Spherical Conversion" ] @@ -56,14 +66,15 @@ "v = torch.tensor([[0.0, -1.0, 1.0]])\n", "r = torch.tensor([[2.5]])\n", "v = v / torch.norm(v) * r * 2\n", - "p_on_sphere_ = RaySphereIntersect(p, v, r)[0]\n", + "p_on_sphere_ = util.RaySphereIntersect(p, v, r)[0][0]\n", "print(p_on_sphere_)\n", "print(p_on_sphere_.norm())\n", - "spher_coord = RayToSpherical(p, v, r)\n", + "spher_coord = util.CartesianToSpherical(p_on_sphere_)\n", "print(spher_coord[..., 1:3].rad2deg())\n", - "p_on_sphere = util.SphericalToCartesian(spher_coord)[0]\n", + "p_on_sphere = util.SphericalToCartesian(spher_coord)\n", + "print(p_on_sphere_.size())\n", "\n", - "fig = plt.figure(figsize=(6, 6))\n", + "fig = plt.figure(figsize=(8, 8))\n", "ax = fig.gca(projection='3d')\n", "plt.xlabel('x')\n", "plt.ylabel('z')\n", @@ -109,8 +120,10 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "# Test Dataset Loader & View-Spherical Transform" ] @@ -121,26 +134,26 @@ "metadata": {}, "outputs": [], "source": [ - "from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset\n", - "from deep_view_syn.data.spherical_view_syn import FastDataLoader\n", + "from data.spherical_view_syn import SphericalViewSynDataset\n", + "from data.loader import FastDataLoader\n", "\n", - "DATA_DIR = '../data/sp_view_syn_2020.12.28'\n", + "DATA_DIR = '../data/nerf_fern'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "\n", - "dataset = FastSphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n", + "dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n", "dataset.set_patch_size((64, 64))\n", "data_loader = FastDataLoader(dataset=dataset, batch_size=4, shuffle=False, drop_last=False)\n", "print(len(dataset))\n", - "plt.figure()\n", + "fig = plt.figure(figsize=(12, 6.5))\n", "i = 0\n", "for indices, patches, rays_o, rays_d in data_loader:\n", " print(i, patches.size(), rays_o.size(), rays_d.size())\n", " for idx in range(len(indices)):\n", - " plt.subplot(4, 4, i + 1)\n", + " 
plt.subplot(4, 7, i + 1)\n", " util.PlotImageTensor(patches[idx])\n", " i += 1\n", - " if i == 16:\n", - " break\n" + " if i == 28:\n", + " break" ] }, { @@ -149,13 +162,15 @@ "metadata": {}, "outputs": [], "source": [ - "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n", + "from data.spherical_view_syn import SphericalViewSynDataset\n", + "from data.loader import FastDataLoader\n", "\n", - "DATA_DIR = '../data/sp_view_syn_2020.12.26'\n", + "DATA_DIR = '../data/nerf_fern'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", "DEPTH_RANGE = (1, 10)\n", "N_DEPTH_LAYERS = 10\n", "\n", + "\n", "def _GetSphereLayers(depth_range: Tuple[float, float], n_layers: int) -> torch.Tensor:\n", " diopter_range = (1 / depth_range[1], 1 / depth_range[0])\n", " step = (diopter_range[1] - diopter_range[0]) / (n_layers - 1)\n", @@ -163,74 +178,54 @@ " depths += [1 / (diopter_range[0] + step * i) for i in range(n_layers)]\n", " return torch.tensor(depths, device=device.GetDevice()).view(-1, 1)\n", "\n", - "train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n", - "train_data_loader = torch.utils.data.DataLoader(\n", - " dataset=train_dataset,\n", - " batch_size=4,\n", - " num_workers=8,\n", - " pin_memory=True,\n", - " shuffle=True,\n", - " drop_last=False)\n", - "print(len(train_data_loader))\n", - "\n", - "print(\"view_res\", train_dataset.view_res)\n", - "print(\"cam_params\", train_dataset.cam_params)\n", - "\n", - "msl_net = MslNet(train_dataset.cam_params,\n", - " _GetSphereLayers(DEPTH_RANGE, N_DEPTH_LAYERS),\n", - " train_dataset.view_res).to(device.GetDevice())\n", - "print(\"sphere layers\", msl_net.rendering.sphere_layers)\n", - "\n", - "p = None\n", - "v = None\n", - "centers = None\n", - "plt.figure(figsize=(6, 6))\n", - "for _, view_images, ray_positions, ray_directions in train_data_loader:\n", - " p = ray_positions\n", - " v = ray_directions\n", - " plt.subplot(2, 2, 1)\n", - " util.PlotImageTensor(view_images[0])\n", - " plt.subplot(2, 2, 2)\n", - " util.PlotImageTensor(view_images[1])\n", - " plt.subplot(2, 2, 3)\n", - " util.PlotImageTensor(view_images[2])\n", - " plt.subplot(2, 2, 4)\n", - " util.PlotImageTensor(view_images[3])\n", - " break\n", - "p_ = util.SphericalToCartesian(RayToSpherical(p.flatten(0, 1), v.flatten(0, 1),\n", - " torch.tensor([[1]], device=device.GetDevice()))) \\\n", - " .view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)\n", - "v = v.view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n", - "p_ = p_[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n", - "\n", - "fig = plt.figure(figsize=(6, 6))\n", - "ax = fig.gca(projection='3d')\n", - "plt.xlabel('x')\n", - "plt.ylabel('z')\n", "\n", - "PlotSphere(ax, 1)\n", + "dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n", + "dataset.set_patch_size(1)\n", + "data_loader = FastDataLoader(\n", + " dataset=dataset, batch_size=4096*16, shuffle=True, drop_last=False)\n", "\n", - "ax.scatter([0], [0], [0], color=\"k\", s=10) # Center\n", + "print(\"view_res\", dataset.view_res)\n", + "print(\"cam_params\", dataset.cam_params)\n", "\n", - "colors = [ 'r', 'g', 'b', 'y' ]\n", - "for i in range(4):\n", - " ax.scatter(p_[i, :, 0], p_[i, :, 2], p_[i, :, 1], color=colors[i], s=3)\n", - " for j in range(p_.shape[1]):\n", - " ax.plot([centers[i, 0], centers[i, 0] + v[i, j, 0]],\n", - " [centers[i, 2], centers[i, 2] + v[i, j, 2]],\n", - " [centers[i, 1], centers[i, 1] + v[i, j, 1]],\n", - " color=colors[i], 
linewidth=0.5, alpha=0.6)\n", + "fig = plt.figure(figsize=(16, 40))\n", "\n", - "ax.set_xlim(-1, 1)\n", - "ax.set_ylim(-1, 1)\n", - "ax.set_zlim(-1, 1)\n", - "\n", - "plt.show()\n" + "for ri in range(0, 10):\n", + " r = ri * 0.2 + 1\n", + " p = None\n", + " centers = None\n", + " pixels = None\n", + " for indices, patches, rays_o, rays_d in data_loader:\n", + " p = util.RaySphereIntersect(\n", + " rays_o, rays_d, torch.tensor([[r]], device=device.GetDevice()))[0] \\\n", + " .view(-1, 3).cpu().numpy()\n", + " centers = rays_o.view(-1, 3).cpu().numpy()\n", + " pixels = patches.view(-1, 3).cpu().numpy()\n", + " break\n", + " \n", + " ax = plt.subplot(5, 2, ri + 1, projection='3d')\n", + " #ax = plt.gca(projection='3d')\n", + " #ax = fig.gca()\n", + " plt.xlabel('x')\n", + " plt.ylabel('z')\n", + " plt.title('r = %f' % r)\n", + "\n", + " # PlotSphere(ax, 1)\n", + "\n", + " ax.scatter([0], [0], [0], color=\"k\", s=10)\n", + " ax.scatter(p[:, 0], p[:, 2], p[:, 1], color=pixels, s=0.5)\n", + "\n", + " #ax.set_xlim(-1, 1)\n", + " #ax.set_ylim(-1, 1)\n", + " #ax.set_zlim(-1, 1)\n", + " ax.view_init(elev=0,azim=-90)\n", + "\n" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 3, "metadata": {}, + "outputs": [], "source": [ "# Test Sampler" ] @@ -241,7 +236,7 @@ "metadata": {}, "outputs": [], "source": [ - "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n", + "from data.spherical_view_syn import SphericalViewSynDataset\n", "\n", "DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", @@ -304,7 +299,7 @@ "metadata": {}, "outputs": [], "source": [ - "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n", + "from data.spherical_view_syn import SphericalViewSynDataset\n", "\n", "DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'\n", "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n", @@ -367,8 +362,10 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "# Test Spherical View Synthesis" ] @@ -381,9 +378,9 @@ "source": [ "import ipywidgets as widgets # 控件库\n", "from IPython.display import display # 显示控件的方法\n", - "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n", - "from deep_view_syn.spher_net import SpherNet\n", - "from deep_view_syn.my import netio\n", + "from data.spherical_view_syn import SphericalViewSynDataset\n", + "from nets.spher_net import SpherNet\n", + "from my import netio\n", "\n", "DATA_DIR = '../data/sp_view_syn_2020.12.28_small'\n", "DATA_DESC_FILE = DATA_DIR + '/train.json'\n", @@ -455,20 +452,12 @@ "})\n", "display(slider_x, slider_y, slider_z, slider_theta, slider_phi, out)\n" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" + "display_name": "Python 3.7.9 64-bit ('pytorch': conda)", + "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2" }, "language_info": { "codemirror_mode": { @@ -480,7 +469,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6-final" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/run_spherical_view_syn.py b/run_spherical_view_syn.py index e49373f..406c8f4 100644 --- a/run_spherical_view_syn.py +++ b/run_spherical_view_syn.py @@ -6,9 +6,6 @@ import torch.optim 
from tensorboardX import SummaryWriter from torch import nn -sys.path.append(os.path.abspath(sys.path[0] + '/../')) -__package__ = "deep_view_syn" - parser = argparse.ArgumentParser() parser.add_argument('--device', type=int, default=3, help='Which CUDA device to use.') @@ -46,15 +43,15 @@ if opt.res: torch.cuda.set_device(opt.device) print("Set CUDA:%d as current device." % torch.cuda.current_device()) -from .my import netio -from .my import util -from .my import device -from .my import loss -from .my.progress_bar import progress_bar -from .my.simple_perf import SimplePerf -from .data.spherical_view_syn import * -from .data.loader import FastDataLoader -from .configs.spherical_view_syn import SphericalViewSynConfig +from my import netio +from my import util +from my import device +from my import loss +from my.progress_bar import progress_bar +from my.simple_perf import SimplePerf +from data.spherical_view_syn import * +from data.loader import FastDataLoader +from configs.spherical_view_syn import SphericalViewSynConfig config = SphericalViewSynConfig() -- GitLab
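
For reference, the raw-to-alpha conversion introduced in dash_test.py above is the standard volume-rendering weighting alpha = 1 - exp(-relu(sigma) * dist). Below is a minimal standalone sketch of that conversion (PyTorch assumed; tensor shapes follow the docstrings in the patch, and the dummy inputs at the end are purely illustrative, not part of the patch):

    import torch

    def raw2alpha(raw, dists, act_fn=torch.relu):
        # alpha = 1 - exp(-relu(sigma) * dist); always in [0, 1)
        return 1.0 - torch.exp(-act_fn(raw) * dists)

    def raw2color(raw, z_vals):
        # raw:    (N_rays, N_samples, 4) -> color channels in [..., :-1], density in [..., -1]
        # z_vals: (N_rays, N_samples) integration times along each ray
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        # the segment after the last sample is treated as (numerically) infinite
        dists = torch.cat([dists, torch.full_like(z_vals[..., :1], 1e10)], -1)
        color = torch.sigmoid(raw[..., :-1])    # (N_rays, N_samples, 3)
        alpha = raw2alpha(raw[..., -1], dists)  # (N_rays, N_samples)
        return color, alpha

    # Illustrative usage: 2 rays, 5 samples each
    raw = torch.randn(2, 5, 4)
    z_vals = torch.linspace(1.0, 50.0, 5).expand(2, 5)
    color, alpha = raw2color(raw, z_vals)
    print(color.shape, alpha.shape)  # torch.Size([2, 5, 3]) torch.Size([2, 5])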