Nianchen Deng / deeplightfield / Commits

Commit 2824f796
Authored Dec 06, 2021 by Nianchen Deng

    sync

Parent: 5699ccbf
Changes: 17 files
.vscode/launch.json

@@ -24,14 +24,19 @@
             "program": "train.py",
             "args": [
                 // "-c",
-                // "snerf_voxels",
-                // "/home/dengnc/dvs/data/__new/barbershop_fovea_r360x80_t0.6/_nets/train_t0.3/snerfadvx_voxels_x4/checkpoint_10.tar",
+                // "snerf_voxels+ls+f32",
+                "/data1/dnc/dvs/data/__nerf/room/_nets/train/snerf_voxels+ls+f32/checkpoint_1.tar",
                 "--prune",
-                "100",
+                "1",
                 "--split",
-                "100"
-                // "data/__new/barbershop_fovea_r360x80_t0.6/train_t0.3.json"
+                "1",
+                "-e",
+                "100",
+                "--views",
+                "5",
                 // "data/__nerf/room/train.json"
             ],
             "justMyCode": false,
             "console": "integratedTerminal"
         },
         {
configs/snerf_voxels+ls+f32.json  (new file, mode 100644)

{
    "model": "SNeRF",
    "args": {
        "color": "rgb",
        "n_pot_encode": 10,
        "n_dir_encode": 4,
        "fc_params": {
            "nf": 256,
            "n_layers": 8,
            "activation": "relu",
            "skips": [4]
        },
        "n_featdim": 32,
        "space": "voxels",
        "steps": [4, 16, 8],
        "n_samples": 16,
        "perturb_sample": true,
        "density_regularization_weight": 1e-4,
        "density_regularization_scale": 1e4
    }
}
\ No newline at end of file
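For orientation, train.py (modified later in this commit) consumes such a configuration by loading the JSON, reading the model class name from "model" and handing "args" to the model factory. A minimal sketch of that flow, assuming the script runs from the repository root; create_model below is a hypothetical stand-in for the repository's mdl.create:

import json
from pathlib import Path

def create_model(model_class: str, model_args: dict):
    # Hypothetical stand-in for mdl.create(model_class, model_args) used in train.py.
    print(f"Would build {model_class} with args: {sorted(model_args)}")

config_name = "snerf_voxels+ls+f32"   # the file added above
with Path(f"configs/{config_name}.json").open() as fp:
    config = json.load(fp)

create_model(config["model"], config["args"])   # "SNeRF", {"color": "rgb", ...}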
model/base.py

@@ -16,7 +16,7 @@ class BaseModelMeta(type):
 class BaseModel(nn.Module, metaclass=BaseModelMeta):
 
-    trainer = "Train"
+    TrainerClass = "Train"
 
     @property
     def args(self):
model/nerf.py

@@ -10,7 +10,7 @@ from utils.misc import masked_scatter
 class NeRF(BaseModel):
 
-    trainer = "TrainWithSpace"
+    TrainerClass = "TrainWithSpace"
     SamplerClass = Sampler
     RendererClass = VolumnRenderer

@@ -124,21 +124,11 @@ class NeRF(BaseModel):
         return self.pot_encoder(x)
 
     def encode_d(self, samples: Samples) -> torch.Tensor:
-        return self.dir_encoder(samples.dirs) if self.dir_encoder is not None else None
-
-    @torch.no_grad()
-    def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
-        densities = self.render(Samples(sampled_points, None, None, None, sampled_voxel_indices), 'density')
-        return 1 - (-densities).exp()
-
-    @torch.no_grad()
-    def pruning(self, threshold: float = 0.5, train_stats=False):
-        return self.space.pruning(self.get_scores, threshold, train_stats)
+        return self.dir_encoder(samples.dirs) if self.dir_encoder else None
 
     @torch.no_grad()
-    def splitting(self):
-        ret = self.space.splitting()
+    def split(self):
+        ret = self.space.split()
         if 'n_samples' in self.args0:
             self.args0['n_samples'] *= 2
         if 'voxel_size' in self.args0:

@@ -149,12 +139,10 @@ class NeRF(BaseModel):
         if 'sample_step' in self.args0:
             self.args0['sample_step'] /= 2
         self.sampler = self.SamplerClass(**self.args)
+        if self.args.get('n_featdim') and hasattr(self, "trainer"):
+            self.trainer.reset_optimizer()
         return ret
 
-    @torch.no_grad()
-    def double_samples(self):
-        pass
-
     @perf
     def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, *, extra_outputs: List[str] = [], **kwargs) -> torch.Tensor:
model/snerf_advance_x.py

@@ -40,16 +40,8 @@ class SNeRFAdvanceX(SNeRFAdvance):
         return self.cores[chunk_id](x, d, outputs, **extras)
 
     @torch.no_grad()
-    def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @torch.no_grad()
-    def pruning(self, threshold: float = 0.5, train_stats=False):
-        raise NotImplementedError()
-
-    @torch.no_grad()
-    def splitting(self):
-        ret = super().splitting()
+    def split(self):
+        ret = super().split()
         k = self.args["n_samples"] // self.space.steps[0].item()
         net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))]
         if len(net_samples) != len(self.cores):
model/snerf_x.py

@@ -4,10 +4,6 @@ from .snerf import *
 class SNeRFX(SNeRF):
 
-    trainer = "TrainWithSpace"
-    SamplerClass = SphericalSampler
-    RendererClass = VolumnRenderer
-
     def __init__(self, args0: dict, args1: dict = {}):
         """
         Initialize a multi-sphere-layer net

@@ -42,16 +38,8 @@ class SNeRFX(SNeRF):
         return self.cores[chunk_id](x, d, outputs)
 
     @torch.no_grad()
-    def get_scores(self, sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @torch.no_grad()
-    def pruning(self, threshold: float = 0.5, train_stats=False):
-        raise NotImplementedError()
-
-    @torch.no_grad()
-    def splitting(self):
-        ret = super().splitting()
+    def split(self):
+        ret = super().split()
         k = self.args["n_samples"] // self.space.steps[0].item()
         net_samples = [val * k for val in self.space.balance_cut(0, len(self.cores))
modules/space.py

 from math import ceil
 import torch
 import numpy as np
-from typing import List, NoReturn, Tuple, Union
+from typing import List, Tuple, Union
 from torch import nn
 from plyfile import PlyData, PlyElement
 from utils.geometry import *
 from utils.constants import *
@@ -73,11 +70,11 @@ class Space(nn.Module):
         return voxel_indices
 
     @torch.no_grad()
-    def pruning(self, score_fn, threshold: float = 0.5, train_stats=False):
+    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
         raise NotImplementedError()
 
     @torch.no_grad()
-    def splitting(self):
+    def split(self):
         raise NotImplementedError()
@@ -108,7 +105,7 @@ class Voxels(Space):
         return self.voxels.size(0)
 
     @property
-    def n_corner(self) -> int:
+    def n_corners(self) -> int:
         """`int` Number of corners"""
         return self.corners.size(0)
@@ -145,12 +142,18 @@ class Voxels(Space):
         :param n_dims `int`: embedding dimension
         :return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
         """
-        name = f'emb_{name}'
-        self.add_module(name, torch.nn.Embedding(self.n_corners.item(), n_dims))
-        return self.__getattr__(name)
+        if self.get_embedding(name) is not None:
+            raise KeyError(f"Embedding '{name}' already existed")
+        emb = torch.nn.Embedding(self.n_corners, n_dims, device=self.device)
+        setattr(self, f'emb_{name}', emb)
+        return emb
 
     def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
-        return getattr(self, f'emb_{name}')
+        return getattr(self, f'emb_{name}', None)
+
+    def set_embedding(self, weight: torch.Tensor, name: str = 'default'):
+        emb = torch.nn.Embedding(*weight.shape, _weight=weight, device=self.device)
+        setattr(self, f'emb_{name}', emb)
 
     def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor, name: str = 'default') -> torch.Tensor:
@@ -167,9 +170,8 @@ class Voxels(Space):
             raise KeyError(f"Embedding '{name}' doesn't exist")
         voxels = self.voxels[voxel_indices]  # (N, 3)
         corner_indices = self.corner_indices[voxel_indices]  # (N, 8)
-        p = (pts - voxels) / self.voxel_size + 0.5  # (N, 3) normed-coords in voxel
-        features = emb(corner_indices).reshape(pts.size(0), 8, -1)  # (N, 8, X)
-        return trilinear_interp(p, features)
+        p = (pts - voxels) / self.voxel_size + .5  # (N, 3) normed-coords in voxel
+        return trilinear_interp(p, emb(corner_indices))
 
     @perf
     def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
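extract_embedding now hands the normalized in-voxel coordinates p and the eight corner embeddings straight to trilinear_interp, whose implementation is not part of this diff. A minimal sketch of what such an interpolation computes, assuming p is (N, 3) in [0, 1] and features is (N, 8, X) with corner k encoding the bit pattern of k over (x, y, z); the actual corner ordering in utils.geometry may differ:

import torch

def trilinear_interp_sketch(p: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
    # Weight of corner k is the product over dims of p (bit set) or 1 - p (bit clear).
    bits = torch.tensor([[(k >> 2) & 1, (k >> 1) & 1, k & 1] for k in range(8)],
                        dtype=p.dtype, device=p.device)                             # (8, 3)
    w = torch.where(bits[None] > 0, p[:, None], 1 - p[:, None]).prod(-1, keepdim=True)  # (N, 8, 1)
    return (w * features).sum(1)                                                    # (N, X)

# Trilinear interpolation reproduces (multi)linear functions exactly, so interpolating
# coordinate-valued embeddings returns the query positions themselves; that is the
# property test.ipynb below checks for the corners of the split voxels.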
@@ -220,17 +222,34 @@ class Voxels(Space):
         return voxel_indices
 
     @torch.no_grad()
-    def splitting(self) -> None:
+    def split(self) -> None:
         """
         Split voxels into smaller voxels with half size.
         """
-        n_voxels_before = self.n_voxels
-        self.steps *= 2
-        self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False) \
-            .reshape(-1, 3)
-        self._update_corners()
+        new_steps = self.steps * 2
+        new_voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False) \
+            .reshape(-1, 3)
+        new_corners, new_corner_indices = get_corners(new_voxels, self.bbox, new_steps)
+
+        # Calculate new embeddings through trilinear interpolation
+        grid_indices_of_new_corners = to_flat_indices(
+            to_grid_coords(new_corners, self.bbox, steps=self.steps).min(self.steps - 1), self.steps)
+        voxel_indices_of_new_corners = self.voxel_indices_in_grid[grid_indices_of_new_corners]
+        for name, _ in self.named_modules():
+            if not name.startswith("emb_"):
+                continue
+            new_emb_weight = self.extract_embedding(new_corners, voxel_indices_of_new_corners, name=name[4:])
+            self.set_embedding(new_emb_weight, name=name[4:])
+
+        # Apply new tensors
+        self.steps = new_steps
+        self.voxels = new_voxels
+        self.corners = new_corners
+        self.corner_indices = new_corner_indices
         self._update_voxel_indices_in_grid()
-        return n_voxels_before, self.n_voxels
+        return self.n_voxels // 8, self.n_voxels
 
     @torch.no_grad()
     def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
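The rewritten split builds the children with split_voxels(self.voxels, self.voxel_size, 2, align_border=False); with the offset rule from split_voxels_local further down in this commit (arange(1 - n, n, 2) * voxel_size * .5 / n when align_border is False), every parent at center c with edge length s gets 2^3 = 8 children centered at c plus or minus s/4, which is why the method can report self.n_voxels // 8 as the pre-split count. A standalone sketch of that subdivision, independent of the repository helpers:

import torch

def split_voxels_sketch(centers: torch.Tensor, voxel_size: float) -> torch.Tensor:
    """Subdivide each voxel into 2x2x2 children (align_border=False behaviour)."""
    n = 2
    c = torch.arange(1 - n, n, 2, dtype=centers.dtype)                 # [-1, 1]
    offsets = torch.stack(torch.meshgrid(c, c, c, indexing='ij'), -1).reshape(-1, 3)
    offsets = offsets * voxel_size * .5 / n                            # +- voxel_size / 4
    return (centers[:, None] + offsets).reshape(-1, 3)                 # (N * 8, 3)

parents = torch.zeros(1, 3)
children = split_voxels_sketch(parents, voxel_size=1.0)
assert children.shape[0] == parents.shape[0] * 8                       # 8 children per voxel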
@@ -239,11 +258,6 @@ class Voxels(Space):
         self._update_voxel_indices_in_grid()
         return keeps.size(0), keeps.sum().item()
 
-    @torch.no_grad()
-    def pruning(self, score_fn, threshold: float = 0.5) -> None:
-        scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0])  # (M)
-        return self.prune(scores > threshold)
-
     def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
         sum_dims = [val for val in range(self.dims) if val != dim]
         return self.voxel_indices_in_grid.reshape(*self.steps).ne(-1).sum(sum_dims)
@@ -261,39 +275,30 @@ class Voxels(Space):
                 part = int(cdf[i]) + 1
         return bins
 
-    def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
-        sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits)
-        sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand(*sampled_xyz.shape[:2])
-        sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.flatten()
-
-    @torch.no_grad()
-    def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor:
-        def get_scores_once(pts, idxs):
-            scores = score_fn(pts, idxs).reshape(-1, bits ** 3)  # (B, P)
-            if reduce_fn is not None:
-                scores = reduce_fn(scores)  # (B[, ...])
-            return scores
-
-        sampled_xyz, sampled_idx = self.sample(bits)
-        chunk_size = 64
-        return torch.cat([
-            get_scores_once(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
-            for i in range(0, self.voxels.size(0), chunk_size)
-        ], 0)  # (M[, ...])
+    def sample(self, S: int, perturb: bool = False, include_border: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        For each voxel, sample `S^3` points uniformly, with small perturb if `perturb` is `True`.
+
+        When `perturb` is `False`, `include_border` can specify whether to sample points from border to border or at centers of sub-voxels.
+        When `perturb` is `True`, points are sampled at centers of sub-voxels, then applying a random offset in sub-voxels.
+
+        :param S `int`: number of samples along each dim
+        :param perturb `bool?`: whether perturb samples, defaults to `False`
+        :param include_border `bool?`: whether include border, defaults to `True`
+        :return `Tensor(N*S^3, 3)`: sampled points
+        :return `Tensor(N*S^3)`: voxel indices of sampled points
+        """
+        pts = split_voxels(self.voxels, self.voxel_size, S, align_border=not perturb and include_border)  # (N, X, D)
+        voxel_indices = torch.arange(self.n_voxels, device=self.device)[:, None] \
+            .expand(*pts.shape[:-1])  # (N) -> (N, X)
+        if perturb:
+            pts += (torch.rand_like(pts) - .5) * self.voxel_size / S
+        return pts.reshape(-1, 3), voxel_indices.flatten()
 
     def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)
 
     def _update_corners(self):
         """
         Update voxel corners.
         """
         corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
         self.register_buffer("corners", corners)
         self.register_buffer("corner_indices", corner_indices)
 
     def _update_voxel_indices_in_grid(self):
         """
         Update voxel indices in grid.
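A quick way to read the new sample signature: for a space with N voxels and S samples per axis it returns N * S^3 points plus, for each point, the index of the voxel it was drawn from. The shapes for a concrete case (using the 88 sub-voxels that test.ipynb below ends up with):

# pts, voxel_indices = space.sample(S=2, perturb=True)   # `space` is a constructed Voxels instance
N, S = 88, 2
print(N * S ** 3)          # 704: pts would be (704, 3), voxel_indices (704,)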
@@ -314,7 +319,7 @@ class Voxels(Space):
         # Handle embeddings
         for name, module in self.named_modules():
             if name.startswith('emb_'):
-                setattr(self, name, torch.nn.Embedding(self.n_corners.item(), module.embedding_dim))
+                setattr(self, name, torch.nn.Embedding(self.n_corners, module.embedding_dim))
 
 
 class Octree(Voxels):
@@ -339,8 +344,8 @@ class Octree(Voxels):
         return octree_ray_intersect(self.voxel_size, n_max_hits, nodes, tree, rays_o, rays_d)
 
     @torch.no_grad()
-    def splitting(self):
-        ret = super().splitting()
+    def split(self):
+        ret = super().split()
         self.clear()
         return ret
test.ipynb  (new file, mode 100644)
{
"cells": [
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from utils.voxels import *\n",
"\n",
"bbox, steps = torch.tensor([[-2, -3.14159, 1], [2, 3.14159, 0]]), torch.tensor([2, 3, 3])\n",
"voxel_size = (bbox[1] - bbox[0]) / steps\n",
"voxels = init_voxels(bbox, steps)\n",
"corners, corner_indices = get_corners(voxels, bbox, steps)\n",
"voxel_indices_in_grid = torch.arange(voxels.shape[0])\n",
"emb = torch.nn.Embedding(corners.shape[0], 3, _weight=corners)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([11, 3]) tensor([ 0, -1, -1, 1, -1, -1, 2, 3, 4, -1, 5, 6, -1, 7, 8, -1, 9, 10])\n"
]
}
],
"source": [
"keeps = torch.tensor([True]*18)\n",
"keeps[torch.tensor([1,2,4,5,9,12,15])] = False\n",
"voxels = voxels[keeps]\n",
"corner_indices = corner_indices[keeps]\n",
"grid_indices, _ = to_grid_indices(voxels, bbox, steps=steps)\n",
"voxel_indices_in_grid = grid_indices.new_full([steps.prod().item()], -1)\n",
"voxel_indices_in_grid[grid_indices] = torch.arange(voxels.shape[0])\n",
"print(voxels.shape, voxel_indices_in_grid)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([88, 3]) torch.Size([185, 3]) torch.Size([88, 8])\n"
]
}
],
"source": [
"new_voxels = split_voxels(voxels, (bbox[1] - bbox[0]) / steps, 2, align_border=False).reshape(-1, 3)\n",
"new_corners, new_corner_indices = get_corners(new_voxels, bbox, steps * 2)\n",
"print(new_voxels.shape, new_corners.shape, new_corner_indices.shape)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([ 0, 0, -1, 0, 0, -1, 1, 1, -1, 1, 1, -1, 2, 2, 3, 3, 4, 4,\n",
" 4, 2, 2, 3, 3, 4, 4, 4, 2, 2, 3, 3, 4, 4, 4, 0, 0, -1,\n",
" 0, 0, -1, 1, 1, -1, 1, 1, -1, 2, 2, 3, 3, 4, 4, 4, 2, 2,\n",
" 3, 3, 4, 4, 4, 2, 2, 3, 3, 4, 4, 4, -1, -1, 5, 5, 6, 6,\n",
" 6, -1, -1, 5, 5, 6, 6, 6, -1, -1, 7, 7, 8, 8, 8, -1, -1, 7,\n",
" 7, 8, 8, 8, -1, -1, 9, 9, 10, 10, 10, -1, -1, 9, 9, 10, 10, 10,\n",
" -1, -1, 9, 9, 10, 10, 10, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6, 7,\n",
" 7, 8, 8, 8, 7, 7, 8, 8, 8, 9, 9, 10, 10, 10, 9, 9, 10, 10,\n",
" 10, 9, 9, 10, 10, 10, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6, 7, 7,\n",
" 8, 8, 8, 7, 7, 8, 8, 8, 9, 9, 10, 10, 10, 9, 9, 10, 10, 10,\n",
" 9, 9, 10, 10, 10])\n",
"tensor(0)\n"
]
}
],
"source": [
"voxel_indices_of_new_corner = voxel_indices_in_grid[to_flat_indices(to_grid_coords(new_corners, bbox, steps=steps).min(steps - 1), steps)]\n",
"print(voxel_indices_of_new_corner)\n",
"p_of_new_corners = (new_corners - voxels[voxel_indices_of_new_corner]) / voxel_size + .5\n",
"print(((new_corners - trilinear_interp(p_of_new_corners, emb(corner_indices[voxel_indices_of_new_corner]))) > 1e-6).sum())"
]
}
],
"metadata": {
"interpreter": {
"hash": "08b118544df3cb8970a671e5837a88fd458f4d4c799ef1fb2709465a22a45b92"
},
"kernelspec": {
"display_name": "Python 3.9.5 64-bit ('base': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
test.py

@@ -38,7 +38,7 @@ from data.loader import DataLoader
 from utils.constants import HUGE_FLOAT
 
-RAYS_PER_BATCH = 2 ** 14
+RAYS_PER_BATCH = 2 ** 12
 DATA_LOADER_CHUNK_SIZE = 1e8
test.txt  (new file, mode 100644; diff collapsed, not shown)
test1.txt  (new file, mode 100644; diff collapsed, not shown)
train.py

@@ -13,8 +13,9 @@ from data.loader import DataLoader
 from utils.misc import list_epochs, print_and_log
 
-RAYS_PER_BATCH = 2 ** 16
+RAYS_PER_BATCH = 2 ** 12
 DATA_LOADER_CHUNK_SIZE = 1e8
+root_dir = Path.cwd()
 
 parser = argparse.ArgumentParser()

@@ -68,7 +69,7 @@ if args.mdl_path:
     model_args = model.args
 else:
     # Create model from specified configuration
-    with Path(f'{sys.path[0]}/configs/{args.config}.json').open() as fp:
+    with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
         config = json.load(fp)
         model_name = args.config
         model_class = config['model']

@@ -76,7 +77,7 @@ else:
     model_args['bbox'] = dataset.bbox
     model_args['depth_range'] = dataset.depth_range
     model, states = mdl.create(model_class, model_args), None
-    model.to(device.default()).train()
+    model.to(device.default())
 
 run_dir = Path(f"_nets/{dataset.name}/{model_name}")
 run_dir.mkdir(parents=True, exist_ok=True)
train/__init__.py

@@ -22,5 +22,5 @@ def get_class(class_name: str) -> type:
 
 def get_trainer(model: BaseModel, **kwargs) -> base.Train:
-    train_class = get_class(model.trainer)
+    train_class = get_class(model.TrainerClass)
     return train_class(model, **kwargs)
train/base.py

@@ -42,8 +42,9 @@ class Train(object, metaclass=BaseTrainMeta):
         self.iters = 0
         self.run_dir = run_dir
 
+        self.model.trainer = self
         self.model.train()
-        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
+        self.reset_optimizer()
 
         if states:
             if 'epoch' in states:

@@ -58,6 +59,9 @@ class Train(object, metaclass=BaseTrainMeta):
         if self.perf_mode:
             enable_perf()
 
+    def reset_optimizer(self):
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4)
+
     def train(self, data_loader: DataLoader, max_epochs: int):
         self.data_loader = data_loader
         self.iters_per_epoch = self.perf_frames or len(data_loader)
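Moving the Adam construction into reset_optimizer matters because NeRF.split (earlier in this commit) swaps the voxel-corner embeddings for new, larger nn.Embedding modules, and an optimizer created before the split would keep referring to the old parameter tensors. A minimal sketch of the pattern outside the repository's classes (the toy model below is an assumption for illustration only):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# The model structure changes, e.g. an embedding is replaced by a larger one ...
model.add_module("head", torch.nn.Linear(4, 2))

# ... so the optimizer is rebuilt to track the new parameter set, which is what
# reset_optimizer() provides when the model asks for it after a split.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)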
train/train_with_space.py

@@ -20,18 +20,15 @@ class TrainWithSpace(Train):
         if self.splitting_loop == 1 or self.epoch % self.splitting_loop == 1:
             try:
                 with torch.no_grad():
-                    before, after = self.model.splitting()
-                    print_and_log(f"Splitting done. # of voxels before: {before}, after: {after}")
+                    before, after = self.model.split()
+                    print_and_log(f"Splitting done: {before} -> {after}")
             except NotImplementedError:
                 print_and_log("Note: The space does not support splitting operation. Just skip it.")
 
         if self.pruning_loop == 1 or self.epoch % self.pruning_loop == 1:
             try:
                 with torch.no_grad():
-                    #before, after = self.model.pruning()
-                    # print(f"Pruning by voxel densities done. # of voxels before: {before}, after: {after}")
-                    # self._prune_inner_voxels()
+                    # self._prune_voxels_by_densities()
                     self._prune_voxels_by_weights()
             except NotImplementedError:
                 print_and_log(

@@ -39,26 +36,26 @@ class TrainWithSpace(Train):
         super()._train_epoch()
 
-    def _prune_inner_voxels(self):
+    def _prune_voxels_by_densities(self):
         space: Voxels = self.model.space
-        voxel_access_counts = torch.zeros(space.n_voxels, dtype=torch.long, device=space.voxels.device)
-        iters_in_epoch = 0
-        batch_size = self.data_loader.batch_size
-        self.data_loader.batch_size = 2 ** 14
-        for _, rays_o, rays_d, _ in self.data_loader:
-            self.model(rays_o, rays_d,
-                       raymarching_early_stop_tolerance=0.01,
-                       raymarching_chunk_size_or_sections=[1],
-                       perturb_sample=False,
-                       voxel_access_counts=voxel_access_counts,
-                       voxel_access_tolerance=0)
-            iters_in_epoch += 1
-            percent = iters_in_epoch / len(self.data_loader) * 100
-            sys.stdout.write(f'Pruning inner voxels... {percent:.1f}%\r')
-        self.data_loader.batch_size = batch_size
-        before, after = space.prune(voxel_access_counts > 0)
-        print(f"Prune inner voxels: {before} -> {after}")
+        threshold = .5
+        bits = 16
+
+        @torch.no_grad()
+        def get_scores(sampled_points: torch.Tensor, sampled_voxel_indices: torch.Tensor) -> torch.Tensor:
+            densities = self.model.render(Samples(sampled_points, None, None, None, sampled_voxel_indices),
+                                          'density')
+            return 1 - (-densities).exp()
+
+        sampled_xyz, sampled_idx = space.sample(bits)
+        chunk_size = 64
+        scores = torch.cat([
+            torch.max(get_scores(sampled_xyz[i:i + chunk_size], sampled_idx[i:i + chunk_size])
+                      .reshape(-1, bits ** 3), -1)[0]
+            for i in range(0, self.voxels.size(0), chunk_size)
+        ], 0)  # (M[, ...])
+        return space.prune(scores > threshold)
 
     def _prune_voxels_by_weights(self):
         space: Voxels = self.model.space
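The inlined get_scores closure converts rendered densities into occupancy scores with 1 - exp(-sigma), the opacity a unit-length sample would contribute under the usual volume-rendering alpha formula; a voxel is kept only if the maximum score over its sampled points exceeds threshold = .5. A small numeric illustration:

import torch

densities = torch.tensor([0.01, 0.1, 1.0, 5.0])
scores = 1 - (-densities).exp()
print(scores)   # ~[0.0100, 0.0952, 0.6321, 0.9933]; only the last two pass a 0.5 threshold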
utils/misc.py

@@ -57,10 +57,11 @@ def meshgrid(*size: int, normalize: bool = False, swap_dim: bool = False) -> torch.Tensor:
     """
     if len(size) == 1:
         size = (size[0], size[0])
-    y, x = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1]))
-    if swap_dim:
-        return torch.stack([y / (size[0] - 1.), x / (size[1] - 1.)], 2) if normalize else torch.stack([y, x], 2)
-    return torch.stack([x / (size[1] - 1.), y / (size[0] - 1.)], 2) if normalize else torch.stack([x, y], 2)
+    y, x = torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]), indexing='ij')
+    if normalize:
+        x.div_(size[1] - 1.)
+        y.div_(size[0] - 1.)
+    return torch.stack([y, x], 2) if swap_dim else torch.stack([x, y], 2)
 
 
 def get_angle(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
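The meshgrid rewrite passes indexing='ij' explicitly (accepted from PyTorch 1.10 onward), which keeps the original row-major (y, x) layout while avoiding the deprecation warning newer releases emit for calls without an indexing argument; normalization is now done in place instead of being folded into four separate stack expressions. A quick check of the layout this relies on:

import torch

y, x = torch.meshgrid(torch.arange(3), torch.arange(4), indexing='ij')
assert y.shape == x.shape == (3, 4)
assert y[2, 0] == 2 and x[0, 3] == 3   # 'ij' keeps the first argument varying along dim 0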
utils/voxels.py

@@ -13,6 +13,13 @@ def get_grid_steps(bbox: torch.Tensor, step_size: Union[torch.Tensor, float]) -> torch.Tensor:
     return ((bbox[1] - bbox[0]) / step_size).ceil().long()
 
 
+def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
+    indices = grid_coords[..., 0]
+    for i in range(1, grid_coords.shape[-1]):
+        indices = indices * steps[i] + grid_coords[..., i]
+    return indices
+
+
 def to_grid_coords(pts: torch.Tensor, bbox: torch.Tensor, *,
                    step_size: Union[torch.Tensor, float] = None,
                    steps: torch.Tensor = None) -> torch.Tensor:
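The new to_flat_indices replaces the dimension-by-dimension if/elif chain removed below with a Horner-style accumulation: for grid coordinates (c0, c1, c2) and steps (s0, s1, s2) it computes (c0 * s1 + c1) * s2 + c2, and likewise for any number of dimensions. With the steps [4, 16, 8] from the new config above, the coordinate (1, 2, 3) flattens to 1*16*8 + 2*8 + 3 = 147. The function as added, with that check:

import torch

def to_flat_indices(grid_coords: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
    indices = grid_coords[..., 0]
    for i in range(1, grid_coords.shape[-1]):
        indices = indices * steps[i] + grid_coords[..., i]
    return indices

steps = torch.tensor([4, 16, 8])
coords = torch.tensor([[1, 2, 3]])
assert to_flat_indices(coords, steps).item() == 1 * 16 * 8 + 2 * 8 + 3   # 147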
@@ -55,20 +62,7 @@ def to_grid_indices(pts: torch.Tensor, bbox: torch.Tensor, *,
         steps = get_grid_steps(bbox, step_size)  # (D)
     grid_coords = to_grid_coords(pts, bbox, step_size=step_size, steps=steps)  # (N..., D)
     outside_mask = torch.logical_or(grid_coords < 0, grid_coords >= steps).any(-1)  # (N...)
-    if pts.size(-1) == 1:
-        grid_indices = grid_coords[..., 0]
-    elif pts.size(-1) == 2:
-        grid_indices = grid_coords[..., 0] * steps[1] + grid_coords[..., 1]
-    elif pts.size(-1) == 3:
-        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] \
-            + grid_coords[..., 1] * steps[2] + grid_coords[..., 2]
-    elif pts.size(-1) == 4:
-        grid_indices = grid_coords[..., 0] * steps[1] * steps[2] * steps[3] \
-            + grid_coords[..., 1] * steps[2] * steps[3] \
-            + grid_coords[..., 2] * steps[3] \
-            + grid_coords[..., 3]
-    else:
-        raise NotImplementedError("The function does not support D>4")
+    grid_indices = to_flat_indices(grid_coords, steps)
     return grid_indices, outside_mask
@@ -76,7 +70,7 @@ def init_voxels(bbox: torch.Tensor, steps: torch.Tensor):
     """
     Initialize voxels.
     """
-    x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)])
+    x, y, z = torch.meshgrid(*[torch.arange(steps[i]) for i in range(3)], indexing="ij")
     return to_voxel_centers(torch.stack([x, y, z], -1).reshape(-1, 3), bbox, steps=steps)
@@ -96,7 +90,7 @@ def to_voxel_centers(grid_coords: torch.Tensor, bbox: torch.Tensor, *,
     :param steps `Tensor(1|D)`: (optional) steps alone every dim
     :return `Tensor(N..., D)`: discretized grid coordinates
     """
-    grid_coords = grid_coords.float() + 0.5
+    grid_coords = grid_coords.float() + .5
     if step_size is not None:
         return grid_coords * step_size + bbox[0]
     return grid_coords / steps * (bbox[1] - bbox[0]) + bbox[0]
@@ -121,8 +115,8 @@ def split_voxels_local(voxel_size: Union[torch.Tensor, float], n: int, align_border
     dtype = like.dtype
     device = like.device
     c = torch.arange(1 - n, n, 2, dtype=dtype, device=device)
-    offset = torch.stack(torch.meshgrid([c] * dims), -1).flatten(0, -2) * voxel_size / 2 / \
-        (n - 1 if align_border else n)
+    offset = torch.stack(torch.meshgrid([c] * dims, indexing='ij'), -1).flatten(0, -2) \
+        * voxel_size * .5 / (n - 1 if align_border else n)
     return offset
@@ -144,7 +138,7 @@ def split_voxels(voxel_centers: torch.Tensor, voxel_size: Union[torch.Tensor, float],
 def get_corners(voxel_centers: torch.Tensor, bbox: torch.Tensor, steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     half_voxel_size = (bbox[1] - bbox[0]) / steps * 0.5
-    expand_bbox = bbox
+    expand_bbox = bbox.clone()
     expand_bbox[0] -= 0.5 * half_voxel_size
     expand_bbox[1] += 0.5 * half_voxel_size
     double_grid_coords = to_grid_coords(voxel_centers, expand_bbox, step_size=half_voxel_size)