Nianchen Deng / deeplightfield

Commit f6604bd2, authored Mar 16, 2021 by Nianchen Deng

rebuttal version

parent 6e54b394

Changes: 21 files
configs/fovea_nmsl4.py

```diff
-from ..my import color_mode
+from my import color_mode

 def update_config(config):
     # Dataset settings
 ...
```
configs/fovea_rgb.py

```diff
-from ..my import color_mode
+from my import color_mode

 def update_config(config):
     # Dataset settings
 ...
```
configs/new_fovea_rgb.py

```diff
-from ..my import color_mode
+from my import color_mode

 def update_config(config):
     # Dataset settings
 ...
```
configs/periph_rgb.py

```diff
-from ..my import color_mode
+from my import color_mode

 def update_config(config):
     # Dataset settings
 ...
```
configs/spherical_view_syn.py

```diff
 import os
 import importlib
-from os.path import join
-from ..my import color_mode
+from my import color_mode
-from ..nets.msl_net import MslNet
+from nets.msl_net import MslNet
-from ..nets.msl_net_new import NewMslNet
+from nets.msl_net_new import NewMslNet
-from ..nets.spher_net import SpherNet

 class SphericalViewSynConfig(object):
 ...
@@ -36,14 +34,13 @@ class SphericalViewSynConfig(object):
     def load(self, path):
         module_name = os.path.splitext(path)[0].replace('/', '.')
         config_module = importlib.import_module(
-            'deep_view_syn.' + module_name)
+            module_name)
         config_module.update_config(self)
         self.name = module_name.split('.')[-1]

     def load_by_name(self, name):
         config_module = importlib.import_module(
-            'deep_view_syn.configs.' + name)
+            'configs.' + name)
         config_module.update_config(self)
         self.name = name
 ...
```
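The `load`/`load_by_name` pattern above resolves a config module by name and lets that module mutate the config object. A minimal sketch of the mechanism, assuming a config module such as `configs/periph_rgb.py` that defines `update_config(config)` (the `Config` class here is an illustrative stand-in, not the repository's class):

```python
import importlib


class Config(object):
    """Illustrative stand-in for SphericalViewSynConfig."""

    def load_by_name(self, name):
        # After this commit, config modules are resolved from the repository
        # root ('configs.<name>') instead of 'deep_view_syn.configs.<name>'.
        config_module = importlib.import_module('configs.' + name)
        config_module.update_config(self)  # the config module mutates this object
        self.name = name


# Usage sketch: cfg = Config(); cfg.load_by_name('periph_rgb')
```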
configs/us_fovea.py

```diff
-from ..my import color_mode
+from my import color_mode

 def update_config(config):
     # Dataset settings
 ...
```
configs/us_periph.py

```diff
-from ..my import color_mode
+from my import color_mode

 def update_config(config):
     # Dataset settings
 ...
```
configs/us_periph_new.py

```diff
-from ..my import color_mode
+from my import color_mode

 def update_config(config):
     # Dataset settings
 ...
```
dash_test.py

```diff
@@ -10,7 +10,7 @@ import plotly.express as px
 import pandas as pd
 from dash.dependencies import Input, Output

-# sys.path.append(os.path.abspath(sys.path[0] + '/../'))
+sys.path.append(os.path.abspath(sys.path[0] + '/../'))
 #__package__ = "deep_view_syn"

 if __name__ == '__main__':
 ...
@@ -24,23 +24,30 @@ if __name__ == '__main__':
     print("Set CUDA:%d as current device." % torch.cuda.current_device())
     torch.autograd.set_grad_enabled(False)

-from deep_view_syn.data.spherical_view_syn import *
-from deep_view_syn.configs.spherical_view_syn import SphericalViewSynConfig
-from deep_view_syn.my import netio
-from deep_view_syn.my import util
-from deep_view_syn.my import device
-from deep_view_syn.my import view
-from deep_view_syn.my.gen_final import GenFinal
-from deep_view_syn.nets.modules import Sampler
+from data.spherical_view_syn import *
+from configs.spherical_view_syn import SphericalViewSynConfig
+from my import netio
+from my import util
+from my import device
+from my import view
+from my.gen_final import GenFinal
+from nets.modules import Sampler

-datadir = None
+datadir = 'data/__0_user_study/us_gas_periph_r135x135_t0.3_2021.01.16/'
+data_desc_file = 'train.json'
+net_config = 'periph_rgb@msl-rgb_e10_fc96x4_d1.00-50.00_s16'
+net_path = datadir + net_config + '/model-epoch_200.pth'
+fov = 45
+res = (256, 256)
+view_idx = 4
+center = (0, 0)

 def load_net(path):
     print(path)
     config = SphericalViewSynConfig()
-    config.from_id(os.path.splitext(os.path.basename(path))[0])
+    config.from_id(net_config)
     config.SAMPLE_PARAMS['perturb_sample'] = False
     net = config.create_net().to(device.GetDevice())
     netio.LoadNet(path, net)
 ...
@@ -64,24 +71,25 @@ def load_views(data_desc_file) -> view.Trans:
     return view.Trans(view_centers, view_rots)

-scenes = {
-    'gas': '__0_user_study/us_gas_all_in_one',
-    'mc': '__0_user_study/us_mc_all_in_one',
-    'bedroom': 'bedroom_all_in_one',
-    'gallery': 'gallery_all_in_one',
-    'lobby': 'lobby_all_in_one'
-}
-scene = 'gas'
-view_file = 'views.json'
-fov_list = [20, 45, 110]
-res_list = [(128, 128), (256, 256), (256, 230)]
-res_full = (1600, 1440)
+cam = view.CameraParam({
+    'fov': fov,
+    'cx': 0.5,
+    'cy': 0.5,
+    'normalized': True
+}, res, device=device.GetDevice())
+net = load_net(net_path)
+sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
+                  spherical=True, lindisp=True, inverse_r=True)
 x = y = None
+views = load_views(data_desc_file)
+print('%d Views loaded.' % views.size()[0])
+test_view = views.get(view_idx)
+rays_o, rays_d = cam.get_global_rays(test_view, True)
+image = net(rays_o.view(-1, 3), rays_d.view(-1, 3)).view(
+    1, res[0], res[1], -1).permute(0, 3, 1, 2)

+app = dash.Dash(__name__, external_stylesheets=[
+                'https://codepen.io/chriddyp/pen/bWLwgP.css'])
 styles = {
     'pre': {
 ...
@@ -90,35 +98,17 @@ styles = {
         'overflowX': 'scroll'
     }
 }

-datadir = 'data/' + scenes[scene] + '/'
-fovea_net = load_net_by_name('fovea')
-periph_net = load_net_by_name('periph')
-gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,
-               device=device.GetDevice())
-sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
-                  spherical=True, lindisp=True, inverse_r=True)
-x = y = None
-views = load_views(view_file)
-print('%d Views loaded.', views.size())
-view_idx = 27
-center = (0, 0)
-test_view = views.get(view_idx)
-images = gen(center, test_view)
-fig = px.imshow(util.Tensor2MatImg(images['fovea']))
+fig = px.imshow(util.Tensor2MatImg(image))
 fig1 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
+fig2 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])

-app = dash.Dash(__name__, external_stylesheets=[
-                'https://codepen.io/chriddyp/pen/bWLwgP.css'])
 app.layout = html.Div([
     html.H3("Drag and draw annotations"),
     html.Div(className='row', children=[
         dcc.Graph(id='image', figure=fig),  # , config=config),
         dcc.Graph(id='scatter', figure=fig1),  # , config=config),
+        dcc.Graph(id='scatter1', figure=fig2),  # , config=config),
         dcc.Slider(id='samples-slider', min=4, max=128, step=None,
                    marks={
                        4: '4',
 ...
@@ -128,43 +118,91 @@ app.layout = html.Div([
                        64: '64',
                        128: '128',
                    },
-                   value=32,
+                   value=33,
                    updatemode='drag'
                    )
     ])
 ])

+def raw2alpha(raw, dists, act_fn=torch.relu):
+    """
+    Function for computing density from model prediction.
+    This value is strictly between [0, 1].
+    """
+    print('act_fn: ', act_fn(raw))
+    print('act_fn * dists: ', act_fn(raw) * dists)
+    return -torch.exp(-act_fn(raw) * dists) + 1.0
+
+
+def raw2color(raw: torch.Tensor, z_vals: torch.Tensor):
+    """
+    Raw value inferred from model to color and alpha
+
+    :param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output
+    :param z_vals ```Tensor(N.rays, N.samples)```: integration time
+    :return ```Tensor(N.rays, N.samples, 1|3)```: color
+    :return ```Tensor(N.rays, N.samples)```: alpha
+    """
+    # Compute 'distance' (in time) between each integration time along a ray.
+    # The 'distance' from the last integration time is infinity.
+    # dists: (N_rays, N_samples)
+    dists = z_vals[..., 1:] - z_vals[..., :-1]
+    last_dist = z_vals[..., 0:1] * 0 + 1e10
+    dists = torch.cat([dists, last_dist], -1)
+    print('dists: ', dists)
+
+    # Extract RGB of each sample position along each ray.
+    color = torch.sigmoid(raw[..., :-1])  # (N_rays, N_samples, 1|3)
+    alpha = raw2alpha(raw[..., -1], dists)
+    return color, alpha

 def draw_scatter():
-    global fig1
+    global fig1, fig2
-    p = torch.tensor([x, y], device=gen.layer_cams[0].c.device)
+    p = torch.tensor([x, y], device=device.GetDevice())
-    ray_d = test_view.trans_vector(gen.layer_cams[0].unproj(p))
+    ray_d = test_view.trans_vector(cam.unproj(p))
     ray_o = test_view.t
-    raw, depths = fovea_net.sample_and_infer(ray_o, ray_d, sampler=sampler)
+    raw, depths = net.sample_and_infer(ray_o, ray_d, sampler=sampler)
-    colors, alphas = fovea_net.rendering.raw2color(raw, depths)
+    colors, alphas = raw2color(raw, depths)
     scatter_x = (1 / depths[0]).cpu().detach().numpy()
     scatter_y = alphas[0].cpu().detach().numpy()
+    scatter_y1 = raw[0, :, 3].cpu().detach().numpy()
     scatter_color = colors[0].cpu().detach().numpy() * 255
     marker_colors = [
-        i  #'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
+        # 'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
+        i
         for i in range(scatter_color.shape[0])
     ]
     marker_colors_str = [
         'rgb(%d,%d,%d)' % (scatter_color[i][0],
                            scatter_color[i][1],
                            scatter_color[i][2])
         for i in range(scatter_color.shape[0])
     ]
-    fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors,
-                      color_continuous_scale=marker_colors_str)  #, color_discrete_map='identity')
+    fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors,
+                      color_continuous_scale=marker_colors_str)  # , color_discrete_map='identity')
     fig1.update_traces(mode='lines+markers')
     fig1.update_xaxes(showgrid=False)
     fig1.update_yaxes(type='linear')
     fig1.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
+    fig2 = px.scatter(x=scatter_x, y=scatter_y1, color=marker_colors,
+                      color_continuous_scale=marker_colors_str)  # , color_discrete_map='identity')
+    fig2.update_traces(mode='lines+markers')
+    fig2.update_xaxes(showgrid=False)
+    fig2.update_yaxes(type='linear')
+    fig2.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})

 @app.callback(
     [Output('image', 'figure'),
-     Output('scatter', 'figure')],
+     Output('scatter', 'figure'),
+     Output('scatter1', 'figure')],
     [Input('image', 'clickData'),
      dash.dependencies.Input('samples-slider', 'value')]
 )
 ...
@@ -194,7 +232,7 @@ def display_hover_data(clickData, samples):
             color="LightSeaGreen",
             width=3,
         ))
-    return fig, fig1
+    return fig, fig1, fig2

 if __name__ == '__main__':
 ...
```
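The `raw2alpha` helper added above is the standard volume-rendering opacity formula, `alpha = 1 - exp(-act_fn(raw) * dists)`: opacity is zero for non-positive raw densities and grows toward 1 with both the predicted density and the sample spacing. A quick numeric check of that formula (the input values here are illustrative, not taken from the repository):

```python
import torch

raw = torch.tensor([-1.0, 0.0, 0.5, 2.0])   # hypothetical raw density outputs
dists = torch.tensor([0.1, 0.1, 0.1, 0.1])  # hypothetical sample spacings

# alpha = 1 - exp(-relu(raw) * dist), written as in raw2alpha above
alpha = -torch.exp(-torch.relu(raw) * dists) + 1.0
print(alpha)  # ~ tensor([0.0000, 0.0000, 0.0488, 0.1813])
```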
data/lf_syn.py
deleted 100644 → 0

```python
from typing import List, Tuple
import torch
import json
from ..my import util


def ReadLightField(path: str, views: Tuple[int, int],
                   flatten_views: bool = False) -> torch.Tensor:
    input_img = util.ReadImageTensor(path, batch_dim=False)
    h = input_img.size()[1] // views[0]
    w = input_img.size()[2] // views[1]
    if flatten_views:
        lf = torch.empty(views[0] * views[1], 3, h, w)
        for y_i in range(views[0]):
            for x_i in range(views[1]):
                lf[y_i * views[1] + x_i, :, :, :] = \
                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
    else:
        lf = torch.empty(views[0], views[1], 3, h, w)
        for y_i in range(views[0]):
            for x_i in range(views[1]):
                lf[y_i, x_i, :, :, :] = \
                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
    return lf


def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor:
    return depth_images[:, 0].unsqueeze(1).mul(255) / 10


class LightFieldSynDataset(torch.utils.data.dataset.Dataset):
    """
    Data loader for light field synthesis task

    Attributes
    --------
    data_dir ```string```: the directory of dataset\n
    n_views ```tuple(int, int)```: rows and columns of views\n
    num_views ```int```: number of views\n
    view_images ```N x H x W Tensor```: images of views\n
    view_depths ```N x H x W Tensor```: depths of views\n
    view_positions ```N x 3 Tensor```: positions of views\n
    sparse_view_images ```N' x H x W Tensor```: images of sparse views\n
    sparse_view_depths ```N' x H x W Tensor```: depths of sparse views\n
    sparse_view_positions ```N' x 3 Tensor```: positions of sparse views\n
    """

    def __init__(self, data_desc_path: str):
        """
        Initialize data loader for light field synthesis task

        The data description file is a JSON file with following fields:

        - lf: string, the path of light field image
        - lf_depth: string, the path of light field depth image
        - n_views: { "x", "y" }, columns and rows of views
        - cam_params: { "f", "c" }, the focal and center of camera (in normalized image space)
        - depth_range: [ min, max ], the range of depth in depth maps
        - depth_layers: int, number of layers in depth maps
        - view_positions: [ [ x, y, z ], ... ], positions of views

        :param data_desc_path: path to the data description file
        """
        self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/'
        with open(data_desc_path, 'r', encoding='utf-8') as file:
            self.data_desc = json.loads(file.read())
        self.n_views = (self.data_desc['n_views']['y'],
                        self.data_desc['n_views']['x'])
        self.num_views = self.n_views[0] * self.n_views[1]
        self.view_images = ReadLightField(
            self.data_dir + self.data_desc['lf'], self.n_views, True)
        self.view_depths = DecodeDepth(ReadLightField(
            self.data_dir + self.data_desc['lf_depth'], self.n_views, True))
        self.cam_params = self.data_desc['cam_params']
        self.depth_range = self.data_desc['depth_range']
        self.depth_layers = self.data_desc['depth_layers']
        self.view_positions = torch.tensor(self.data_desc['view_positions'])
        _, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \
            = self._GetCornerViews()
        self.diopter_of_layers = self._GetDiopterOfLayers()

    def __len__(self):
        return self.num_views

    def __getitem__(self, idx):
        return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx]

    def _GetCornerViews(self):
        corner_selector = torch.zeros(self.num_views, dtype=torch.bool)
        corner_selector[0] = corner_selector[self.n_views[1] - 1] \
            = corner_selector[self.num_views - self.n_views[1]] \
            = corner_selector[self.num_views - 1] = True
        return self.__getitem__(corner_selector)

    def _GetDiopterOfLayers(self) -> List[float]:
        diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0])
        step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1)
        diopter_of_layers = [diopter_range[0] + step * i
                             for i in range(self.depth_layers)]
        diopter_of_layers.insert(0, 0)
        return diopter_of_layers
```
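The `__init__` docstring above fixes the schema of the data description file. A hypothetical `train.json` matching those fields, written from Python (field names follow the docstring; all concrete values are illustrative):

```python
import json

# Hypothetical description of a 3x3 light field for LightFieldSynDataset.
data_desc = {
    "lf": "lf.png",                       # grid image containing all views
    "lf_depth": "lf_depth.png",           # matching grid of encoded depth maps
    "n_views": {"x": 3, "y": 3},          # columns and rows of views
    "cam_params": {"f": 0.9, "c": 0.5},   # focal / center in normalized image space
    "depth_range": [1, 10],               # min / max depth in the depth maps
    "depth_layers": 10,                   # number of depth layers
    "view_positions": [[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [0.2, 0.0, 0.0]]
}

with open("train.json", "w", encoding="utf-8") as f:
    json.dump(data_desc, f, indent=2)
```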
data/loader.py

```diff
 import torch
 import math
-from ..my import device
+from my import device

 class FastDataLoader(object):
 ...
```
data/other.py

```diff
 import torch
 import os
-import json
 import glob
-import cv2
 import numpy as np
 import torchvision.transforms as transforms
+from typing import List, Tuple
 from torchvision import datasets
 from torch.utils.data import DataLoader
+import cv2
-from Flow import *
+from my.flow import *
+import json
-from gen_image import *
+from my.gen_image import *
-import util
+from my import util

+
+def ReadLightField(path: str, views: Tuple[int, int],
+                   flatten_views: bool = False) -> torch.Tensor:
+    input_img = util.ReadImageTensor(path, batch_dim=False)
+    h = input_img.size()[1] // views[0]
+    w = input_img.size()[2] // views[1]
+    if flatten_views:
+        lf = torch.empty(views[0] * views[1], 3, h, w)
+        for y_i in range(views[0]):
+            for x_i in range(views[1]):
+                lf[y_i * views[1] + x_i, :, :, :] = \
+                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
+    else:
+        lf = torch.empty(views[0], views[1], 3, h, w)
+        for y_i in range(views[0]):
+            for x_i in range(views[1]):
+                lf[y_i, x_i, :, :, :] = \
+                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
+    return lf
+
+
+def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor:
+    return depth_images[:, 0].unsqueeze(1).mul(255) / 10
+
+
+class LightFieldSynDataset(torch.utils.data.dataset.Dataset):
+    """
+    Data loader for light field synthesis task
+
+    Attributes
+    --------
+    data_dir ```string```: the directory of dataset\n
+    n_views ```tuple(int, int)```: rows and columns of views\n
+    num_views ```int```: number of views\n
+    view_images ```N x H x W Tensor```: images of views\n
+    view_depths ```N x H x W Tensor```: depths of views\n
+    view_positions ```N x 3 Tensor```: positions of views\n
+    sparse_view_images ```N' x H x W Tensor```: images of sparse views\n
+    sparse_view_depths ```N' x H x W Tensor```: depths of sparse views\n
+    sparse_view_positions ```N' x 3 Tensor```: positions of sparse views\n
+    """
+
+    def __init__(self, data_desc_path: str):
+        """
+        Initialize data loader for light field synthesis task
+
+        The data description file is a JSON file with following fields:
+
+        - lf: string, the path of light field image
+        - lf_depth: string, the path of light field depth image
+        - n_views: { "x", "y" }, columns and rows of views
+        - cam_params: { "f", "c" }, the focal and center of camera (in normalized image space)
+        - depth_range: [ min, max ], the range of depth in depth maps
+        - depth_layers: int, number of layers in depth maps
+        - view_positions: [ [ x, y, z ], ... ], positions of views
+
+        :param data_desc_path: path to the data description file
+        """
+        self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/'
+        with open(data_desc_path, 'r', encoding='utf-8') as file:
+            self.data_desc = json.loads(file.read())
+        self.n_views = (self.data_desc['n_views']['y'],
+                        self.data_desc['n_views']['x'])
+        self.num_views = self.n_views[0] * self.n_views[1]
+        self.view_images = ReadLightField(
+            self.data_dir + self.data_desc['lf'], self.n_views, True)
+        self.view_depths = DecodeDepth(ReadLightField(
+            self.data_dir + self.data_desc['lf_depth'], self.n_views, True))
+        self.cam_params = self.data_desc['cam_params']
+        self.depth_range = self.data_desc['depth_range']
+        self.depth_layers = self.data_desc['depth_layers']
+        self.view_positions = torch.tensor(self.data_desc['view_positions'])
+        _, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \
+            = self._GetCornerViews()
+        self.diopter_of_layers = self._GetDiopterOfLayers()
+
+    def __len__(self):
+        return self.num_views
+
+    def __getitem__(self, idx):
+        return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx]
+
+    def _GetCornerViews(self):
+        corner_selector = torch.zeros(self.num_views, dtype=torch.bool)
+        corner_selector[0] = corner_selector[self.n_views[1] - 1] \
+            = corner_selector[self.num_views - self.n_views[1]] \
+            = corner_selector[self.num_views - 1] = True
+        return self.__getitem__(corner_selector)
+
+    def _GetDiopterOfLayers(self) -> List[float]:
+        diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0])
+        step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1)
+        diopter_of_layers = [diopter_range[0] + step * i
+                             for i in range(self.depth_layers)]
+        diopter_of_layers.insert(0, 0)
+        return diopter_of_layers
+
 import time

 class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
 ...
@@ -90,8 +185,8 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
             # print(lf_image_big.shape)
             for i in range(9):
                 lf_image = lf_image_big[i//3*IM_H:i//3 *
                                         IM_H+IM_H, i % 3*IM_W:i % 3*IM_W+IM_W, 0:3]
                 # IF GrayScale
                 # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
                 # print(lf_image.shape)
 ...
@@ -146,8 +241,8 @@ class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
             # print(lf_image_big.shape)
             for i in range(9):
                 lf_image = lf_image_big[i//3*IM_H:i//3 *
                                         IM_H+IM_H, i % 3*IM_W:i % 3*IM_W+IM_W, 0:3]
                 # IF GrayScale
                 # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
                 # print(lf_image.shape)
 ...
@@ -214,8 +309,8 @@ class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
                 lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
                 for j in range(9):
                     lf_image = lf_image_big[j//3*IM_H:j//3 *
                                             IM_H+IM_H, j % 3*IM_W:j % 3*IM_W+IM_W, 0:3]
                     lf_image_one_sample.append(lf_image)
                 gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(
 ...
@@ -297,8 +392,8 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
                 lf_dim = int(self.conf.light_field_dim)
                 for j in range(lf_dim**2):
                     lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim *
                                             IM_H+IM_H, j % lf_dim*IM_W:j % lf_dim*IM_W+IM_W, 0:3]
                     lf_image_one_sample.append(lf_image)
                 gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(
 ...
@@ -333,7 +428,7 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
                 retinal_invalid.append(retinal_invalid_i)
             # lf_images: 5,9,320,320
             flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"]
                                            [indices[i - 1]]) for i in range(1, len(indices))])
             flow_map = flow.getMap()
             flow_invalid_mask = flow.b_invalid_mask
             # print("flow:",flow_map.shape)
 ...
```
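`_GetDiopterOfLayers` in the class added above spaces the depth layers uniformly in diopters (reciprocal depth) rather than in metric depth, and prepends a 0-diopter layer (infinitely far). A worked check with the illustrative values `depth_range = (1, 10)` and `depth_layers = 10`:

```python
# Mirrors the arithmetic of _GetDiopterOfLayers above; values are examples.
depth_range = (1, 10)
depth_layers = 10

diopter_range = (1 / depth_range[1], 1 / depth_range[0])           # (0.1, 1.0)
step = (diopter_range[1] - diopter_range[0]) / (depth_layers - 1)  # 0.1
layers = [diopter_range[0] + step * i for i in range(depth_layers)]
layers.insert(0, 0)
print(layers)  # [0, 0.1, 0.2, ..., 1.0] -> depths [inf, 10, 5, ..., 1]
```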
data/spherical_view_syn.py

```diff
@@ -4,10 +4,10 @@ import torch
 import torchvision.transforms.functional as trans_f
 import torch.nn.functional as nn_f
 from typing import Tuple, Union
-from ..my import util
+from my import util
-from ..my import device
+from my import device
-from ..my import view
+from my import view
-from ..my import color_mode
+from my import color_mode

 class SphericalViewSynDataset(object):
 ...
@@ -129,6 +129,13 @@ class SphericalViewSynDataset(object):
         self.n_views = self.view_centers.size(0)
         self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]

+        if 'gl_coord' in data_desc and data_desc['gl_coord'] == True:
+            print('Convert from OGL coordinate to DX coordinate (i. e. right-hand to left-hand)')
+            self.cam_params.f[1] *= -1
+            self.view_centers[:, 2] *= -1
+            self.view_rots[:, 2] *= -1
+            self.view_rots[..., 2] *= -1
+
     def set_patch_size(self, patch_size: Union[int, Tuple[int, int]],
                        offset: Union[int, Tuple[int, int]] = 0):
         """
 ...
```
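The block added above converts a right-handed (OpenGL) camera set to a left-handed (DirectX) one by flipping the z axis: with F = diag(1, 1, -1), positions become F·c and rotations become F·R·F, which negates the z row and the z column of each rotation matrix. A minimal standalone sketch of the same transform (the function name and tensor shapes here are assumptions for illustration):

```python
import torch


def gl_to_dx(view_centers: torch.Tensor, view_rots: torch.Tensor):
    """Flip handedness; view_centers: (N, 3), view_rots: (N, 3, 3)."""
    centers = view_centers.clone()
    rots = view_rots.clone()
    centers[:, 2] *= -1   # F @ c: negate z of every camera position
    rots[:, 2] *= -1      # F @ R: negate the z row of every rotation
    rots[..., 2] *= -1    # (F @ R) @ F: negate the z column as well
    return centers, rots  # note R[2, 2] is negated twice, i.e. unchanged
```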
nets/modules.py

```diff
 from typing import List, Tuple
 import torch
 import torch.nn as nn
-from ..my import device
+from my import device
-from ..my import util
+from my import util

 class FcLayer(nn.Module):
 ...
```
nets/msl_net.py

```diff
@@ -2,9 +2,8 @@ import math
 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import util
+from my import util
-from ..my import color_mode
+from my import color_mode

 class MslNet(nn.Module):
 ...
```
nets/msl_net_new.py

```diff
 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import color_mode
+from my import color_mode
-from ..my.simple_perf import SimplePerf
+from my.simple_perf import SimplePerf

 class NewMslNet(nn.Module):
 ...
```
nets/msl_net_new_export.py

```diff
@@ -2,10 +2,10 @@ from typing import Tuple
 import math
 import torch
 import torch.nn as nn
-from ..my import net_modules
+from my import net_modules
-from ..my import util
+from my import util
-from ..my import device
+from my import device
-from ..my import color_mode
+from my import color_mode
 from .msl_net_new import NewMslNet
 ...
```
nets/spher_net.py

```diff
 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import util
+from my import util

 class SpherNet(nn.Module):
 ...
```
nets/trans_unet.py

```diff
 from typing import List
 import torch
 import torch.nn as nn
-from ..pytorch_prototyping.pytorch_prototyping import *
+from pytorch_prototyping.pytorch_prototyping import *
-from ..my import util
+from my import util
-from ..my import device
+from my import device

 class Encoder(nn.Module):
 ...
```
notebook/test_spherical_view_syn.ipynb

```diff
@@ -2,20 +2,28 @@
 "cells": [
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 2,
   "metadata": {},
-  "outputs": [],
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "Set CUDA:2 as current device.\n"
+    ]
+   }
+  ],
   "source": [
    "import sys\n",
    "import os\n",
-   "sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
+   "sys.path.append(os.path.abspath(sys.path[0] + '/../'))\n",
    "\n",
    "import torch\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
-   "from deep_view_syn.my import util\n",
+   "from my import util\n",
-   "from deep_view_syn.msl_net import *\n",
+   "from nets.msl_net import *\n",
    "\n",
    "# Select device\n",
    "torch.cuda.set_device(2)\n",
 ...
@@ -23,8 +31,10 @@
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": 1,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Test Ray-Sphere Intersection & Cartesian-Spherical Conversion"
   ]
 ...
@@ -56,14 +66,15 @@
   "v = torch.tensor([[0.0, -1.0, 1.0]])\n",
   "r = torch.tensor([[2.5]])\n",
   "v = v / torch.norm(v) * r * 2\n",
-  "p_on_sphere_ = RaySphereIntersect(p, v, r)[0]\n",
+  "p_on_sphere_ = util.RaySphereIntersect(p, v, r)[0][0]\n",
   "print(p_on_sphere_)\n",
   "print(p_on_sphere_.norm())\n",
-  "spher_coord = RayToSpherical(p, v, r)\n",
+  "spher_coord = util.CartesianToSpherical(p_on_sphere_)\n",
   "print(spher_coord[..., 1:3].rad2deg())\n",
-  "p_on_sphere = util.SphericalToCartesian(spher_coord)[0]\n",
+  "p_on_sphere = util.SphericalToCartesian(spher_coord)\n",
+  "print(p_on_sphere_.size())\n",
   "\n",
-  "fig = plt.figure(figsize=(6, 6))\n",
+  "fig = plt.figure(figsize=(8, 8))\n",
   "ax = fig.gca(projection='3d')\n",
   "plt.xlabel('x')\n",
   "plt.ylabel('z')\n",
 ...
@@ -109,8 +120,10 @@
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Test Dataset Loader & View-Spherical Transform"
   ]
 ...
@@ -121,26 +134,26 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
-   "from deep_view_syn.data.spherical_view_syn import FastDataLoader\n",
+   "from data.loader import FastDataLoader\n",
    "\n",
-   "DATA_DIR = '../data/sp_view_syn_2020.12.28'\n",
+   "DATA_DIR = '../data/nerf_fern'\n",
    "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
    "\n",
-   "dataset = FastSphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
+   "dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
    "dataset.set_patch_size((64, 64))\n",
    "data_loader = FastDataLoader(dataset=dataset, batch_size=4, shuffle=False, drop_last=False)\n",
    "print(len(dataset))\n",
-   "plt.figure()\n",
+   "fig = plt.figure(figsize=(12, 6.5))\n",
    "i = 0\n",
    "for indices, patches, rays_o, rays_d in data_loader:\n",
    "    print(i, patches.size(), rays_o.size(), rays_d.size())\n",
    "    for idx in range(len(indices)):\n",
-   "        plt.subplot(4, 4, i + 1)\n",
+   "        plt.subplot(4, 7, i + 1)\n",
    "        util.PlotImageTensor(patches[idx])\n",
    "        i += 1\n",
-   "    if i == 16:\n",
+   "    if i == 28:\n",
-   "        break\n"
+   "        break"
   ]
  },
  {
 ...
@@ -149,13 +162,15 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.loader import FastDataLoader\n",
    "\n",
-   "DATA_DIR = '../data/sp_view_syn_2020.12.26'\n",
+   "DATA_DIR = '../data/nerf_fern'\n",
    "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
    "DEPTH_RANGE = (1, 10)\n",
    "N_DEPTH_LAYERS = 10\n",
    "\n",
    "\n",
    "def _GetSphereLayers(depth_range: Tuple[float, float], n_layers: int) -> torch.Tensor:\n",
    "    diopter_range = (1 / depth_range[1], 1 / depth_range[0])\n",
    "    step = (diopter_range[1] - diopter_range[0]) / (n_layers - 1)\n",
 ...
@@ -163,74 +178,54 @@
    "    depths += [1 / (diopter_range[0] + step * i) for i in range(n_layers)]\n",
    "    return torch.tensor(depths, device=device.GetDevice()).view(-1, 1)\n",
    "\n",
-   "train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
-   "train_data_loader = torch.utils.data.DataLoader(\n",
-   "    dataset=train_dataset,\n",
-   "    batch_size=4,\n",
-   "    num_workers=8,\n",
-   "    pin_memory=True,\n",
-   "    shuffle=True,\n",
-   "    drop_last=False)\n",
-   "print(len(train_data_loader))\n",
-   "\n",
-   "print(\"view_res\", train_dataset.view_res)\n",
-   "print(\"cam_params\", train_dataset.cam_params)\n",
-   "\n",
-   "msl_net = MslNet(train_dataset.cam_params,\n",
-   "                 _GetSphereLayers(DEPTH_RANGE, N_DEPTH_LAYERS),\n",
-   "                 train_dataset.view_res).to(device.GetDevice())\n",
-   "print(\"sphere layers\", msl_net.rendering.sphere_layers)\n",
-   "\n",
-   "p = None\n",
-   "v = None\n",
-   "centers = None\n",
-   "plt.figure(figsize=(6, 6))\n",
-   "for _, view_images, ray_positions, ray_directions in train_data_loader:\n",
-   "    p = ray_positions\n",
-   "    v = ray_directions\n",
-   "    plt.subplot(2, 2, 1)\n",
-   "    util.PlotImageTensor(view_images[0])\n",
-   "    plt.subplot(2, 2, 2)\n",
-   "    util.PlotImageTensor(view_images[1])\n",
-   "    plt.subplot(2, 2, 3)\n",
-   "    util.PlotImageTensor(view_images[2])\n",
-   "    plt.subplot(2, 2, 4)\n",
-   "    util.PlotImageTensor(view_images[3])\n",
-   "    break\n",
-   "p_ = util.SphericalToCartesian(RayToSpherical(p.flatten(0, 1), v.flatten(0, 1),\n",
-   "                                              torch.tensor([[1]], device=device.GetDevice()))) \\\n",
-   "    .view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)\n",
-   "v = v.view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n",
-   "p_ = p_[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n",
-   "\n",
-   "fig = plt.figure(figsize=(6, 6))\n",
-   "ax = fig.gca(projection='3d')\n",
-   "plt.xlabel('x')\n",
-   "plt.ylabel('z')\n",
-   "\n",
-   "PlotSphere(ax, 1)\n",
-   "\n",
-   "ax.scatter([0], [0], [0], color=\"k\", s=10) # Center\n",
-   "\n",
-   "colors = [ 'r', 'g', 'b', 'y' ]\n",
-   "for i in range(4):\n",
-   "    ax.scatter(p_[i, :, 0], p_[i, :, 2], p_[i, :, 1], color=colors[i], s=3)\n",
-   "    for j in range(p_.shape[1]):\n",
-   "        ax.plot([centers[i, 0], centers[i, 0] + v[i, j, 0]],\n",
-   "                [centers[i, 2], centers[i, 2] + v[i, j, 2]],\n",
-   "                [centers[i, 1], centers[i, 1] + v[i, j, 1]],\n",
-   "                color=colors[i], linewidth=0.5, alpha=0.6)\n",
-   "\n",
-   "ax.set_xlim(-1, 1)\n",
-   "ax.set_ylim(-1, 1)\n",
-   "ax.set_zlim(-1, 1)\n",
-   "\n",
-   "plt.show()\n"
+   "dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
+   "dataset.set_patch_size(1)\n",
+   "data_loader = FastDataLoader(\n",
+   "    dataset=dataset, batch_size=4096*16, shuffle=True, drop_last=False)\n",
+   "\n",
+   "print(\"view_res\", dataset.view_res)\n",
+   "print(\"cam_params\", dataset.cam_params)\n",
+   "\n",
+   "fig = plt.figure(figsize=(16, 40))\n",
+   "\n",
+   "for ri in range(0, 10):\n",
+   "    r = ri * 0.2 + 1\n",
+   "    p = None\n",
+   "    centers = None\n",
+   "    pixels = None\n",
+   "    for indices, patches, rays_o, rays_d in data_loader:\n",
+   "        p = util.RaySphereIntersect(\n",
+   "            rays_o, rays_d, torch.tensor([[r]], device=device.GetDevice()))[0] \\\n",
+   "            .view(-1, 3).cpu().numpy()\n",
+   "        centers = rays_o.view(-1, 3).cpu().numpy()\n",
+   "        pixels = patches.view(-1, 3).cpu().numpy()\n",
+   "        break\n",
+   "    \n",
+   "    ax = plt.subplot(5, 2, ri + 1, projection='3d')\n",
+   "    #ax = plt.gca(projection='3d')\n",
+   "    #ax = fig.gca()\n",
+   "    plt.xlabel('x')\n",
+   "    plt.ylabel('z')\n",
+   "    plt.title('r = %f' % r)\n",
+   "\n",
+   "    # PlotSphere(ax, 1)\n",
+   "\n",
+   "    ax.scatter([0], [0], [0], color=\"k\", s=10)\n",
+   "    ax.scatter(p[:, 0], p[:, 2], p[:, 1], color=pixels, s=0.5)\n",
+   "\n",
+   "    #ax.set_xlim(-1, 1)\n",
+   "    #ax.set_ylim(-1, 1)\n",
+   "    #ax.set_zlim(-1, 1)\n",
+   "    ax.view_init(elev=0,azim=-90)\n",
+   "\n"
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": 3,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Test Sampler"
   ]
 ...
@@ -241,7 +236,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
    "\n",
    "DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'\n",
    "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
 ...
@@ -304,7 +299,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
    "\n",
    "DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'\n",
    "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
 ...
@@ -367,8 +362,10 @@
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Test Spherical View Synthesis"
   ]
 ...
@@ -381,9 +378,9 @@
   "source": [
    "import ipywidgets as widgets  # widget library\n",
    "from IPython.display import display  # for displaying widgets\n",
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
-   "from deep_view_syn.spher_net import SpherNet\n",
+   "from nets.spher_net import SpherNet\n",
-   "from deep_view_syn.my import netio\n",
+   "from my import netio\n",
    "\n",
    "DATA_DIR = '../data/sp_view_syn_2020.12.28_small'\n",
    "DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
 ...
@@ -455,20 +452,12 @@
    "})\n",
    "display(slider_x, slider_y, slider_z, slider_theta, slider_phi, out)\n"
   ]
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 3",
+  "display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
-  "language": "python",
-  "name": "python3"
+  "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
 },
 "language_info": {
  "codemirror_mode": {
 ...
@@ -480,7 +469,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
-  "version": "3.7.6-final"
+  "version": "3.7.9"
 }
 },
 "nbformat": 4,
 ...
```
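The notebook cells above exercise `util.RaySphereIntersect`, which finds where rays from inside a sphere hit its surface. A hedged sketch of that computation (the actual signature and return layout in `my/util.py` may differ): solve ||o + t·d||² = r² for t and keep the forward root, which for an origin inside the sphere is the `+` branch of the quadratic.

```python
import torch


def ray_sphere_intersect(o: torch.Tensor, d: torch.Tensor, r: torch.Tensor):
    """o, d: (N, 3) ray origins and directions; r: (M,) sphere radii
    centered at the origin. Returns (N, M, 3) intersection points
    (NaN where a ray misses a sphere)."""
    a = (d * d).sum(-1, keepdim=True)                     # (N, 1)
    b = 2 * (o * d).sum(-1, keepdim=True)                 # (N, 1)
    c = (o * o).sum(-1, keepdim=True) - r[None, :] ** 2   # (N, M)
    disc = b * b - 4 * a * c                              # (N, M)
    # '+' root = intersection along the ray direction when o is inside the sphere
    t = (-b + torch.sqrt(disc)) / (2 * a)                 # NaN where disc < 0
    return o[:, None, :] + t[..., None] * d[:, None, :]
```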