Nianchen Deng / deeplightfield · Commits · 2824f796

Commit 2824f796, authored Dec 06, 2021 by Nianchen Deng
Commit message: sync
Parent: 5699ccbf
Changes: 17 files
.vscode/launch.json
@@ -24,14 +24,19 @@
             "program": "train.py",
             "args": [
                 // "-c",
-                // "snerf_voxels",
-                "/home/dengnc/dvs/data/__new/barbershop_fovea_r360x80_t0.6/_nets/train_t0.3/snerfadvx_voxels_x4/checkpoint_10.tar",
+                // "snerf_voxels+ls+f32",
+                "/data1/dnc/dvs/data/__nerf/room/_nets/train/snerf_voxels+ls+f32/checkpoint_1.tar",
                 "--prune",
-                "100",
+                "1",
                 "--split",
-                "100"
-                // "data/__new/barbershop_fovea_r360x80_t0.6/train_t0.3.json"
+                "1",
+                "-e",
+                "100",
+                "--views",
+                "5",
+                // "data/__nerf/room/train.json"
             ],
+            "justMyCode": false,
             "console": "integratedTerminal"
         },
         {
...
configs/snerf_voxels+ls+f32.json (new file, mode 100644)
{
    "model": "SNeRF",
    "args": {
        "color": "rgb",
        "n_pot_encode": 10,
        "n_dir_encode": 4,
        "fc_params": {
            "nf": 256,
            "n_layers": 8,
            "activation": "relu",
            "skips": [4]
        },
        "n_featdim": 32,
        "space": "voxels",
        "steps": [4, 16, 8],
        "n_samples": 16,
        "perturb_sample": true,
        "density_regularization_weight": 1e-4,
        "density_regularization_scale": 1e4
    }
}
\ No newline at end of file
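A minimal sketch of how a configuration file like this is consumed, following the loading flow visible in the train.py hunk further down (json.load, then the "model" name and "args" dict are handed to the model factory). The load_config helper here is hypothetical and only stands in for the repository's own mdl.create path.

import json
from pathlib import Path

def load_config(path: str):
    """Read a model config and return (model_class_name, model_args)."""
    with Path(path).open() as fp:
        config = json.load(fp)
    # "model" selects the model class (e.g. "SNeRF"); "args" holds its hyper-parameters
    return config["model"], config["args"]

if __name__ == "__main__":
    model_class, model_args = load_config("configs/snerf_voxels+ls+f32.json")
    print(model_class)                      # SNeRF
    print(model_args["fc_params"]["nf"])    # 256
    print(model_args["steps"])              # [4, 16, 8]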
model/base.py
@@ -16,7 +16,7 @@ class BaseModelMeta(type):

 class BaseModel(nn.Module, metaclass=BaseModelMeta):
-    trainer = "Train"
+    TrainerClass = "Train"

     @property
     def args(self):
...
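The rename from `trainer` to `TrainerClass` appears to free the name `trainer` for the instance attribute that train/base.py now sets (`self.model.trainer = self`), so a model can reach back into its trainer, for example `self.trainer.reset_optimizer()` in the model/nerf.py hunk below. A minimal sketch of that pattern with simplified stand-in classes (not the repository's actual classes):

class BaseModel:
    TrainerClass = "Train"          # class-level: which trainer type to build for this model

class Train:
    def __init__(self, model):
        self.model = model
        model.trainer = self        # instance-level: let the model call back into its trainer

    def reset_optimizer(self):
        print("optimizer rebuilt")

model = BaseModel()
trainer = Train(model)
# The model can now trigger trainer-side actions after it restructures itself:
if hasattr(model, "trainer"):
    model.trainer.reset_optimizer()     # prints "optimizer rebuilt"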
model/nerf.py
@@ -10,7 +10,7 @@ from utils.misc import masked_scatter

 class NeRF(BaseModel):
-    trainer = "TrainWithSpace"
+    TrainerClass = "TrainWithSpace"
     SamplerClass = Sampler
     RendererClass = VolumnRenderer
...
@@ -124,21 +124,11 @@ class NeRF(BaseModel):
         return self.pot_encoder(x)

     def encode_d(self, samples: Samples) -> torch.Tensor:
-        return self.dir_encoder(samples.dirs) if self.dir_encoder is not None else None
+        return self.dir_encoder(samples.dirs) if self.dir_encoder else None

     @torch.no_grad()
-    def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
-        densities = self.render(Samples(sampled_points, None, None, None, sampled_voxel_indices),
-                                'density')
-        return 1 - (-densities).exp()
-
-    @torch.no_grad()
-    def pruning(self, threshold: float = 0.5, train_stats=False):
-        return self.space.pruning(self.get_scores, threshold, train_stats)
-
-    @torch.no_grad()
-    def splitting(self):
-        ret = self.space.splitting()
+    def split(self):
+        ret = self.space.split()
         if 'n_samples' in self.args0:
             self.args0['n_samples'] *= 2
         if 'voxel_size' in self.args0:
...
@@ -149,12 +139,10 @@ class NeRF(BaseModel):
         if 'sample_step' in self.args0:
             self.args0['sample_step'] /= 2
         self.sampler = self.SamplerClass(**self.args)
+        if self.args.get('n_featdim') and hasattr(self, "trainer"):
+            self.trainer.reset_optimizer()
         return ret

-    @torch.no_grad()
-    def double_samples(self):
-        pass
-
     @perf
     def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *,
                 extra_outputs: List[str] = [], **kwargs) -> torch.Tensor:
...
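The `get_scores` logic removed here (and re-created inside train/train_with_space.py further down) converts predicted densities into occupancy-like scores via alpha = 1 - exp(-sigma). A small self-contained check of that mapping:

import torch

def density_to_score(densities: torch.Tensor) -> torch.Tensor:
    # 1 - exp(-sigma): close to 0 for empty space, approaches 1 as density grows
    return 1 - (-densities).exp()

sigma = torch.tensor([0.0, 0.1, 1.0, 10.0])
print(density_to_score(sigma))   # tensor([0.0000, 0.0952, 0.6321, 1.0000])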
model/snerf_advance_x.py
@@ -40,16 +40,8 @@ class SNeRFAdvanceX(SNeRFAdvance):
         return self.cores[chunk_id](x, d, outputs, **extras)

     @torch.no_grad()
-    def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @torch.no_grad()
-    def pruning(self, threshold: float = 0.5, train_stats=False):
-        raise NotImplementedError()
-
-    @torch.no_grad()
-    def splitting(self):
-        ret = super().splitting()
+    def split(self):
+        ret = super().split()
         k = self.args["n_samples"] // self.space.steps[0].item()
         net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))]
         if len(net_samples) != len(self.cores):
...
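The `net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))]` line redistributes ray samples over the per-chunk core networks after a split. The behaviour of `balance_cut` itself is not shown in this diff, so the following is only a hedged sketch with a made-up cut result to illustrate the scaling step:

# Hypothetical example: assume balance_cut(0, 3) returned [4, 16, 8] steps per core network
cut = [4, 16, 8]            # assumed output of space.balance_cut(0, len(cores))
n_samples, steps0 = 64, 4   # assumed model args ("n_samples" and space.steps[0])
k = n_samples // steps0     # samples per step along dim 0, as in the hunk above
net_samples = [val * k for val in cut]
print(net_samples)          # [64, 256, 128]
assert len(net_samples) == len(cut)   # the real code checks this against len(self.cores)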
model/snerf_x.py
@@ -4,10 +4,6 @@ from .snerf import *

 class SNeRFX(SNeRF):
-    trainer = "TrainWithSpace"
-    SamplerClass = SphericalSampler
-    RendererClass = VolumnRenderer
-
     def __init__(self, args0: dict, args1: dict = {}):
         """
         Initialize a multi-sphere-layer net
...
@@ -42,16 +38,8 @@ class SNeRFX(SNeRF):
         return self.cores[chunk_id](x, d, outputs)

     @torch.no_grad()
-    def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @torch.no_grad()
-    def pruning(self, threshold: float = 0.5, train_stats=False):
-        raise NotImplementedError()
-
-    @torch.no_grad()
-    def splitting(self):
-        ret = super().splitting()
+    def split(self):
+        ret = super().split()
         k = self.args["n_samples"] // self.space.steps[0].item()
         net_samples = [
             val * k for val in self.space.balance_cut(0, len(self.cores))
...
modules/space.py
-from math import ceil
 import torch
-import numpy as np
-from typing import List, Tuple, Union
+from typing import List, NoReturn, Tuple, Union
 from torch import nn
-from plyfile import PlyData, PlyElement
 from utils.geometry import *
 from utils.constants import *
...
@@ -73,11 +70,11 @@ class Space(nn.Module):
         return voxel_indices

     @torch.no_grad()
-    def pruning(self, score_fn, threshold: float = 0.5, train_stats=False):
+    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
         raise NotImplementedError()

     @torch.no_grad()
-    def splitting(self):
+    def split(self):
         raise NotImplementedError()
...
@@ -108,7 +105,7 @@ class Voxels(Space):
         return self.voxels.size(0)

     @property
-    def n_corner(self) -> int:
+    def n_corners(self) -> int:
         """`int` Number of corners"""
         return self.corners.size(0)
...
@@ -145,12 +142,18 @@ class Voxels(Space):
         :param n_dims `int`: embedding dimension
         :return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
         """
-        name = f'emb_{name}'
-        self.add_module(name, torch.nn.Embedding(self.n_corners.item(), n_dims))
-        return self.__getattr__(name)
+        if self.get_embedding(name) is not None:
+            raise KeyError(f"Embedding '{name}' already existed")
+        emb = torch.nn.Embedding(self.n_corners, n_dims, device=self.device)
+        setattr(self, f'emb_{name}', emb)
+        return emb

     def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
-        return getattr(self, f'emb_{name}')
+        return getattr(self, f'emb_{name}', None)
+
+    def set_embedding(self, weight: torch.Tensor, name: str = 'default'):
+        emb = torch.nn.Embedding(*weight.shape, _weight=weight, device=self.device)
+        setattr(self, f'emb_{name}', emb)

     def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
                           name: str = 'default') -> torch.Tensor:
...
@@ -167,9 +170,8 @@ class Voxels(Space):
             raise KeyError(f"Embedding '{name}' doesn't exist")
         voxels = self.voxels[voxel_indices]  # (N, 3)
         corner_indices = self.corner_indices[voxel_indices]  # (N, 8)
-        p = (pts - voxels) / self.voxel_size + 0.5  # (N, 3) normed-coords in voxel
-        features = emb(corner_indices).reshape(pts.size(0), 8, -1)  # (N, 8, X)
-        return trilinear_interp(p, features)
+        p = (pts - voxels) / self.voxel_size + .5  # (N, 3) normed-coords in voxel
+        return trilinear_interp(p, emb(corner_indices))

     @perf
     def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
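`extract_embedding` now passes the raw `emb(corner_indices)` lookup of shape (N, 8, X) straight into `trilinear_interp` together with the in-voxel coordinates `p` in [0, 1]^3. Below is a minimal reference implementation of that interpolation, assuming a corner ordering where the last index varies fastest over z, then y, then x; the real `trilinear_interp` in utils.geometry may order corners differently, so treat this only as a sketch:

import torch

def trilinear_interp_ref(p: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
    """p: (N, 3) coords in [0,1]^3; corner_values: (N, 8, X) values at the voxel corners.

    Corner c is assumed to be indexed by bits (bx, by, bz) -> c = bx*4 + by*2 + bz.
    """
    x, y, z = p[:, 0:1], p[:, 1:2], p[:, 2:3]
    w = []
    for c in range(8):
        bx, by, bz = (c >> 2) & 1, (c >> 1) & 1, c & 1
        wx = x if bx else 1 - x
        wy = y if by else 1 - y
        wz = z if bz else 1 - z
        w.append(wx * wy * wz)                    # (N, 1) weight of corner c
    w = torch.stack(w, 1)                         # (N, 8, 1)
    return (w * corner_values).sum(1)             # (N, X)

# Interpolating the corner positions themselves must reproduce the query point:
corners = torch.tensor([[[x, y, z] for x in (0., 1.) for y in (0., 1.) for z in (0., 1.)]])
p = torch.tensor([[0.25, 0.5, 0.75]])
print(trilinear_interp_ref(p, corners))           # tensor([[0.2500, 0.5000, 0.7500]])

This is also essentially the property that test.ipynb (added in this commit) checks when it compares `new_corners` against the interpolated corner-position embedding.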
...
@@ -220,17 +222,34 @@ class Voxels(Space):
         return voxel_indices

     @torch.no_grad()
-    def splitting(self) -> None:
+    def split(self) -> None:
         """
         Split voxels into smaller voxels with half size.
         """
-        n_voxels_before = self.n_voxels
-        self.steps *= 2
-        self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False) \
+        new_steps = self.steps * 2
+        new_voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False) \
             .reshape(-1, 3)
-        self._update_corners()
+        new_corners, new_corner_indices = get_corners(new_voxels, self.bbox, new_steps)
+
+        # Calculate new embeddings through trilinear interpolation
+        grid_indices_of_new_corners = to_flat_indices(
+            to_grid_coords(new_corners, self.bbox, steps=self.steps).min(self.steps - 1), self.steps)
+        voxel_indices_of_new_corners = self.voxel_indices_in_grid[grid_indices_of_new_corners]
+        for name, _ in self.named_modules():
+            if not name.startswith("emb_"):
+                continue
+            new_emb_weight = self.extract_embedding(new_corners, voxel_indices_of_new_corners,
+                                                    name=name[4:])
+            self.set_embedding(new_emb_weight, name=name[4:])
+
+        # Apply new tensors
+        self.steps = new_steps
+        self.voxels = new_voxels
+        self.corners = new_corners
+        self.corner_indices = new_corner_indices
         self._update_voxel_indices_in_grid()
-        return n_voxels_before, self.n_voxels
+        return self.n_voxels // 8, self.n_voxels

     @torch.no_grad()
     def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
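The rewritten `split` computes the new steps, voxel centers, corners and interpolated embeddings from the old grid first and only then swaps everything in. Since `split_voxels(..., 2, ...)` cuts every voxel into 2x2x2 children, the returned pair `(self.n_voxels // 8, self.n_voxels)` is the before/after voxel count. A toy check of that 8x relation, assuming a dense (unpruned) grid:

import torch

steps = torch.tensor([2, 3, 3])            # grid resolution per axis (same values as in test.ipynb)
n_voxels_before = int(steps.prod())        # 18 voxels
new_steps = steps * 2                      # each axis doubled by the split
n_voxels_after = int(new_steps.prod())     # 144 voxels
print(n_voxels_before, n_voxels_after)     # 18 144
assert n_voxels_after // 8 == n_voxels_before   # every voxel became 2*2*2 children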
...
@@ -239,11 +258,6 @@ class Voxels(Space):
         self._update_voxel_indices_in_grid()
         return keeps.size(0), keeps.sum().item()

-    @torch.no_grad()
-    def pruning(self, score_fn, threshold: float = 0.5) -> None:
-        scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0])  # (M)
-        return self.prune(scores > threshold)
-
     def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
         sum_dims = [val for val in range(self.dims) if val != dim]
         return self.voxel_indices_in_grid.reshape(*self.steps).ne(-1).sum(sum_dims)
...
@@ -261,39 +275,30 @@ class Voxels(Space):
             part = int(cdf[i]) + 1
         return bins

-    def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
-        sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits)
-        sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand(
-            *sampled_xyz.shape[:2])
-        sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.flatten()
-
-    @torch.no_grad()
-    def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor:
-        def get_scores_once(pts, idxs):
-            scores = score_fn(pts, idxs).reshape(-1, bits ** 3)  # (B, P)
-            if reduce_fn is not None:
-                scores = reduce_fn(scores)  # (B[, ...])
-            return scores
-
-        sampled_xyz, sampled_idx = self.sample(bits)
-        chunk_size = 64
-        return torch.cat([
-            get_scores_once(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
-            for i in range(0, self.voxels.size(0), chunk_size)
-        ], 0)  # (M[, ...])
+    def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        For each voxel, sample `S^3` points uniformly, with small perturb if `perturb` is `True`.
+
+        When `perturb` is `False`, `include_border` can specify whether to sample points from border to border or at centers of sub-voxels.
+        When `perturb` is `True`, points are sampled at centers of sub-voxels, then applying a random offset in sub-voxels.
+
+        :param S `int`: number of samples along each dim
+        :param perturb `bool?`: whether perturb samples, defaults to `False`
+        :param include_border `bool?`: whether include border, defaults to `True`
+        :return `Tensor(N*S^3, 3)`: sampled points
+        :return `Tensor(N*S^3)`: voxel indices of sampled points
+        """
+        pts = split_voxels(self.voxels, self.voxel_size, S, align_border=not perturb and include_border)  # (N, X, D)
+        voxel_indices = torch.arange(self.n_voxels, device=self.device)[:, None] \
+            .expand(*pts.shape[:-1])  # (N) -> (N, X)
+        if perturb:
+            pts += (torch.rand_like(pts) - .5) * self.voxel_size / S
+        return pts.reshape(-1, 3), voxel_indices.flatten()

     def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)

-    def _update_corners(self):
-        """
-        Update voxel corners.
-        """
-        corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
-        self.register_buffer("corners", corners)
-        self.register_buffer("corner_indices", corner_indices)
-
     def _update_voxel_indices_in_grid(self):
         """
         Update voxel indices in grid.
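The new `Voxels.sample` draws S points per axis in every voxel (S^3 per voxel) and, when `perturb` is set, jitters each point inside its sub-voxel by up to half the sub-voxel size. A standalone sketch of that sampling pattern for a single voxel, using the same centre/offset arithmetic as the split_voxels_local hunk in utils/voxels.py below (the real version builds offsets for all voxels at once with `split_voxels`):

import torch

def sample_in_voxel(center: torch.Tensor, voxel_size: float, S: int, perturb: bool = False):
    """Return (S**3, 3) sample points inside one voxel centred at `center`."""
    # centres of the S x S x S sub-voxels; offsets lie strictly inside (-voxel_size/2, voxel_size/2)
    c = torch.arange(1 - S, S, 2, dtype=torch.float) * voxel_size * .5 / S
    offset = torch.stack(torch.meshgrid(c, c, c, indexing='ij'), -1).reshape(-1, 3)
    pts = center + offset
    if perturb:
        # random shift of at most half a sub-voxel in every direction
        pts = pts + (torch.rand_like(pts) - .5) * voxel_size / S
    return pts

pts = sample_in_voxel(torch.zeros(3), voxel_size=1.0, S=4)
print(pts.shape)             # torch.Size([64, 3])
print(pts.abs().max())       # tensor(0.3750)  -> strictly inside the unit voxel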
...
@@ -314,7 +319,7 @@ class Voxels(Space):
         # Handle embeddings
         for name, module in self.named_modules():
             if name.startswith('emb_'):
-                setattr(self, name, torch.nn.Embedding(self.n_corners.item(), module.embedding_dim))
+                setattr(self, name, torch.nn.Embedding(self.n_corners, module.embedding_dim))


 class Octree(Voxels):
...
@@ -339,8 +344,8 @@ class Octree(Voxels):
         return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)

     @torch.no_grad()
-    def splitting(self):
-        ret = super().splitting()
+    def split(self):
+        ret = super().split()
         self.clear()
         return ret
...
...
test.ipynb (new file, mode 100644)
{
"cells": [
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from utils.voxels import *\n",
"\n",
"bbox, steps = torch.tensor([[-2, -3.14159, 1], [2, 3.14159, 0]]), torch.tensor([2, 3, 3])\n",
"voxel_size = (bbox[1] - bbox[0]) / steps\n",
"voxels = init_voxels(bbox, steps)\n",
"corners, corner_indices = get_corners(voxels, bbox, steps)\n",
"voxel_indices_in_grid = torch.arange(voxels.shape[0])\n",
"emb = torch.nn.Embedding(corners.shape[0], 3, _weight=corners)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([11, 3]) tensor([ 0, -1, -1, 1, -1, -1, 2, 3, 4, -1, 5, 6, -1, 7, 8, -1, 9, 10])\n"
]
}
],
"source": [
"keeps = torch.tensor([True]*18)\n",
"keeps[torch.tensor([1,2,4,5,9,12,15])] = False\n",
"voxels = voxels[keeps]\n",
"corner_indices = corner_indices[keeps]\n",
"grid_indices, _ = to_grid_indices(voxels, bbox, steps=steps)\n",
"voxel_indices_in_grid = grid_indices.new_full([steps.prod().item()], -1)\n",
"voxel_indices_in_grid[grid_indices] = torch.arange(voxels.shape[0])\n",
"print(voxels.shape, voxel_indices_in_grid)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([88, 3]) torch.Size([185, 3]) torch.Size([88, 8])\n"
]
}
],
"source": [
"new_voxels = split_voxels(voxels, (bbox[1] - bbox[0]) / steps, 2, align_border=False).reshape(-1, 3)\n",
"new_corners, new_corner_indices = get_corners(new_voxels, bbox, steps * 2)\n",
"print(new_voxels.shape, new_corners.shape, new_corner_indices.shape)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([ 0, 0, -1, 0, 0, -1, 1, 1, -1, 1, 1, -1, 2, 2, 3, 3, 4, 4,\n",
" 4, 2, 2, 3, 3, 4, 4, 4, 2, 2, 3, 3, 4, 4, 4, 0, 0, -1,\n",
" 0, 0, -1, 1, 1, -1, 1, 1, -1, 2, 2, 3, 3, 4, 4, 4, 2, 2,\n",
" 3, 3, 4, 4, 4, 2, 2, 3, 3, 4, 4, 4, -1, -1, 5, 5, 6, 6,\n",
" 6, -1, -1, 5, 5, 6, 6, 6, -1, -1, 7, 7, 8, 8, 8, -1, -1, 7,\n",
" 7, 8, 8, 8, -1, -1, 9, 9, 10, 10, 10, -1, -1, 9, 9, 10, 10, 10,\n",
" -1, -1, 9, 9, 10, 10, 10, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6, 7,\n",
" 7, 8, 8, 8, 7, 7, 8, 8, 8, 9, 9, 10, 10, 10, 9, 9, 10, 10,\n",
" 10, 9, 9, 10, 10, 10, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6, 7, 7,\n",
" 8, 8, 8, 7, 7, 8, 8, 8, 9, 9, 10, 10, 10, 9, 9, 10, 10, 10,\n",
" 9, 9, 10, 10, 10])\n",
"tensor(0)\n"
]
}
],
"source": [
"voxel_indices_of_new_corner = voxel_indices_in_grid[to_flat_indices(to_grid_coords(new_corners, bbox, steps=steps).min(steps - 1), steps)]\n",
"print(voxel_indices_of_new_corner)\n",
"p_of_new_corners = (new_corners - voxels[voxel_indices_of_new_corner]) / voxel_size + .5\n",
"print(((new_corners - trilinear_interp(p_of_new_corners, emb(corner_indices[voxel_indices_of_new_corner]))) > 1e-6).sum())"
]
}
],
"metadata": {
"interpreter": {
"hash": "08b118544df3cb8970a671e5837a88fd458f4d4c799ef1fb2709465a22a45b92"
},
"kernelspec": {
"display_name": "Python 3.9.5 64-bit ('base': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
test.py
@@ -38,7 +38,7 @@ from data.loader import DataLoader
 from utils.constants import HUGE_FLOAT

-RAYS_PER_BATCH = 2 ** 14
+RAYS_PER_BATCH = 2 ** 12
 DATA_LOADER_CHUNK_SIZE = 1e8
...
test.txt (new file, mode 100644): diff collapsed.

test1.txt (new file, mode 100644): diff collapsed.
train.py
@@ -13,8 +13,9 @@ from data.loader import DataLoader
 from utils.misc import list_epochs, print_and_log

-RAYS_PER_BATCH = 2 ** 16
+RAYS_PER_BATCH = 2 ** 12
 DATA_LOADER_CHUNK_SIZE = 1e8
+root_dir = Path.cwd()

 parser = argparse.ArgumentParser()
...
@@ -68,7 +69,7 @@ if args.mdl_path:
     model_args = model.args
 else:
     # Create model from specified configuration
-    with Path(f'{sys.path[0]}/configs/{args.config}.json').open() as fp:
+    with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
         config = json.load(fp)
     model_name = args.config
     model_class = config['model']
...
@@ -76,7 +77,7 @@ else:
     model_args['bbox'] = dataset.bbox
     model_args['depth_range'] = dataset.depth_range
     model, states = mdl.create(model_class, model_args), None
-model.to(device.default()).train()
+model.to(device.default())
 run_dir = Path(f"_nets/{dataset.name}/{model_name}")
 run_dir.mkdir(parents=True, exist_ok=True)
...
train/__init__.py
@@ -22,5 +22,5 @@ def get_class(class_name: str) -> type:

 def get_trainer(model: BaseModel, **kwargs) -> base.Train:
-    train_class = get_class(model.trainer)
+    train_class = get_class(model.TrainerClass)
     return train_class(model, **kwargs)
train/base.py
@@ -42,8 +42,9 @@ class Train(object, metaclass=BaseTrainMeta):
         self.iters = 0
         self.run_dir = run_dir
+        self.model.trainer = self
         self.model.train()
-        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
+        self.reset_optimizer()
         if states:
             if 'epoch' in states:
...
@@ -58,6 +59,9 @@ class Train(object, metaclass=BaseTrainMeta):
         if self.perf_mode:
             enable_perf()

+    def reset_optimizer(self):
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
+
     def train(self, data_loader: DataLoader, max_epochs: int):
         self.data_loader = data_loader
         self.iters_per_epoch = self.perf_frames or len(data_loader)
...
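Pulling optimizer construction into `reset_optimizer` matters because `split()` (see model/nerf.py above) recreates the corner embeddings, so an Adam instance built earlier would keep referring to parameters that are no longer part of the model. A minimal sketch of the rebuild-after-restructuring pattern, with a toy module standing in for the model:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# ... some training ... then the model restructures itself (e.g. voxel split -> new embeddings)
model.add_module("extra", torch.nn.Linear(4, 2))   # new parameters the old optimizer does not track

def reset_optimizer():
    # same recipe as Train.reset_optimizer: rebuild Adam over the *current* parameter set
    return torch.optim.Adam(model.parameters(), lr=5e-4)

optimizer = reset_optimizer()
print(sum(p.numel() for g in optimizer.param_groups for p in g["params"]))  # now includes the new layer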
train/train_with_space.py
@@ -20,18 +20,15 @@ class TrainWithSpace(Train):
        if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
            try:
                with torch.no_grad():
-                    before, after = self.model.splitting()
+                    before, after = self.model.split()
                    print_and_log(
-                        f"Splitting done: {before} -> {after}")
+                        f"Splitting done. # of voxels before: {before}, after: {after}")
            except NotImplementedError:
                print_and_log(
                    "Note: The space does not support splitting operation. Just skip it.")

        if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
            try:
                with torch.no_grad():
-                    #before, after = self.model.pruning()
-                    # print(f"Pruning by voxel densities done. # of voxels before: {before}, after: {after}")
-                    # self._prune_inner_voxels()
+                    # self._prune_voxels_by_densities()
                    self._prune_voxels_by_weights()
            except NotImplementedError:
                print_and_log(
...
@@ -39,26 +36,26 @@ class TrainWithSpace(Train):
        super()._train_epoch()

-    def _prune_inner_voxels(self):
+    def _prune_voxels_by_densities(self):
        space: Voxels = self.model.space
-        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long,
-                                          device=space.voxels.device)
-        iters_in_epoch = 0
-        batch_size = self.data_loader.batch_size
-        self.data_loader.batch_size = 2 ** 14
-        for _, rays_o, rays_d, _ in self.data_loader:
-            self.model(rays_o, rays_d,
-                       raymarching_early_stop_tolerance=0.01,
-                       raymarching_chunk_size_or_sections=[1],
-                       perturb_sample=False,
-                       voxel_access_counts=voxel_access_counts,
-                       voxel_access_tolerance=0)
-            iters_in_epoch += 1
-            percent = iters_in_epoch / len(self.data_loader) * 100
-            sys.stdout.write(f'Pruning inner voxels... {percent:.1f}%   \r')
-        self.data_loader.batch_size = batch_size
-        before, after = space.prune(voxel_access_counts > 0)
-        print(f"Prune inner voxels: {before} -> {after}")
+        threshold = .5
+        bits = 16
+
+        @torch.no_grad()
+        def get_scores(sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
+            densities = self.model.render(
+                Samples(sampled_points, None, None, None, sampled_voxel_indices),
+                'density')
+            return 1 - (-densities).exp()
+
+        sampled_xyz, sampled_idx = space.sample(bits)
+        chunk_size = 64
+        scores = torch.cat([
+            torch.max(get_scores(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
+                      .reshape(-1, bits ** 3), -1)[0]
+            for i in range(0, self.voxels.size(0), chunk_size)
+        ], 0)  # (M[, ...])
+        return space.prune(scores > threshold)

     def _prune_voxels_by_weights(self):
        space: Voxels = self.model.space
...
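The new `_prune_voxels_by_densities` scores every voxel by the maximum opacity 1 - exp(-sigma) over its sampled points and keeps only voxels whose score exceeds the threshold. A compact sketch of that keep-mask computation with fake densities (the real code obtains them from `self.model.render`):

import torch

def prune_mask(densities: torch.Tensor, threshold: float = .5) -> torch.Tensor:
    """densities: (M, P) predicted sigma for P sample points in each of M voxels."""
    scores = 1 - (-densities).exp()          # per-point opacity
    scores = scores.max(-1)[0]               # per-voxel score = max over its samples
    return scores > threshold                # True = keep the voxel

densities = torch.tensor([[0.0, 0.0, 0.1],   # nearly empty voxel
                          [0.2, 3.0, 0.0]])  # voxel with at least one dense sample
print(prune_mask(densities))                 # tensor([False,  True])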
utils/misc.py
@@ -57,10 +57,11 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> tor
    """
    if len(size) == 1:
        size = (size[0], size[0])
-    y, x = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1]))
-    if swap_dim:
-        return torch.stack([y / (size[0] - 1.), x / (size[1] - 1.)], 2) if normalize else torch.stack([y, x], 2)
-    return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2) if normalize else torch.stack([x, y], 2)
+    y, x = torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]), indexing='ij')
+    if normalize:
+        x.div_(size[1] - 1.)
+        y.div_(size[0] - 1.)
+    return torch.stack([y, x], 2) if swap_dim else torch.stack([x, y], 2)

 def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
...
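The meshgrid rewrite passes `indexing='ij'` explicitly (matrix indexing, first output varies along rows) and normalizes in place before a single stack, instead of dividing inside two separate return expressions. A small sketch of the expected behaviour for a 2x3 grid under both settings of `swap_dim`; the shapes and values are inferred from the hunk above, so treat this as an approximation of the real utils.misc.meshgrid rather than a copy of it:

import torch

def meshgrid(h: int, w: int, normalize: bool = False, swap_dim: bool = False) -> torch.Tensor:
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    x, y = x.float(), y.float()
    if normalize:
        x.div_(w - 1.)
        y.div_(h - 1.)
    return torch.stack([y, x], 2) if swap_dim else torch.stack([x, y], 2)

g = meshgrid(2, 3)
print(g.shape)                               # torch.Size([2, 3, 2])
print(g[1, 2])                               # tensor([2., 1.])  -> (x, y) ordering by default
print(meshgrid(2, 3, swap_dim=True)[1, 2])   # tensor([1., 2.])  -> (y, x) ordering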
utils/voxels.py
@@ -13,6 +13,13 @@ def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) ->
     return ((bbox[1] - bbox[0]) / step_size).ceil().long()


+def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
+    indices = grid_coords[..., 0]
+    for i in range(1, grid_coords.shape[-1]):
+        indices = indices * steps[i] + grid_coords[..., i]
+    return indices
+
+
 def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
                    step_size: Union[torch.Tensor, float] = None,
                    steps: torch.Tensor = None) -> torch.Tensor:
...
@@ -55,20 +62,7 @@ def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
        steps = get_grid_steps(bbox, step_size)  # (D)
    grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps)  # (N..., D)
    outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1)  # (N...)
-    if pts.size(-1) == 1:
-        grid_indices = grid_coords[..., 0]
-    elif pts.size(-1) == 2:
-        grid_indices = grid_coords[..., 0] * steps[1] + grid_coords[..., 1]
-    elif pts.size(-1) == 3:
-        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] \
-            + grid_coords[..., 1] * steps[2] + grid_coords[..., 2]
-    elif pts.size(-1) == 4:
-        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] * steps[3] \
-            + grid_coords[..., 1] * steps[2] * steps[3] \
-            + grid_coords[..., 2] * steps[3] \
-            + grid_coords[..., 3]
-    else:
-        raise NotImplementedError("The function does not support D>4")
+    grid_indices = to_flat_indices(grid_coords, steps)
    return grid_indices, outside_mask
...
@@ -76,7 +70,7 @@ def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
    """
    Initialize voxels.
    """
-    x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)])
+    x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)], indexing="ij")
    return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps)
...
@@ -96,7 +90,7 @@ def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
    :param steps `Tensor(1|D)`: (optional) steps alone every dim
    :return `Tensor(N..., D)`: discretized grid coordinates
    """
-    grid_coords = grid_coords.float() + 0.5
+    grid_coords = grid_coords.float() + .5
    if step_size is not None:
        return grid_coords * step_size + bbox[0]
    return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0]
...
@@ -121,8 +115,8 @@ def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_bor
        dtype = like.dtype
        device = like.device
    c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
-    offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) * voxel_size / 2 / \
-        (n - 1 if align_border else n)
+    offset = torch.stack(torch.meshgrid([c] * dims, indexing='ij'), -1).flatten(0, -2) \
+        * voxel_size * .5 / (n - 1 if align_border else n)
    return offset
...
@@ -144,7 +138,7 @@ def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, fl
 def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
-    expand_bbox = bbox
+    expand_bbox = bbox.clone()
    expand_bbox[0] -= 0.5 * half_voxel_size
    expand_bbox[1] += 0.5 * half_voxel_size
    double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size)
...
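The new `to_flat_indices` helper collapses D-dimensional grid coordinates into row-major flat indices, the same thing the removed chain of `elif pts.size(-1) == ...` branches did, but for any D. A quick standalone check against the closed-form 3-D formula from the removed branch:

import torch

def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
    indices = grid_coords[..., 0]
    for i in range(1, grid_coords.shape[-1]):
        indices = indices * steps[i] + grid_coords[..., i]
    return indices

steps = torch.tensor([2, 3, 3])
coords = torch.tensor([[0, 0, 0], [0, 2, 1], [1, 2, 2]])
flat = to_flat_indices(coords, steps)
print(flat)   # tensor([ 0,  7, 17])
# matches i0 * steps[1] * steps[2] + i1 * steps[2] + i2 from the removed 3-D branch
assert torch.equal(flat, coords[:, 0] * 9 + coords[:, 1] * 3 + coords[:, 2])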