Commit f6604bd2 authored by Nianchen Deng

rebuttal version

parent 6e54b394
-from ..my import color_mode
+from my import color_mode
 def update_config(config):
     # Dataset settings
...
-from ..my import color_mode
+from my import color_mode
 def update_config(config):
     # Dataset settings
...
-from ..my import color_mode
+from my import color_mode
 def update_config(config):
     # Dataset settings
...
-from ..my import color_mode
+from my import color_mode
 def update_config(config):
     # Dataset settings
...
 import os
 import importlib
-from os.path import join
-from ..my import color_mode
-from ..nets.msl_net import MslNet
-from ..nets.msl_net_new import NewMslNet
-from ..nets.spher_net import SpherNet
+from my import color_mode
+from nets.msl_net import MslNet
+from nets.msl_net_new import NewMslNet
 class SphericalViewSynConfig(object):
@@ -36,14 +34,13 @@ class SphericalViewSynConfig(object):
     def load(self, path):
         module_name = os.path.splitext(path)[0].replace('/', '.')
-        config_module = importlib.import_module(
-            'deep_view_syn.' + module_name)
+        config_module = importlib.import_module(module_name)
         config_module.update_config(self)
         self.name = module_name.split('.')[-1]

     def load_by_name(self, name):
         config_module = importlib.import_module(
-            'deep_view_syn.configs.' + name)
+            'configs.' + name)
         config_module.update_config(self)
         self.name = name
...
-from ..my import color_mode
+from my import color_mode
 def update_config(config):
     # Dataset settings
...
-from ..my import color_mode
+from my import color_mode
 def update_config(config):
     # Dataset settings
...
-from ..my import color_mode
+from my import color_mode
 def update_config(config):
     # Dataset settings
...
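Both changes above drop the 'deep_view_syn.' package prefix, so config modules are now resolved relative to the repository root rather than as a sub-package. A minimal hedged sketch of the lookup that load_by_name performs after this change (the config name is hypothetical, and this assumes the scripts are started from the repository root so that 'configs' is importable as a top-level package):

import importlib

def load_config_by_name(name: str):
    # Sketch of the lookup done by SphericalViewSynConfig.load_by_name above:
    # with the 'deep_view_syn.' prefix removed, 'configs' must be importable
    # from the working directory (or from an entry on sys.path).
    return importlib.import_module('configs.' + name)

# Hypothetical usage; every config module in this repo exposes update_config(config):
# load_config_by_name('my_experiment').update_config(config)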
@@ -10,7 +10,7 @@ import plotly.express as px
 import pandas as pd
 from dash.dependencies import Input, Output

-sys.path.append(os.path.abspath(sys.path[0] + '/../'))
+#sys.path.append(os.path.abspath(sys.path[0] + '/../'))
 #__package__ = "deep_view_syn"

 if __name__ == '__main__':
@@ -24,23 +24,30 @@ if __name__ == '__main__':
     print("Set CUDA:%d as current device." % torch.cuda.current_device())
     torch.autograd.set_grad_enabled(False)

-from deep_view_syn.data.spherical_view_syn import *
-from deep_view_syn.configs.spherical_view_syn import SphericalViewSynConfig
-from deep_view_syn.my import netio
-from deep_view_syn.my import util
-from deep_view_syn.my import device
-from deep_view_syn.my import view
-from deep_view_syn.my.gen_final import GenFinal
-from deep_view_syn.nets.modules import Sampler
+from data.spherical_view_syn import *
+from configs.spherical_view_syn import SphericalViewSynConfig
+from my import netio
+from my import util
+from my import device
+from my import view
+from my.gen_final import GenFinal
+from nets.modules import Sampler

-datadir = None
+datadir = 'data/__0_user_study/us_gas_periph_r135x135_t0.3_2021.01.16/'
+data_desc_file = 'train.json'
+net_config = 'periph_rgb@msl-rgb_e10_fc96x4_d1.00-50.00_s16'
+net_path = datadir + net_config + '/model-epoch_200.pth'
+fov = 45
+res = (256, 256)
+view_idx = 4
+center = (0, 0)

 def load_net(path):
     print(path)
     config = SphericalViewSynConfig()
-    config.from_id(os.path.splitext(os.path.basename(path))[0])
+    config.from_id(net_config)
     config.SAMPLE_PARAMS['perturb_sample'] = False
     net = config.create_net().to(device.GetDevice())
     netio.LoadNet(path, net)
@@ -64,24 +71,25 @@ def load_views(data_desc_file) -> view.Trans:
     return view.Trans(view_centers, view_rots)

-scenes = {
-    'gas': '__0_user_study/us_gas_all_in_one',
-    'mc': '__0_user_study/us_mc_all_in_one',
-    'bedroom': 'bedroom_all_in_one',
-    'gallery': 'gallery_all_in_one',
-    'lobby': 'lobby_all_in_one'
-}
-fov_list = [20, 45, 110]
-res_list = [(128, 128), (256, 256), (256, 230)]
-res_full = (1600, 1440)
-scene = 'gas'
-view_file = 'views.json'
+cam = view.CameraParam({
+    'fov': fov,
+    'cx': 0.5,
+    'cy': 0.5,
+    'normalized': True
+}, res, device=device.GetDevice())
+net = load_net(net_path)
+sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
+                  spherical=True, lindisp=True, inverse_r=True)
+x = y = None
+views = load_views(data_desc_file)
+print('%d Views loaded.' % views.size()[0])
+test_view = views.get(view_idx)
+rays_o, rays_d = cam.get_global_rays(test_view, True)
+image = net(rays_o.view(-1, 3), rays_d.view(-1, 3)).view(1,
+    res[0], res[1], -1).permute(0, 3, 1, 2)
+app = dash.Dash(__name__, external_stylesheets=[
+    'https://codepen.io/chriddyp/pen/bWLwgP.css'])
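To make the reshape at the end of this new setup block easier to follow, here is a hedged sketch of the shape flow only; DummyNet below is hypothetical and simply stands in for the loaded MSL network, which maps per-ray origins and directions to per-ray colors:

import torch

class DummyNet(torch.nn.Module):
    # Hypothetical stand-in: one RGB value per input ray.
    def forward(self, rays_o, rays_d):           # (N, 3), (N, 3)
        return torch.rand(rays_o.size(0), 3)     # (N, 3)

H, W = 256, 256
rays_o = torch.zeros(H, W, 3)                    # one ray origin per pixel
rays_d = torch.rand(H, W, 3)                     # one ray direction per pixel

out = DummyNet()(rays_o.view(-1, 3), rays_d.view(-1, 3))   # (H*W, 3)
image = out.view(1, H, W, -1).permute(0, 3, 1, 2)          # (1, 3, H, W), the NCHW layout passed to util.Tensor2MatImg
print(image.shape)  # torch.Size([1, 3, 256, 256])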
 styles = {
     'pre': {
@@ -90,35 +98,17 @@ styles = {
         'overflowX': 'scroll'
     }
 }

+fig = px.imshow(util.Tensor2MatImg(image))
-datadir = 'data/' + scenes[scene] + '/'
-fovea_net = load_net_by_name('fovea')
-periph_net = load_net_by_name('periph')
-gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,
-               device=device.GetDevice())
-sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
-                  spherical=True, lindisp=True, inverse_r=True)
-x = y = None
-views = load_views(view_file)
-print('%d Views loaded.', views.size())
-view_idx = 27
-center = (0, 0)
-test_view = views.get(view_idx)
-images = gen(center, test_view)
-fig = px.imshow(util.Tensor2MatImg(images['fovea']))
 fig1 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
+fig2 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
-app = dash.Dash(__name__, external_stylesheets=[
-    'https://codepen.io/chriddyp/pen/bWLwgP.css'])
 app.layout = html.Div([
     html.H3("Drag and draw annotations"),
     html.Div(className='row', children=[
         dcc.Graph(id='image', figure=fig),  # , config=config),
         dcc.Graph(id='scatter', figure=fig1),  # , config=config),
+        dcc.Graph(id='scatter1', figure=fig2),  # , config=config),
         dcc.Slider(id='samples-slider', min=4, max=128, step=None,
                    marks={
                        4: '4',
@@ -128,43 +118,91 @@ app.layout = html.Div([
                        64: '64',
                        128: '128',
                    },
-                   value=32,
+                   value=33,
                    updatemode='drag'
                    )
     ])
 ])
+def raw2alpha(raw, dists, act_fn=torch.relu):
+    """
+    Compute alpha (opacity) values from the model's predicted density.
+    The result lies in [0, 1).
+    """
+    print('act_fn: ', act_fn(raw))
+    print('act_fn * dists: ', act_fn(raw) * dists)
+    return -torch.exp(-act_fn(raw) * dists) + 1.0
+
+
+def raw2color(raw: torch.Tensor, z_vals: torch.Tensor):
+    """
+    Convert raw values inferred from the model to color and alpha.
+
+    :param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output
+    :param z_vals ```Tensor(N.rays, N.samples)```: integration time
+    :return ```Tensor(N.rays, N.samples, 1|3)```: color
+    :return ```Tensor(N.rays, N.samples)```: alpha
+    """
+    # Compute 'distance' (in time) between each integration time along a ray.
+    # The 'distance' from the last integration time is infinity.
+    # dists: (N_rays, N_samples)
+    dists = z_vals[..., 1:] - z_vals[..., :-1]
+    last_dist = z_vals[..., 0:1] * 0 + 1e10
+    dists = torch.cat([dists, last_dist], -1)
+    print('dists: ', dists)
+
+    # Extract RGB of each sample position along each ray.
+    color = torch.sigmoid(raw[..., :-1])  # (N_rays, N_samples, 1|3)
+    alpha = raw2alpha(raw[..., -1], dists)
+    return color, alpha
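For reference, the alphas computed above are what the scatter plots below visualize along a single ray; in a full renderer they would typically be combined front to back into one color per ray. The following is a hedged sketch of that standard accumulation (it is not part of this commit), using the same tensor shapes as raw2color:

import torch

def composite(color: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # Standard NeRF-style compositing sketch (assumed, not taken from this repo):
    # weight_i = alpha_i * prod_{j<i} (1 - alpha_j), output = sum_i weight_i * color_i.
    # color: (N_rays, N_samples, C), alpha: (N_rays, N_samples) -> (N_rays, C)
    trans = torch.cumprod(
        torch.cat([torch.ones_like(alpha[:, :1]), 1.0 - alpha + 1e-10], dim=-1),
        dim=-1)[:, :-1]                               # transmittance before each sample
    weights = alpha * trans                           # (N_rays, N_samples)
    return (weights[..., None] * color).sum(-2)       # (N_rays, C)

# Tiny usage example: two rays, three samples each.
color = torch.rand(2, 3, 3)
alpha = torch.tensor([[0.1, 0.5, 0.9], [0.0, 0.2, 0.4]])
print(composite(color, alpha).shape)  # torch.Size([2, 3])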
 def draw_scatter():
-    global fig1
-    p = torch.tensor([x, y], device=gen.layer_cams[0].c.device)
-    ray_d = test_view.trans_vector(gen.layer_cams[0].unproj(p))
+    global fig1, fig2
+    p = torch.tensor([x, y], device=device.GetDevice())
+    ray_d = test_view.trans_vector(cam.unproj(p))
     ray_o = test_view.t
-    raw, depths = fovea_net.sample_and_infer(ray_o, ray_d, sampler=sampler)
-    colors, alphas = fovea_net.rendering.raw2color(raw, depths)
+    raw, depths = net.sample_and_infer(ray_o, ray_d, sampler=sampler)
+    colors, alphas = raw2color(raw, depths)
     scatter_x = (1 / depths[0]).cpu().detach().numpy()
     scatter_y = alphas[0].cpu().detach().numpy()
+    scatter_y1 = raw[0, :, 3].cpu().detach().numpy()
     scatter_color = colors[0].cpu().detach().numpy() * 255
     marker_colors = [
-        i#'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
+        # 'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
+        i
        for i in range(scatter_color.shape[0])
    ]
    marker_colors_str = [
-        'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
+        'rgb(%d,%d,%d)' % (scatter_color[i][0],
+                           scatter_color[i][1], scatter_color[i][2])
        for i in range(scatter_color.shape[0])
    ]
-    fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors, color_continuous_scale=marker_colors_str)#, color_discrete_map='identity')
+    fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors,
+                      color_continuous_scale=marker_colors_str)  # , color_discrete_map='identity')
     fig1.update_traces(mode='lines+markers')
     fig1.update_xaxes(showgrid=False)
     fig1.update_yaxes(type='linear')
     fig1.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
+    fig2 = px.scatter(x=scatter_x, y=scatter_y1, color=marker_colors,
+                      color_continuous_scale=marker_colors_str)  # , color_discrete_map='identity')
+    fig2.update_traces(mode='lines+markers')
+    fig2.update_xaxes(showgrid=False)
+    fig2.update_yaxes(type='linear')
+    fig2.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})


 @app.callback(
     [Output('image', 'figure'),
-     Output('scatter', 'figure')],
+     Output('scatter', 'figure'),
+     Output('scatter1', 'figure')],
     [Input('image', 'clickData'),
      dash.dependencies.Input('samples-slider', 'value')]
 )
@@ -194,7 +232,7 @@ def display_hover_data(clickData, samples):
                       color="LightSeaGreen",
                       width=3,
                   ))
-    return fig, fig1
+    return fig, fig1, fig2

 if __name__ == '__main__':
...
from typing import List, Tuple
import torch
import json
from ..my import util


def ReadLightField(path: str, views: Tuple[int, int], flatten_views: bool = False) -> torch.Tensor:
    input_img = util.ReadImageTensor(path, batch_dim=False)
    h = input_img.size()[1] // views[0]
    w = input_img.size()[2] // views[1]
    if flatten_views:
        lf = torch.empty(views[0] * views[1], 3, h, w)
        for y_i in range(views[0]):
            for x_i in range(views[1]):
                lf[y_i * views[1] + x_i, :, :, :] = \
                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
    else:
        lf = torch.empty(views[0], views[1], 3, h, w)
        for y_i in range(views[0]):
            for x_i in range(views[1]):
                lf[y_i, x_i, :, :, :] = \
                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
    return lf


def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor:
    return depth_images[:, 0].unsqueeze(1).mul(255) / 10


class LightFieldSynDataset(torch.utils.data.dataset.Dataset):
    """
    Data loader for light field synthesis task

    Attributes
    --------
    data_dir ```string```: the directory of dataset\n
    n_views ```tuple(int, int)```: rows and columns of views\n
    num_views ```int```: number of views\n
    view_images ```N x H x W Tensor```: images of views\n
    view_depths ```N x H x W Tensor```: depths of views\n
    view_positions ```N x 3 Tensor```: positions of views\n
    sparse_view_images ```N' x H x W Tensor```: images of sparse views\n
    sparse_view_depths ```N' x H x W Tensor```: depths of sparse views\n
    sparse_view_positions ```N' x 3 Tensor```: positions of sparse views\n
    """

    def __init__(self, data_desc_path: str):
        """
        Initialize data loader for light field synthesis task

        The data description file is a JSON file with following fields:

        - lf: string, the path of light field image
        - lf_depth: string, the path of light field depth image
        - n_views: { "x", "y" }, columns and rows of views
        - cam_params: { "f", "c" }, the focal and center of camera (in normalized image space)
        - depth_range: [ min, max ], the range of depth in depth maps
        - depth_layers: int, number of layers in depth maps
        - view_positions: [ [ x, y, z ], ... ], positions of views

        :param data_desc_path: path to the data description file
        """
        self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/'
        with open(data_desc_path, 'r', encoding='utf-8') as file:
            self.data_desc = json.loads(file.read())
        self.n_views = (self.data_desc['n_views']['y'],
                        self.data_desc['n_views']['x'])
        self.num_views = self.n_views[0] * self.n_views[1]
        self.view_images = ReadLightField(
            self.data_dir + self.data_desc['lf'], self.n_views, True)
        self.view_depths = DecodeDepth(ReadLightField(
            self.data_dir + self.data_desc['lf_depth'], self.n_views, True))
        self.cam_params = self.data_desc['cam_params']
        self.depth_range = self.data_desc['depth_range']
        self.depth_layers = self.data_desc['depth_layers']
        self.view_positions = torch.tensor(self.data_desc['view_positions'])
        _, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \
            = self._GetCornerViews()
        self.diopter_of_layers = self._GetDiopterOfLayers()

    def __len__(self):
        return self.num_views

    def __getitem__(self, idx):
        return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx]

    def _GetCornerViews(self):
        corner_selector = torch.zeros(self.num_views, dtype=torch.bool)
        corner_selector[0] = corner_selector[self.n_views[1] - 1] \
            = corner_selector[self.num_views - self.n_views[1]] \
            = corner_selector[self.num_views - 1] = True
        return self.__getitem__(corner_selector)

    def _GetDiopterOfLayers(self) -> List[float]:
        diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0])
        step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1)
        diopter_of_layers = [diopter_range[0] + step * i for i in range(self.depth_layers)]
        diopter_of_layers.insert(0, 0)
        return diopter_of_layers
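As an illustration of the data description format documented in the __init__ docstring above, here is a hedged example; every file name and value below is made up and only mirrors the listed fields:

import json

example_desc = {
    'lf': 'lf.png',                              # packed grid of views in one image (hypothetical file name)
    'lf_depth': 'lf_depth.png',                  # matching packed depth image (hypothetical file name)
    'n_views': {'x': 3, 'y': 3},                 # 3 columns x 3 rows of views
    'cam_params': {'f': 1.0, 'c': [0.5, 0.5]},   # focal and center in normalized image space (guessed shapes)
    'depth_range': [1, 10],
    'depth_layers': 10,
    'view_positions': [[x * 0.1, y * 0.1, 0.0] for y in range(3) for x in range(3)],
}

with open('train.json', 'w', encoding='utf-8') as f:
    json.dump(example_desc, f, indent=2)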
 import torch
 import math
-from ..my import device
+from my import device

 class FastDataLoader(object):
...
 import torch
 import os
+import json
 import glob
+import cv2
 import numpy as np
 import torchvision.transforms as transforms
+from typing import List, Tuple
 from torchvision import datasets
 from torch.utils.data import DataLoader
-import cv2
-import json
-from Flow import *
-from gen_image import *
-import util
+from my.flow import *
+from my.gen_image import *
+from my import util
def ReadLightField(path: str, views: Tuple[int, int], flatten_views: bool = False) -> torch.Tensor:
    input_img = util.ReadImageTensor(path, batch_dim=False)
    h = input_img.size()[1] // views[0]
    w = input_img.size()[2] // views[1]
    if flatten_views:
        lf = torch.empty(views[0] * views[1], 3, h, w)
        for y_i in range(views[0]):
            for x_i in range(views[1]):
                lf[y_i * views[1] + x_i, :, :, :] = \
                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
    else:
        lf = torch.empty(views[0], views[1], 3, h, w)
        for y_i in range(views[0]):
            for x_i in range(views[1]):
                lf[y_i, x_i, :, :, :] = \
                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
    return lf


def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor:
    return depth_images[:, 0].unsqueeze(1).mul(255) / 10
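A small hedged illustration of what these two helpers do: ReadLightField slices a packed grid image into per-view tensors (the slicing below mirrors the flatten_views=True branch), and DecodeDepth rescales the red channel of the depth image back into depth values; all numbers here are illustrative only:

import torch

# A fake 3x3 light field 'image' of size (3, 3*H, 3*W), sliced the same way
# ReadLightField does with flatten_views=True.
H, W = 4, 4
packed = torch.rand(3, 3 * H, 3 * W)
views = torch.stack([packed[:, y * H:(y + 1) * H, x * W:(x + 1) * W]
                     for y in range(3) for x in range(3)])
print(views.shape)  # torch.Size([9, 3, 4, 4]): one (C, H, W) tensor per view

# DecodeDepth: red channel in [0, 1], scaled by 255 / 10, so a stored value of
# 0.4 decodes to roughly 0.4 * 255 / 10 = 10.2 (in whatever units the renderer used).
depth_images = torch.full((9, 3, H, W), 0.4)
depth = depth_images[:, 0].unsqueeze(1).mul(255) / 10
print(depth[0, 0, 0, 0].item())  # ~10.2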
class LightFieldSynDataset(torch.utils.data.dataset.Dataset):
    """
    Data loader for light field synthesis task

    Attributes
    --------
    data_dir ```string```: the directory of dataset\n
    n_views ```tuple(int, int)```: rows and columns of views\n
    num_views ```int```: number of views\n
    view_images ```N x H x W Tensor```: images of views\n
    view_depths ```N x H x W Tensor```: depths of views\n
    view_positions ```N x 3 Tensor```: positions of views\n
    sparse_view_images ```N' x H x W Tensor```: images of sparse views\n
    sparse_view_depths ```N' x H x W Tensor```: depths of sparse views\n
    sparse_view_positions ```N' x 3 Tensor```: positions of sparse views\n
    """

    def __init__(self, data_desc_path: str):
        """
        Initialize data loader for light field synthesis task

        The data description file is a JSON file with following fields:

        - lf: string, the path of light field image
        - lf_depth: string, the path of light field depth image
        - n_views: { "x", "y" }, columns and rows of views
        - cam_params: { "f", "c" }, the focal and center of camera (in normalized image space)
        - depth_range: [ min, max ], the range of depth in depth maps
        - depth_layers: int, number of layers in depth maps
        - view_positions: [ [ x, y, z ], ... ], positions of views

        :param data_desc_path: path to the data description file
        """
        self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/'
        with open(data_desc_path, 'r', encoding='utf-8') as file:
            self.data_desc = json.loads(file.read())
        self.n_views = (self.data_desc['n_views']['y'],
                        self.data_desc['n_views']['x'])
        self.num_views = self.n_views[0] * self.n_views[1]
        self.view_images = ReadLightField(
            self.data_dir + self.data_desc['lf'], self.n_views, True)
        self.view_depths = DecodeDepth(ReadLightField(
            self.data_dir + self.data_desc['lf_depth'], self.n_views, True))
        self.cam_params = self.data_desc['cam_params']
        self.depth_range = self.data_desc['depth_range']
        self.depth_layers = self.data_desc['depth_layers']
        self.view_positions = torch.tensor(self.data_desc['view_positions'])
        _, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \
            = self._GetCornerViews()
        self.diopter_of_layers = self._GetDiopterOfLayers()

    def __len__(self):
        return self.num_views

    def __getitem__(self, idx):
        return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx]

    def _GetCornerViews(self):
        corner_selector = torch.zeros(self.num_views, dtype=torch.bool)
        corner_selector[0] = corner_selector[self.n_views[1] - 1] \
            = corner_selector[self.num_views - self.n_views[1]] \
            = corner_selector[self.num_views - 1] = True
        return self.__getitem__(corner_selector)

    def _GetDiopterOfLayers(self) -> List[float]:
        diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0])
        step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1)
        diopter_of_layers = [diopter_range[0] +
                             step * i for i in range(self.depth_layers)]
        diopter_of_layers.insert(0, 0)
        return diopter_of_layers
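To make _GetDiopterOfLayers concrete, a hedged worked example using depth_range = [1, 10] and depth_layers = 10 (the same constants the notebook later in this commit uses as DEPTH_RANGE and N_DEPTH_LAYERS):

# Worked example of the layer spacing computed by _GetDiopterOfLayers above.
depth_range = (1, 10)
depth_layers = 10

diopter_range = (1 / depth_range[1], 1 / depth_range[0])            # (0.1, 1.0)
step = (diopter_range[1] - diopter_range[0]) / (depth_layers - 1)   # 0.1
diopter_of_layers = [diopter_range[0] + step * i for i in range(depth_layers)]
diopter_of_layers.insert(0, 0)

print(diopter_of_layers)
# [0, 0.1, 0.2, ..., 1.0] (up to floating-point rounding): ten layers evenly
# spaced in diopters (inverse depth) from 10 m down to 1 m, plus one extra
# layer at 0 diopters, i.e. at infinity.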
 class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
@@ -90,8 +185,8 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
         # print(lf_image_big.shape)
         for i in range(9):
-            lf_image = lf_image_big[i//3*IM_H:i//3 *
-                                    IM_H+IM_H, i % 3*IM_W:i % 3*IM_W+IM_W, 0:3]
+            lf_image = lf_image_big[i // 3 * IM_H:i // 3 *
+                                    IM_H + IM_H, i % 3 * IM_W:i % 3 * IM_W + IM_W, 0:3]
             # IF GrayScale
             # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
             # print(lf_image.shape)
@@ -146,8 +241,8 @@ class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
         # print(lf_image_big.shape)
         for i in range(9):
-            lf_image = lf_image_big[i//3*IM_H:i//3 *
-                                    IM_H+IM_H, i % 3*IM_W:i % 3*IM_W+IM_W, 0:3]
+            lf_image = lf_image_big[i // 3 * IM_H:i // 3 *
+                                    IM_H + IM_H, i % 3 * IM_W:i % 3 * IM_W + IM_W, 0:3]
             # IF GrayScale
             # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
             # print(lf_image.shape)
@@ -214,8 +309,8 @@ class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
             lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
             for j in range(9):
-                lf_image = lf_image_big[j//3*IM_H:j//3 *
-                                        IM_H+IM_H, j % 3*IM_W:j % 3*IM_W+IM_W, 0:3]
+                lf_image = lf_image_big[j // 3 * IM_H:j // 3 *
+                                        IM_H + IM_H, j % 3 * IM_W:j % 3 * IM_W + IM_W, 0:3]
                 lf_image_one_sample.append(lf_image)
             gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(
@@ -297,8 +392,8 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
             lf_dim = int(self.conf.light_field_dim)
             for j in range(lf_dim**2):
-                lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim *
-                                        IM_H+IM_H, j % lf_dim*IM_W:j % lf_dim*IM_W+IM_W, 0:3]
+                lf_image = lf_image_big[j // lf_dim * IM_H:j // lf_dim *
+                                        IM_H + IM_H, j % lf_dim * IM_W:j % lf_dim * IM_W + IM_W, 0:3]
                 lf_image_one_sample.append(lf_image)
             gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(
@@ -333,7 +428,7 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
                 retinal_invalid.append(retinal_invalid_i)
             # lf_images: 5,9,320,320
             flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"]
-                                           [indices[i-1]]) for i in range(1, len(indices))])
+                                           [indices[i - 1]]) for i in range(1, len(indices))])
             flow_map = flow.getMap()
             flow_invalid_mask = flow.b_invalid_mask
             # print("flow:",flow_map.shape)
...
@@ -4,10 +4,10 @@ import torch
 import torchvision.transforms.functional as trans_f
 import torch.nn.functional as nn_f
 from typing import Tuple, Union
-from ..my import util
-from ..my import device
-from ..my import view
-from ..my import color_mode
+from my import util
+from my import device
+from my import view
+from my import color_mode
 class SphericalViewSynDataset(object):
@@ -129,6 +129,13 @@ class SphericalViewSynDataset(object):
         self.n_views = self.view_centers.size(0)
         self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]

+        if 'gl_coord' in data_desc and data_desc['gl_coord'] == True:
+            print('Convert from OGL coordinate to DX coordinate (i. e. right-hand to left-hand)')
+            self.cam_params.f[1] *= -1
+            self.view_centers[:, 2] *= -1
+            self.view_rots[:, 2] *= -1
+            self.view_rots[..., 2] *= -1
+
     def set_patch_size(self, patch_size: Union[int, Tuple[int, int]],
                        offset: Union[int, Tuple[int, int]] = 0):
         """
...
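The gl_coord branch added above converts camera poses from OpenGL's right-handed convention to a left-handed (DirectX-style) one by flipping the z axis. A hedged standalone sketch of the same transform on a plain rotation matrix and position, independent of this class: flipping z is the similarity transform R' = S R S with S = diag(1, 1, -1), which is exactly "negate row 2 and column 2 of R", while positions just get their z component negated.

import torch

S = torch.diag(torch.tensor([1.0, 1.0, -1.0]))

R = torch.tensor([[0.0, -1.0, 0.0],
                  [1.0,  0.0, 0.0],
                  [0.0,  0.0, 1.0]])       # some right-handed rotation
t = torch.tensor([0.3, 0.2, 1.5])          # some camera position

R_lh = S @ R @ S                           # flip handedness of the rotation
t_lh = t * torch.tensor([1.0, 1.0, -1.0])  # negate z of the position

# Equivalent to the in-place updates used in the dataset code above:
R_check = R.clone()
R_check[2, :] *= -1                        # negate row 2
R_check[:, 2] *= -1                        # negate column 2
print(torch.allclose(R_lh, R_check))       # True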
 from typing import List, Tuple
 import torch
 import torch.nn as nn
-from ..my import device
-from ..my import util
+from my import device
+from my import util

 class FcLayer(nn.Module):
...
@@ -2,9 +2,8 @@ import math
 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import util
-from ..my import color_mode
+from my import util
+from my import color_mode

 class MslNet(nn.Module):
...
 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import color_mode
-from ..my.simple_perf import SimplePerf
+from my import color_mode
+from my.simple_perf import SimplePerf

 class NewMslNet(nn.Module):
...
@@ -2,10 +2,10 @@ from typing import Tuple
 import math
 import torch
 import torch.nn as nn
-from ..my import net_modules
-from ..my import util
-from ..my import device
-from ..my import color_mode
+from my import net_modules
+from my import util
+from my import device
+from my import color_mode
 from .msl_net_new import NewMslNet
...
 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import util
+from my import util

 class SpherNet(nn.Module):
...
 from typing import List
 import torch
 import torch.nn as nn
-from ..pytorch_prototyping.pytorch_prototyping import *
-from ..my import util
-from ..my import device
+from pytorch_prototyping.pytorch_prototyping import *
+from my import util
+from my import device

 class Encoder(nn.Module):
...
@@ -2,20 +2,28 @@
 "cells": [
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 2,
   "metadata": {},
-  "outputs": [],
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "Set CUDA:2 as current device.\n"
+    ]
+   }
+  ],
   "source": [
    "import sys\n",
    "import os\n",
-   "sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
+   "sys.path.append(os.path.abspath(sys.path[0] + '/../'))\n",
    "\n",
    "import torch\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
-   "from deep_view_syn.my import util\n",
-   "from deep_view_syn.msl_net import *\n",
+   "from my import util\n",
+   "from nets.msl_net import *\n",
    "\n",
    "# Select device\n",
    "torch.cuda.set_device(2)\n",
@@ -23,8 +31,10 @@
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": 1,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Test Ray-Sphere Intersection & Cartesian-Spherical Conversion"
   ]
@@ -56,14 +66,15 @@
    "v = torch.tensor([[0.0, -1.0, 1.0]])\n",
    "r = torch.tensor([[2.5]])\n",
    "v = v / torch.norm(v) * r * 2\n",
-   "p_on_sphere_ = RaySphereIntersect(p, v, r)[0]\n",
+   "p_on_sphere_ = util.RaySphereIntersect(p, v, r)[0][0]\n",
    "print(p_on_sphere_)\n",
    "print(p_on_sphere_.norm())\n",
-   "spher_coord = RayToSpherical(p, v, r)\n",
+   "spher_coord = util.CartesianToSpherical(p_on_sphere_)\n",
    "print(spher_coord[..., 1:3].rad2deg())\n",
-   "p_on_sphere = util.SphericalToCartesian(spher_coord)[0]\n",
+   "p_on_sphere = util.SphericalToCartesian(spher_coord)\n",
+   "print(p_on_sphere_.size())\n",
    "\n",
-   "fig = plt.figure(figsize=(6, 6))\n",
+   "fig = plt.figure(figsize=(8, 8))\n",
    "ax = fig.gca(projection='3d')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('z')\n",
@@ -109,8 +120,10 @@
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Test Dataset Loader & View-Spherical Transform"
   ]
@@ -121,26 +134,26 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset\n",
-   "from deep_view_syn.data.spherical_view_syn import FastDataLoader\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.loader import FastDataLoader\n",
    "\n",
-   "DATA_DIR = '../data/sp_view_syn_2020.12.28'\n",
+   "DATA_DIR = '../data/nerf_fern'\n",
    "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
    "\n",
-   "dataset = FastSphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
+   "dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
    "dataset.set_patch_size((64, 64))\n",
    "data_loader = FastDataLoader(dataset=dataset, batch_size=4, shuffle=False, drop_last=False)\n",
    "print(len(dataset))\n",
-   "plt.figure()\n",
+   "fig = plt.figure(figsize=(12, 6.5))\n",
    "i = 0\n",
    "for indices, patches, rays_o, rays_d in data_loader:\n",
    "    print(i, patches.size(), rays_o.size(), rays_d.size())\n",
    "    for idx in range(len(indices)):\n",
-   "        plt.subplot(4, 4, i + 1)\n",
+   "        plt.subplot(4, 7, i + 1)\n",
    "        util.PlotImageTensor(patches[idx])\n",
    "        i += 1\n",
-   "        if i == 16:\n",
+   "        if i == 28:\n",
-   "            break\n"
+   "            break"
   ]
  },
  {
@@ -149,13 +162,15 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.loader import FastDataLoader\n",
    "\n",
-   "DATA_DIR = '../data/sp_view_syn_2020.12.26'\n",
+   "DATA_DIR = '../data/nerf_fern'\n",
    "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
    "DEPTH_RANGE = (1, 10)\n",
    "N_DEPTH_LAYERS = 10\n",
    "\n",
+   "\n",
    "def _GetSphereLayers(depth_range: Tuple[float, float], n_layers: int) -> torch.Tensor:\n",
    "    diopter_range = (1 / depth_range[1], 1 / depth_range[0])\n",
    "    step = (diopter_range[1] - diopter_range[0]) / (n_layers - 1)\n",
@@ -163,74 +178,54 @@
    "    depths += [1 / (diopter_range[0] + step * i) for i in range(n_layers)]\n",
    "    return torch.tensor(depths, device=device.GetDevice()).view(-1, 1)\n",
    "\n",
-   "train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
-   "train_data_loader = torch.utils.data.DataLoader(\n",
-   "    dataset=train_dataset,\n",
-   "    batch_size=4,\n",
-   "    num_workers=8,\n",
-   "    pin_memory=True,\n",
-   "    shuffle=True,\n",
-   "    drop_last=False)\n",
-   "print(len(train_data_loader))\n",
-   "\n",
-   "print(\"view_res\", train_dataset.view_res)\n",
-   "print(\"cam_params\", train_dataset.cam_params)\n",
-   "\n",
-   "msl_net = MslNet(train_dataset.cam_params,\n",
-   "                 _GetSphereLayers(DEPTH_RANGE, N_DEPTH_LAYERS),\n",
-   "                 train_dataset.view_res).to(device.GetDevice())\n",
-   "print(\"sphere layers\", msl_net.rendering.sphere_layers)\n",
-   "\n",
-   "p = None\n",
-   "v = None\n",
-   "centers = None\n",
-   "plt.figure(figsize=(6, 6))\n",
-   "for _, view_images, ray_positions, ray_directions in train_data_loader:\n",
-   "    p = ray_positions\n",
-   "    v = ray_directions\n",
-   "    plt.subplot(2, 2, 1)\n",
-   "    util.PlotImageTensor(view_images[0])\n",
-   "    plt.subplot(2, 2, 2)\n",
-   "    util.PlotImageTensor(view_images[1])\n",
-   "    plt.subplot(2, 2, 3)\n",
-   "    util.PlotImageTensor(view_images[2])\n",
-   "    plt.subplot(2, 2, 4)\n",
-   "    util.PlotImageTensor(view_images[3])\n",
-   "    break\n",
-   "p_ = util.SphericalToCartesian(RayToSpherical(p.flatten(0, 1), v.flatten(0, 1),\n",
-   "                                              torch.tensor([[1]], device=device.GetDevice()))) \\\n",
-   "    .view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)\n",
-   "v = v.view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n",
-   "p_ = p_[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n",
-   "\n",
-   "fig = plt.figure(figsize=(6, 6))\n",
-   "ax = fig.gca(projection='3d')\n",
-   "plt.xlabel('x')\n",
-   "plt.ylabel('z')\n",
    "\n",
-   "PlotSphere(ax, 1)\n",
+   "dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
+   "dataset.set_patch_size(1)\n",
+   "data_loader = FastDataLoader(\n",
+   "    dataset=dataset, batch_size=4096*16, shuffle=True, drop_last=False)\n",
    "\n",
-   "ax.scatter([0], [0], [0], color=\"k\", s=10) # Center\n",
+   "print(\"view_res\", dataset.view_res)\n",
+   "print(\"cam_params\", dataset.cam_params)\n",
    "\n",
-   "colors = [ 'r', 'g', 'b', 'y' ]\n",
-   "for i in range(4):\n",
-   "    ax.scatter(p_[i, :, 0], p_[i, :, 2], p_[i, :, 1], color=colors[i], s=3)\n",
-   "    for j in range(p_.shape[1]):\n",
-   "        ax.plot([centers[i, 0], centers[i, 0] + v[i, j, 0]],\n",
-   "                [centers[i, 2], centers[i, 2] + v[i, j, 2]],\n",
-   "                [centers[i, 1], centers[i, 1] + v[i, j, 1]],\n",
-   "                color=colors[i], linewidth=0.5, alpha=0.6)\n",
+   "fig = plt.figure(figsize=(16, 40))\n",
    "\n",
-   "ax.set_xlim(-1, 1)\n",
-   "ax.set_ylim(-1, 1)\n",
-   "ax.set_zlim(-1, 1)\n",
-   "\n",
-   "plt.show()\n"
+   "for ri in range(0, 10):\n",
+   "    r = ri * 0.2 + 1\n",
+   "    p = None\n",
+   "    centers = None\n",
+   "    pixels = None\n",
+   "    for indices, patches, rays_o, rays_d in data_loader:\n",
+   "        p = util.RaySphereIntersect(\n",
+   "            rays_o, rays_d, torch.tensor([[r]], device=device.GetDevice()))[0] \\\n",
+   "            .view(-1, 3).cpu().numpy()\n",
+   "        centers = rays_o.view(-1, 3).cpu().numpy()\n",
+   "        pixels = patches.view(-1, 3).cpu().numpy()\n",
+   "        break\n",
+   "    \n",
+   "    ax = plt.subplot(5, 2, ri + 1, projection='3d')\n",
+   "    #ax = plt.gca(projection='3d')\n",
+   "    #ax = fig.gca()\n",
+   "    plt.xlabel('x')\n",
+   "    plt.ylabel('z')\n",
+   "    plt.title('r = %f' % r)\n",
+   "\n",
+   "    # PlotSphere(ax, 1)\n",
+   "\n",
+   "    ax.scatter([0], [0], [0], color=\"k\", s=10)\n",
+   "    ax.scatter(p[:, 0], p[:, 2], p[:, 1], color=pixels, s=0.5)\n",
+   "\n",
+   "    #ax.set_xlim(-1, 1)\n",
+   "    #ax.set_ylim(-1, 1)\n",
+   "    #ax.set_zlim(-1, 1)\n",
+   "    ax.view_init(elev=0,azim=-90)\n",
+   "\n"
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": 3,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Test Sampler"
   ]
@@ -241,7 +236,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
    "\n",
    "DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'\n",
    "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
@@ -304,7 +299,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
    "\n",
    "DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'\n",
    "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
@@ -367,8 +362,10 @@
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Test Spherical View Synthesis"
   ]
@@ -381,9 +378,9 @@
   "source": [
    "import ipywidgets as widgets  # widget library\n",
    "from IPython.display import display  # used to display the widgets\n",
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
-   "from deep_view_syn.spher_net import SpherNet\n",
-   "from deep_view_syn.my import netio\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from nets.spher_net import SpherNet\n",
+   "from my import netio\n",
    "\n",
    "DATA_DIR = '../data/sp_view_syn_2020.12.28_small'\n",
    "DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
@@ -455,20 +452,12 @@
    "})\n",
    "display(slider_x, slider_y, slider_z, slider_theta, slider_phi, out)\n"
   ]
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 3",
-  "language": "python",
-  "name": "python3"
+  "display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
+  "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
  },
 "language_info": {
  "codemirror_mode": {
@@ -480,7 +469,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.7.6-final"
+  "version": "3.7.9"
  }
 },
 "nbformat": 4,
...
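The first notebook cell above exercises util.RaySphereIntersect together with the Cartesian/spherical conversions. For background, a hedged standalone sketch of the textbook ray-sphere intersection that such a helper is expected to implement (assumed behavior, not code taken from the repository; the sphere is centered at the origin):

import torch

def ray_sphere_intersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    # Solve ||p + t * v||^2 = r^2 for t and return the far hit point,
    # which is the relevant root when rays start inside the sphere.
    # p, v: (N, 3) ray origins and directions; r: (N, 1) radii.
    a = (v * v).sum(-1, keepdim=True)
    b = 2 * (p * v).sum(-1, keepdim=True)
    c = (p * p).sum(-1, keepdim=True) - r * r
    disc = torch.clamp(b * b - 4 * a * c, min=0)
    t = (-b + torch.sqrt(disc)) / (2 * a)
    return p + t * v

# Same test values as the notebook cell: the hit point must lie on the r = 2.5 sphere.
p = torch.tensor([[0.0, 0.0, 0.0]])
v = torch.tensor([[0.0, -1.0, 1.0]])
r = torch.tensor([[2.5]])
hit = ray_sphere_intersect(p, v, r)
print(hit, hit.norm())  # norm is ~2.5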