Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
Nianchen Deng
deeplightfield
Commits
c570c3b1
Commit
c570c3b1
authored
Dec 28, 2020
by
BobYeah
Browse files
checkpoint
parent
172b5205
Changes
5
Hide whitespace changes
Inline
Side-by-side
data/spherical_view_syn.py
View file @
c570c3b1
...
@@ -2,7 +2,6 @@ import torch
...
@@ -2,7 +2,6 @@ import torch
import
torchvision.transforms.functional
as
trans_f
import
torchvision.transforms.functional
as
trans_f
import
json
import
json
from
..my
import
util
from
..my
import
util
from
..my
import
imgio
class
SphericalViewSynDataset
(
torch
.
utils
.
data
.
dataset
.
Dataset
):
class
SphericalViewSynDataset
(
torch
.
utils
.
data
.
dataset
.
Dataset
):
...
@@ -44,8 +43,11 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
...
@@ -44,8 +43,11 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
# Load dataset description file
# Load dataset description file
with
open
(
dataset_desc_path
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
with
open
(
dataset_desc_path
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
data_desc
=
json
.
loads
(
file
.
read
())
data_desc
=
json
.
loads
(
file
.
read
())
self
.
view_file_pattern
:
str
=
self
.
data_dir
+
\
if
data_desc
[
'view_file_pattern'
]
==
''
:
data_desc
[
'view_file_pattern'
]
self
.
load_images
=
False
else
:
self
.
view_file_pattern
:
str
=
self
.
data_dir
+
\
data_desc
[
'view_file_pattern'
]
self
.
view_res
=
(
data_desc
[
'view_res'
][
'y'
],
self
.
view_res
=
(
data_desc
[
'view_res'
][
'y'
],
data_desc
[
'view_res'
][
'x'
])
data_desc
[
'view_res'
][
'x'
])
self
.
cam_params
=
data_desc
[
'cam_params'
]
self
.
cam_params
=
data_desc
[
'cam_params'
]
...
@@ -54,7 +56,7 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
...
@@ -54,7 +56,7 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
.
view
(
-
1
,
3
,
3
)
# (N, 3, 3)
.
view
(
-
1
,
3
,
3
)
# (N, 3, 3)
# Load view images
# Load view images
if
load_images
:
if
self
.
load_images
:
self
.
view_images
=
util
.
ReadImageTensor
(
self
.
view_images
=
util
.
ReadImageTensor
(
[
self
.
view_file_pattern
%
i
for
i
in
range
(
self
.
view_centers
.
size
(
0
))])
[
self
.
view_file_pattern
%
i
for
i
in
range
(
self
.
view_centers
.
size
(
0
))])
if
gray
:
if
gray
:
...
@@ -75,8 +77,8 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
...
@@ -75,8 +77,8 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
# Flatten rays if ray_as_item = True
# Flatten rays if ray_as_item = True
if
ray_as_item
:
if
ray_as_item
:
self
.
view_pixels
=
self
.
view_images
.
permute
(
self
.
view_pixels
=
self
.
view_images
.
permute
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
)
0
,
2
)
if
self
.
view_images
!=
None
else
None
self
.
ray_positions
=
self
.
ray_positions
.
flatten
(
0
,
1
)
self
.
ray_positions
=
self
.
ray_positions
.
flatten
(
0
,
1
)
self
.
ray_directions
=
self
.
ray_directions
.
flatten
(
0
,
1
)
self
.
ray_directions
=
self
.
ray_directions
.
flatten
(
0
,
1
)
...
@@ -88,4 +90,4 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
...
@@ -88,4 +90,4 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
if
self
.
ray_as_item
:
if
self
.
ray_as_item
:
return
idx
,
self
.
view_pixels
[
idx
],
self
.
ray_positions
[
idx
],
self
.
ray_directions
[
idx
]
return
idx
,
self
.
view_pixels
[
idx
],
self
.
ray_positions
[
idx
],
self
.
ray_directions
[
idx
]
return
idx
,
self
.
view_images
[
idx
],
self
.
ray_positions
[
idx
],
self
.
ray_directions
[
idx
]
return
idx
,
self
.
view_images
[
idx
],
self
.
ray_positions
[
idx
],
self
.
ray_directions
[
idx
]
return
idx
,
self
.
ray_positions
[
idx
],
self
.
ray_directions
[
idx
]
return
idx
,
False
,
self
.
ray_positions
[
idx
],
self
.
ray_directions
[
idx
]
image_scale.py
0 → 100644
View file @
c570c3b1
import
sys
import
os
sys
.
path
.
append
(
os
.
path
.
abspath
(
sys
.
path
[
0
]
+
'/../'
))
__package__
=
"deeplightfield"
import
argparse
from
PIL
import
Image
from
.my
import
util
def batch_scale(src, target, size):
    """
    Resize every JPEG/PNG image found directly in `src` and save it under `target`

    Files are matched by extension (`.jpg` or `.png`, compared case-insensitively) and
    written to `target` under their original file name.

    :param src: directory listed (non-recursively) for input images
    :param target: output directory, prepared through `util.CreateDirIfNeed`
    :param size: size handed to `Image.resize`; the command-line entry point passes (width, height)
    """
    util.CreateDirIfNeed(target)
    for file_name in os.listdir(src):
        # Lower-case the extension so that e.g. "photo.PNG" is not silently skipped
        postfix = os.path.splitext(file_name)[1].lower()
        if postfix not in ('.jpg', '.png'):
            continue
        # The context manager releases the source file handle once the resized copy is saved
        with Image.open(os.path.join(src, file_name)) as im:
            im.resize(size).save(os.path.join(target, file_name))
if __name__ == '__main__':
    # Command-line entry point: scale all images of a directory to a fixed size
    parser = argparse.ArgumentParser()
    parser.add_argument('src', type=str,
                        help='Source directory.')
    parser.add_argument('target', type=str,
                        help='Target directory.')
    # Both dimensions are required: a missing one would otherwise reach Image.resize as None
    parser.add_argument('--width', type=int, required=True,
                        help='Width of output images (pixel)')
    parser.add_argument('--height', type=int, required=True,
                        help='Height of output images (pixel)')
    opt = parser.parse_args()
    batch_scale(opt.src, opt.target, (opt.width, opt.height))
msl_net.py
View file @
c570c3b1
from
typing
import
List
,
Tuple
from
typing
import
List
,
Tuple
from
math
import
pi
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
.
pytorch_prototyping.pytorch_prototyping
import
*
from
.
my
import
net_modules
from
.my
import
util
from
.my
import
util
from
.my
import
device
from
.my
import
device
def CartesianToSpherical(cart: torch.Tensor) -> torch.Tensor:
    """
    Convert coordinates from Cartesian to Spherical

    `theta` is `atan2(z, x)` with negative angles shifted by 2*pi so it is non-negative;
    `phi` is `acos(y / r)`, i.e. the angle measured from the +Y axis.
    The origin, where the angles are undefined, maps to (0, 0, pi/2) instead of NaN.

    :param cart: ... x 3, coordinates in Cartesian (x, y, z)
    :return: ... x 3, coordinates in Spherical (r, theta, phi)
    """
    rho = torch.norm(cart, p=2, dim=-1)
    theta = torch.atan2(cart[..., 2], cart[..., 0])
    theta = theta + (theta < 0).type_as(theta) * (2 * pi)
    # Avoid 0 / 0 at the origin: divide by 1 there so that the ratio is a finite 0
    safe_rho = torch.where(rho > 0, rho, torch.ones_like(rho))
    # Rounding can push |y / rho| marginally above 1, where acos would return NaN
    cos_phi = (cart[..., 1] / safe_rho).clamp(-1, 1)
    phi = torch.acos(cos_phi)
    return torch.stack([rho, theta, phi], dim=-1)
def SphericalToCartesian(spher: torch.Tensor) -> torch.Tensor:
    """
    Convert coordinates from Spherical to Cartesian

    :param spher: ... x 3, coordinates in Spherical (r, theta, phi)
    :return: ... x 3, coordinates in Cartesian
    """
    radius = spher[..., 0]
    azimuth = spher[..., 1]
    polar = spher[..., 2]
    sin_polar = torch.sin(polar)
    # Y is the polar axis; X and Z span the plane the azimuth rotates in
    x = radius * torch.cos(azimuth) * sin_polar
    y = radius * torch.cos(polar)
    z = radius * torch.sin(azimuth) * sin_polar
    return torch.stack([x, y, z], dim=-1)
def
RaySphereIntersect
(
p
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
r
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
RaySphereIntersect
(
p
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
r
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
Calculate intersections of each rays and each spheres
Calculate intersections of each rays and each spheres
...
@@ -68,115 +37,74 @@ def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.T
...
@@ -68,115 +37,74 @@ def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.T
:return: B x B' x 3, spherical coordinates
:return: B x B' x 3, spherical coordinates
"""
"""
p_on_spheres
=
RaySphereIntersect
(
p
,
v
,
r
)
p_on_spheres
=
RaySphereIntersect
(
p
,
v
,
r
)
return
CartesianToSpherical
(
p_on_spheres
)
return
util
.
CartesianToSpherical
(
p_on_spheres
)
class FcNet(nn.Module):
    """
    Fully-connected network: a stack of `Linear + ReLU` stages of width `nf`
    followed by a final `Linear` that maps to `out_chns`
    """

    def __init__(self, in_chns: int, out_chns: int, nf: int, n_layers: int):
        """
        Build the layer stack and initialize its weights

        :param in_chns: number of input features
        :param out_chns: number of output features
        :param nf: number of features of every hidden stage
        :param n_layers: number of `Linear + ReLU` stages, the input stage included
                         (the input stage is always created)
        """
        super().__init__()
        # Input feature count of each stage: the input stage, then (n_layers - 1) hidden ones
        stage_inputs = [in_chns] + [nf] * (n_layers - 1)
        self.layers = []
        for n_in in stage_inputs:
            # A LayerNorm after each Linear is currently disabled
            self.layers.extend((nn.Linear(n_in, nf), nn.ReLU()))
        self.layers.append(nn.Linear(nf, out_chns))
        self.net = nn.Sequential(*self.layers)
        self.net.apply(self.init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the input through the layer stack

        :param x: input tensor whose last dimension is `in_chns`
        :return: output tensor whose last dimension is `out_chns`
        """
        return self.net(x)

    def init_weights(self, m):
        """
        Initialize a `Linear` module with Xavier-normal weights and zero bias

        :param m: module visited by `nn.Module.apply`; anything but `nn.Linear` is left untouched
        """
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0.0)
class
Rendering
(
nn
.
Module
):
class
Rendering
(
nn
.
Module
):
def
__init__
(
self
,
sphere_layers
:
List
[
float
]
):
def
__init__
(
self
):
"""
"""
Initialize a Rendering module
Initialize a Rendering module
:param sphere_layers: L x 1, radius of sphere layers
"""
"""
super
().
__init__
()
super
().
__init__
()
self
.
sphere_layers
=
torch
.
tensor
(
sphere_layers
,
device
=
device
.
GetDevice
())
def
forward
(
self
,
net
:
FcNet
,
p
:
torch
.
Tensor
,
v
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
color_alpha
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
[summary]
Blend layers to get final color
:param net: the full-connected net
:param color_alpha ```Tensor(B, L, C)```: RGB or gray with alpha channel
:param p: B x 3, positions of rays
:return ```Tensor(B, C-1)``` blended pixels
:param v: B x 3, directions of rays
:return B x 1/3, view images by blended layers
"""
"""
L
=
self
.
sphere_layers
.
size
()[
0
]
c
=
color_alpha
[...,
:
-
1
]
sp
=
RayToSpherical
(
p
,
v
,
self
.
sphere_layers
)
# B x L x 3
a
=
color_alpha
[...,
-
1
:]
sp
[...,
0
]
=
1
/
sp
[...,
0
]
# Radius to diopter
color_alpha
:
torch
.
Tensor
=
net
(
sp
.
flatten
(
0
,
1
)).
view
(
p
.
size
()[
0
],
L
,
-
1
)
if
(
color_alpha
.
size
(
-
1
)
==
2
):
# Grayscale
c
=
color_alpha
[...,
0
:
1
]
a
=
color_alpha
[...,
1
:
2
]
else
:
# RGB
c
=
color_alpha
[...,
0
:
3
]
a
=
color_alpha
[...,
3
:
4
]
blended
=
c
[:,
0
,
:]
*
a
[:,
0
,
:]
blended
=
c
[:,
0
,
:]
*
a
[:,
0
,
:]
for
l
in
range
(
1
,
L
):
for
l
in
range
(
1
,
color_alpha
.
size
(
1
)
):
blended
=
blended
*
(
1
-
a
[:,
l
,
:])
+
c
[:,
l
,
:]
*
a
[:,
l
,
:]
blended
=
blended
*
(
1
-
a
[:,
l
,
:])
+
c
[:,
l
,
:]
*
a
[:,
l
,
:]
return
blended
return
blended
class
MslNet
(
nn
.
Module
):
class
MslNet
(
nn
.
Module
):
def
__init__
(
self
,
cam_params
,
sphere_layers
:
List
[
float
],
out_res
:
Tuple
[
int
,
int
],
gray
=
False
):
def
__init__
(
self
,
cam_params
,
fc_params
,
sphere_layers
:
List
[
float
],
out_res
:
Tuple
[
int
,
int
],
gray
=
False
,
encode_to_dim
:
int
=
0
):
"""
"""
Initialize a multi-sphere-layer net
Initialize a multi-sphere-layer net
:param cam_params: intrinsic parameters of camera
:param cam_params: intrinsic parameters of camera
:param sphere_layers: L x 1, radius of sphere layers
:param fc_params: parameters of full-connection network
:param sphere_layers: list(L), radius of sphere layers
:param out_res: resolution of output view image
:param out_res: resolution of output view image
:param gray: is grayscale mode
:param encode_to_dim: encode input to number of dimensions
"""
"""
super
().
__init__
()
super
().
__init__
()
self
.
cam_params
=
cam_params
self
.
cam_params
=
cam_params
self
.
sphere_layers
=
torch
.
tensor
(
sphere_layers
,
dtype
=
torch
.
float
,
device
=
device
.
GetDevice
())
self
.
in_chns
=
3
self
.
out_res
=
out_res
self
.
out_res
=
out_res
self
.
v_local
=
util
.
GetLocalViewRays
(
self
.
cam_params
,
out_res
,
flatten
=
True
)
\
self
.
input_encoder
=
net_modules
.
InputEncoder
.
Get
(
.
to
(
device
.
GetDevice
())
# N x 3
encode_to_dim
,
self
.
in_chns
)
#self.net = FCBlock(hidden_ch=64,
fc_params
[
'in_chns'
]
=
self
.
input_encoder
.
out_dim
# num_hidden_layers=4,
fc_params
[
'out_chns'
]
=
2
if
gray
else
4
# in_features=3,
self
.
net
=
net_modules
.
FcNet
(
**
fc_params
)
# out_features=2 if gray else 4,
self
.
rendering
=
Rendering
()
# outermost_linear=True)
self
.
net
=
FcNet
(
in_chns
=
3
,
out_chns
=
2
if
gray
else
4
,
nf
=
256
,
n_layers
=
8
)
def
forward
(
self
,
ray_positions
:
torch
.
Tensor
,
ray_directions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
self
.
rendering
=
Rendering
(
sphere_layers
)
def
forward
(
self
,
view_centers
:
torch
.
Tensor
,
view_rots
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
T_view -> image
rays -> colors
:param
view_centers: B x 3, centers of view
s
:param
ray_positions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray position
s
:param
view_rots: B x 3 x 3, rotation matrices of view
s
:param
ray_directions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray direction
s
:return:
B x 1/3 x H_out x W_out
, inferred images
of view
s
:return:
Tensor(B, 1|3, H, W)|Tensor(B, 1|3)
, inferred images
/pixel
s
"""
"""
# Transpose matrix so we can perform vec x mat
p
=
ray_positions
.
view
(
-
1
,
3
)
view_rots_t
=
view_rots
.
permute
(
0
,
2
,
1
)
v
=
ray_directions
.
view
(
-
1
,
3
)
spher
=
RayToSpherical
(
p
,
v
,
self
.
sphere_layers
).
flatten
(
0
,
1
)
# p and v are B x N x 3 tensor
color_alpha
=
self
.
net
(
self
.
input_encoder
(
spher
)).
view
(
p
=
view_centers
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
v_local
.
size
(
0
),
-
1
)
p
.
size
(
0
),
self
.
sphere_layers
.
size
(
0
),
-
1
)
v
=
torch
.
matmul
(
self
.
v_local
,
view_rots_t
)
c
:
torch
.
Tensor
=
self
.
rendering
(
color_alpha
)
c
:
torch
.
Tensor
=
self
.
rendering
(
self
.
net
,
p
.
flatten
(
0
,
1
),
v
.
flatten
(
0
,
1
))
# (BN) x 3
# unflatten
# unflatten
return
c
.
view
(
view_center
s
.
size
(
0
),
self
.
out_res
[
0
],
return
c
.
view
(
ray_direction
s
.
size
(
0
),
self
.
out_res
[
0
],
self
.
out_res
[
1
],
-
1
).
permute
(
0
,
3
,
1
,
2
)
self
.
out_res
[
1
],
-
1
).
permute
(
0
,
3
,
1
,
2
)
if
len
(
ray_directions
.
size
())
==
3
else
c
run_spherical_view_syn.py
View file @
c570c3b1
import
sys
import
sys
sys
.
path
.
append
(
'/e/dengnc'
)
import
os
sys
.
path
.
append
(
os
.
path
.
abspath
(
sys
.
path
[
0
]
+
'/../'
))
__package__
=
"deeplightfield"
__package__
=
"deeplightfield"
import
argparse
import
argparse
import
torch
import
torch
import
torch.optim
import
torch.optim
import
torchvision
import
torchvision
from
typing
import
List
,
Tuple
from
tensorboardX
import
SummaryWriter
from
tensorboardX
import
SummaryWriter
from
torch
import
nn
from
torch
import
nn
from
.my
import
netio
from
.my
import
netio
from
.my
import
util
from
.my
import
util
from
.my
import
device
from
.my
import
device
from
.my.simple_perf
import
SimplePerf
from
.my.simple_perf
import
SimplePerf
from
.loss.loss
import
PerceptionReconstructionLoss
from
.data.spherical_view_syn
import
SphericalViewSynDataset
from
.data.spherical_view_syn
import
SphericalViewSynDataset
from
.msl_net
import
MslNet
from
.msl_net
import
MslNet
from
.spher_net
import
SpherNet
from
.spher_net
import
SpherNet
...
@@ -36,9 +35,9 @@ TRAIN_MODE = True
...
@@ -36,9 +35,9 @@ TRAIN_MODE = True
EVAL_TIME_PERFORMANCE
=
False
EVAL_TIME_PERFORMANCE
=
False
RAY_AS_ITEM
=
True
RAY_AS_ITEM
=
True
# ========
# ========
#
GRAY = True
GRAY
=
True
ROT_ONLY
=
True
#
ROT_ONLY = True
TRAIN_MODE
=
False
#
TRAIN_MODE = False
#EVAL_TIME_PERFORMANCE = True
#EVAL_TIME_PERFORMANCE = True
#RAY_AS_ITEM = False
#RAY_AS_ITEM = False
...
@@ -48,39 +47,39 @@ N_DEPTH_LAYERS = 10
...
@@ -48,39 +47,39 @@ N_DEPTH_LAYERS = 10
N_ENCODE_DIM
=
10
N_ENCODE_DIM
=
10
FC_PARAMS
=
{
FC_PARAMS
=
{
'nf'
:
128
,
'nf'
:
128
,
'n_layers'
:
6
,
'n_layers'
:
8
,
'skips'
:
[
4
]
'skips'
:
[
4
]
}
}
# Train
# Train
TRAIN_DATA_DESC_FILE
=
'train.json'
BATCH_SIZE
=
2048
if
RAY_AS_ITEM
else
4
BATCH_SIZE
=
2048
if
RAY_AS_ITEM
else
4
EPOCH_RANGE
=
range
(
0
,
500
)
EPOCH_RANGE
=
range
(
0
,
500
)
SAVE_INTERVAL
=
20
SAVE_INTERVAL
=
20
# Test
TEST_NET_NAME
=
'model-epoch_500'
TEST_DATA_DESC_FILE
=
'test_fovea.json'
TEST_BATCH_SIZE
=
5
# Paths
# Paths
DATA_DIR
=
sys
.
path
[
0
]
+
'/data/sp_view_syn_2020.12.2
6_rotonly
/'
DATA_DIR
=
sys
.
path
[
0
]
+
'/data/sp_view_syn_2020.12.2
8
/'
RUN_ID
=
'%s_ray_b%d_encode%d_fc%dx%d%s'
%
(
'gray'
if
GRAY
else
'rgb'
,
RUN_ID
=
'%s_ray_b%d_encode%d_fc%dx%d%s'
%
(
'gray'
if
GRAY
else
'rgb'
,
BATCH_SIZE
,
BATCH_SIZE
,
N_ENCODE_DIM
,
N_ENCODE_DIM
,
FC_PARAMS
[
'nf'
],
FC_PARAMS
[
'nf'
],
FC_PARAMS
[
'n_layers'
],
FC_PARAMS
[
'n_layers'
],
'_skip_%d'
%
FC_PARAMS
[
'skips'
][
0
]
if
len
(
FC_PARAMS
[
'skips'
])
>
0
else
''
)
'_skip_%d'
%
FC_PARAMS
[
'skips'
][
0
]
if
len
(
FC_PARAMS
[
'skips'
])
>
0
else
''
)
TRAIN_DATA_DESC_FILE
=
DATA_DIR
+
'train.json'
RUN_DIR
=
DATA_DIR
+
RUN_ID
+
'/'
RUN_DIR
=
DATA_DIR
+
RUN_ID
+
'/'
OUTPUT_DIR
=
RUN_DIR
+
'output/'
OUTPUT_DIR
=
RUN_DIR
+
'output/'
LOG_DIR
=
RUN_DIR
+
'log/'
LOG_DIR
=
RUN_DIR
+
'log/'
# Test
TEST_NET_NAME
=
'model-epoch_100'
TEST_BATCH_SIZE
=
5
def
train
():
def
train
():
# 1. Initialize data loader
# 1. Initialize data loader
print
(
"Load dataset: "
+
TRAIN_DATA_DESC_FILE
)
print
(
"Load dataset: "
+
DATA_DIR
+
TRAIN_DATA_DESC_FILE
)
train_dataset
=
SphericalViewSynDataset
(
train_dataset
=
SphericalViewSynDataset
(
DATA_DIR
+
TRAIN_DATA_DESC_FILE
,
TRAIN_DATA_DESC_FILE
,
gray
=
GRAY
,
ray_as_item
=
RAY_AS_ITEM
)
gray
=
GRAY
,
ray_as_item
=
RAY_AS_ITEM
)
train_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
train_dataset
,
dataset
=
train_dataset
,
batch_size
=
BATCH_SIZE
,
batch_size
=
BATCH_SIZE
,
...
@@ -98,10 +97,12 @@ def train():
...
@@ -98,10 +97,12 @@ def train():
encode_to_dim
=
N_ENCODE_DIM
).
to
(
device
.
GetDevice
())
encode_to_dim
=
N_ENCODE_DIM
).
to
(
device
.
GetDevice
())
else
:
else
:
model
=
MslNet
(
cam_params
=
train_dataset
.
cam_params
,
model
=
MslNet
(
cam_params
=
train_dataset
.
cam_params
,
fc_params
=
FC_PARAMS
,
sphere_layers
=
util
.
GetDepthLayers
(
sphere_layers
=
util
.
GetDepthLayers
(
DEPTH_RANGE
,
N_DEPTH_LAYERS
),
DEPTH_RANGE
,
N_DEPTH_LAYERS
),
out_res
=
train_dataset
.
view_res
,
out_res
=
train_dataset
.
view_res
,
gray
=
GRAY
).
to
(
device
.
GetDevice
())
gray
=
GRAY
,
encode_to_dim
=
N_ENCODE_DIM
).
to
(
device
.
GetDevice
())
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
5e-4
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
5e-4
)
loss
=
nn
.
MSELoss
()
loss
=
nn
.
MSELoss
()
...
@@ -172,11 +173,11 @@ def train():
...
@@ -172,11 +173,11 @@ def train():
def
test
(
net_file
:
str
):
def
test
(
net_file
:
str
):
# 1. Load train dataset
# 1. Load train dataset
print
(
"Load dataset: "
+
TRAIN
_DATA_DESC_FILE
)
print
(
"Load dataset: "
+
DATA_DIR
+
TEST
_DATA_DESC_FILE
)
t
rain
_dataset
=
SphericalViewSynDataset
(
TRAIN
_DATA_DESC_FILE
,
t
est
_dataset
=
SphericalViewSynDataset
(
DATA_DIR
+
TEST
_DATA_DESC_FILE
,
load_images
=
True
,
gray
=
GRAY
)
load_images
=
True
,
gray
=
GRAY
)
t
rain
_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
t
est
_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
t
rain
_dataset
,
dataset
=
t
est
_dataset
,
batch_size
=
TEST_BATCH_SIZE
,
batch_size
=
TEST_BATCH_SIZE
,
pin_memory
=
True
,
pin_memory
=
True
,
shuffle
=
False
,
shuffle
=
False
,
...
@@ -184,37 +185,38 @@ def test(net_file: str):
...
@@ -184,37 +185,38 @@ def test(net_file: str):
# 2. Load trained model
# 2. Load trained model
if
ROT_ONLY
:
if
ROT_ONLY
:
model
=
SpherNet
(
cam_params
=
t
rain
_dataset
.
cam_params
,
model
=
SpherNet
(
cam_params
=
t
est
_dataset
.
cam_params
,
fc_params
=
FC_PARAMS
,
fc_params
=
FC_PARAMS
,
out_res
=
t
rain
_dataset
.
view_res
,
out_res
=
t
est
_dataset
.
view_res
,
gray
=
GRAY
,
gray
=
GRAY
,
encode_to_dim
=
N_ENCODE_DIM
).
to
(
device
.
GetDevice
())
encode_to_dim
=
N_ENCODE_DIM
).
to
(
device
.
GetDevice
())
else
:
else
:
model
=
MslNet
(
cam_params
=
t
rain
_dataset
.
cam_params
,
model
=
MslNet
(
cam_params
=
t
est
_dataset
.
cam_params
,
sphere_layers
=
_GetSphere
Layers
(
sphere_layers
=
util
.
GetDepth
Layers
(
DEPTH_RANGE
,
N_DEPTH_LAYERS
),
DEPTH_RANGE
,
N_DEPTH_LAYERS
),
out_res
=
t
rain
_dataset
.
view_res
,
out_res
=
t
est
_dataset
.
view_res
,
gray
=
GRAY
).
to
(
device
.
GetDevice
())
gray
=
GRAY
).
to
(
device
.
GetDevice
())
netio
.
LoadNet
(
net_file
,
model
)
netio
.
LoadNet
(
net_file
,
model
)
# 3. Test on train dataset
# 3. Test on train dataset
print
(
"Begin test on train dataset, batch size is %d"
%
TEST_BATCH_SIZE
)
print
(
"Begin test on train dataset, batch size is %d"
%
TEST_BATCH_SIZE
)
ut
il
.
CreateDirIfNeed
(
OUTPUT_DIR
)
o
ut
put_dir
=
'%s%s/%s/'
%
(
OUTPUT_DIR
,
TEST_NET_NAME
,
TEST_DATA_DESC_FILE
)
util
.
CreateDirIfNeed
(
OUTPUT_DIR
+
TEST_NET_NAME
)
util
.
CreateDirIfNeed
(
output_dir
)
perf
=
SimplePerf
(
True
,
start
=
True
)
perf
=
SimplePerf
(
True
,
start
=
True
)
i
=
0
i
=
0
for
view_idxs
,
view_images
,
ray_positions
,
ray_directions
in
t
rain
_data_loader
:
for
view_idxs
,
view_images
,
ray_positions
,
ray_directions
in
t
est
_data_loader
:
ray_positions
=
ray_positions
.
to
(
device
.
GetDevice
())
ray_positions
=
ray_positions
.
to
(
device
.
GetDevice
())
ray_directions
=
ray_directions
.
to
(
device
.
GetDevice
())
ray_directions
=
ray_directions
.
to
(
device
.
GetDevice
())
perf
.
Checkpoint
(
"%d - Load"
%
i
)
perf
.
Checkpoint
(
"%d - Load"
%
i
)
out_view_images
=
model
(
ray_positions
,
ray_directions
)
out_view_images
=
model
(
ray_positions
,
ray_directions
)
perf
.
Checkpoint
(
"%d - Infer"
%
i
)
perf
.
Checkpoint
(
"%d - Infer"
%
i
)
util
.
WriteImageTensor
(
if
test_dataset
.
load_images
:
view_images
,
util
.
WriteImageTensor
(
[
'%s%s/gt_view_%04d.png'
%
(
OUTPUT_DIR
,
TEST_NET_NAME
,
i
)
for
i
in
view_idxs
])
view_images
,
[
'%sgt_view_%04d.png'
%
(
output_dir
,
i
)
for
i
in
view_idxs
])
util
.
WriteImageTensor
(
util
.
WriteImageTensor
(
out_view_images
,
out_view_images
,
[
'%s
%s/
out_view_%04d.png'
%
(
OUTPUT_DIR
,
TEST_NET_NAME
,
i
)
for
i
in
view_idxs
])
[
'%sout_view_%04d.png'
%
(
output_dir
,
i
)
for
i
in
view_idxs
])
perf
.
Checkpoint
(
"%d - Write"
%
i
)
perf
.
Checkpoint
(
"%d - Write"
%
i
)
i
+=
1
i
+=
1
...
...
spher_net.py
View file @
c570c3b1
from
typing
import
List
,
Tuple
from
typing
import
Tuple
from
math
import
pi
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
.pytorch_prototyping.pytorch_prototyping
import
*
from
.my
import
net_modules
from
.my
import
net_modules
from
.my
import
util
from
.my
import
util
from
.my
import
device
def RaySphereIntersect(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    """
    Calculate intersections of each rays and each spheres

    For every ray/sphere pair this solves |p + k * v| = r (spheres centered at the origin)
    and keeps the root with the positive square-root term.

    :param p: B x 3, positions of rays
    :param v: B x 3, directions of rays
    :param r: B'(1D), radius of spheres
    :return: B x B' x 3, points of intersection
    """
    # Insert a sphere axis: B x 1 x 3
    ray_o = p.unsqueeze(1)
    ray_d = v.unsqueeze(1)
    # Dot products over the coordinate axis: B x 1
    o_dot_o = (ray_o * ray_o).sum(dim=2)
    d_dot_d = (ray_d * ray_d).sum(dim=2)
    o_dot_d = (ray_o * ray_d).sum(dim=2)
    # Broadcasting against r (B') gives one value per ray/sphere pair: B x B'
    discriminant = o_dot_d * o_dot_d - d_dot_d * (o_dot_o - r * r)
    k = (discriminant.sqrt() - o_dot_d) / d_dot_d
    # k: B x B' x 1, scales each direction to reach its sphere
    return ray_o + k.unsqueeze(2) * ray_d
def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    """
    Calculate the spherical coordinates of the intersections of each rays and each spheres

    :param p: B x 3, positions of rays
    :param v: B x 3, directions of rays
    :param r: B' (1D, as expected by `RaySphereIntersect`), radius of spheres
    :return: B x B' x 3, spherical coordinates
    """
    return util.CartesianToSpherical(RaySphereIntersect(p, v, r))
class
SpherNet
(
nn
.
Module
):
class
SpherNet
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment