Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
Nianchen Deng
deeplightfield
Commits
c570c3b1
Commit
c570c3b1
authored
Dec 28, 2020
by
BobYeah
Browse files
checkpoint
parent
172b5205
Changes
5
Show whitespace changes
Inline
Side-by-side
data/spherical_view_syn.py
View file @
c570c3b1
...
...
@@ -2,7 +2,6 @@ import torch
import
torchvision.transforms.functional
as
trans_f
import
json
from
..my
import
util
from
..my
import
imgio
class
SphericalViewSynDataset
(
torch
.
utils
.
data
.
dataset
.
Dataset
):
...
...
@@ -44,6 +43,9 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
# Load dataset description file
with
open
(
dataset_desc_path
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
data_desc
=
json
.
loads
(
file
.
read
())
if
data_desc
[
'view_file_pattern'
]
==
''
:
self
.
load_images
=
False
else
:
self
.
view_file_pattern
:
str
=
self
.
data_dir
+
\
data_desc
[
'view_file_pattern'
]
self
.
view_res
=
(
data_desc
[
'view_res'
][
'y'
],
...
...
@@ -54,7 +56,7 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
.
view
(
-
1
,
3
,
3
)
# (N, 3, 3)
# Load view images
if
load_images
:
if
self
.
load_images
:
self
.
view_images
=
util
.
ReadImageTensor
(
[
self
.
view_file_pattern
%
i
for
i
in
range
(
self
.
view_centers
.
size
(
0
))])
if
gray
:
...
...
@@ -75,8 +77,8 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
# Flatten rays if ray_as_item = True
if
ray_as_item
:
self
.
view_pixels
=
self
.
view_images
.
permute
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
)
self
.
view_pixels
=
self
.
view_images
.
permute
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
)
if
self
.
view_images
!=
None
else
None
self
.
ray_positions
=
self
.
ray_positions
.
flatten
(
0
,
1
)
self
.
ray_directions
=
self
.
ray_directions
.
flatten
(
0
,
1
)
...
...
@@ -88,4 +90,4 @@ class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
if
self
.
ray_as_item
:
return
idx
,
self
.
view_pixels
[
idx
],
self
.
ray_positions
[
idx
],
self
.
ray_directions
[
idx
]
return
idx
,
self
.
view_images
[
idx
],
self
.
ray_positions
[
idx
],
self
.
ray_directions
[
idx
]
return
idx
,
self
.
ray_positions
[
idx
],
self
.
ray_directions
[
idx
]
return
idx
,
False
,
self
.
ray_positions
[
idx
],
self
.
ray_directions
[
idx
]
image_scale.py
0 → 100644
View file @
c570c3b1
import
sys
import
os
sys
.
path
.
append
(
os
.
path
.
abspath
(
sys
.
path
[
0
]
+
'/../'
))
__package__
=
"deeplightfield"
import
argparse
from
PIL
import
Image
from
.my
import
util
def
batch_scale
(
src
,
target
,
size
):
util
.
CreateDirIfNeed
(
target
)
for
file_name
in
os
.
listdir
(
src
):
postfix
=
os
.
path
.
splitext
(
file_name
)[
1
]
if
postfix
==
'.jpg'
or
postfix
==
'.png'
:
im
=
Image
.
open
(
os
.
path
.
join
(
src
,
file_name
))
im
=
im
.
resize
(
size
)
im
.
save
(
os
.
path
.
join
(
target
,
file_name
))
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'src'
,
type
=
str
,
help
=
'Source directory.'
)
parser
.
add_argument
(
'target'
,
type
=
str
,
help
=
'Target directory.'
)
parser
.
add_argument
(
'--width'
,
type
=
int
,
help
=
'Width of output images (pixel)'
)
parser
.
add_argument
(
'--height'
,
type
=
int
,
help
=
'Height of output images (pixel)'
)
opt
=
parser
.
parse_args
()
batch_scale
(
opt
.
src
,
opt
.
target
,
(
opt
.
width
,
opt
.
height
))
msl_net.py
View file @
c570c3b1
from
typing
import
List
,
Tuple
from
math
import
pi
import
torch
import
torch.nn
as
nn
from
.
pytorch_prototyping.pytorch_prototyping
import
*
from
.
my
import
net_modules
from
.my
import
util
from
.my
import
device
def
CartesianToSpherical
(
cart
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Convert coordinates from Cartesian to Spherical
:param cart: ... x 3, coordinates in Cartesian
:return: ... x 3, coordinates in Spherical (r, theta, phi)
"""
rho
=
torch
.
norm
(
cart
,
p
=
2
,
dim
=-
1
)
theta
=
torch
.
atan2
(
cart
[...,
2
],
cart
[...,
0
])
theta
=
theta
+
(
theta
<
0
).
type_as
(
theta
)
*
(
2
*
pi
)
phi
=
torch
.
acos
(
cart
[...,
1
]
/
rho
)
return
torch
.
stack
([
rho
,
theta
,
phi
],
dim
=-
1
)
def
SphericalToCartesian
(
spher
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Convert coordinates from Spherical to Cartesian
:param spher: ... x 3, coordinates in Spherical
:return: ... x 3, coordinates in Cartesian (r, theta, phi)
"""
rho
=
spher
[...,
0
]
sin_theta_phi
=
torch
.
sin
(
spher
[...,
1
:
3
])
cos_theta_phi
=
torch
.
cos
(
spher
[...,
1
:
3
])
x
=
rho
*
cos_theta_phi
[...,
0
]
*
sin_theta_phi
[...,
1
]
y
=
rho
*
cos_theta_phi
[...,
1
]
z
=
rho
*
sin_theta_phi
[...,
0
]
*
sin_theta_phi
[...,
1
]
return
torch
.
stack
([
x
,
y
,
z
],
dim
=-
1
)
def
RaySphereIntersect
(
p
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
r
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Calculate intersections of each rays and each spheres
...
...
@@ -68,115 +37,74 @@ def RayToSpherical(p: torch.Tensor, v: torch.Tensor, r: torch.Tensor) -> torch.T
:return: B x B' x 3, spherical coordinates
"""
p_on_spheres
=
RaySphereIntersect
(
p
,
v
,
r
)
return
CartesianToSpherical
(
p_on_spheres
)
class
FcNet
(
nn
.
Module
):
def
__init__
(
self
,
in_chns
:
int
,
out_chns
:
int
,
nf
:
int
,
n_layers
:
int
):
super
().
__init__
()
self
.
layers
=
list
()
self
.
layers
+=
[
nn
.
Linear
(
in_chns
,
nf
),
#nn.LayerNorm([nf]),
nn
.
ReLU
()
]
for
_
in
range
(
1
,
n_layers
):
self
.
layers
+=
[
nn
.
Linear
(
nf
,
nf
),
#nn.LayerNorm([nf]),
nn
.
ReLU
()
]
self
.
layers
.
append
(
nn
.
Linear
(
nf
,
out_chns
))
self
.
net
=
nn
.
Sequential
(
*
self
.
layers
)
self
.
net
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
net
(
x
)
def
init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
nn
.
init
.
xavier_normal_
(
m
.
weight
)
nn
.
init
.
constant_
(
m
.
bias
,
0.0
)
return
util
.
CartesianToSpherical
(
p_on_spheres
)
class
Rendering
(
nn
.
Module
):
def
__init__
(
self
,
sphere_layers
:
List
[
float
]
):
def
__init__
(
self
):
"""
Initialize a Rendering module
:param sphere_layers: L x 1, radius of sphere layers
"""
super
().
__init__
()
self
.
sphere_layers
=
torch
.
tensor
(
sphere_layers
,
device
=
device
.
GetDevice
())
def
forward
(
self
,
net
:
FcNet
,
p
:
torch
.
Tensor
,
v
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
color_alpha
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
[summary]
Blend layers to get final color
:param net: the full-connected net
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:return B x 1/3, view images by blended layers
:param color_alpha ```Tensor(B, L, C)```: RGB or gray with alpha channel
:return ```Tensor(B, C-1)``` blended pixels
"""
L
=
self
.
sphere_layers
.
size
()[
0
]
sp
=
RayToSpherical
(
p
,
v
,
self
.
sphere_layers
)
# B x L x 3
sp
[...,
0
]
=
1
/
sp
[...,
0
]
# Radius to diopter
color_alpha
:
torch
.
Tensor
=
net
(
sp
.
flatten
(
0
,
1
)).
view
(
p
.
size
()[
0
],
L
,
-
1
)
if
(
color_alpha
.
size
(
-
1
)
==
2
):
# Grayscale
c
=
color_alpha
[...,
0
:
1
]
a
=
color_alpha
[...,
1
:
2
]
else
:
# RGB
c
=
color_alpha
[...,
0
:
3
]
a
=
color_alpha
[...,
3
:
4
]
c
=
color_alpha
[...,
:
-
1
]
a
=
color_alpha
[...,
-
1
:]
blended
=
c
[:,
0
,
:]
*
a
[:,
0
,
:]
for
l
in
range
(
1
,
L
):
for
l
in
range
(
1
,
color_alpha
.
size
(
1
)
):
blended
=
blended
*
(
1
-
a
[:,
l
,
:])
+
c
[:,
l
,
:]
*
a
[:,
l
,
:]
return
blended
class
MslNet
(
nn
.
Module
):
def
__init__
(
self
,
cam_params
,
sphere_layers
:
List
[
float
],
out_res
:
Tuple
[
int
,
int
],
gray
=
False
):
def
__init__
(
self
,
cam_params
,
fc_params
,
sphere_layers
:
List
[
float
],
out_res
:
Tuple
[
int
,
int
],
gray
=
False
,
encode_to_dim
:
int
=
0
):
"""
Initialize a multi-sphere-layer net
:param cam_params: intrinsic parameters of camera
:param sphere_layers: L x 1, radius of sphere layers
:param fc_params: parameters of full-connection network
:param sphere_layers: list(L), radius of sphere layers
:param out_res: resolution of output view image
:param gray: is grayscale mode
:param encode_to_dim: encode input to number of dimensions
"""
super
().
__init__
()
self
.
cam_params
=
cam_params
self
.
sphere_layers
=
torch
.
tensor
(
sphere_layers
,
dtype
=
torch
.
float
,
device
=
device
.
GetDevice
())
self
.
in_chns
=
3
self
.
out_res
=
out_res
self
.
v_local
=
util
.
GetLocalViewRays
(
self
.
cam_params
,
out_res
,
flatten
=
True
)
\
.
to
(
device
.
GetDevice
())
# N x 3
#self.net = FCBlock(hidden_ch=64,
# num_hidden_layers=4,
# in_features=3,
# out_features=2 if gray else 4,
# outermost_linear=True)
self
.
net
=
FcNet
(
in_chns
=
3
,
out_chns
=
2
if
gray
else
4
,
nf
=
256
,
n_layers
=
8
)
self
.
rendering
=
Rendering
(
sphere_layers
)
def
forward
(
self
,
view_centers
:
torch
.
Tensor
,
view_rots
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
T_view -> image
:param view_centers: B x 3, centers of views
:param view_rots: B x 3 x 3, rotation matrices of views
:return: B x 1/3 x H_out x W_out, inferred images of views
"""
# Transpose matrix so we can perform vec x mat
view_rots_t
=
view_rots
.
permute
(
0
,
2
,
1
)
# p and v are B x N x 3 tensor
p
=
view_centers
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
v_local
.
size
(
0
),
-
1
)
v
=
torch
.
matmul
(
self
.
v_local
,
view_rots_t
)
c
:
torch
.
Tensor
=
self
.
rendering
(
self
.
net
,
p
.
flatten
(
0
,
1
),
v
.
flatten
(
0
,
1
))
# (BN) x 3
self
.
input_encoder
=
net_modules
.
InputEncoder
.
Get
(
encode_to_dim
,
self
.
in_chns
)
fc_params
[
'in_chns'
]
=
self
.
input_encoder
.
out_dim
fc_params
[
'out_chns'
]
=
2
if
gray
else
4
self
.
net
=
net_modules
.
FcNet
(
**
fc_params
)
self
.
rendering
=
Rendering
()
def
forward
(
self
,
ray_positions
:
torch
.
Tensor
,
ray_directions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
rays -> colors
:param ray_positions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray positions
:param ray_directions ```Tensor(B, M, 3)|Tensor(B, 3)```: ray directions
:return: Tensor(B, 1|3, H, W)|Tensor(B, 1|3), inferred images/pixels
"""
p
=
ray_positions
.
view
(
-
1
,
3
)
v
=
ray_directions
.
view
(
-
1
,
3
)
spher
=
RayToSpherical
(
p
,
v
,
self
.
sphere_layers
).
flatten
(
0
,
1
)
color_alpha
=
self
.
net
(
self
.
input_encoder
(
spher
)).
view
(
p
.
size
(
0
),
self
.
sphere_layers
.
size
(
0
),
-
1
)
c
:
torch
.
Tensor
=
self
.
rendering
(
color_alpha
)
# unflatten
return
c
.
view
(
view_center
s
.
size
(
0
),
self
.
out_res
[
0
],
self
.
out_res
[
1
],
-
1
).
permute
(
0
,
3
,
1
,
2
)
return
c
.
view
(
ray_direction
s
.
size
(
0
),
self
.
out_res
[
0
],
self
.
out_res
[
1
],
-
1
).
permute
(
0
,
3
,
1
,
2
)
if
len
(
ray_directions
.
size
())
==
3
else
c
run_spherical_view_syn.py
View file @
c570c3b1
import
sys
sys
.
path
.
append
(
'/e/dengnc'
)
import
os
sys
.
path
.
append
(
os
.
path
.
abspath
(
sys
.
path
[
0
]
+
'/../'
))
__package__
=
"deeplightfield"
import
argparse
import
torch
import
torch.optim
import
torchvision
from
typing
import
List
,
Tuple
from
tensorboardX
import
SummaryWriter
from
torch
import
nn
from
.my
import
netio
from
.my
import
util
from
.my
import
device
from
.my.simple_perf
import
SimplePerf
from
.loss.loss
import
PerceptionReconstructionLoss
from
.data.spherical_view_syn
import
SphericalViewSynDataset
from
.msl_net
import
MslNet
from
.spher_net
import
SpherNet
...
...
@@ -36,9 +35,9 @@ TRAIN_MODE = True
EVAL_TIME_PERFORMANCE
=
False
RAY_AS_ITEM
=
True
# ========
#
GRAY = True
ROT_ONLY
=
True
TRAIN_MODE
=
False
GRAY
=
True
#
ROT_ONLY = True
#
TRAIN_MODE = False
#EVAL_TIME_PERFORMANCE = True
#RAY_AS_ITEM = False
...
...
@@ -48,39 +47,39 @@ N_DEPTH_LAYERS = 10
N_ENCODE_DIM
=
10
FC_PARAMS
=
{
'nf'
:
128
,
'n_layers'
:
6
,
'n_layers'
:
8
,
'skips'
:
[
4
]
}
# Train
TRAIN_DATA_DESC_FILE
=
'train.json'
BATCH_SIZE
=
2048
if
RAY_AS_ITEM
else
4
EPOCH_RANGE
=
range
(
0
,
500
)
SAVE_INTERVAL
=
20
# Test
TEST_NET_NAME
=
'model-epoch_500'
TEST_DATA_DESC_FILE
=
'test_fovea.json'
TEST_BATCH_SIZE
=
5
# Paths
DATA_DIR
=
sys
.
path
[
0
]
+
'/data/sp_view_syn_2020.12.2
6_rotonly
/'
DATA_DIR
=
sys
.
path
[
0
]
+
'/data/sp_view_syn_2020.12.2
8
/'
RUN_ID
=
'%s_ray_b%d_encode%d_fc%dx%d%s'
%
(
'gray'
if
GRAY
else
'rgb'
,
BATCH_SIZE
,
N_ENCODE_DIM
,
FC_PARAMS
[
'nf'
],
FC_PARAMS
[
'n_layers'
],
'_skip_%d'
%
FC_PARAMS
[
'skips'
][
0
]
if
len
(
FC_PARAMS
[
'skips'
])
>
0
else
''
)
TRAIN_DATA_DESC_FILE
=
DATA_DIR
+
'train.json'
RUN_DIR
=
DATA_DIR
+
RUN_ID
+
'/'
OUTPUT_DIR
=
RUN_DIR
+
'output/'
LOG_DIR
=
RUN_DIR
+
'log/'
# Test
TEST_NET_NAME
=
'model-epoch_100'
TEST_BATCH_SIZE
=
5
def
train
():
# 1. Initialize data loader
print
(
"Load dataset: "
+
TRAIN_DATA_DESC_FILE
)
train_dataset
=
SphericalViewSynDataset
(
TRAIN_DATA_DESC_FILE
,
gray
=
GRAY
,
ray_as_item
=
RAY_AS_ITEM
)
print
(
"Load dataset: "
+
DATA_DIR
+
TRAIN_DATA_DESC_FILE
)
train_dataset
=
SphericalViewSynDataset
(
DATA_DIR
+
TRAIN_DATA_DESC_FILE
,
gray
=
GRAY
,
ray_as_item
=
RAY_AS_ITEM
)
train_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
train_dataset
,
batch_size
=
BATCH_SIZE
,
...
...
@@ -98,10 +97,12 @@ def train():
encode_to_dim
=
N_ENCODE_DIM
).
to
(
device
.
GetDevice
())
else
:
model
=
MslNet
(
cam_params
=
train_dataset
.
cam_params
,
fc_params
=
FC_PARAMS
,
sphere_layers
=
util
.
GetDepthLayers
(
DEPTH_RANGE
,
N_DEPTH_LAYERS
),
out_res
=
train_dataset
.
view_res
,
gray
=
GRAY
).
to
(
device
.
GetDevice
())
gray
=
GRAY
,
encode_to_dim
=
N_ENCODE_DIM
).
to
(
device
.
GetDevice
())
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
5e-4
)
loss
=
nn
.
MSELoss
()
...
...
@@ -172,11 +173,11 @@ def train():
def
test
(
net_file
:
str
):
# 1. Load train dataset
print
(
"Load dataset: "
+
TRAIN
_DATA_DESC_FILE
)
t
rain
_dataset
=
SphericalViewSynDataset
(
TRAIN
_DATA_DESC_FILE
,
print
(
"Load dataset: "
+
DATA_DIR
+
TEST
_DATA_DESC_FILE
)
t
est
_dataset
=
SphericalViewSynDataset
(
DATA_DIR
+
TEST
_DATA_DESC_FILE
,
load_images
=
True
,
gray
=
GRAY
)
t
rain
_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
t
rain
_dataset
,
t
est
_data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
t
est
_dataset
,
batch_size
=
TEST_BATCH_SIZE
,
pin_memory
=
True
,
shuffle
=
False
,
...
...
@@ -184,37 +185,38 @@ def test(net_file: str):
# 2. Load trained model
if
ROT_ONLY
:
model
=
SpherNet
(
cam_params
=
t
rain
_dataset
.
cam_params
,
model
=
SpherNet
(
cam_params
=
t
est
_dataset
.
cam_params
,
fc_params
=
FC_PARAMS
,
out_res
=
t
rain
_dataset
.
view_res
,
out_res
=
t
est
_dataset
.
view_res
,
gray
=
GRAY
,
encode_to_dim
=
N_ENCODE_DIM
).
to
(
device
.
GetDevice
())
else
:
model
=
MslNet
(
cam_params
=
t
rain
_dataset
.
cam_params
,
sphere_layers
=
_GetSphere
Layers
(
model
=
MslNet
(
cam_params
=
t
est
_dataset
.
cam_params
,
sphere_layers
=
util
.
GetDepth
Layers
(
DEPTH_RANGE
,
N_DEPTH_LAYERS
),
out_res
=
t
rain
_dataset
.
view_res
,
out_res
=
t
est
_dataset
.
view_res
,
gray
=
GRAY
).
to
(
device
.
GetDevice
())
netio
.
LoadNet
(
net_file
,
model
)
# 3. Test on train dataset
print
(
"Begin test on train dataset, batch size is %d"
%
TEST_BATCH_SIZE
)
ut
il
.
CreateDirIfNeed
(
OUTPUT_DIR
)
util
.
CreateDirIfNeed
(
OUTPUT_DIR
+
TEST_NET_NAME
)
o
ut
put_dir
=
'%s%s/%s/'
%
(
OUTPUT_DIR
,
TEST_NET_NAME
,
TEST_DATA_DESC_FILE
)
util
.
CreateDirIfNeed
(
output_dir
)
perf
=
SimplePerf
(
True
,
start
=
True
)
i
=
0
for
view_idxs
,
view_images
,
ray_positions
,
ray_directions
in
t
rain
_data_loader
:
for
view_idxs
,
view_images
,
ray_positions
,
ray_directions
in
t
est
_data_loader
:
ray_positions
=
ray_positions
.
to
(
device
.
GetDevice
())
ray_directions
=
ray_directions
.
to
(
device
.
GetDevice
())
perf
.
Checkpoint
(
"%d - Load"
%
i
)
out_view_images
=
model
(
ray_positions
,
ray_directions
)
perf
.
Checkpoint
(
"%d - Infer"
%
i
)
if
test_dataset
.
load_images
:
util
.
WriteImageTensor
(
view_images
,
[
'%s
%s/
gt_view_%04d.png'
%
(
OUTPUT_DIR
,
TEST_NET_NAME
,
i
)
for
i
in
view_idxs
])
[
'%sgt_view_%04d.png'
%
(
output_dir
,
i
)
for
i
in
view_idxs
])
util
.
WriteImageTensor
(
out_view_images
,
[
'%s
%s/
out_view_%04d.png'
%
(
OUTPUT_DIR
,
TEST_NET_NAME
,
i
)
for
i
in
view_idxs
])
[
'%sout_view_%04d.png'
%
(
output_dir
,
i
)
for
i
in
view_idxs
])
perf
.
Checkpoint
(
"%d - Write"
%
i
)
i
+=
1
...
...
spher_net.py
View file @
c570c3b1
from
typing
import
List
,
Tuple
from
math
import
pi
from
typing
import
Tuple
import
torch
import
torch.nn
as
nn
from
.pytorch_prototyping.pytorch_prototyping
import
*
from
.my
import
net_modules
from
.my
import
util
from
.my
import
device
def
RaySphereIntersect
(
p
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
r
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B'(1D), radius of spheres
:return: B x B' x 3, points of intersection
"""
# p, v: Expand to B x 1 x 3
p
=
p
.
unsqueeze
(
1
)
v
=
v
.
unsqueeze
(
1
)
# pp, vv, pv: B x 1
pp
=
(
p
*
p
).
sum
(
dim
=
2
)
vv
=
(
v
*
v
).
sum
(
dim
=
2
)
pv
=
(
p
*
v
).
sum
(
dim
=
2
)
# k: Expand to B x B' x 1
k
=
(((
pv
*
pv
-
vv
*
(
pp
-
r
*
r
)).
sqrt
()
-
pv
)
/
vv
).
unsqueeze
(
2
)
return
p
+
k
*
v
def
RayToSpherical
(
p
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
r
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Calculate intersections of each rays and each spheres
:param p: B x 3, positions of rays
:param v: B x 3, directions of rays
:param r: B' x 1, radius of spheres
:return: B x B' x 3, spherical coordinates
"""
p_on_spheres
=
RaySphereIntersect
(
p
,
v
,
r
)
return
util
.
CartesianToSpherical
(
p_on_spheres
)
class
SpherNet
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment