Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
Nianchen Deng
deeplightfield
Commits
4860d0af
Commit
4860d0af
authored
Jun 26, 2022
by
Nianchen Deng
Browse files
a runable version
parent
338ae906
Changes
4
Show whitespace changes
Inline
Side-by-side
data/pano_dataset.py
View file @
4860d0af
...
...
@@ -106,7 +106,7 @@ class PanoDataset(object):
"""
self
.
c
=
c
self
.
device
=
device
self
.
_load_desc
(
desc
,
res
,
views_to_load
,
load_images
)
self
.
_load_desc
(
desc
,
data_dir
,
res
,
views_to_load
,
load_images
)
def
get_data
(
self
):
return
{
...
...
data/view_dataset.py
View file @
4860d0af
...
...
@@ -91,7 +91,7 @@ class ViewDataset(object):
rays_o
=
self
.
centers
[
view_idx
]
rays_d
=
self
.
dataset
.
cam_rays
[
pix_idx
]
# (N, 3)
r
=
self
.
rots
[
view_idx
].
movedim
(
-
1
,
-
2
)
# (N, 3, 3)
rays_d
=
torch
.
matmul
(
rays_d
,
r
)
rays_d
=
(
rays_d
[:,
None
]
@
r
)[:,
0
]
extra_data
=
{}
if
self
.
colors
is
not
None
:
extra_data
[
'colors'
]
=
self
.
colors
[
idx
]
...
...
@@ -150,15 +150,15 @@ class ViewDataset(object):
load_depths
:
bool
,
load_bins
:
bool
):
if
load_images
and
desc
.
get
(
'view_file_pattern'
):
self
.
image_path
=
os
.
path
.
join
(
self
.
data_dir
,
desc
[
'view_file_pattern'
])
self
.
image_path
=
os
.
path
.
join
(
os
.
getcwd
()
,
desc
[
'view_file_pattern'
])
else
:
self
.
image_path
=
None
if
load_depths
and
desc
.
get
(
'depth_file_pattern'
):
self
.
depth_path
=
os
.
path
.
join
(
self
.
data_dir
,
desc
[
'depth_file_pattern'
])
self
.
depth_path
=
os
.
path
.
join
(
os
.
getcwd
()
,
desc
[
'depth_file_pattern'
])
else
:
self
.
depth_path
=
None
if
load_bins
and
desc
.
get
(
'bins_file_pattern'
):
self
.
bins_path
=
os
.
path
.
join
(
self
.
data_dir
,
desc
[
'bins_file_pattern'
])
self
.
bins_path
=
os
.
path
.
join
(
os
.
getcwd
()
,
desc
[
'bins_file_pattern'
])
else
:
self
.
bins_path
=
None
self
.
res
=
res
if
res
else
misc
.
values
(
desc
[
'view_res'
],
'y'
,
'x'
)
...
...
modules/sampler.py
View file @
4860d0af
...
...
@@ -164,24 +164,3 @@ class PdfSampler(nn.Module):
return
samples
\ No newline at end of file
class
VoxelSampler
(
nn
.
Module
):
def
__init__
(
self
,
*
,
depth_range
:
Tuple
[
float
,
float
],
n_samples
:
int
,
perturb_sample
:
bool
,
lindisp
:
bool
,
space
):
"""
Initialize a Sampler module
:param depth_range: depth range for sampler
:param n_samples: count to sample along ray
:param perturb_sample: perturb the sample depths
:param lindisp: If True, sample linearly in inverse depth rather than in depth
"""
super
().
__init__
()
self
.
lindisp
=
lindisp
self
.
perturb_sample
=
perturb_sample
self
.
n_samples
=
n_samples
self
.
space
=
space
self
.
s_range
=
(
1
/
depth_range
[
0
],
1
/
depth_range
[
1
])
if
self
.
lindisp
else
depth_range
def
forward
(
self
,
rays_o
,
rays_d
,
*
,
weights
,
s_vals
=
None
,
include_s_vals
=
False
):
\ No newline at end of file
run_spherical_view_syn.py
View file @
4860d0af
...
...
@@ -341,6 +341,7 @@ def test():
load_images
=
args
.
output_flags
[
'perf'
])
data_loader
=
DataLoader
(
dataset
,
TEST_BATCH_SIZE
,
chunk_max_items
=
TEST_MAX_CHUNK_ITEMS
,
shuffle
=
False
)
print
(
dataset
.
image_path
)
# 2. Load trained model
netio
.
load
(
test_model_path
,
model
)
...
...
@@ -367,7 +368,10 @@ def test():
if
args
.
output_flags
[
'perf'
]:
perf
=
Perf
(
True
,
start
=
True
)
for
_
,
rays_o
,
rays_d
,
_
in
data_loader
:
gt
=
[]
for
_
,
rays_o
,
rays_d
,
extra_data
in
data_loader
:
if
args
.
output_flags
[
"perf"
]
and
"colors"
in
extra_data
:
gt
.
append
(
extra_data
[
"colors"
])
n_rays
=
rays_o
.
size
(
0
)
ret
=
model
(
rays_o
,
rays_d
,
ret_depth
=
args
.
output_flags
[
'depth'
],
...
...
@@ -388,8 +392,8 @@ def test():
*
0.5
+
0.5
)
*
(
vals
>
0.1
)
idx
=
slice
(
offset
,
offset
+
n_rays
)
for
key
in
out
:
print
(
"key "
,
key
,
", idx "
,
idx
,
", out is "
,
out
[
key
].
shape
,
", ret is "
,
ret
[
key
].
shape
,
", rays is "
,
n_rays
)
#
print("key ", key, ", idx ", idx, ", out is ",
#
out[key].shape, ", ret is ", ret[key].shape, ", rays is ", n_rays)
out
[
key
][
idx
]
=
ret
[
key
]
if
not
args
.
log_redirect
:
progress_bar
(
i
,
math
.
ceil
(
total_pixels
/
n_rays
),
'Inferring...'
)
...
...
@@ -416,11 +420,11 @@ def test():
if
args
.
output_flags
[
'perf'
]:
perf_errors
=
torch
.
ones
(
n
)
*
NaN
perf_ssims
=
torch
.
ones
(
n
)
*
NaN
if
dataset
.
images
!=
None
:
if
len
(
gt
)
>
0
:
gt
=
torch
.
cat
(
gt
).
reshape
(
n
,
*
dataset
.
res
,
-
1
).
movedim
(
-
1
,
-
3
)
for
i
in
range
(
n
):
perf_errors
[
i
]
=
loss_mse
(
dataset
.
images
[
i
],
out
[
'color'
][
i
]).
item
()
perf_ssims
[
i
]
=
ssim
(
dataset
.
images
[
i
:
i
+
1
],
out
[
'color'
][
i
:
i
+
1
]).
item
()
*
100
perf_errors
[
i
]
=
loss_mse
(
gt
[
i
],
out
[
'color'
][
i
]).
item
()
perf_ssims
[
i
]
=
ssim
(
gt
[
i
:
i
+
1
],
out
[
'color'
][
i
:
i
+
1
]).
item
()
*
100
perf_mean_time
=
tot_time
/
n
perf_mean_error
=
torch
.
mean
(
perf_errors
).
item
()
perf_name
=
'perf_%s_%.1fms_%.2e.csv'
%
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment