Commit f6604bd2 authored by Nianchen Deng

rebuttal version

parent 6e54b394
from ..my import color_mode
from my import color_mode
def update_config(config):
# Dataset settings
......
from ..my import color_mode
from my import color_mode
def update_config(config):
# Dataset settings
......
from ..my import color_mode
from my import color_mode
def update_config(config):
# Dataset settings
......
from ..my import color_mode
from my import color_mode
def update_config(config):
# Dataset settings
......
import os
import importlib
from os.path import join
from ..my import color_mode
from ..nets.msl_net import MslNet
from ..nets.msl_net_new import NewMslNet
from ..nets.spher_net import SpherNet
from my import color_mode
from nets.msl_net import MslNet
from nets.msl_net_new import NewMslNet
class SphericalViewSynConfig(object):
......@@ -36,14 +34,13 @@ class SphericalViewSynConfig(object):
def load(self, path):
module_name = os.path.splitext(path)[0].replace('/', '.')
config_module = importlib.import_module(
'deep_view_syn.' + module_name)
config_module = importlib.import_module(module_name)
config_module.update_config(self)
self.name = module_name.split('.')[-1]
def load_by_name(self, name):
config_module = importlib.import_module(
'deep_view_syn.configs.' + name)
'configs.' + name)
config_module.update_config(self)
self.name = name
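For context, a minimal usage sketch of the updated loader, assuming scripts are now run from the repository root so that configs/ is importable as a top-level package (the config name below is hypothetical):

# Loads configs/<name>.py and applies its update_config() to this instance.
config = SphericalViewSynConfig()
config.load_by_name('fovea_rgb')                     # hypothetical config name
net = config.create_net().to(device.GetDevice())     # create_net() as used later in this commit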
......
from ..my import color_mode
from my import color_mode
def update_config(config):
# Dataset settings
......
from ..my import color_mode
from my import color_mode
def update_config(config):
# Dataset settings
......
from ..my import color_mode
from my import color_mode
def update_config(config):
# Dataset settings
......
......@@ -10,7 +10,7 @@ import plotly.express as px
import pandas as pd
from dash.dependencies import Input, Output
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
#sys.path.append(os.path.abspath(sys.path[0] + '/../'))
#__package__ = "deep_view_syn"
if __name__ == '__main__':
......@@ -24,23 +24,30 @@ if __name__ == '__main__':
print("Set CUDA:%d as current device." % torch.cuda.current_device())
torch.autograd.set_grad_enabled(False)
from deep_view_syn.data.spherical_view_syn import *
from deep_view_syn.configs.spherical_view_syn import SphericalViewSynConfig
from deep_view_syn.my import netio
from deep_view_syn.my import util
from deep_view_syn.my import device
from deep_view_syn.my import view
from deep_view_syn.my.gen_final import GenFinal
from deep_view_syn.nets.modules import Sampler
datadir = None
from data.spherical_view_syn import *
from configs.spherical_view_syn import SphericalViewSynConfig
from my import netio
from my import util
from my import device
from my import view
from my.gen_final import GenFinal
from nets.modules import Sampler
datadir = 'data/__0_user_study/us_gas_periph_r135x135_t0.3_2021.01.16/'
data_desc_file = 'train.json'
net_config = 'periph_rgb@msl-rgb_e10_fc96x4_d1.00-50.00_s16'
net_path = datadir + net_config + '/model-epoch_200.pth'
fov = 45
res = (256, 256)
view_idx = 4
center = (0, 0)
def load_net(path):
print(path)
config = SphericalViewSynConfig()
config.from_id(os.path.splitext(os.path.basename(path))[0])
config.from_id(net_config)
config.SAMPLE_PARAMS['perturb_sample'] = False
net = config.create_net().to(device.GetDevice())
netio.LoadNet(path, net)
......@@ -64,24 +71,25 @@ def load_views(data_desc_file) -> view.Trans:
return view.Trans(view_centers, view_rots)
scenes = {
'gas': '__0_user_study/us_gas_all_in_one',
'mc': '__0_user_study/us_mc_all_in_one',
'bedroom': 'bedroom_all_in_one',
'gallery': 'gallery_all_in_one',
'lobby': 'lobby_all_in_one'
}
fov_list = [20, 45, 110]
res_list = [(128, 128), (256, 256), (256, 230)]
res_full = (1600, 1440)
cam = view.CameraParam({
'fov': fov,
'cx': 0.5,
'cy': 0.5,
'normalized': True
}, res, device=device.GetDevice())
net = load_net(net_path)
sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
spherical=True, lindisp=True, inverse_r=True)
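# Note: lindisp=True samples linearly in disparity (1/depth) rather than linearly in depth, and
# spherical=True / inverse_r=True place the samples on concentric spheres around the viewer;
# this reading is inferred from the parameter names and depth_range=(1, 50), not stated here.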
x = y = None
views = load_views(data_desc_file)
print('%d Views loaded.' % views.size()[0])
scene = 'gas'
view_file = 'views.json'
test_view = views.get(view_idx)
rays_o, rays_d = cam.get_global_rays(test_view, True)
image = net(rays_o.view(-1, 3), rays_d.view(-1, 3)).view(1,
res[0], res[1], -1).permute(0, 3, 1, 2)
app = dash.Dash(__name__, external_stylesheets=[
'https://codepen.io/chriddyp/pen/bWLwgP.css'])
styles = {
'pre': {
......@@ -90,35 +98,17 @@ styles = {
'overflowX': 'scroll'
}
}
datadir = 'data/' + scenes[scene] + '/'
fovea_net = load_net_by_name('fovea')
periph_net = load_net_by_name('periph')
gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,
device=device.GetDevice())
sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
spherical=True, lindisp=True, inverse_r=True)
x = y = None
views = load_views(view_file)
print('%d Views loaded.', views.size())
view_idx = 27
center = (0, 0)
test_view = views.get(view_idx)
images = gen(center, test_view)
fig = px.imshow(util.Tensor2MatImg(images['fovea']))
fig = px.imshow(util.Tensor2MatImg(image))
fig1 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
fig2 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
app = dash.Dash(__name__, external_stylesheets=[
'https://codepen.io/chriddyp/pen/bWLwgP.css'])
app.layout = html.Div([
html.H3("Drag and draw annotations"),
html.Div(className='row', children=[
dcc.Graph(id='image', figure=fig), # , config=config),
dcc.Graph(id='scatter', figure=fig1), # , config=config),
dcc.Graph(id='scatter1', figure=fig2), # , config=config),
dcc.Slider(id='samples-slider', min=4, max=128, step=None,
marks={
4: '4',
......@@ -128,43 +118,91 @@ app.layout = html.Div([
64: '64',
128: '128',
},
value=32,
value=33,
updatemode='drag'
)
])
])
def raw2alpha(raw, dists, act_fn=torch.relu):
"""
Compute per-sample alpha (opacity) from the raw density prediction:
alpha = 1 - exp(-act_fn(raw) * dists), which lies in [0, 1).
"""
print('act_fn: ', act_fn(raw))
print('act_fn * dists: ', act_fn(raw) * dists)
return -torch.exp(-act_fn(raw) * dists) + 1.0
def raw2color(raw: torch.Tensor, z_vals: torch.Tensor):
"""
Raw value inferred from model to color and alpha
:param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output
:param z_vals ```Tensor(N.rays, N.samples)```: integration time
:return ```Tensor(N.rays, N.samples, 1|3)```: color
:return ```Tensor(N.rays, N.samples)```: alpha
"""
# Compute 'distance' (in time) between each integration time along a ray.
# The 'distance' from the last integration time is infinity.
# dists: (N_rays, N_samples)
dists = z_vals[..., 1:] - z_vals[..., :-1]
last_dist = z_vals[..., 0:1] * 0 + 1e10
dists = torch.cat([dists, last_dist], -1)
print('dists: ', dists)
# Extract RGB of each sample position along each ray.
color = torch.sigmoid(raw[..., :-1]) # (N_rays, N_samples, 1|3)
alpha = raw2alpha(raw[..., -1], dists)
return color, alpha
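As a reading aid (not part of this commit), a sketch of the standard volume-rendering accumulation that such color/alpha pairs typically feed into; the function name is illustrative:

def composite(color: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """
    Accumulate per-sample colors front-to-back along each ray
    :param color ```Tensor(N.rays, N.samples, 1|3)```: per-sample color
    :param alpha ```Tensor(N.rays, N.samples)```: per-sample alpha
    :return ```Tensor(N.rays, 1|3)```: accumulated ray color
    """
    # Transmittance T_i = prod_{j<i} (1 - alpha_j), computed as an exclusive cumulative product
    trans = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]),
                                     1.0 - alpha[..., :-1] + 1e-10], dim=-1), dim=-1)
    weights = alpha * trans                               # (N.rays, N.samples)
    return torch.sum(weights[..., None] * color, dim=-2)  # (N.rays, 1|3)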
def draw_scatter():
global fig1
p = torch.tensor([x, y], device=gen.layer_cams[0].c.device)
ray_d = test_view.trans_vector(gen.layer_cams[0].unproj(p))
global fig1, fig2
p = torch.tensor([x, y], device=device.GetDevice())
ray_d = test_view.trans_vector(cam.unproj(p))
ray_o = test_view.t
raw, depths = fovea_net.sample_and_infer(ray_o, ray_d, sampler=sampler)
colors, alphas = fovea_net.rendering.raw2color(raw, depths)
raw, depths = net.sample_and_infer(ray_o, ray_d, sampler=sampler)
colors, alphas = raw2color(raw, depths)
scatter_x = (1 / depths[0]).cpu().detach().numpy()
scatter_y = alphas[0].cpu().detach().numpy()
scatter_y1 = raw[0, :, 3].cpu().detach().numpy()
scatter_color = colors[0].cpu().detach().numpy() * 255
marker_colors = [
i#'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
# 'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
i
for i in range(scatter_color.shape[0])
]
marker_colors_str = [
'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
'rgb(%d,%d,%d)' % (scatter_color[i][0],
scatter_color[i][1], scatter_color[i][2])
for i in range(scatter_color.shape[0])
]
fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors, color_continuous_scale=marker_colors_str)#, color_discrete_map='identity')
fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors,
color_continuous_scale=marker_colors_str) # , color_discrete_map='identity')
fig1.update_traces(mode='lines+markers')
fig1.update_xaxes(showgrid=False)
fig1.update_yaxes(type='linear')
fig1.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
fig2 = px.scatter(x=scatter_x, y=scatter_y1, color=marker_colors,
color_continuous_scale=marker_colors_str) # , color_discrete_map='identity')
fig2.update_traces(mode='lines+markers')
fig2.update_xaxes(showgrid=False)
fig2.update_yaxes(type='linear')
fig2.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
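# Note: fig1 plots per-sample alpha against inverse depth (disparity, scatter_x = 1 / depths)
# along the clicked ray; fig2 plots the raw, pre-activation density channel raw[0, :, 3]
# against the same axis.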
@app.callback(
[Output('image', 'figure'),
Output('scatter', 'figure')],
Output('scatter', 'figure'),
Output('scatter1', 'figure')],
[Input('image', 'clickData'),
dash.dependencies.Input('samples-slider', 'value')]
)
......@@ -194,7 +232,7 @@ def display_hover_data(clickData, samples):
color="LightSeaGreen",
width=3,
))
return fig, fig1
return fig, fig1, fig2
if __name__ == '__main__':
......
from typing import List, Tuple
import torch
import json
from ..my import util
def ReadLightField(path: str, views: Tuple[int, int], flatten_views: bool = False) -> torch.Tensor:
input_img = util.ReadImageTensor(path, batch_dim=False)
h = input_img.size()[1] // views[0]
w = input_img.size()[2] // views[1]
if flatten_views:
lf = torch.empty(views[0] * views[1], 3, h, w)
for y_i in range(views[0]):
for x_i in range(views[1]):
lf[y_i * views[1] + x_i, :, :, :] = \
input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
else:
lf = torch.empty(views[0], views[1], 3, h, w)
for y_i in range(views[0]):
for x_i in range(views[1]):
lf[y_i, x_i, :, :, :] = \
input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
return lf
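A minimal usage sketch, assuming the light field is stored as a single image containing a 3x3 grid of views (the path is hypothetical):

# With flatten_views=True the grid image is split into 9 view tensors of shape (3, H/3, W/3).
lf = ReadLightField('data/lf_example.png', views=(3, 3), flatten_views=True)  # -> (9, 3, h, w)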
def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor:
return depth_images[:, 0].unsqueeze(1).mul(255) / 10
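# Note: DecodeDepth recovers scene depth from the red channel of a normalized depth image,
# assuming depths were encoded as depth * 10 / 255 on export (inferred from the code, not
# stated in this commit); e.g. a channel value of 0.5 decodes to 0.5 * 255 / 10 = 12.75 units.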
class LightFieldSynDataset(torch.utils.data.dataset.Dataset):
"""
Data loader for light field synthesis task
Attributes
--------
data_dir ```string```: the directory of the dataset\n
n_views ```tuple(int, int)```: rows and columns of views\n
num_views ```int```: number of views\n
view_images ```N x H x W Tensor```: images of views\n
view_depths ```N x H x W Tensor```: depths of views\n
view_positions ```N x 3 Tensor```: positions of views\n
sparse_view_images ```N' x H x W Tensor```: images of sparse views\n
sparse_view_depths ```N' x H x W Tensor```: depths of sparse views\n
sparse_view_positions ```N' x 3 Tensor```: positions of sparse views\n
"""
def __init__(self, data_desc_path: str):
"""
Initialize data loader for light field synthesis task
The data description file is a JSON file with the following fields:
- lf: string, the path of light field image
- lf_depth: string, the path of light field depth image
- n_views: { "x", "y" }, columns and rows of views
- cam_params: { "f", "c" }, the focal and center of camera (in normalized image space)
- depth_range: [ min, max ], the range of depth in depth maps
- depth_layers: int, number of layers in depth maps
- view_positions: [ [ x, y, z ], ... ], positions of views
:param data_desc_path: path to the data description file
"""
self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/'
with open(data_desc_path, 'r', encoding='utf-8') as file:
self.data_desc = json.loads(file.read())
self.n_views = (self.data_desc['n_views']
['y'], self.data_desc['n_views']['x'])
self.num_views = self.n_views[0] * self.n_views[1]
self.view_images = ReadLightField(
self.data_dir + self.data_desc['lf'], self.n_views, True)
self.view_depths = DecodeDepth(ReadLightField(
self.data_dir + self.data_desc['lf_depth'], self.n_views, True))
self.cam_params = self.data_desc['cam_params']
self.depth_range = self.data_desc['depth_range']
self.depth_layers = self.data_desc['depth_layers']
self.view_positions = torch.tensor(self.data_desc['view_positions'])
_, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \
= self._GetCornerViews()
self.diopter_of_layers = self._GetDiopterOfLayers()
def __len__(self):
return self.num_views
def __getitem__(self, idx):
return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx]
def _GetCornerViews(self):
corner_selector = torch.zeros(self.num_views, dtype=torch.bool)
corner_selector[0] = corner_selector[self.n_views[1] - 1] \
= corner_selector[self.num_views - self.n_views[1]] \
= corner_selector[self.num_views - 1] = True
return self.__getitem__(corner_selector)
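# Note: for a 3x3 grid (n_views = (3, 3)) the selected corner indices are 0, 2, 6 and 8,
# i.e. the four outermost views of the light field.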
def _GetDiopterOfLayers(self) -> List[float]:
diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0])
step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1)
diopter_of_layers = [diopter_range[0] + step * i for i in range(self.depth_layers)]
diopter_of_layers.insert(0, 0)
return diopter_of_layers
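A worked example of _GetDiopterOfLayers with illustrative values (not taken from any dataset in this commit):

# depth_range = [1, 10], depth_layers = 4
#   diopter_range = (0.1, 1.0), step = 0.3
#   returns [0, 0.1, 0.4, 0.7, 1.0]   (a 0-diopter layer, i.e. infinity, is prepended)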
import torch
import math
from ..my import device
from my import device
class FastDataLoader(object):
......
import torch
import os
import json
import glob
import cv2
import numpy as np
import torchvision.transforms as transforms
from typing import List, Tuple
from torchvision import datasets
from torch.utils.data import DataLoader
import cv2
import json
from Flow import *
from gen_image import *
import util
from my.flow import *
from my.gen_image import *
from my import util
def ReadLightField(path: str, views: Tuple[int, int], flatten_views: bool = False) -> torch.Tensor:
input_img = util.ReadImageTensor(path, batch_dim=False)
h = input_img.size()[1] // views[0]
w = input_img.size()[2] // views[1]
if flatten_views:
lf = torch.empty(views[0] * views[1], 3, h, w)
for y_i in range(views[0]):
for x_i in range(views[1]):
lf[y_i * views[1] + x_i, :, :, :] = \
input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
else:
lf = torch.empty(views[0], views[1], 3, h, w)
for y_i in range(views[0]):
for x_i in range(views[1]):
lf[y_i, x_i, :, :, :] = \
input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
return lf
def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor:
return depth_images[:, 0].unsqueeze(1).mul(255) / 10
class LightFieldSynDataset(torch.utils.data.dataset.Dataset):
"""
Data loader for light field synthesis task
Attributes
--------
data_dir ```string```: the directory of the dataset\n
n_views ```tuple(int, int)```: rows and columns of views\n
num_views ```int```: number of views\n
view_images ```N x H x W Tensor```: images of views\n
view_depths ```N x H x W Tensor```: depths of views\n
view_positions ```N x 3 Tensor```: positions of views\n
sparse_view_images ```N' x H x W Tensor```: images of sparse views\n
sparse_view_depths ```N' x H x W Tensor```: depths of sparse views\n
sparse_view_positions ```N' x 3 Tensor```: positions of sparse views\n
"""
def __init__(self, data_desc_path: str):
"""
Initialize data loader for light field synthesis task
The data description file is a JSON file with the following fields:
- lf: string, the path of light field image
- lf_depth: string, the path of light field depth image
- n_views: { "x", "y" }, columns and rows of views
- cam_params: { "f", "c" }, the focal and center of camera (in normalized image space)
- depth_range: [ min, max ], the range of depth in depth maps
- depth_layers: int, number of layers in depth maps
- view_positions: [ [ x, y, z ], ... ], positions of views
:param data_desc_path: path to the data description file
"""
self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/'
with open(data_desc_path, 'r', encoding='utf-8') as file:
self.data_desc = json.loads(file.read())
self.n_views = (self.data_desc['n_views']
['y'], self.data_desc['n_views']['x'])
self.num_views = self.n_views[0] * self.n_views[1]
self.view_images = ReadLightField(
self.data_dir + self.data_desc['lf'], self.n_views, True)
self.view_depths = DecodeDepth(ReadLightField(
self.data_dir + self.data_desc['lf_depth'], self.n_views, True))
self.cam_params = self.data_desc['cam_params']
self.depth_range = self.data_desc['depth_range']
self.depth_layers = self.data_desc['depth_layers']
self.view_positions = torch.tensor(self.data_desc['view_positions'])
_, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \
= self._GetCornerViews()
self.diopter_of_layers = self._GetDiopterOfLayers()
def __len__(self):
return self.num_views
def __getitem__(self, idx):
return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx]
def _GetCornerViews(self):
corner_selector = torch.zeros(self.num_views, dtype=torch.bool)
corner_selector[0] = corner_selector[self.n_views[1] - 1] \
= corner_selector[self.num_views - self.n_views[1]] \
= corner_selector[self.num_views - 1] = True
return self.__getitem__(corner_selector)
import time
def _GetDiopterOfLayers(self) -> List[float]:
diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0])
step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1)
diopter_of_layers = [diopter_range[0] +
step * i for i in range(self.depth_layers)]
diopter_of_layers.insert(0, 0)
return diopter_of_layers
class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
......@@ -90,8 +185,8 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3 *
IM_H+IM_H, i % 3*IM_W:i % 3*IM_W+IM_W, 0:3]
lf_image = lf_image_big[i // 3 * IM_H:i // 3 *
IM_H + IM_H, i % 3 * IM_W:i % 3 * IM_W + IM_W, 0:3]
# IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
......@@ -146,8 +241,8 @@ class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
# print(lf_image_big.shape)
for i in range(9):
lf_image = lf_image_big[i//3*IM_H:i//3 *
IM_H+IM_H, i % 3*IM_W:i % 3*IM_W+IM_W, 0:3]
lf_image = lf_image_big[i // 3 * IM_H:i // 3 *
IM_H + IM_H, i % 3 * IM_W:i % 3 * IM_W + IM_W, 0:3]
# IF GrayScale
# lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
# print(lf_image.shape)
......@@ -214,8 +309,8 @@ class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
for j in range(9):
lf_image = lf_image_big[j//3*IM_H:j//3 *
IM_H+IM_H, j % 3*IM_W:j % 3*IM_W+IM_W, 0:3]
lf_image = lf_image_big[j // 3 * IM_H:j // 3 *
IM_H + IM_H, j % 3 * IM_W:j % 3 * IM_W + IM_W, 0:3]
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(
......@@ -297,8 +392,8 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
lf_dim = int(self.conf.light_field_dim)
for j in range(lf_dim**2):
lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim *
IM_H+IM_H, j % lf_dim*IM_W:j % lf_dim*IM_W+IM_W, 0:3]
lf_image = lf_image_big[j // lf_dim * IM_H:j // lf_dim *
IM_H + IM_H, j % lf_dim * IM_W:j % lf_dim * IM_W + IM_W, 0:3]
lf_image_one_sample.append(lf_image)
gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(
......@@ -333,7 +428,7 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
retinal_invalid.append(retinal_invalid_i)
# lf_images: 5,9,320,320
flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"]
[indices[i-1]]) for i in range(1, len(indices))])
[indices[i - 1]]) for i in range(1, len(indices))])
flow_map = flow.getMap()
flow_invalid_mask = flow.b_invalid_mask
# print("flow:",flow_map.shape)
......
......@@ -4,10 +4,10 @@ import torch
import torchvision.transforms.functional as trans_f
import torch.nn.functional as nn_f
from typing import Tuple, Union
from ..my import util
from ..my import device
from ..my import view
from ..my import color_mode
from my import util
from my import device
from my import view
from my import color_mode
class SphericalViewSynDataset(object):
......@@ -129,6 +129,13 @@ class SphericalViewSynDataset(object):
self.n_views = self.view_centers.size(0)
self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]
if 'gl_coord' in data_desc and data_desc['gl_coord'] == True:
print('Convert from OpenGL coordinates to DirectX coordinates (i.e. right-handed to left-handed)')
self.cam_params.f[1] *= -1
self.view_centers[:, 2] *= -1
self.view_rots[:, 2] *= -1
self.view_rots[..., 2] *= -1
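# Note: assuming view_rots has shape (N, 3, 3), the updated indexing [..., 2] negates the third
# column of each rotation matrix (the old [:, 2] form negated the third row); together with
# flipping the y focal length and the z component of the view centers this converts between
# right- and left-handed conventions.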
def set_patch_size(self, patch_size: Union[int, Tuple[int, int]],
offset: Union[int, Tuple[int, int]] = 0):
"""
......
from typing import List, Tuple
import torch
import torch.nn as nn
from ..my import device
from ..my import util
from my import device
from my import util
class FcLayer(nn.Module):
......
......@@ -2,9 +2,8 @@ import math
import torch
import torch.nn as nn
from .modules import *
from ..my import util
from ..my import color_mode
from my import util
from my import color_mode
class MslNet(nn.Module):
......
import torch
import torch.nn as nn
from .modules import *
from ..my import color_mode
from ..my.simple_perf import SimplePerf
from my import color_mode
from my.simple_perf import SimplePerf
class NewMslNet(nn.Module):
......
......@@ -2,10 +2,10 @@ from typing import Tuple
import math
import torch
import torch.nn as nn
from ..my import net_modules
from ..my import util
from ..my import device
from ..my import color_mode
from my import net_modules
from my import util
from my import device
from my import color_mode
from .msl_net_new import NewMslNet
......
import torch
import torch.nn as nn
from .modules import *
from ..my import util
from my import util
class SpherNet(nn.Module):
......
from typing import List
import torch
import torch.nn as nn
from ..pytorch_prototyping.pytorch_prototyping import *
from ..my import util
from ..my import device
from pytorch_prototyping.pytorch_prototyping import *
from my import util
from my import device
class Encoder(nn.Module):
......
......@@ -2,20 +2,28 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set CUDA:2 as current device.\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
"sys.path.append(os.path.abspath(sys.path[0] + '/../'))\n",
"\n",
"import torch\n",
"import math\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from deep_view_syn.my import util\n",
"from deep_view_syn.msl_net import *\n",
"from my import util\n",
"from nets.msl_net import *\n",
"\n",
"# Select device\n",
"torch.cuda.set_device(2)\n",
......@@ -23,8 +31,10 @@
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Test Ray-Sphere Intersection & Cartesian-Spherical Conversion"
]
......@@ -56,14 +66,15 @@
"v = torch.tensor([[0.0, -1.0, 1.0]])\n",
"r = torch.tensor([[2.5]])\n",
"v = v / torch.norm(v) * r * 2\n",
"p_on_sphere_ = RaySphereIntersect(p, v, r)[0]\n",
"p_on_sphere_ = util.RaySphereIntersect(p, v, r)[0][0]\n",
"print(p_on_sphere_)\n",
"print(p_on_sphere_.norm())\n",
"spher_coord = RayToSpherical(p, v, r)\n",
"spher_coord = util.CartesianToSpherical(p_on_sphere_)\n",
"print(spher_coord[..., 1:3].rad2deg())\n",
"p_on_sphere = util.SphericalToCartesian(spher_coord)[0]\n",
"p_on_sphere = util.SphericalToCartesian(spher_coord)\n",
"print(p_on_sphere_.size())\n",
"\n",
"fig = plt.figure(figsize=(6, 6))\n",
"fig = plt.figure(figsize=(8, 8))\n",
"ax = fig.gca(projection='3d')\n",
"plt.xlabel('x')\n",
"plt.ylabel('z')\n",
......@@ -109,8 +120,10 @@
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test Dataset Loader & View-Spherical Transform"
]
......@@ -121,26 +134,26 @@
"metadata": {},
"outputs": [],
"source": [
"from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset\n",
"from deep_view_syn.data.spherical_view_syn import FastDataLoader\n",
"from data.spherical_view_syn import SphericalViewSynDataset\n",
"from data.loader import FastDataLoader\n",
"\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.28'\n",
"DATA_DIR = '../data/nerf_fern'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
"\n",
"dataset = FastSphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
"dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
"dataset.set_patch_size((64, 64))\n",
"data_loader = FastDataLoader(dataset=dataset, batch_size=4, shuffle=False, drop_last=False)\n",
"print(len(dataset))\n",
"plt.figure()\n",
"fig = plt.figure(figsize=(12, 6.5))\n",
"i = 0\n",
"for indices, patches, rays_o, rays_d in data_loader:\n",
" print(i, patches.size(), rays_o.size(), rays_d.size())\n",
" for idx in range(len(indices)):\n",
" plt.subplot(4, 4, i + 1)\n",
" plt.subplot(4, 7, i + 1)\n",
" util.PlotImageTensor(patches[idx])\n",
" i += 1\n",
" if i == 16:\n",
" break\n"
" if i == 28:\n",
" break"
]
},
{
......@@ -149,13 +162,15 @@
"metadata": {},
"outputs": [],
"source": [
"from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"from data.spherical_view_syn import SphericalViewSynDataset\n",
"from data.loader import FastDataLoader\n",
"\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.26'\n",
"DATA_DIR = '../data/nerf_fern'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
"DEPTH_RANGE = (1, 10)\n",
"N_DEPTH_LAYERS = 10\n",
"\n",
"\n",
"def _GetSphereLayers(depth_range: Tuple[float, float], n_layers: int) -> torch.Tensor:\n",
" diopter_range = (1 / depth_range[1], 1 / depth_range[0])\n",
" step = (diopter_range[1] - diopter_range[0]) / (n_layers - 1)\n",
......@@ -163,74 +178,54 @@
" depths += [1 / (diopter_range[0] + step * i) for i in range(n_layers)]\n",
" return torch.tensor(depths, device=device.GetDevice()).view(-1, 1)\n",
"\n",
"train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
"train_data_loader = torch.utils.data.DataLoader(\n",
" dataset=train_dataset,\n",
" batch_size=4,\n",
" num_workers=8,\n",
" pin_memory=True,\n",
" shuffle=True,\n",
" drop_last=False)\n",
"print(len(train_data_loader))\n",
"\n",
"print(\"view_res\", train_dataset.view_res)\n",
"print(\"cam_params\", train_dataset.cam_params)\n",
"\n",
"msl_net = MslNet(train_dataset.cam_params,\n",
" _GetSphereLayers(DEPTH_RANGE, N_DEPTH_LAYERS),\n",
" train_dataset.view_res).to(device.GetDevice())\n",
"print(\"sphere layers\", msl_net.rendering.sphere_layers)\n",
"\n",
"p = None\n",
"v = None\n",
"centers = None\n",
"plt.figure(figsize=(6, 6))\n",
"for _, view_images, ray_positions, ray_directions in train_data_loader:\n",
" p = ray_positions\n",
" v = ray_directions\n",
" plt.subplot(2, 2, 1)\n",
" util.PlotImageTensor(view_images[0])\n",
" plt.subplot(2, 2, 2)\n",
" util.PlotImageTensor(view_images[1])\n",
" plt.subplot(2, 2, 3)\n",
" util.PlotImageTensor(view_images[2])\n",
" plt.subplot(2, 2, 4)\n",
" util.PlotImageTensor(view_images[3])\n",
" break\n",
"p_ = util.SphericalToCartesian(RayToSpherical(p.flatten(0, 1), v.flatten(0, 1),\n",
" torch.tensor([[1]], device=device.GetDevice()))) \\\n",
" .view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)\n",
"v = v.view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n",
"p_ = p_[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n",
"\n",
"fig = plt.figure(figsize=(6, 6))\n",
"ax = fig.gca(projection='3d')\n",
"plt.xlabel('x')\n",
"plt.ylabel('z')\n",
"\n",
"PlotSphere(ax, 1)\n",
"dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
"dataset.set_patch_size(1)\n",
"data_loader = FastDataLoader(\n",
" dataset=dataset, batch_size=4096*16, shuffle=True, drop_last=False)\n",
"\n",
"ax.scatter([0], [0], [0], color=\"k\", s=10) # Center\n",
"print(\"view_res\", dataset.view_res)\n",
"print(\"cam_params\", dataset.cam_params)\n",
"\n",
"colors = [ 'r', 'g', 'b', 'y' ]\n",
"for i in range(4):\n",
" ax.scatter(p_[i, :, 0], p_[i, :, 2], p_[i, :, 1], color=colors[i], s=3)\n",
" for j in range(p_.shape[1]):\n",
" ax.plot([centers[i, 0], centers[i, 0] + v[i, j, 0]],\n",
" [centers[i, 2], centers[i, 2] + v[i, j, 2]],\n",
" [centers[i, 1], centers[i, 1] + v[i, j, 1]],\n",
" color=colors[i], linewidth=0.5, alpha=0.6)\n",
"fig = plt.figure(figsize=(16, 40))\n",
"\n",
"ax.set_xlim(-1, 1)\n",
"ax.set_ylim(-1, 1)\n",
"ax.set_zlim(-1, 1)\n",
"\n",
"plt.show()\n"
"for ri in range(0, 10):\n",
" r = ri * 0.2 + 1\n",
" p = None\n",
" centers = None\n",
" pixels = None\n",
" for indices, patches, rays_o, rays_d in data_loader:\n",
" p = util.RaySphereIntersect(\n",
" rays_o, rays_d, torch.tensor([[r]], device=device.GetDevice()))[0] \\\n",
" .view(-1, 3).cpu().numpy()\n",
" centers = rays_o.view(-1, 3).cpu().numpy()\n",
" pixels = patches.view(-1, 3).cpu().numpy()\n",
" break\n",
" \n",
" ax = plt.subplot(5, 2, ri + 1, projection='3d')\n",
" #ax = plt.gca(projection='3d')\n",
" #ax = fig.gca()\n",
" plt.xlabel('x')\n",
" plt.ylabel('z')\n",
" plt.title('r = %f' % r)\n",
"\n",
" # PlotSphere(ax, 1)\n",
"\n",
" ax.scatter([0], [0], [0], color=\"k\", s=10)\n",
" ax.scatter(p[:, 0], p[:, 2], p[:, 1], color=pixels, s=0.5)\n",
"\n",
" #ax.set_xlim(-1, 1)\n",
" #ax.set_ylim(-1, 1)\n",
" #ax.set_zlim(-1, 1)\n",
" ax.view_init(elev=0,azim=-90)\n",
"\n"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Test Sampler"
]
......@@ -241,7 +236,7 @@
"metadata": {},
"outputs": [],
"source": [
"from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"from data.spherical_view_syn import SphericalViewSynDataset\n",
"\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
......@@ -304,7 +299,7 @@
"metadata": {},
"outputs": [],
"source": [
"from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"from data.spherical_view_syn import SphericalViewSynDataset\n",
"\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'\n",
"TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
......@@ -367,8 +362,10 @@
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test Spherical View Synthesis"
]
......@@ -381,9 +378,9 @@
"source": [
"import ipywidgets as widgets # 控件库\n",
"from IPython.display import display # 显示控件的方法\n",
"from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
"from deep_view_syn.spher_net import SpherNet\n",
"from deep_view_syn.my import netio\n",
"from data.spherical_view_syn import SphericalViewSynDataset\n",
"from nets.spher_net import SpherNet\n",
"from my import netio\n",
"\n",
"DATA_DIR = '../data/sp_view_syn_2020.12.28_small'\n",
"DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
......@@ -455,20 +452,12 @@
"})\n",
"display(slider_x, slider_y, slider_z, slider_theta, slider_phi, out)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
"display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
"name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
},
"language_info": {
"codemirror_mode": {
......@@ -480,7 +469,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6-final"
"version": "3.7.9"
}
},
"nbformat": 4,
......