Nianchen Deng / deeplightfield

Commit 5699ccbf authored Dec 03, 2021 by Nianchen Deng

sync

parent 338ae906 (152 changed files)
train/train_with_space.py  0 → 100644 (new file)
from modules.sampler import Samples
from modules.space import Octree, Voxels
from utils.mem_profiler import MemProfiler
from utils.misc import print_and_log
from .base import *


class TrainWithSpace(Train):

    def __init__(self, model: BaseModel, pruning_loop: int = 10000,
                 splitting_loop: int = 10000, **kwargs) -> None:
        super().__init__(model, **kwargs)
        self.pruning_loop = pruning_loop
        self.splitting_loop = splitting_loop
        #MemProfiler.enable = True

    def _train_epoch(self):
        if not self.perf_mode:
            if self.epoch != 1:
                if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
                    try:
                        with torch.no_grad():
                            before, after = self.model.splitting()
                        print_and_log(
                            f"Splitting done. # of voxels before: {before}, after: {after}")
                    except NotImplementedError:
                        print_and_log(
                            "Note: The space does not support splitting operation. Just skip it.")
                if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
                    try:
                        with torch.no_grad():
                            #before, after = self.model.pruning()
                            # print(f"Pruning by voxel densities done. # of voxels before: {before}, after: {after}")
                            # self._prune_inner_voxels()
                            self._prune_voxels_by_weights()
                    except NotImplementedError:
                        print_and_log(
                            "Note: The space does not support pruning operation. Just skip it.")
        super()._train_epoch()

    def _prune_inner_voxels(self):
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        iters_in_epoch = 0
        batch_size = self.data_loader.batch_size
        self.data_loader.batch_size = 2 ** 14
        for _, rays_o, rays_d, _ in self.data_loader:
            self.model(rays_o, rays_d,
                       raymarching_early_stop_tolerance=0.01,
                       raymarching_chunk_size_or_sections=[1],
                       perturb_sample=False,
                       voxel_access_counts=voxel_access_counts,
                       voxel_access_tolerance=0)
            iters_in_epoch += 1
            percent = iters_in_epoch / len(self.data_loader) * 100
            sys.stdout.write(f'Pruning inner voxels... {percent:.1f}%\r')
        self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print(f"Prune inner voxels: {before} -> {after}")

    def _prune_voxels_by_weights(self):
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        iters_in_epoch = 0
        batch_size = self.data_loader.batch_size
        self.data_loader.batch_size = 2 ** 14
        for _, rays_o, rays_d, _ in self.data_loader:
            ret = self.model(rays_o, rays_d,
                             raymarching_early_stop_tolerance=0,
                             raymarching_chunk_size_or_sections=None,
                             perturb_sample=False,
                             extra_outputs=['weights'])
            valid_mask = ret['weights'][..., 0] > 0.01
            accessed_voxels = ret['samples'].voxel_indices[valid_mask]
            voxel_access_counts.index_add_(0, accessed_voxels, torch.ones_like(accessed_voxels))
            iters_in_epoch += 1
            percent = iters_in_epoch / len(self.data_loader) * 100
            sys.stdout.write(f'Pruning by weights... {percent:.1f}%\r')
        self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune by weights: {before} -> {after}")

    def _prune_voxels_by_voxel_weights(self):
        space: Voxels = self.model.space
        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
                                          device=space.voxels.device)
        with torch.no_grad():
            batch_size = self.data_loader.batch_size
            self.data_loader.batch_size = 2 ** 14
            iters_in_epoch = 0
            for _, rays_o, rays_d, _ in self.data_loader:
                ret = self.model(rays_o, rays_d,
                                 raymarching_early_stop_tolerance=0,
                                 raymarching_chunk_size_or_sections=None,
                                 perturb_sample=False,
                                 extra_outputs=['weights'])
                self._accumulate_access_count_by_weight(ret['samples'], ret['weights'][..., 0],
                                                        voxel_access_counts)
                iters_in_epoch += 1
                percent = iters_in_epoch / len(self.data_loader) * 100
                sys.stdout.write(f'Pruning by voxel weights... {percent:.1f}%\r')
            self.data_loader.batch_size = batch_size
        before, after = space.prune(voxel_access_counts > 0)
        print_and_log(f"Prune by voxel weights: {before} -> {after}")

    def _accumulate_access_count_by_weight(self, samples: Samples, weights: torch.Tensor,
                                           voxel_access_counts: torch.Tensor):
        uni_vidxs = -torch.ones_like(samples.voxel_indices)
        vidx_accu = torch.zeros_like(samples.voxel_indices, dtype=torch.float)
        uni_vidxs_row = torch.arange(samples.size[0], dtype=torch.long, device=samples.device)
        uni_vidxs_head = torch.zeros_like(samples.voxel_indices[:, 0])
        uni_vidxs[:, 0] = samples.voxel_indices[:, 0]
        vidx_accu[:, 0].add_(weights[:, 0])
        for i in range(samples.size[1]):
            # For rows whose voxel changed, move the head one step forward
            next_voxel = uni_vidxs[uni_vidxs_row, uni_vidxs_head].ne(samples.voxel_indices[:, i])
            uni_vidxs_head[next_voxel].add_(1)
            # Set voxel indices and accumulate weights
            uni_vidxs[uni_vidxs_row, uni_vidxs_head] = samples.voxel_indices[:, i]
            vidx_accu[uni_vidxs_row, uni_vidxs_head].add_(weights[:, i])
        max_accu = vidx_accu.max(dim=1, keepdim=True)[0]
        uni_vidxs[vidx_accu < max_accu * 0.1] = -1
        access_voxels, access_count = uni_vidxs.unique(return_counts=True)
        voxel_access_counts[access_voxels[1:]].add_(access_count[1:])
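The pruning passes above all reduce to the same counting pattern: run rays through the model at a large batch size, mark which voxels receive significant blending weight, and prune the rest. A minimal, self-contained sketch of that pattern (shapes and values invented for illustration; the repository's `Samples` and `Voxels` classes are not used here):

import torch

n_voxels = 8
voxel_access_counts = torch.zeros(n_voxels, dtype=torch.long)

# Pretend one batch produced these per-sample voxel indices and weights:
voxel_indices = torch.tensor([0, 0, 3, 5, 5, 5, 7])
weights = torch.tensor([0.5, 0.02, 0.3, 0.9, 0.001, 0.2, 0.005])

# Keep only samples with significant weight, then bump one counter per hit,
# exactly as _prune_voxels_by_weights does with index_add_.
valid_mask = weights > 0.01
accessed = voxel_indices[valid_mask]
voxel_access_counts.index_add_(0, accessed, torch.ones_like(accessed))

print(voxel_access_counts)           # tensor([2, 0, 0, 1, 0, 2, 0, 0])
keep_mask = voxel_access_counts > 0  # what space.prune(...) receives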
train_oracle.py
@@ -260,8 +260,8 @@ def train():
     if epochRange.start > 1:
         iters = netio.load(f'{run_dir}model-epoch_{epochRange.start - 1}.pth', model)
     else:
-        misc.create_dir(run_dir)
-        misc.create_dir(log_dir)
+        os.makedirs(run_dir, exist_ok=True)
+        os.makedirs(log_dir, exist_ok=True)
         iters = 0
     # 3. Train
@@ -333,7 +333,7 @@ def test():
     # 4. Save results
     print('Saving results...')
-    misc.create_dir(output_dir)
+    os.makedirs(output_dir, exist_ok=True)
     for key in out:
         shape = [n] + list(dataset.view_res) + list(out[key].size()[1:])
@@ -367,7 +367,7 @@ def test():
                      for i in range(n)])
     output_subdir = f"{output_dir}/{output_dataset_id}_bins"
-    misc.create_dir(output_subdir)
+    os.makedirs(output_subdir, exist_ok=True)
     img.save(out['bins'], [f'{output_subdir}/{i:0>4d}.png' for i in dataset.view_idxs])
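The recurring change in this and the following scripts swaps the project's `misc.create_dir` helper for the standard library. A minimal sketch of why the two are interchangeable here (the helper body is the one deleted from utils/misc.py later in this commit):

import os

def create_dir(path):          # old helper: check existence, then create
    if not os.path.exists(path):
        os.makedirs(path)

# new call: same effect, creates intermediate directories, no separate check
os.makedirs('run/example', exist_ok=True)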
upsampling/run_upsampling.py
@@ -60,7 +60,7 @@ args.color = color.from_str(args.color)
 def train():
-    misc.create_dir(run_dir)
+    os.makedirs(run_dir, exist_ok=True)
     train_set = UpsamplingDataset('.', 'input/out_view_%04d.png', 'gt/view_%04d.png',
                                   color=args.color)
     training_data_loader = FastDataLoader(dataset=train_set,
@@ -80,7 +80,7 @@ def train():
 def test():
-    misc.create_dir(os.path.dirname(args.testOutPatt))
+    os.makedirs(os.path.dirname(args.testOutPatt), exist_ok=True)
     train_set = UpsamplingDataset('.', 'input/out_view_%04d.png', None, color=args.color)
     training_data_loader = FastDataLoader(dataset=train_set,
utils/constants.py
@@ -2,4 +2,6 @@ import math
 HUGE_FLOAT = 1e10
 TINY_FLOAT = 1e-6
-PI = math.pi
\ No newline at end of file
+PI = math.pi
+NAN = math.nan
+E = math.e
\ No newline at end of file
utils/geometry.py  0 → 100644 (new file)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union

import numpy as np
import torch
import torch.nn.functional as F

INF = 1000.0


def ones_like(x):
    T = torch if isinstance(x, torch.Tensor) else np
    return T.ones_like(x)


def stack(x):
    T = torch if isinstance(x[0], torch.Tensor) else np
    return T.stack(x)


def matmul(x, y):
    T = torch if isinstance(x, torch.Tensor) else np
    return T.matmul(x, y)


def cross(x, y, axis=0):
    T = torch if isinstance(x, torch.Tensor) else np
    return T.cross(x, y, axis)


def cat(x, axis=1):
    if isinstance(x[0], torch.Tensor):
        return torch.cat(x, dim=axis)
    return np.concatenate(x, axis=axis)


def normalize(x, axis=-1, order=2):
    if isinstance(x, torch.Tensor):
        l2 = x.norm(p=order, dim=axis, keepdim=True)
        return x / (l2 + 1e-8), l2
    else:
        l2 = np.linalg.norm(x, order, axis)
        l2 = np.expand_dims(l2, axis)
        l2[l2 == 0] = 1
        return x / l2, l2


def parse_extrinsics(extrinsics, world2camera=True):
    """ this function is only for numpy for now """
    if extrinsics.shape[0] == 3 and extrinsics.shape[1] == 4:
        extrinsics = np.vstack([extrinsics, np.array([[0, 0, 0, 1.0]])])
    if extrinsics.shape[0] == 1 and extrinsics.shape[1] == 16:
        extrinsics = extrinsics.reshape(4, 4)
    if world2camera:
        extrinsics = np.linalg.inv(extrinsics).astype(np.float32)
    return extrinsics


def parse_intrinsics(intrinsics):
    fx = intrinsics[0, 0]
    fy = intrinsics[1, 1]
    cx = intrinsics[0, 2]
    cy = intrinsics[1, 2]
    return fx, fy, cx, cy


def uv2cam(uv, z, intrinsics, homogeneous=False):
    fx, fy, cx, cy = parse_intrinsics(intrinsics)
    x_lift = (uv[0] - cx) / fx * z
    y_lift = (uv[1] - cy) / fy * z
    z_lift = ones_like(x_lift) * z
    if homogeneous:
        return stack([x_lift, y_lift, z_lift, ones_like(z_lift)])
    else:
        return stack([x_lift, y_lift, z_lift])


def cam2world(xyz_cam, inv_RT):
    return matmul(inv_RT, xyz_cam)[:3]


def r6d2mat(d6: torch.Tensor) -> torch.Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalisation per Section B of [1].

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)
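A quick sanity check for r6d2mat (not part of the commit): the Gram-Schmidt construction should always yield an orthonormal matrix, so R @ R^T should be the identity for any non-degenerate input.

import torch

d6 = torch.randn(4, 6)                    # a batch of four 6D representations
R = r6d2mat(d6)                           # (4, 3, 3)
I = torch.eye(3).expand(4, 3, 3)
print(torch.allclose(R @ R.transpose(-1, -2), I, atol=1e-5))  # expected: True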
def get_ray_direction(ray_start, uv, intrinsics, inv_RT, depths=None):
    if depths is None:
        depths = 1
    rt_cam = uv2cam(uv, depths, intrinsics, True)
    rt = cam2world(rt_cam, inv_RT)
    ray_dir, _ = normalize(rt - ray_start[:, None], axis=0)
    return ray_dir


def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
    """
    This function takes a vector 'camera_position' which specifies the location
    of the camera in world coordinates and two vectors `at` and `up` which
    indicate the position of the object and the up directions of the world
    coordinate system respectively. The object is assumed to be centered at
    the origin.

    The output is a rotation matrix representing the transformation
    from world coordinates -> view coordinates.

    Input:
        camera_position: 3
        at: 1 x 3 or N x 3 (0, 0, 0) in default
        up: 1 x 3 or N x 3 (0, 1, 0) in default
    """
    if at is None:
        at = torch.zeros_like(camera_position)
    else:
        at = torch.tensor(at).type_as(camera_position)
    if up is None:
        up = torch.zeros_like(camera_position)
        up[2] = -1
    else:
        up = torch.tensor(up).type_as(camera_position)

    z_axis = normalize(at - camera_position)[0]
    x_axis = normalize(cross(up, z_axis))[0]
    y_axis = normalize(cross(z_axis, x_axis))[0]

    R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
    return R


def ray(ray_start, ray_dir, depths):
    return ray_start + ray_dir * depths


def compute_normal_map(ray_start, ray_dir, depths, RT, width=512, proj=False):
    raise NotImplementedError("This function needs fairnr.data.data_utils to work. "
                              "Will remove this dependency later.")
    # TODO:
    # this function is pytorch-only (for now)
    wld_coords = ray(ray_start, ray_dir, depths.unsqueeze(-1)).transpose(0, 1)
    cam_coords = matmul(RT[:3, :3], wld_coords) + RT[:3, 3].unsqueeze(-1)
    cam_coords = D.unflatten_img(cam_coords, width)

    # estimate local normal
    shift_l = cam_coords[:, 2:, :]
    shift_r = cam_coords[:, :-2, :]
    shift_u = cam_coords[:, :, 2:]
    shift_d = cam_coords[:, :, :-2]
    diff_hor = normalize(shift_r - shift_l, axis=0)[0][:, :, 1:-1]
    diff_ver = normalize(shift_u - shift_d, axis=0)[0][:, 1:-1, :]
    normal = cross(diff_hor, diff_ver)
    _normal = normal.new_zeros(*cam_coords.size())
    _normal[:, 1:-1, 1:-1] = normal
    _normal = _normal.reshape(3, -1).transpose(0, 1)

    # compute the projected color
    if proj:
        _normal = normalize(_normal, axis=1)[0]
        wld_coords0 = ray(ray_start, ray_dir, 0).transpose(0, 1)
        cam_coords0 = matmul(RT[:3, :3], wld_coords0) + RT[:3, 3].unsqueeze(-1)
        cam_coords0 = D.unflatten_img(cam_coords0, width)
        cam_raydir = normalize(cam_coords - cam_coords0, 0)[0].reshape(3, -1).transpose(0, 1)
        proj_factor = (_normal * cam_raydir).sum(-1).abs() * 0.8 + 0.2
        return proj_factor
    return _normal


# helper functions for encoder

def padding_points(xs, pad):
    if len(xs) == 1:
        return xs[0].unsqueeze(0)
    maxlen = max([x.size(0) for x in xs])
    xt = xs[0].new_ones(len(xs), maxlen, xs[0].size(1)).fill_(pad)
    for i in range(len(xs)):
        xt[i, :xs[i].size(0)] = xs[i]
    return xt


def pruning_points(feats, points, scores, depth=0, th=0.5):
    if depth > 0:
        g = int(8 ** depth)
        scores = scores.reshape(scores.size(0), -1, g).sum(-1, keepdim=True)
        scores = scores.expand(*scores.size()[:2], g).reshape(scores.size(0), -1)
    alpha = (1 - torch.exp(-scores)) > th
    feats = [feats[i][alpha[i]] for i in range(alpha.size(0))]
    points = [points[i][alpha[i]] for i in range(alpha.size(0))]
    points = padding_points(points, INF)
    feats = padding_points(feats, 0)
    return feats, points


def offset_points(point_xyz: torch.Tensor,
                  half_voxel: Union[torch.Tensor, int, float] = 1,
                  offset_only: bool = False, bits: int = 2) -> torch.Tensor:
    """
    [summary]

    :param point_xyz `Tensor(N, 3)`: [description]
    :param half_voxel `Tensor(1) | int | float`: [description], defaults to 1
    :param offset_only `bool`: [description], defaults to False
    :param bits `int`: [description], defaults to 2
    :return `Tensor(N, X, 3)|Tensor(X, 3)`: [description]
    """
    c = torch.arange(1 - bits, bits, 2, dtype=point_xyz.dtype, device=point_xyz.device)
    offset = (torch.stack(torch.meshgrid(c, c, c), dim=-1).reshape(-1, 3)) / (bits - 1) * half_voxel
    return offset if offset_only else point_xyz[:, None] + offset
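With the default bits=2, offset_points enumerates the eight corner directions of a voxel. A small illustration (values invented):

import torch

centers = torch.zeros(1, 3)                       # one voxel centered at the origin
corners = offset_points(centers, half_voxel=0.5)  # (1, 8, 3)
print(corners[0])                                 # the 8 combinations of (+-0.5, +-0.5, +-0.5)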
def discretize_points(voxel_points, voxel_size):
    # this function turns voxel centers/corners into integer indices
    # we assume all points are already put as voxels (real numbers)
    minimal_voxel_point = voxel_points.min(dim=0, keepdim=True)[0]
    voxel_indices = ((voxel_points - minimal_voxel_point) / voxel_size).round_().long()  # float
    residual = (voxel_points - voxel_indices.type_as(voxel_points) * voxel_size).mean(0, keepdim=True)
    return voxel_indices, residual


def expand_points(voxel_points, voxel_size):
    _voxel_size = min([
        torch.sqrt(((voxel_points[j:j + 1] - voxel_points[j + 1:]) ** 2).sum(-1).min())
        for j in range(100)
    ])
    depth = int(np.round(torch.log2(_voxel_size / voxel_size)))
    if depth > 0:
        half_voxel = _voxel_size / 2.0
        for _ in range(depth):
            voxel_points = offset_points(voxel_points, half_voxel / 2.0).reshape(-1, 3)
            half_voxel = half_voxel / 2.0
    return voxel_points, depth


def get_edge(depth_pts, voxel_pts, voxel_size, th=0.05):
    voxel_pts = offset_points(voxel_pts, voxel_size / 2.0)
    diff_pts = (voxel_pts - depth_pts[:, None, :]).norm(dim=2)
    ab = diff_pts.sort(dim=1)[0][:, :2]
    a, b = ab[:, 0], ab[:, 1]
    c = voxel_size
    p = (ab.sum(-1) + c) / 2.0
    h = (p * (p - a) * (p - b) * (p - c)) ** 0.5 / c
    return h < (th * voxel_size)


# fill-in image
def fill_in(shape, hits, input, initial=1.0):
    input_sizes = [k for k in input.size()]
    if (len(input_sizes) == len(shape)) and \
            all([shape[i] == input_sizes[i] for i in range(len(shape))]):
        return input  # shape is the same, no need to fill

    if isinstance(initial, torch.Tensor):
        output = initial.expand(*shape)
    else:
        output = input.new_ones(*shape) * initial
    if input is not None:
        if len(shape) == 1:
            return output.masked_scatter(hits, input)
        return output.masked_scatter(hits.unsqueeze(-1).expand(*shape), input)
    return output
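A hypothetical use of fill_in: scattering per-hit ray results back into a full-size buffer (shapes invented for the example):

import torch

n_rays = 5
hits = torch.tensor([True, False, True, False, True])
colors = torch.rand(int(hits.sum()), 3)                 # one RGB value per hit ray
full = fill_in((n_rays, 3), hits, colors, initial=0.0)
# full[i] holds colors[k] for the k-th hit ray and 0.0 for missed rays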
utils/img.py
 import os
+from pathlib import Path
 import shutil
 import torch
 import matplotlib.pyplot as plt
 import numpy as np
 import torch.nn.functional as nn_f
-from typing import Tuple
+from typing import List, Tuple, Union
 from . import misc
 from .constants import *
@@ -65,7 +66,7 @@ def load(*paths: str, permute=True, with_alpha=False) -> torch.Tensor:
     chns = 4 if with_alpha else 3
     new_paths = []
     for path in paths:
-        new_paths += [path] if isinstance(path, str) else list(path)
+        new_paths += [path] if isinstance(path, (str, Path)) else list(path)
     imgs = np.stack([plt.imread(path)[..., :chns] for path in new_paths])
     if imgs.dtype == 'uint8':
         imgs = imgs.astype(np.float32) / 255
@@ -76,7 +77,7 @@ def load_seq(path: str, n: int, permute=True, with_alpha=False) -> torch.Tensor:
     return load([path % i for i in range(n)], permute=permute, with_alpha=with_alpha)


-def save(input: torch.Tensor, *paths: str):
+def save(input: torch.Tensor, *paths: Union[str, Path, List[Union[str, Path]]]):
     """
     Save one or multiple torch-image(s) to `paths`
@@ -86,7 +87,7 @@ def save(input: torch.Tensor, *paths: str):
     """
     new_paths = []
     for path in paths:
-        new_paths += [path] if isinstance(path, str) else list(path)
+        new_paths += [path] if isinstance(path, (str, Path)) else list(path)
     if len(input.size()) < 4:
         input = input[None]
     if input.size(0) != len(new_paths):
@@ -100,9 +101,9 @@ def save(input: torch.Tensor, *paths: str):
         plt.imsave(path, np_img[i])


-def save_seq(input: torch.Tensor, path: str):
+def save_seq(input: torch.Tensor, path: Union[str, Path]):
     n = 1 if len(input.size()) <= 3 else input.size(0)
-    return save(input, [path % i for i in range(n)])
+    return save(input, [str(path) % i for i in range(n)])


 def plot(input: torch.Tensor, *, ax: plt.Axes = None):
@@ -118,7 +119,7 @@ def plot(input: torch.Tensor, *, ax: plt.Axes = None):
     return plt.imshow(im) if ax is None else ax.imshow(im)


-def save_video(frames: torch.Tensor, path: str, fps: int,
+def save_video(frames: torch.Tensor, path: Union[str, Path], fps: int,
                repeat: int = 1, pingpong: bool = False):
     """
     Encode and save a sequence of frames as video file
@@ -134,19 +135,16 @@ def save_video(frames: torch.Tensor, path: str, fps: int,
         frames = torch.cat([frames, frames.flip(0)], 0)
     if repeat > 1:
         frames = frames.expand(repeat, -1, -1, -1, -1).flatten(0, 1)
-    dir, file_name = os.path.split(path)
-    if not dir:
-        dir = './'
-    misc.create_dir(dir)
-    cwd = os.getcwd()
-    os.chdir(dir)
-    temp_out_dir = os.path.splitext(file_name)[0] + '_tempout'
-    misc.create_dir(temp_out_dir)
-    os.chdir(temp_out_dir)
-    save_seq(frames, 'out_%04d.png')
-    os.system(f'ffmpeg -y -r {fps:d} -i out_%04d.png -c:v libx264 ../{file_name}')
-    os.chdir(cwd)
-    shutil.rmtree(os.path.join(dir, temp_out_dir))
+    path = Path(path)
+    tempdir = Path('/dev/shm/dvs_tmp/video')
+    inferout = tempdir / path.stem / f"%04d.bmp"
+    os.makedirs(inferout.parent, exist_ok=True)
+    os.makedirs(path.parent, exist_ok=True)
+    save_seq(frames, inferout)
+    os.system(f'ffmpeg -y -r {fps:d} -i {inferout} -c:v libx264 {path}')
+    shutil.rmtree(inferout.parent)


 def horizontal_shift(input: torch.Tensor, offset: int, dim=-1) -> torch.Tensor:
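With the new signature, save_video accepts str or Path and stages frames under /dev/shm before invoking ffmpeg. A hedged usage sketch (assuming the module is importable as utils.img and ffmpeg is on PATH):

import torch
from utils import img

frames = torch.rand(30, 3, 64, 64)   # 30 small RGB frames
img.save_video(frames, 'out/demo.mp4', fps=15, repeat=2, pingpong=True)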
utils/mem_profiler.py
@@ -2,13 +2,14 @@ from cgitb import enable
 import torch
 from .device import *


 class MemProfiler:
     enable = False

     @staticmethod
-    def print_memory_stats(prefix, last_allocated=None, device=None):
-        if not MemProfiler.enable:
+    def print_memory_stats(prefix, last_allocated=None, device=None, enable_once=False):
+        if not enable_once and not MemProfiler.enable:
             return
         if device is None:
             device = default()
utils/misc.py
 import os
+from itertools import repeat
+import logging
+from pathlib import Path
+import re
 import shutil
 import torch
 import glm
 import csv
 import numpy as np
-from typing import List, Tuple, Union
+from typing import List, Union
 from torch.types import Number
 from .constants import *
 from .device import *
@@ -59,31 +63,11 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
     return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2) if normalize else torch.stack([x, y], 2)


-def create_dir(path):
-    if not os.path.exists(path):
-        os.makedirs(path)


 def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-    angle = -torch.atan(x / y) + (y < 0) * PI + 0.5 * PI
+    angle = -torch.atan(x / y) - (y < 0) * PI + 0.5 * PI
     return angle


-def depth_sample(depth_range: Tuple[float, float], n: int, lindisp: bool) -> torch.Tensor:
-    """
-    Get [n_layers] foreground layers whose diopters are distributed uniformly
-    in [depth_range] plus a background layer
-
-    :param depth_range: depth range of foreground layers
-    :param n_layers: number of foreground layers
-    :return: list of [n_layers+1] depths
-    """
-    if lindisp:
-        depth_range = (1 / depth_range[0], 1 / depth_range[1])
-    samples = torch.linspace(depth_range[0], depth_range[1], n)
-    return samples


 def broadcast_cat(input: torch.Tensor,
                   s: Union[Number, List[Number], torch.Tensor],
                   dim=-1,
@@ -130,4 +114,73 @@ def view_like(input: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
     return input.view(out_shape)


 def values(map, *keys):
     return list(map[key] for key in keys)
+
+
+def format_time(seconds):
+    days = int(seconds / 3600 / 24)
+    seconds = seconds - days * 3600 * 24
+    hours = int(seconds / 3600)
+    seconds = seconds - hours * 3600
+    minutes = int(seconds / 60)
+    seconds = seconds - minutes * 60
+    seconds_final = int(seconds)
+    seconds = seconds - seconds_final
+    millis = int(seconds * 1000)
+    if days > 0:
+        output = f"{days}D{hours:0>2d}h{minutes:0>2d}m"
+    elif hours > 0:
+        output = f"{hours:0>2d}h{minutes:0>2d}m{seconds_final:0>2d}s"
+    elif minutes > 0:
+        output = f"{minutes:0>2d}m{seconds_final:0>2d}s"
+    elif seconds_final > 0:
+        output = f"{seconds_final:0>2d}s{millis:0>3d}ms"
+    elif millis > 0:
+        output = f"{millis:0>3d}ms"
+    else:
+        output = '0ms'
+    return output
+
+
+def print_and_log(s):
+    print(s)
+    logging.info(s)
+
+
+def masked_scatter(mask: torch.Tensor, value: torch.Tensor,
+                   initial: Union[torch.Tensor, Number] = 0):
+    """
+    Extend PyTorch's built-in `masked_scatter` function
+
+    :param mask `Tensor(M...)`: the boolean mask
+    :param value `Tensor(N, D...)`: the value to fill in with, should have at least as many
+                                    elements as the number of ones in `mask`
+    :param initial `Tensor(M..., D...) | Number`: (optional) the destination tensor to fill; if a
+                                                  number is given instead, a new destination tensor
+                                                  filled with that value is created, defaults to 0
+    :return `Tensor(M..., D...)`: the destination tensor after filled
+    """
+    M_ = mask.size()
+    D_ = value.size()[1:]
+    if not isinstance(initial, torch.Tensor):
+        initial = value.new_full([*M_, *D_], initial)
+    return initial.masked_scatter(mask.reshape(*M_, *repeat(1, len(D_))), value)
+
+
+def list_epochs(dir: Path, pattern: str) -> List[int]:
+    prefix = pattern.split("*")[0]
+    epoch_list = [int(str(path.stem)[len(prefix):]) for path in dir.glob(pattern)]
+    epoch_list.sort()
+    return epoch_list
+
+
+def rename_seqs_with_offset(dir: Path, file_pattern: str, offset: int):
+    start, end = re.search(r'%0\dd', file_pattern).span()
+    prefix, suffix = start, len(file_pattern) - end
+    seqs = [int(path.name[prefix:-suffix])
+            for path in dir.glob(re.sub(r'%0\dd', "*", file_pattern))]
+    seqs.sort(reverse=offset > 0)
+    for i in seqs:
+        (dir / (file_pattern % i)).rename(dir / (file_pattern % (i + offset)))
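An illustration of the new masked_scatter helper (tensor values invented): rows of value are written at the True positions of mask, and everything else keeps the initial value.

import torch

mask = torch.tensor([True, False, True])    # (M) = (3)
value = torch.tensor([[1., 2.], [3., 4.]])  # (N, D) = (2, 2)
out = masked_scatter(mask, value, initial=0)
print(out)  # tensor([[1., 2.], [0., 0.], [3., 4.]])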
utils/perf.py
 from numpy import average
 import torch
 import torch.cuda
 from typing import Dict, List, OrderedDict


 class Perf(object):
-    def __init__(self, enable, start=False) -> None:
-        super().__init__()
-        self.enable = enable
-        self.start_event = None
-        if start:
-            self.start()
-
-    def start(self):
-        if not self.enable:
-            return
-        if self.start_event == None:
-            self.start_event = torch.cuda.Event(enable_timing=True)
-            self.end_event = torch.cuda.Event(enable_timing=True)
-        torch.cuda.synchronize()
-        self.start_event.record()
-
-    def checkpoint(self, name: str = None, end: bool = False):
-        if not self.enable:
-            return 0
-        self.end_event.record()
-        torch.cuda.synchronize()
-        duration = self.start_event.elapsed_time(self.end_event)
-        if name:
-            print('%s: %.1fms' % (name, duration))
-        if not end:
-            self.start_event.record()
-        return duration
+    frames: List[Dict[str, float]]
+
+    class Node:
+        def __init__(self, name, parent=None) -> None:
+            self.name = name
+            self.parent = parent
+            self.events = []
+            self.event_names = []
+            self.child_nodes = []
+            self.child_nodes_event_idx = []
+            self.add_checkpoint("Start")
+
+        def add_checkpoint(self, name):
+            event = torch.cuda.Event(enable_timing=True)
+            event.record()
+            self.events.append(event)
+            self.event_names.append(name)
+
+        def add_child(self, name):
+            child = Perf.Node(name, self)
+            self.child_nodes.append(child)
+            self.child_nodes_event_idx.append(len(self.events))
+            return child
+
+        def close(self):
+            self.add_checkpoint("End")
+            return self.parent
+
+        def duration(self, i0=0, i1=-1) -> float:
+            return self.events[i0].elapsed_time(self.events[i1])
+
+        def result(self, prefix: str = '') -> OrderedDict[str, float]:
+            path = f"{prefix}{self.name}"
+            res = {path: self.duration()}
+            j = 0
+            for i in range(1, len(self.events) - 1):
+                event_path = f"{path}/{self.event_names[i]}"
+                res[event_path] = self.duration(i - 1, i)
+                while j < len(self.child_nodes):
+                    if self.child_nodes_event_idx[j] > i:
+                        break
+                    res.update(self.child_nodes[j].result(f"{event_path}/"))
+                    j += 1
+            while j < len(self.child_nodes):
+                res.update(self.child_nodes[j].result(f"{path}/"))
+                j += 1
+            return res
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.root_node = None
+        self.current_node = None
+        self.frames = []
+
+    def start_node(self, name):
+        if self.current_node is None:
+            self.root_node = self.current_node = Perf.Node(name)
+        else:
+            self.current_node = self.current_node.add_child(name)
+
+    def checkpoint(self, name):
+        self.current_node.add_checkpoint(name)
+
+    def end_node(self):
+        self.current_node = self.current_node.close()
+        if self.current_node is None:
+            torch.cuda.synchronize()
+            self.frames.append(self.root_node.result())
+
+    def get_result(self, i=None):
+        if i is not None:
+            return self.frames[i]
+        if len(self.frames) == 0:
+            return {}
+        res = {key: [val] for key, val in self.frames[0].items()}
+        for i in range(1, len(self.frames)):
+            for key, val in self.frames[i].items():
+                res[key].append(val)
+        return {key: average(val) for key, val in res.items()}
+
+
+default_perf_object = None
+
+
+def enable_perf():
+    global default_perf_object
+    default_perf_object = Perf()
+
+
+def perf(fn_or_name):
+    if isinstance(fn_or_name, str):
+        name = fn_or_name
+
+        def perf_with_name(fn):
+            def wrap_perf(*args, **kwargs):
+                start_node(name)
+                ret = fn(*args, **kwargs)
+                end_node()
+                return ret
+            return wrap_perf
+        return perf_with_name
+
+    fn = fn_or_name
+
+    def wrap_perf(*args, **kwargs):
+        start_node(fn.__qualname__)
+        ret = fn(*args, **kwargs)
+        end_node()
+        return ret
+    return wrap_perf
+
+
+def start_node(name):
+    if default_perf_object is not None:
+        default_perf_object.start_node(name)
+
+
+def end_node():
+    if default_perf_object is not None:
+        default_perf_object.end_node()
+
+
+def checkpoint(name):
+    if default_perf_object is not None:
+        default_perf_object.checkpoint(name)
+
+
+def get_perf_result(i=None):
+    if default_perf_object is not None:
+        return default_perf_object.get_result(i)
+    return None
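How the rewritten profiler appears to be used, based only on the API above (the frame boundaries and names are assumptions; CUDA is required since timing uses torch.cuda.Event):

from utils.perf import enable_perf, perf, checkpoint, start_node, end_node, get_perf_result

enable_perf()                 # install the module-level Perf() object

@perf("render")               # named node; a bare @perf uses fn.__qualname__
def render():
    checkpoint("sample")      # marks the end of the first stage inside this node
    checkpoint("shade")

start_node("frame")           # a frame is recorded when the root node closes
render()
end_node()
print(get_perf_result())      # per-node durations in ms, averaged over frames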
utils/progress_bar.py
+import shutil
 import sys
 import time
 import os
+from .misc import format_time
+from .constants import NAN

-bar_length = 50
-LAST_T = time.time()
-BEGIN_T = LAST_T
+last_time = time.time()
+begin_time = last_time


-def get_terminal_columns():
-    return os.get_terminal_size().columns


-def progress_bar(current, total, msg=None, premsg=None):
-    global LAST_T, BEGIN_T
+def progress_bar(current, total, msg=None, premsg=None, barmsg=None):
+    global last_time, begin_time
     if current == 0:
-        BEGIN_T = time.time()  # Reset for new bar.
+        begin_time = time.time()  # Reset for new bar.
     current_time = time.time()
-    step_time = current_time - LAST_T
-    LAST_T = current_time
-    total_time = current_time - BEGIN_T
+    step_time = current_time - last_time
+    total_time = current_time - begin_time
+    last_time = current_time
+    estimated_time = 0 if current == 0 else total_time / current * (total - current)
+    show_opt = int(current_time) % 6 >= 3 and current < total
+    show_barmsg = barmsg is not None and show_opt
     str0 = f"{premsg}[" if premsg else '['
-    str1 = f"] {current + 1:d}/{total:d} | Step: {format_time(step_time)} | Tot: {format_time(total_time)}"
+    str1 = f"] {current:d}/{total:d} | Step: {format_time(step_time)} | " + \
+        (f"Eta: {format_time(estimated_time)}" if show_opt else f"Tot: {format_time(total_time)}")
     if msg:
         str1 += f" | {msg}"
-    tot_cols = get_terminal_columns()
+    tot_cols = shutil.get_terminal_size().columns - 10
     bar_length = tot_cols - len(str0) - len(str1)
-    current_len = int(bar_length * (current + 1) / total)
-    rest_len = int(bar_length - current_len)
-    if current_len == 0:
-        str_bar = '.' * rest_len
-    else:
-        str_bar = '=' * (current_len - 1) + '>' + '.' * rest_len
-    sys.stdout.write(str0 + str_bar + str1)
-    if current < total - 1:
-        sys.stdout.write('\r')
-    else:
-        sys.stdout.write('\n')
+    if show_barmsg and bar_length < len(barmsg):
+        sys.stdout.write(str0[:-1] + barmsg)
+    elif bar_length <= 0:
+        sys.stdout.write(str0[:-1] + str1[2:])
+    else:
+        current_len = int(bar_length * current / total)
+        rest_len = int(bar_length - current_len)
+        str_bar = ''
+        if current_len > 0:
+            str_bar += '=' * (current_len - 1) + '>'
+        str_bar += '.' * rest_len
+        if show_barmsg:
+            str_bar = barmsg + str_bar[len(barmsg):]
+        sys.stdout.write(str0 + str_bar + str1)
+    sys.stdout.write('\r' if current < total else '\n')
     sys.stdout.flush()


-# return the formatted time
-def format_time(seconds):
-    days = int(seconds / 3600 / 24)
-    seconds = seconds - days * 3600 * 24
-    hours = int(seconds / 3600)
-    seconds = seconds - hours * 3600
-    minutes = int(seconds / 60)
-    seconds = seconds - minutes * 60
-    seconds_final = int(seconds)
-    seconds = seconds - seconds_final
-    millis = int(seconds * 1000)
-    output = ''
-    time_index = 1
-    if days > 0:
-        output += str(days) + 'D'
-        time_index += 1
-    if hours > 0 and time_index <= 2:
-        output += str(hours) + 'h'
-        time_index += 1
-    if minutes > 0 and time_index <= 2:
-        output += str(minutes) + 'm'
-        time_index += 1
-    if seconds_final > 0 and time_index <= 2:
-        output += '%02ds' % seconds_final
-        time_index += 1
-    if millis > 0 and time_index <= 2:
-        output += '%03dms' % millis
-        time_index += 1
-    if output == '':
-        output = '0ms'
-    return output
utils/sphere.py
-from typing import List, Union
+from typing import Union
 import torch
 import math
 from . import misc
@@ -13,12 +13,12 @@ def cartesian2spherical(cart: torch.Tensor, inverse_r: bool = False) -> torch.Te
     :return `Tensor(..., 3)`: coordinates in Spherical (r, theta, phi)
     """
     rho = torch.sqrt(torch.sum(cart * cart, dim=-1))
-    theta = misc.get_angle(cart[..., 0], cart[..., 2])
+    theta = misc.get_angle(cart[..., 2], cart[..., 0])
     if inverse_r:
         rho = rho.reciprocal()
-        phi = torch.acos(cart[..., 1] * rho)
+        phi = torch.asin(cart[..., 1] * rho)
     else:
-        phi = torch.acos(cart[..., 1] / rho)
+        phi = torch.asin(cart[..., 1] / rho)
     return torch.stack([rho, theta, phi], dim=-1)
@@ -34,9 +34,9 @@ def spherical2cartesian(spher: torch.Tensor, inverse_r: bool = False) -> torch.T
         rho = rho.reciprocal()
     sin_theta_phi = torch.sin(spher[..., 1:3])
     cos_theta_phi = torch.cos(spher[..., 1:3])
-    x = rho * cos_theta_phi[..., 0] * sin_theta_phi[..., 1]
-    y = rho * cos_theta_phi[..., 1]
-    z = rho * sin_theta_phi[..., 0] * sin_theta_phi[..., 1]
+    x = rho * sin_theta_phi[..., 0] * cos_theta_phi[..., 1]
+    y = rho * sin_theta_phi[..., 1]
+    z = rho * cos_theta_phi[..., 0] * cos_theta_phi[..., 1]
     return torch.stack([x, y, z], dim=-1)
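The convention changes (theta now taken from get_angle(z, x), phi from asin) keep the two conversions mutually inverse. A hedged round-trip check (assuming the module imports as utils.sphere):

import torch
from utils import sphere

cart = torch.tensor([[0.3, -0.2, 0.9]])
spher = sphere.cartesian2spherical(cart)
cart2 = sphere.spherical2cartesian(spher)
print(torch.allclose(cart, cart2, atol=1e-6))  # expected: True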
utils/voxels.py  0 → 100644 (new file)
import torch
from typing import Tuple, Union


def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor:
    """
    Get grid steps along every dim.

    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: step size
    :return `Tensor(D)`: grid steps along every dim
    """
    return ((bbox[1] - bbox[0]) / step_size).ceil().long()


def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
                   step_size: Union[torch.Tensor, float] = None,
                   steps: torch.Tensor = None) -> torch.Tensor:
    """
    Get discretized (integer) grid coordinates of points.

    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
    specified, then the grid coordinates will be calculated according to the step size, ignoring
    the value of `steps`.

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N..., D)`: discretized grid coordinates
    """
    if step_size is not None:
        return ((pts - bbox[0]) / step_size).floor().long()
    return ((pts - bbox[0]) / (bbox[1] - bbox[0]) * steps).floor().long()


def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
                    step_size: Union[torch.Tensor, float] = None,
                    steps: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get flattened grid indices of points.

    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
    specified, then the grid indices will be calculated according to the step size, ignoring
    the value of `steps`.

    :param pts `Tensor(N..., D)`: points
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N...)`: grid indices
    :return `Tensor(N...)`: a mask tensor indicating whether each returned index falls outside the grid
    """
    if step_size is not None:
        steps = get_grid_steps(bbox, step_size)  # (D)
    grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps)  # (N..., D)
    outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1)  # (N...)
    if pts.size(-1) == 1:
        grid_indices = grid_coords[..., 0]
    elif pts.size(-1) == 2:
        grid_indices = grid_coords[..., 0] * steps[1] + grid_coords[..., 1]
    elif pts.size(-1) == 3:
        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] \
            + grid_coords[..., 1] * steps[2] \
            + grid_coords[..., 2]
    elif pts.size(-1) == 4:
        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] * steps[3] \
            + grid_coords[..., 1] * steps[2] * steps[3] \
            + grid_coords[..., 2] * steps[3] \
            + grid_coords[..., 3]
    else:
        raise NotImplementedError("The function does not support D>4")
    return grid_indices, outside_mask
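A small worked example of the flattening in to_grid_indices (grid and points invented), mirroring the D == 3 branch:

import torch

bbox = torch.tensor([[0., 0., 0.], [1., 1., 1.]])
pts = torch.tensor([[0.1, 0.5, 0.9],     # inside the bbox
                    [1.2, 0.5, 0.5]])    # outside along x
indices, outside = to_grid_indices(pts, bbox, steps=torch.tensor([2, 2, 2]))
print(indices)  # tensor([ 3, 11]): (0,1,1) -> 0*4 + 1*2 + 1 = 3; the 11 is masked out
print(outside)  # tensor([False,  True])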
def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
    """
    Initialize voxels.
    """
    x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)])
    return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps)


def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
                     step_size: Union[torch.Tensor, float] = None,
                     steps: torch.Tensor = None) -> torch.Tensor:
    """
    Convert discretized grid coordinates back to voxel center positions.

    At least one of the parameters `step_size` and `steps` should be specified. If `step_size` is
    specified, then the voxel centers will be calculated according to the step size, ignoring
    the value of `steps`.

    :param grid_coords `Tensor(N..., D)`: discretized grid coordinates
    :param bbox `Tensor(2, D)`: bounding box
    :param step_size `Tensor(1|D) | float`: (optional) step size
    :param steps `Tensor(1|D)`: (optional) steps along every dim
    :return `Tensor(N..., D)`: voxel centers
    """
    grid_coords = grid_coords.float() + 0.5
    if step_size is not None:
        return grid_coords * step_size + bbox[0]
    return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0]


def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_border: bool = True,
                       dims=3, *, dtype: torch.dtype = None, device: torch.device = None,
                       like: torch.Tensor = None):
    """
    [summary]

    :param voxel_size `Tensor(D)|float`: [description]
    :param n `int`: [description]
    :param align_border `bool`: [description], defaults to True
    :param dims `int`: [description], defaults to 3
    :param dtype `dtype`: [description], defaults to None
    :param device `device`: [description], defaults to None
    :param like `Tensor(*)`: take dtype and device from this tensor if specified
    :return `Tensor(X, D)`: [description]
    """
    if like is not None:
        dtype = like.dtype
        device = like.device
    c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
    offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) * voxel_size / 2 / \
        (n - 1 if align_border else n)
    return offset


def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, float], n: int,
                 align_border: bool = True):
    """
    [summary]

    :param voxel_centers `Tensor(N, D)`: [description]
    :param voxel_size `Tensor(D)|float`: [description]
    :param n `int`: [description]
    :param align_border `bool`: [description], defaults to True
    :return `Tensor(N, X, D)`: [description]
    """
    return voxel_centers[:, None] + split_voxels_local(voxel_size, n, align_border,
                                                       voxel_centers.shape[-1],
                                                       like=voxel_centers)


def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
    expand_bbox = bbox
    expand_bbox[0] -= 0.5 * half_voxel_size
    expand_bbox[1] += 0.5 * half_voxel_size
    double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size)
    # (M, 3) -> [1, 3, 5, ...]
    corner_coords = split_voxels(double_grid_coords, 2, 2).reshape(-1, 3)
    # (8M, 3) -> [0, 2, 4, ...]
    corner_coords, corner_indices = corner_coords.unique(dim=0, sorted=True, return_inverse=True)
    corners = to_voxel_centers(corner_coords, expand_bbox, step_size=half_voxel_size)
    return corners, corner_indices.reshape(-1, 8)


def trilinear_interp(pts: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
    """
    Perform trilinear interpolation in unit voxel ([0,0,0] ~ [1,1,1]).

    :param pts `Tensor(N, 3)`: uniform coordinates in voxels
    :param corner_values `Tensor(N, 8X)|Tensor(N, 8, X)`: values at corners of voxels
    :return `Tensor(N, X)`: interpolated values
    """
    pts = pts[:, None]  # (N, 1, 3)
    corners = split_voxels_local(1, 2, like=pts) + 0.5  # (8, 3)
    weights = (pts * corners * 2 - pts - corners + 1).prod(-1, keepdim=True)  # (N, 8, 1)
    corner_values = corner_values.reshape(corner_values.size(0), 8, -1)  # (N, 8, X)
    return (weights * corner_values).sum(1)
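Two point checks for trilinear_interp (values invented): a query at a voxel corner returns that corner's value exactly, and the voxel center averages all eight corners.

import torch

pts = torch.tensor([[0., 0., 0.],       # exactly at corner 0
                    [0.5, 0.5, 0.5]])   # voxel center
corner_values = torch.arange(8, dtype=torch.float).expand(2, 8)  # (N, 8)
print(trilinear_interp(pts, corner_values))  # tensor([[0.0000], [3.5000]])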