Nianchen Deng / deeplightfield · Commits

Commit f6604bd2, authored Mar 16, 2021 by Nianchen Deng

rebuttal version

parent 6e54b394 · Changes: 21
configs/fovea_nmsl4.py

-from ..my import color_mode
+from my import color_mode


 def update_config(config):
     # Dataset settings
...
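This one-line change, package-relative `from ..my import color_mode` replaced by top-level `from my import color_mode`, repeats across all the config files below: it lets the modules resolve when scripts are run directly from the repository root instead of through an installed `deep_view_syn` package. A minimal sketch of the difference (the repository path here is hypothetical):

    import sys
    sys.path.insert(0, '/path/to/deeplightfield')  # hypothetical: repo root on sys.path

    from my import color_mode      # resolves as a top-level module when run as a script
    # from ..my import color_mode  # relative form; only valid inside a package,
    #                              # e.g. run via `python -m deep_view_syn.configs.fovea_nmsl4`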
configs/fovea_rgb.py

-from ..my import color_mode
+from my import color_mode


 def update_config(config):
     # Dataset settings
...
configs/new_fovea_rgb.py

-from ..my import color_mode
+from my import color_mode


 def update_config(config):
     # Dataset settings
...
configs/periph_rgb.py

-from ..my import color_mode
+from my import color_mode


 def update_config(config):
     # Dataset settings
...
configs/spherical_view_syn.py

 import os
 import importlib
 from os.path import join
-from ..my import color_mode
-from ..nets.msl_net import MslNet
-from ..nets.msl_net_new import NewMslNet
-from ..nets.spher_net import SpherNet
+from my import color_mode
+from nets.msl_net import MslNet
+from nets.msl_net_new import NewMslNet


 class SphericalViewSynConfig(object):
...
@@ -36,14 +34,13 @@ class SphericalViewSynConfig(object):
     def load(self, path):
         module_name = os.path.splitext(path)[0].replace('/', '.')
-        config_module = importlib.import_module(
-            'deep_view_syn.' + module_name)
+        config_module = importlib.import_module(module_name)
         config_module.update_config(self)
         self.name = module_name.split('.')[-1]

     def load_by_name(self, name):
         config_module = importlib.import_module(
-            'deep_view_syn.configs.' + name)
+            'configs.' + name)
         config_module.update_config(self)
         self.name = name
...
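`load` and `load_by_name` resolve a config file as a Python module and let it mutate the config object in place; dropping the `'deep_view_syn.'` prefix makes the lookup match the new top-level layout. A self-contained sketch of the same importlib pattern (`Config` is a stand-in for `SphericalViewSynConfig`; each config module defines `update_config(config)`, as the diffs above show):

    import importlib
    import os


    class Config:
        """Stand-in for SphericalViewSynConfig."""
        name = None


    def load(path: str) -> Config:
        # 'configs/fovea_rgb.py' -> module name 'configs.fovea_rgb'
        module_name = os.path.splitext(path)[0].replace('/', '.')
        module = importlib.import_module(module_name)
        config = Config()
        module.update_config(config)   # the config module fills in its settings
        config.name = module_name.split('.')[-1]
        return config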
configs/us_fovea.py

-from ..my import color_mode
+from my import color_mode


 def update_config(config):
     # Dataset settings
...

configs/us_periph.py

-from ..my import color_mode
+from my import color_mode


 def update_config(config):
     # Dataset settings
...

configs/us_periph_new.py

-from ..my import color_mode
+from my import color_mode


 def update_config(config):
     # Dataset settings
...
dash_test.py

...
@@ -10,7 +10,7 @@ import plotly.express as px
 import pandas as pd
 from dash.dependencies import Input, Output

-sys.path.append(os.path.abspath(sys.path[0] + '/../'))
+# sys.path.append(os.path.abspath(sys.path[0] + '/../'))
 #__package__ = "deep_view_syn"

 if __name__ == '__main__':
...
@@ -24,23 +24,30 @@ if __name__ == '__main__':
     print("Set CUDA:%d as current device." % torch.cuda.current_device())
     torch.autograd.set_grad_enabled(False)

-from deep_view_syn.data.spherical_view_syn import *
-from deep_view_syn.configs.spherical_view_syn import SphericalViewSynConfig
-from deep_view_syn.my import netio
-from deep_view_syn.my import util
-from deep_view_syn.my import device
-from deep_view_syn.my import view
-from deep_view_syn.my.gen_final import GenFinal
-from deep_view_syn.nets.modules import Sampler
-
-datadir = None
+from data.spherical_view_syn import *
+from configs.spherical_view_syn import SphericalViewSynConfig
+from my import netio
+from my import util
+from my import device
+from my import view
+from my.gen_final import GenFinal
+from nets.modules import Sampler
+
+datadir = 'data/__0_user_study/us_gas_periph_r135x135_t0.3_2021.01.16/'
+data_desc_file = 'train.json'
+net_config = 'periph_rgb@msl-rgb_e10_fc96x4_d1.00-50.00_s16'
+net_path = datadir + net_config + '/model-epoch_200.pth'
+fov = 45
+res = (256, 256)
+view_idx = 4
+center = (0, 0)


 def load_net(path):
     print(path)
     config = SphericalViewSynConfig()
-    config.from_id(os.path.splitext(os.path.basename(path))[0])
+    config.from_id(net_config)
     config.SAMPLE_PARAMS['perturb_sample'] = False
     net = config.create_net().to(device.GetDevice())
     netio.LoadNet(path, net)
...
@@ -64,24 +71,25 @@ def load_views(data_desc_file) -> view.Trans:
     return view.Trans(view_centers, view_rots)


-scenes = {
-    'gas': '__0_user_study/us_gas_all_in_one',
-    'mc': '__0_user_study/us_mc_all_in_one',
-    'bedroom': 'bedroom_all_in_one',
-    'gallery': 'gallery_all_in_one',
-    'lobby': 'lobby_all_in_one'
-}
-fov_list = [20, 45, 110]
-res_list = [(128, 128), (256, 256), (256, 230)]
-res_full = (1600, 1440)
-scene = 'gas'
-view_file = 'views.json'
+cam = view.CameraParam({'fov': fov, 'cx': 0.5, 'cy': 0.5, 'normalized': True},
+                       res, device=device.GetDevice())
+net = load_net(net_path)
+sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
+                  spherical=True, lindisp=True, inverse_r=True)
+x = y = None
+views = load_views(data_desc_file)
+print('%d Views loaded.' % views.size()[0])
+test_view = views.get(view_idx)
+rays_o, rays_d = cam.get_global_rays(test_view, True)
+image = net(rays_o.view(-1, 3), rays_d.view(-1, 3)) \
+    .view(1, res[0], res[1], -1).permute(0, 3, 1, 2)

 app = dash.Dash(__name__, external_stylesheets=[
                 'https://codepen.io/chriddyp/pen/bWLwgP.css'])
 styles = {
     'pre': {
...
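The `Sampler` added above places 32 samples per ray between depths 1 and 50 with `lindisp=True`, i.e. spaced linearly in disparity (1/depth), so near depths are sampled densely and far depths sparsely. The sketch below shows the standard linear-in-disparity spacing rule (as in NeRF); that `Sampler` in nets/modules.py uses exactly this formula is an assumption:

    import torch

    def lindisp_depths(near: float, far: float, n_samples: int) -> torch.Tensor:
        """Depths spaced linearly in disparity (1/depth)."""
        t = torch.linspace(0, 1, n_samples)
        return 1.0 / (1.0 / near * (1 - t) + 1.0 / far * t)

    print(lindisp_depths(1.0, 50.0, 8))
    # tensor([ 1.0000,  1.1628,  1.3889,  1.7241,  2.2727,  3.3333,  6.2500, 50.0000])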
@@ -90,35 +98,17 @@ styles = {
         'overflowX': 'scroll'
     }
 }

-datadir = 'data/' + scenes[scene] + '/'
-fovea_net = load_net_by_name('fovea')
-periph_net = load_net_by_name('periph')
-gen = GenFinal(fov_list, res_list, res_full, fovea_net, periph_net,
-               device=device.GetDevice())
-sampler = Sampler(depth_range=(1, 50), n_samples=32, perturb_sample=False,
-                  spherical=True, lindisp=True, inverse_r=True)
-x = y = None
-views = load_views(view_file)
-print('%d Views loaded.', views.size())
-view_idx = 27
-center = (0, 0)
-test_view = views.get(view_idx)
-images = gen(center, test_view)
-fig = px.imshow(util.Tensor2MatImg(images['fovea']))
+fig = px.imshow(util.Tensor2MatImg(image))
 fig1 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
+fig2 = px.scatter(x=[0, 1, 2], y=[2, 0, 1])
 app = dash.Dash(__name__, external_stylesheets=[
                 'https://codepen.io/chriddyp/pen/bWLwgP.css'])
 app.layout = html.Div([
     html.H3("Drag and draw annotations"),
     html.Div(className='row', children=[
         dcc.Graph(id='image', figure=fig),  # , config=config),
         dcc.Graph(id='scatter', figure=fig1),  # , config=config),
+        dcc.Graph(id='scatter1', figure=fig2),  # , config=config),
         dcc.Slider(
             id='samples-slider',
             min=4,
             max=128,
             step=None,
             marks={
                 4: '4',
...
@@ -128,43 +118,91 @@ app.layout = html.Div([
                 64: '64',
                 128: '128',
             },
-            value=32,
+            value=33,
             updatemode='drag')
     ])
 ])
+def raw2alpha(raw, dists, act_fn=torch.relu):
+    """
+    Function for computing density from model prediction.
+    This value is strictly between [0, 1].
+    """
+    print('act_fn: ', act_fn(raw))
+    print('act_fn * dists: ', act_fn(raw) * dists)
+    return -torch.exp(-act_fn(raw) * dists) + 1.0
+
+
+def raw2color(raw: torch.Tensor, z_vals: torch.Tensor):
+    """
+    Raw value inferred from model to color and alpha
+
+    :param raw ```Tensor(N.rays, N.samples, 2|4)```: model's output
+    :param z_vals ```Tensor(N.rays, N.samples)```: integration time
+    :return ```Tensor(N.rays, N.samples, 1|3)```: color
+    :return ```Tensor(N.rays, N.samples)```: alpha
+    """
+    # Compute 'distance' (in time) between each integration time along a ray.
+    # The 'distance' from the last integration time is infinity.
+    # dists: (N_rays, N_samples)
+    dists = z_vals[..., 1:] - z_vals[..., :-1]
+    last_dist = z_vals[..., 0:1] * 0 + 1e10
+    dists = torch.cat([dists, last_dist], -1)
+    print('dists: ', dists)
+
+    # Extract RGB of each sample position along each ray.
+    color = torch.sigmoid(raw[..., :-1])  # (N_rays, N_samples, 1|3)
+    alpha = raw2alpha(raw[..., -1], dists)
+    return color, alpha
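`raw2alpha` is the standard NeRF opacity conversion: for predicted density σ and inter-sample distance δ, the per-sample opacity is α = 1 − exp(−σδ), which is 0 in empty space and approaches 1 as σδ grows. A quick standalone check of that identity:

    import torch

    sigma = torch.tensor([0.0, 0.5, 2.0, 10.0])   # densities after act_fn
    delta = torch.tensor([0.1, 0.1, 0.1, 0.1])    # distances between samples
    alpha = 1.0 - torch.exp(-sigma * delta)       # same value raw2alpha returns
    print(alpha)  # tensor([0.0000, 0.0488, 0.1813, 0.6321])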
 def draw_scatter():
-    global fig1
-    p = torch.tensor([x, y], device=gen.layer_cams[0].c.device)
-    ray_d = test_view.trans_vector(gen.layer_cams[0].unproj(p))
+    global fig1, fig2
+    p = torch.tensor([x, y], device=device.GetDevice())
+    ray_d = test_view.trans_vector(cam.unproj(p))
     ray_o = test_view.t
-    raw, depths = fovea_net.sample_and_infer(ray_o, ray_d, sampler=sampler)
-    colors, alphas = fovea_net.rendering.raw2color(raw, depths)
+    raw, depths = net.sample_and_infer(ray_o, ray_d, sampler=sampler)
+    colors, alphas = raw2color(raw, depths)
     scatter_x = (1 / depths[0]).cpu().detach().numpy()
     scatter_y = alphas[0].cpu().detach().numpy()
+    scatter_y1 = raw[0, :, 3].cpu().detach().numpy()
     scatter_color = colors[0].cpu().detach().numpy() * 255
     marker_colors = [
-        i  #'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
+        # 'rgb(%d,%d,%d)' % (scatter_color[i][0], scatter_color[i][1], scatter_color[i][2])
+        i
         for i in range(scatter_color.shape[0])
     ]
     marker_colors_str = [
         'rgb(%d,%d,%d)' % (scatter_color[i][0],
                            scatter_color[i][1],
                            scatter_color[i][2])
         for i in range(scatter_color.shape[0])
     ]
     fig1 = px.scatter(x=scatter_x, y=scatter_y, color=marker_colors,
-                      color_continuous_scale=marker_colors_str)  #, color_discrete_map='identity')
+                      color_continuous_scale=marker_colors_str)  # , color_discrete_map='identity')
     fig1.update_traces(mode='lines+markers')
     fig1.update_xaxes(showgrid=False)
     fig1.update_yaxes(type='linear')
     fig1.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
+    fig2 = px.scatter(x=scatter_x, y=scatter_y1, color=marker_colors,
+                      color_continuous_scale=marker_colors_str)  # , color_discrete_map='identity')
+    fig2.update_traces(mode='lines+markers')
+    fig2.update_xaxes(showgrid=False)
+    fig2.update_yaxes(type='linear')
+    fig2.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})

 @app.callback(
-    [Output('image', 'figure'), Output('scatter', 'figure')],
+    [Output('image', 'figure'), Output('scatter', 'figure'),
+     Output('scatter1', 'figure')],
     [Input('image', 'clickData'),
      dash.dependencies.Input('samples-slider', 'value')]
 )
...
@@ -194,7 +232,7 @@ def display_hover_data(clickData, samples):
             color="LightSeaGreen",
             width=3,
         ))
-    return fig, fig1
+    return fig, fig1, fig2


 if __name__ == '__main__':
...
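A Dash callback must return exactly as many values as it declares `Output`s, in the same order; adding `Output('scatter1', 'figure')` to the decorator is what forces `display_hover_data` to return the third figure. A minimal self-contained sketch of the same pattern (ids are hypothetical; assumes Dash ≥ 2.0, where `dcc`/`html` import from `dash`):

    import dash
    from dash import dcc, html
    from dash.dependencies import Input, Output
    import plotly.express as px

    app = dash.Dash(__name__)
    app.layout = html.Div([
        dcc.Slider(id='n', min=1, max=10, value=3),
        dcc.Graph(id='a'), dcc.Graph(id='b'),
    ])

    @app.callback([Output('a', 'figure'), Output('b', 'figure')],
                  [Input('n', 'value')])
    def update(n):
        xs = list(range(int(n)))
        # One return value per declared Output, in declaration order.
        return px.scatter(x=xs, y=xs), px.scatter(x=xs, y=[v * v for v in xs])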
data/lf_syn.py  (deleted, 100644 → 0)
The entire file is removed:

from typing import List, Tuple
import torch
import json
from ..my import util


def ReadLightField(path: str, views: Tuple[int, int],
                   flatten_views: bool = False) -> torch.Tensor:
    input_img = util.ReadImageTensor(path, batch_dim=False)
    h = input_img.size()[1] // views[0]
    w = input_img.size()[2] // views[1]
    if flatten_views:
        lf = torch.empty(views[0] * views[1], 3, h, w)
        for y_i in range(views[0]):
            for x_i in range(views[1]):
                lf[y_i * views[1] + x_i, :, :, :] = \
                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
    else:
        lf = torch.empty(views[0], views[1], 3, h, w)
        for y_i in range(views[0]):
            for x_i in range(views[1]):
                lf[y_i, x_i, :, :, :] = \
                    input_img[:, y_i * h:(y_i + 1) * h, x_i * w:(x_i + 1) * w]
    return lf


def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor:
    return depth_images[:, 0].unsqueeze(1).mul(255) / 10


class LightFieldSynDataset(torch.utils.data.dataset.Dataset):
    """
    Data loader for light field synthesis task

    Attributes
    --------
    data_dir ```string```: the directory of dataset
    n_views ```tuple(int, int)```: rows and columns of views
    num_views ```int```: number of views
    view_images ```N x H x W Tensor```: images of views
    view_depths ```N x H x W Tensor```: depths of views
    view_positions ```N x 3 Tensor```: positions of views
    sparse_view_images ```N' x H x W Tensor```: images of sparse views
    sparse_view_depths ```N' x H x W Tensor```: depths of sparse views
    sparse_view_positions ```N' x 3 Tensor```: positions of sparse views
    """

    def __init__(self, data_desc_path: str):
        """
        Initialize data loader for light field synthesis task

        The data description file is a JSON file with following fields:
        - lf: string, the path of light field image
        - lf_depth: string, the path of light field depth image
        - n_views: { "x", "y" }, columns and rows of views
        - cam_params: { "f", "c" }, the focal and center of camera (in normalized image space)
        - depth_range: [ min, max ], the range of depth in depth maps
        - depth_layers: int, number of layers in depth maps
        - view_positions: [ [ x, y, z ], ... ], positions of views

        :param data_desc_path: path to the data description file
        """
        self.data_dir = data_desc_path.rsplit('/', 1)[0] + '/'
        with open(data_desc_path, 'r', encoding='utf-8') as file:
            self.data_desc = json.loads(file.read())
        self.n_views = (self.data_desc['n_views']['y'],
                        self.data_desc['n_views']['x'])
        self.num_views = self.n_views[0] * self.n_views[1]
        self.view_images = ReadLightField(
            self.data_dir + self.data_desc['lf'], self.n_views, True)
        self.view_depths = DecodeDepth(ReadLightField(
            self.data_dir + self.data_desc['lf_depth'], self.n_views, True))
        self.cam_params = self.data_desc['cam_params']
        self.depth_range = self.data_desc['depth_range']
        self.depth_layers = self.data_desc['depth_layers']
        self.view_positions = torch.tensor(self.data_desc['view_positions'])
        _, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \
            = self._GetCornerViews()
        self.diopter_of_layers = self._GetDiopterOfLayers()

    def __len__(self):
        return self.num_views

    def __getitem__(self, idx):
        return idx, self.view_images[idx], self.view_depths[idx], \
            self.view_positions[idx]

    def _GetCornerViews(self):
        corner_selector = torch.zeros(self.num_views, dtype=torch.bool)
        corner_selector[0] = corner_selector[self.n_views[1] - 1] \
            = corner_selector[self.num_views - self.n_views[1]] \
            = corner_selector[self.num_views - 1] = True
        return self.__getitem__(corner_selector)

    def _GetDiopterOfLayers(self) -> List[float]:
        diopter_range = (1 / self.depth_range[1], 1 / self.depth_range[0])
        step = (diopter_range[1] - diopter_range[0]) / (self.depth_layers - 1)
        diopter_of_layers = [diopter_range[0] + step * i
                             for i in range(self.depth_layers)]
        diopter_of_layers.insert(0, 0)
        return diopter_of_layers
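`_GetDiopterOfLayers` spaces the display layers uniformly in diopters (1/depth), the spacing along which accommodation and defocus blur vary roughly linearly, and prepends a 0-diopter layer (optical infinity). A standalone rerun of that arithmetic for a depth range of [1, 10] with 10 layers:

    # Recomputing _GetDiopterOfLayers for depth_range=[1, 10], depth_layers=10.
    depth_range = (1.0, 10.0)
    depth_layers = 10
    diopter_range = (1 / depth_range[1], 1 / depth_range[0])           # (0.1, 1.0)
    step = (diopter_range[1] - diopter_range[0]) / (depth_layers - 1)  # 0.1
    layers = [diopter_range[0] + step * i for i in range(depth_layers)]
    layers.insert(0, 0)  # extra layer at 0 diopters = infinity
    print(layers)        # [0, 0.1, 0.2, ..., 1.0]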
data/loader.py

 import torch
 import math
-from ..my import device
+from my import device


 class FastDataLoader(object):
...
data/other.py

 import torch
 import os
 import json
 import glob
 import cv2
 import numpy as np
 import torchvision.transforms as transforms
 from typing import List, Tuple
 from torchvision import datasets
 from torch.utils.data import DataLoader
-import cv2
-import json
-from Flow import *
-from gen_image import *
-import util
+from my.flow import *
+from my.gen_image import *
+from my import util

(The `ReadLightField`, `DecodeDepth` and `LightFieldSynDataset` definitions deleted from data/lf_syn.py above are added here verbatim, except for a stray `import time` between `_GetCornerViews` and `_GetDiopterOfLayers`; the listing is not repeated.)

 class lightFieldSynDataLoader(torch.utils.data.dataset.Dataset):
...
The remaining hunks in this file only rewrap long slicing expressions; the code itself is unchanged:

@@ -90,8 +185,8 @@ class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
         # print(lf_image_big.shape)
         for i in range(9):
             lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,
                                     i % 3*IM_W:i % 3*IM_W+IM_W, 0:3]
             # IF GrayScale
             # lf_image = lf_image_big[i//3*IM_H:i//3*IM_H+IM_H,i%3*IM_W:i%3*IM_W+IM_W,0:1]
             # print(lf_image.shape)
...
@@ -146,8 +241,8 @@ class lightFieldValDataLoader(torch.utils.data.dataset.Dataset):
         (same rewrap of the lf_image slicing as above)
...
@@ -214,8 +309,8 @@ class lightFieldSeqDataLoader(torch.utils.data.dataset.Dataset):
         lf_image_big = cv2.cvtColor(lf_image_big, cv2.COLOR_BGR2RGB)
         for j in range(9):
             lf_image = lf_image_big[j//3*IM_H:j//3*IM_H+IM_H,
                                     j % 3*IM_W:j % 3*IM_W+IM_W, 0:3]
             lf_image_one_sample.append(lf_image)
         gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(
...
@@ -297,8 +392,8 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
         lf_dim = int(self.conf.light_field_dim)
         for j in range(lf_dim**2):
             lf_image = lf_image_big[j//lf_dim*IM_H:j//lf_dim*IM_H+IM_H,
                                     j % lf_dim*IM_W:j % lf_dim*IM_W+IM_W, 0:3]
             lf_image_one_sample.append(lf_image)
         gt_i = cv2.imread(fd_gt_path, cv2.IMREAD_UNCHANGED).astype(
...
@@ -333,7 +428,7 @@ class lightFieldFlowSeqDataLoader(torch.utils.data.dataset.Dataset):
             retinal_invalid.append(retinal_invalid_i)
         # lf_images: 5,9,320,320
         flow = Flow.Load([os.path.join(self.file_dir_path, self.dataset_desc["flow"]
                                        [indices[i - 1]]) for i in range(1, len(indices))])
         flow_map = flow.getMap()
         flow_invalid_mask = flow.b_invalid_mask
         # print("flow:",flow_map.shape)
...
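The `i//3`, `i%3` slicing above cuts the 9 sub-views out of a 3×3-tiled light-field image: integer division picks the tile row, the remainder picks the tile column. A standalone sketch with a dummy array:

    import numpy as np

    # Extract the 9 sub-views of a 3x3-tiled light-field image.
    IM_H, IM_W = 320, 320
    lf_image_big = np.zeros((3 * IM_H, 3 * IM_W, 3), dtype=np.uint8)

    views = [lf_image_big[i // 3 * IM_H:(i // 3 + 1) * IM_H,
                          i % 3 * IM_W:(i % 3 + 1) * IM_W, 0:3]
             for i in range(9)]        # row-major: i//3 = row, i%3 = column
    print(len(views), views[0].shape)  # 9 (320, 320, 3)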
data/spherical_view_syn.py

...
@@ -4,10 +4,10 @@ import torch
 import torchvision.transforms.functional as trans_f
 import torch.nn.functional as nn_f
 from typing import Tuple, Union
-from ..my import util
-from ..my import device
-from ..my import view
-from ..my import color_mode
+from my import util
+from my import device
+from my import view
+from my import color_mode


 class SphericalViewSynDataset(object):
...
@@ -129,6 +129,13 @@ class SphericalViewSynDataset(object):
         self.n_views = self.view_centers.size(0)
         self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]

+        if 'gl_coord' in data_desc and data_desc['gl_coord'] == True:
+            print('Convert from OGL coordinate to DX coordinate (i. e. right-hand to left-hand)')
+            self.cam_params.f[1] *= -1
+            self.view_centers[:, 2] *= -1
+            self.view_rots[:, 2] *= -1
+            self.view_rots[..., 2] *= -1
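The new `gl_coord` branch converts right-handed OpenGL poses to the left-handed convention used elsewhere by mirroring across the z = 0 plane: the z component of each camera center and both the z row (`view_rots[:, 2]`) and z column (`view_rots[..., 2]`) of each rotation matrix are negated, along with the vertical focal length. A standalone sketch of the same flip on a single pose with made-up values:

    import torch

    # Hypothetical right-handed (OpenGL) pose: 3x3 rotation + camera center.
    rot = torch.eye(3)
    center = torch.tensor([0.2, 0.1, -1.5])

    # Right-hand -> left-hand: mirror across z = 0.
    center[2] *= -1
    rot[2, :] *= -1   # z row    (view_rots[:, 2] on a batch of matrices)
    rot[:, 2] *= -1   # z column (view_rots[..., 2])
    print(rot, center)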
     def set_patch_size(self, patch_size: Union[int, Tuple[int, int]],
                        offset: Union[int, Tuple[int, int]] = 0):
         """
...
nets/modules.py

 from typing import List, Tuple
 import torch
 import torch.nn as nn
-from ..my import device
-from ..my import util
+from my import device
+from my import util


 class FcLayer(nn.Module):
...
nets/msl_net.py

...
@@ -2,9 +2,8 @@ import math
 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import util
-from ..my import color_mode
+from my import util
+from my import color_mode


 class MslNet(nn.Module):
...
nets/msl_net_new.py

 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import color_mode
-from ..my.simple_perf import SimplePerf
+from my import color_mode
+from my.simple_perf import SimplePerf


 class NewMslNet(nn.Module):
...
nets/msl_net_new_export.py

...
@@ -2,10 +2,10 @@ from typing import Tuple
 import math
 import torch
 import torch.nn as nn
-from ..my import net_modules
-from ..my import util
-from ..my import device
-from ..my import color_mode
+from my import net_modules
+from my import util
+from my import device
+from my import color_mode
 from .msl_net_new import NewMslNet
...
nets/spher_net.py

 import torch
 import torch.nn as nn
 from .modules import *
-from ..my import util
+from my import util


 class SpherNet(nn.Module):
...
nets/trans_unet.py

 from typing import List
 import torch
 import torch.nn as nn
-from ..pytorch_prototyping.pytorch_prototyping import *
-from ..my import util
-from ..my import device
+from pytorch_prototyping.pytorch_prototyping import *
+from my import util
+from my import device


 class Encoder(nn.Module):
...
notebook/test_spherical_view_syn.ipynb

...
@@ -2,20 +2,28 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Set CUDA:2 as current device.\n"
+     ]
+    }
+   ],
    "source": [
     "import sys\n",
     "import os\n",
-    "sys.path.append(os.path.abspath(sys.path[0] + '/../../'))\n",
+    "sys.path.append(os.path.abspath(sys.path[0] + '/../'))\n",
     "\n",
     "import torch\n",
     "import math\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
-    "from deep_view_syn.my import util\n",
-    "from deep_view_syn.msl_net import *\n",
+    "from my import util\n",
+    "from nets.msl_net import *\n",
     "\n",
     "# Select device\n",
     "torch.cuda.set_device(2)\n",
...
@@ -23,8 +31,10 @@
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": 1,
    "metadata": {},
+   "outputs": [],
    "source": [
     "# Test Ray-Sphere Intersection & Cartesian-Spherical Conversion"
    ]
...
@@ -56,14 +66,15 @@
    "v = torch.tensor([[0.0, -1.0, 1.0]])\n",
    "r = torch.tensor([[2.5]])\n",
    "v = v / torch.norm(v) * r * 2\n",
-   "p_on_sphere_ = RaySphereIntersect(p, v, r)[0]\n",
+   "p_on_sphere_ = util.RaySphereIntersect(p, v, r)[0][0]\n",
    "print(p_on_sphere_)\n",
    "print(p_on_sphere_.norm())\n",
-   "spher_coord = RayToSpherical(p, v, r)\n",
+   "spher_coord = util.CartesianToSpherical(p_on_sphere_)\n",
    "print(spher_coord[..., 1:3].rad2deg())\n",
-   "p_on_sphere = util.SphericalToCartesian(spher_coord)[0]\n",
+   "p_on_sphere = util.SphericalToCartesian(spher_coord)\n",
    "print(p_on_sphere_.size())\n",
    "\n",
-   "fig = plt.figure(figsize=(6, 6))\n",
+   "fig = plt.figure(figsize=(8, 8))\n",
    "ax = fig.gca(projection='3d')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('z')\n",
...
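This cell intersects rays with a sphere of radius r centered at the origin: substituting p + t·v into ‖x‖ = r gives the quadratic t²‖v‖² + 2t(p·v) + ‖p‖² − r² = 0, whose smallest non-negative root is the first hit. A standalone sketch for a single ray (the repo's `util.RaySphereIntersect` presumably implements a batched version of this):

    import math

    def ray_sphere_intersect(p, v, r):
        """First intersection of ray p + t*v (t >= 0) with the
        origin-centered sphere of radius r; None if the ray misses."""
        a = sum(vi * vi for vi in v)
        b = 2 * sum(pi * vi for pi, vi in zip(p, v))
        c = sum(pi * pi for pi in p) - r * r
        disc = b * b - 4 * a * c
        if disc < 0:
            return None                           # ray misses the sphere
        t = (-b - math.sqrt(disc)) / (2 * a)      # nearer root
        if t < 0:
            t = (-b + math.sqrt(disc)) / (2 * a)  # origin inside the sphere
        if t < 0:
            return None
        return [pi + t * vi for pi, vi in zip(p, v)]

    print(ray_sphere_intersect([0.0, 0.0, 0.0], [0.0, -1.0, 1.0], 2.5))
    # [0.0, -1.7678, 1.7678]; its norm is 2.5, i.e. the point lies on the sphere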
@@ -109,8 +120,10 @@
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
     "# Test Dataset Loader & View-Spherical Transform"
    ]
...
@@ -121,26 +134,26 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from deep_view_syn.data.spherical_view_syn import FastSphericalViewSynDataset\n",
-    "from deep_view_syn.data.spherical_view_syn import FastDataLoader\n",
+    "from data.spherical_view_syn import SphericalViewSynDataset\n",
+    "from data.loader import FastDataLoader\n",
     "\n",
-    "DATA_DIR = '../data/sp_view_syn_2020.12.28'\n",
+    "DATA_DIR = '../data/nerf_fern'\n",
     "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
     "\n",
-    "dataset = FastSphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
+    "dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
     "dataset.set_patch_size((64, 64))\n",
     "data_loader = FastDataLoader(dataset=dataset, batch_size=4, shuffle=False, drop_last=False)\n",
     "print(len(dataset))\n",
-    "plt.figure()\n",
+    "fig = plt.figure(figsize=(12, 6.5))\n",
     "i = 0\n",
     "for indices, patches, rays_o, rays_d in data_loader:\n",
     "    print(i, patches.size(), rays_o.size(), rays_d.size())\n",
     "    for idx in range(len(indices)):\n",
-    "        plt.subplot(4, 4, i + 1)\n",
+    "        plt.subplot(4, 7, i + 1)\n",
     "        util.PlotImageTensor(patches[idx])\n",
     "        i += 1\n",
-    "    if i == 16:\n",
-    "        break\n"
+    "    if i == 28:\n",
+    "        break"
    ]
   },
   {
...
@@ -149,13 +162,15 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+    "from data.spherical_view_syn import SphericalViewSynDataset\n",
+    "from data.loader import FastDataLoader\n",
     "\n",
-    "DATA_DIR = '../data/sp_view_syn_2020.12.26'\n",
+    "DATA_DIR = '../data/nerf_fern'\n",
     "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
     "DEPTH_RANGE = (1, 10)\n",
     "N_DEPTH_LAYERS = 10\n",
     "\n",
+    "\n",
     "def _GetSphereLayers(depth_range: Tuple[float, float], n_layers: int) -> torch.Tensor:\n",
     "    diopter_range = (1 / depth_range[1], 1 / depth_range[0])\n",
     "    step = (diopter_range[1] - diopter_range[0]) / (n_layers - 1)\n",
...
@@ -163,74 +178,54 @@
    "    depths += [1 / (diopter_range[0] + step * i) for i in range(n_layers)]\n",
    "    return torch.tensor(depths, device=device.GetDevice()).view(-1, 1)\n",
    "\n",
-   "train_dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
-   "train_data_loader = torch.utils.data.DataLoader(\n",
-   "    dataset=train_dataset,\n",
-   "    batch_size=4,\n",
-   "    num_workers=8,\n",
-   "    pin_memory=True,\n",
-   "    shuffle=True,\n",
-   "    drop_last=False)\n",
-   "print(len(train_data_loader))\n",
-   "\n",
-   "print(\"view_res\", train_dataset.view_res)\n",
-   "print(\"cam_params\", train_dataset.cam_params)\n",
-   "\n",
-   "msl_net = MslNet(train_dataset.cam_params,\n",
-   "                 _GetSphereLayers(DEPTH_RANGE, N_DEPTH_LAYERS),\n",
-   "                 train_dataset.view_res).to(device.GetDevice())\n",
-   "print(\"sphere layers\", msl_net.rendering.sphere_layers)\n",
-   "\n",
-   "p = None\n",
-   "v = None\n",
-   "centers = None\n",
-   "plt.figure(figsize=(6, 6))\n",
-   "for _, view_images, ray_positions, ray_directions in train_data_loader:\n",
-   "    p = ray_positions\n",
-   "    v = ray_directions\n",
-   "    plt.subplot(2, 2, 1)\n",
-   "    util.PlotImageTensor(view_images[0])\n",
-   "    plt.subplot(2, 2, 2)\n",
-   "    util.PlotImageTensor(view_images[1])\n",
-   "    plt.subplot(2, 2, 3)\n",
-   "    util.PlotImageTensor(view_images[2])\n",
-   "    plt.subplot(2, 2, 4)\n",
-   "    util.PlotImageTensor(view_images[3])\n",
-   "    break\n",
-   "p_ = util.SphericalToCartesian(RayToSpherical(p.flatten(0, 1), v.flatten(0, 1),\n",
-   "                                              torch.tensor([[1]], device=device.GetDevice()))) \\\n",
-   "    .view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)\n",
-   "v = v.view(4, train_dataset.view_res[0], train_dataset.view_res[1], 3)[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n",
-   "p_ = p_[:, 0::50, 0::50, :].flatten(1, 2).cpu().numpy()\n",
-   "\n",
-   "fig = plt.figure(figsize=(6, 6))\n",
-   "ax = fig.gca(projection='3d')\n",
-   "plt.xlabel('x')\n",
-   "plt.ylabel('z')\n",
-   "\n",
-   "PlotSphere(ax, 1)\n",
-   "\n",
-   "ax.scatter([0], [0], [0], color=\"k\", s=10)  # Center\n",
-   "\n",
-   "colors = [ 'r', 'g', 'b', 'y' ]\n",
-   "for i in range(4):\n",
-   "    ax.scatter(p_[i, :, 0], p_[i, :, 2], p_[i, :, 1], color=colors[i], s=3)\n",
-   "    for j in range(p_.shape[1]):\n",
-   "        ax.plot([centers[i, 0], centers[i, 0] + v[i, j, 0]],\n",
-   "                [centers[i, 2], centers[i, 2] + v[i, j, 2]],\n",
-   "                [centers[i, 1], centers[i, 1] + v[i, j, 1]],\n",
-   "                color=colors[i], linewidth=0.5, alpha=0.6)\n",
-   "\n",
-   "ax.set_xlim(-1, 1)\n",
-   "ax.set_ylim(-1, 1)\n",
-   "ax.set_zlim(-1, 1)\n",
-   "\n",
-   "plt.show()\n"
+   "dataset = SphericalViewSynDataset(TRAIN_DATA_DESC_FILE)\n",
+   "dataset.set_patch_size(1)\n",
+   "data_loader = FastDataLoader(\n",
+   "    dataset=dataset, batch_size=4096*16, shuffle=True, drop_last=False)\n",
+   "\n",
+   "print(\"view_res\", dataset.view_res)\n",
+   "print(\"cam_params\", dataset.cam_params)\n",
+   "\n",
+   "fig = plt.figure(figsize=(16, 40))\n",
+   "\n",
+   "for ri in range(0, 10):\n",
+   "    r = ri * 0.2 + 1\n",
+   "    p = None\n",
+   "    centers = None\n",
+   "    pixels = None\n",
+   "    for indices, patches, rays_o, rays_d in data_loader:\n",
+   "        p = util.RaySphereIntersect(\n",
+   "            rays_o, rays_d, torch.tensor([[r]], device=device.GetDevice()))[0] \\\n",
+   "            .view(-1, 3).cpu().numpy()\n",
+   "        centers = rays_o.view(-1, 3).cpu().numpy()\n",
+   "        pixels = patches.view(-1, 3).cpu().numpy()\n",
+   "        break\n",
+   "    \n",
+   "    ax = plt.subplot(5, 2, ri + 1, projection='3d')\n",
+   "    #ax = plt.gca(projection='3d')\n",
+   "    #ax = fig.gca()\n",
+   "    plt.xlabel('x')\n",
+   "    plt.ylabel('z')\n",
+   "    plt.title('r = %f' % r)\n",
+   "\n",
+   "    # PlotSphere(ax, 1)\n",
+   "\n",
+   "    ax.scatter([0], [0], [0], color=\"k\", s=10)\n",
+   "    ax.scatter(p[:, 0], p[:, 2], p[:, 1], color=pixels, s=0.5)\n",
+   "\n",
+   "    #ax.set_xlim(-1, 1)\n",
+   "    #ax.set_ylim(-1, 1)\n",
+   "    #ax.set_zlim(-1, 1)\n",
+   "    ax.view_init(elev=0,azim=-90)\n",
+   "\n"
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": 3,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Test Sampler"
   ]
...
@@ -241,7 +236,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
    "\n",
    "DATA_DIR = '../data/sp_view_syn_2020.12.29_finetrans'\n",
    "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
...
@@ -304,7 +299,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
    "\n",
    "DATA_DIR = '../data/sp_view_syn_2020.12.26_rotonly'\n",
    "TRAIN_DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
...
@@ -367,8 +362,10 @@
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": null,
   "metadata": {},
+  "outputs": [],
   "source": [
    "# Test Spherical View Synthesis"
   ]
...
@@ -381,9 +378,9 @@
   "source": [
    "import ipywidgets as widgets  # widget library\n",
    "from IPython.display import display  # method for displaying widgets\n",
-   "from deep_view_syn.data.spherical_view_syn import SphericalViewSynDataset\n",
-   "from deep_view_syn.spher_net import SpherNet\n",
-   "from deep_view_syn.my import netio\n",
+   "from data.spherical_view_syn import SphericalViewSynDataset\n",
+   "from nets.spher_net import SpherNet\n",
+   "from my import netio\n",
    "\n",
    "DATA_DIR = '../data/sp_view_syn_2020.12.28_small'\n",
    "DATA_DESC_FILE = DATA_DIR + '/train.json'\n",
...
@@ -455,20 +452,12 @@
   "})\n",
   "display(slider_x, slider_y, slider_z, slider_theta, slider_phi, out)\n"
  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": []
 }
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 3",
-  "language": "python",
-  "name": "python3"
+  "display_name": "Python 3.7.9 64-bit ('pytorch': conda)",
+  "name": "python379jvsc74a57bd0660ca2a75467d3af74a68fcc6f40bc78ab96b99ff17d2f100b5ca821fbb183f2"
 },
 "language_info": {
  "codemirror_mode": {
...
@@ -480,7 +469,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.7.6-final"
+ "version": "3.7.9"
 }
},
"nbformat": 4,
...