Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
Nianchen Deng
deeplightfield
Commits
f1dd9e3a
Commit
f1dd9e3a
authored
Sep 06, 2021
by
Nianchen Deng
Browse files
tog'21 baseline
parent
c10f614f
Changes
4
Show whitespace changes
Inline
Side-by-side
data/dataset_factory.py
0 → 100644
View file @
f1dd9e3a
import
os
import
json
import
utils.device
from
.pano_dataset
import
PanoDataset
from
.view_dataset
import
ViewDataset
class
DatasetFactory
(
object
):
@
staticmethod
def
load
(
path
,
device
=
None
,
**
kwargs
):
device
=
device
or
utils
.
device
.
default
()
data_dir
=
os
.
path
.
dirname
(
path
)
with
open
(
path
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
data_desc
=
json
.
loads
(
file
.
read
())
cwd
=
os
.
getcwd
()
os
.
chdir
(
data_dir
)
if
'type'
in
data_desc
and
data_desc
[
'type'
]
==
'pano'
:
dataset
=
PanoDataset
(
data_desc
,
device
=
device
,
**
kwargs
)
else
:
dataset
=
ViewDataset
(
data_desc
,
device
=
device
,
**
kwargs
)
os
.
chdir
(
cwd
)
return
dataset
\ No newline at end of file
data/loader.py
View file @
f1dd9e3a
from
doctest
import
debug_script
from
logging
import
*
import
threading
import
torch
import
torch
import
math
import
math
from
utils
import
device
class
FastDataLoader
(
object
):
class
Preloader
(
object
):
def
__init__
(
self
,
device
=
None
)
->
None
:
super
().
__init__
()
self
.
stream
=
torch
.
cuda
.
Stream
(
device
)
self
.
event_chunk_loaded
=
None
def
preload_chunk
(
self
,
chunk
):
if
self
.
event_chunk_loaded
is
not
None
:
self
.
event_chunk_loaded
.
wait
()
if
chunk
.
loaded
:
return
# print(f'Preloader: preload chunk #{chunk.id}')
self
.
event_chunk_loaded
=
threading
.
Event
()
threading
.
Thread
(
target
=
Preloader
.
_load_chunk
,
args
=
(
self
,
chunk
)).
start
()
def
_load_chunk
(
self
,
chunk
):
with
torch
.
cuda
.
stream
(
self
.
stream
):
chunk
.
load
()
self
.
event_chunk_loaded
.
set
()
# print(f'Preloader: chunk #{chunk.id} is loaded')
class
DataLoader
(
object
):
class
Iter
(
object
):
class
Iter
(
object
):
def
__init__
(
self
,
dataset
,
batch_size
,
shuffle
,
d
rop_last
)
->
None
:
def
__init__
(
self
,
chunks
,
batch_size
,
shuffle
,
d
evice
:
torch
.
device
,
preloader
:
Preloader
)
:
super
().
__init__
()
super
().
__init__
()
self
.
indices
=
torch
.
randperm
(
len
(
dataset
),
device
=
device
.
default
())
\
if
shuffle
else
torch
.
arange
(
len
(
dataset
),
device
=
device
.
default
())
self
.
offset
=
0
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
dataset
=
dataset
self
.
chunks
=
chunks
self
.
drop_last
=
drop_last
self
.
offset
=
-
1
self
.
chunk_idx
=
-
1
self
.
current_chunk
=
None
self
.
shuffle
=
shuffle
self
.
device
=
device
self
.
preloader
=
preloader
def
__del__
(
self
):
#print('DataLoader.Iter: clean chunks')
if
self
.
preloader
is
not
None
and
self
.
preloader
.
event_chunk_loaded
is
not
None
:
self
.
preloader
.
event_chunk_loaded
.
wait
()
chunks_to_reserve
=
1
if
self
.
preloader
is
None
else
2
for
i
in
range
(
chunks_to_reserve
,
len
(
self
.
chunks
)):
if
self
.
chunks
[
i
].
loaded
:
self
.
chunks
[
i
].
release
()
def
__next__
(
self
):
def
__next__
(
self
):
if
self
.
offset
+
(
self
.
batch_size
if
self
.
drop_last
else
0
)
>=
len
(
self
.
dataset
):
if
self
.
offset
==
-
1
:
self
.
_next_chunk
()
stop
=
min
(
self
.
offset
+
self
.
batch_size
,
len
(
self
.
current_chunk
))
if
self
.
indices
is
not
None
:
indices
=
self
.
indices
[
self
.
offset
:
stop
]
else
:
indices
=
torch
.
arange
(
self
.
offset
,
stop
,
device
=
self
.
device
)
self
.
offset
=
stop
if
self
.
offset
>=
len
(
self
.
current_chunk
):
self
.
offset
=
-
1
return
self
.
current_chunk
[
indices
]
def
_next_chunk
(
self
):
if
self
.
current_chunk
is
not
None
:
chunks_to_reserve
=
1
if
self
.
preloader
is
None
else
2
if
len
(
self
.
chunks
)
>
chunks_to_reserve
:
self
.
current_chunk
.
release
()
if
self
.
chunk_idx
>=
len
(
self
.
chunks
)
-
1
:
raise
StopIteration
()
raise
StopIteration
()
indices
=
self
.
indices
[
self
.
offset
:
self
.
offset
+
self
.
batch_size
]
self
.
chunk_idx
+=
1
self
.
offset
+=
self
.
batch_size
self
.
current_chunk
=
self
.
chunks
[
self
.
chunk_idx
]
return
self
.
dataset
[
indices
]
self
.
offset
=
0
self
.
indices
=
torch
.
randperm
(
len
(
self
.
current_chunk
),
device
=
self
.
device
)
\
if
self
.
shuffle
else
None
if
self
.
preloader
is
not
None
:
self
.
preloader
.
preload_chunk
(
self
.
chunks
[(
self
.
chunk_idx
+
1
)
%
len
(
self
.
chunks
)])
def
__init__
(
self
,
dataset
,
batch_size
,
shuffle
,
drop_last
=
False
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
dataset
,
batch_size
,
*
,
chunk_max_items
=
None
,
shuffle
=
False
,
enable_preload
=
True
):
super
().
__init__
()
super
().
__init__
()
self
.
dataset
=
dataset
self
.
dataset
=
dataset
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
self
.
drop_last
=
drop_last
self
.
preloader
=
Preloader
(
self
.
dataset
.
device
)
if
enable_preload
else
None
self
.
_init_chunks
(
chunk_max_items
)
def
__iter__
(
self
):
def
__iter__
(
self
):
return
Fast
DataLoader
.
Iter
(
self
.
dataset
,
self
.
batch_size
,
return
DataLoader
.
Iter
(
self
.
chunks
,
self
.
batch_size
,
self
.
shuffle
,
self
.
dataset
.
device
,
self
.
shuffle
,
self
.
drop_last
)
self
.
preloader
)
def
__len__
(
self
):
def
__len__
(
self
):
return
math
.
floor
(
len
(
self
.
dataset
)
/
self
.
batch_size
)
if
self
.
drop_last
\
return
sum
(
math
.
ceil
(
len
(
chunk
)
/
self
.
batch_size
)
for
chunk
in
self
.
chunks
)
else
math
.
ceil
(
len
(
self
.
dataset
)
/
self
.
batch_size
)
def
_init_chunks
(
self
,
chunk_max_items
):
data
=
self
.
dataset
.
get_data
()
if
self
.
shuffle
:
rand_seq
=
torch
.
randperm
(
self
.
dataset
.
n_views
,
device
=
self
.
dataset
.
device
)
for
key
in
data
:
data
[
key
]
=
data
[
key
][
rand_seq
]
self
.
chunks
=
[]
n_chunks
=
1
if
chunk_max_items
is
None
else
\
math
.
ceil
(
self
.
dataset
.
n_pixels
/
chunk_max_items
)
views_per_chunk
=
math
.
ceil
(
self
.
dataset
.
n_views
/
n_chunks
)
for
offset
in
range
(
0
,
self
.
dataset
.
n_views
,
views_per_chunk
):
sel
=
slice
(
offset
,
offset
+
views_per_chunk
)
chunk_data
=
{}
for
key
in
data
:
chunk_data
[
key
]
=
data
[
key
][
sel
]
self
.
chunks
.
append
(
self
.
dataset
.
Chunk
(
len
(
self
.
chunks
),
self
.
dataset
,
**
chunk_data
))
if
self
.
preloader
is
not
None
:
self
.
preloader
.
preload_chunk
(
self
.
chunks
[
0
])
data/pano_dataset.py
0 → 100644
View file @
f1dd9e3a
import
os
import
torch
import
torch.nn.functional
as
nn_f
from
typing
import
Tuple
,
Union
from
utils
import
img
from
utils
import
color
from
utils
import
misc
from
utils
import
sphere
from
utils.mem_profiler
import
*
from
utils.constants
import
*
class
PanoDataset
(
object
):
"""
Data loader for spherical view synthesis task
Attributes
--------
data_dir ```str```: the directory of dataset
\n
view_file_pattern ```str```: the filename pattern of view images
\n
cam_params ```object```: camera intrinsic parameters
\n
centers ```Tensor(N, 3)```: centers of views
\n
view_rots ```Tensor(N, 3, 3)```: rotation matrices of views
\n
images ```Tensor(N, 3, H, W)```: images of views
\n
view_depths ```Tensor(N, H, W)```: depths of views
\n
"""
class
Chunk
(
object
):
def
__init__
(
self
,
id
,
dataset
,
*
,
indices
:
torch
.
Tensor
,
centers
:
torch
.
Tensor
):
"""
[summary]
:param dataset `PanoDataset`: dataset object
:param indices `Tensor(N)`: indices of views
:param centers `Tensor(N, 3)`: centers of views
"""
self
.
id
=
id
self
.
dataset
=
dataset
self
.
indices
=
indices
self
.
centers
=
centers
self
.
n_views
=
self
.
indices
.
size
(
0
)
self
.
n_pixels_per_view
=
self
.
dataset
.
res
[
0
]
*
self
.
dataset
.
res
[
1
]
self
.
colors_cpu
=
None
self
.
colors
=
None
self
.
loaded
=
False
def
release
(
self
):
self
.
colors
=
None
self
.
loaded
=
False
MemProfiler
.
print_memory_stats
(
f
'Chunk #
{
self
.
id
}
released'
)
def
load
(
self
):
if
self
.
dataset
.
image_path
is
not
None
and
self
.
colors_cpu
is
None
:
images
=
color
.
cvt
(
img
.
load
(
self
.
dataset
.
image_path
%
i
for
i
in
self
.
indices
),
color
.
RGB
,
self
.
dataset
.
c
)
if
self
.
dataset
.
res
!=
list
(
images
.
shape
[
-
2
:]):
images
=
nn_f
.
interpolate
(
images
,
self
.
dataset
.
res
)
self
.
colors_cpu
=
images
.
permute
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
)
if
self
.
colors_cpu
is
not
None
:
self
.
colors
=
self
.
colors_cpu
.
to
(
self
.
dataset
.
device
)
self
.
loaded
=
True
MemProfiler
.
print_memory_stats
(
f
'Chunk #
{
self
.
id
}
(
{
self
.
n_views
}
views, '
f
'
{
self
.
colors
.
numel
()
*
self
.
colors
.
element_size
()
/
1024
/
1024
:.
2
f
}
MB) loaded'
)
def
__len__
(
self
):
return
self
.
n_views
*
self
.
n_pixels_per_view
def
__getitem__
(
self
,
idx
):
if
not
self
.
loaded
:
self
.
load
()
view_idx
=
idx
//
self
.
n_pixels_per_view
pix_idx
=
idx
%
self
.
n_pixels_per_view
extra_data
=
{}
if
self
.
colors
is
not
None
:
extra_data
[
'colors'
]
=
self
.
colors
[
idx
]
rays_o
=
self
.
centers
[
view_idx
]
rays_d
=
self
.
dataset
.
pano_rays
[
pix_idx
]
return
idx
,
rays_o
,
rays_d
,
extra_data
def
__init__
(
self
,
desc
:
dict
,
*
,
c
:
int
=
color
.
RGB
,
load_images
:
bool
=
True
,
res
:
Tuple
[
int
,
int
]
=
None
,
views_to_load
:
Union
[
range
,
torch
.
Tensor
]
=
None
,
device
:
torch
.
device
=
None
,
**
kwargs
):
"""
Initialize data loader for spherical view synthesis task
The dataset description file is a JSON file with following fields:
- view_file_pattern: string, the path pattern of view images
- view_res: { "x", "y" }, the resolution of view
- depth_range: { "min", "max" }, the depth range
- range: { "min": [...], "max": [...] }, the range of translation and rotation
- centers: [ [ x, y, z ], ... ], centers of views
:param desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param c ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether calculate rays
"""
self
.
c
=
c
self
.
device
=
device
self
.
_load_desc
(
desc
,
res
,
views_to_load
,
load_images
)
def
get_data
(
self
):
return
{
'indices'
:
self
.
indices
,
'centers'
:
self
.
centers
}
def
_load_desc
(
self
,
desc
:
dict
,
res
:
Tuple
[
int
,
int
],
views_to_load
:
Union
[
range
,
torch
.
Tensor
],
load_images
:
bool
):
if
load_images
and
desc
.
get
(
'view_file_pattern'
):
self
.
image_path
=
os
.
path
.
join
(
os
.
getcwd
(),
desc
[
'view_file_pattern'
])
else
:
self
.
image_path
=
None
self
.
res
=
res
if
res
else
misc
.
values
(
desc
[
'view_res'
],
'y'
,
'x'
)
self
.
depth_range
=
misc
.
values
(
desc
[
'depth_range'
],
'min'
,
'max'
)
\
if
'depth_range'
in
desc
else
None
self
.
range
=
misc
.
values
(
desc
[
'range'
],
'min'
,
'max'
)
if
'range'
in
desc
else
None
self
.
samples
=
desc
.
get
(
'samples'
)
self
.
centers
=
torch
.
tensor
(
desc
[
'view_centers'
],
device
=
self
.
device
)
# (N, 3)
self
.
indices
=
torch
.
tensor
(
desc
[
'views'
]
if
'views'
in
desc
else
list
(
range
(
self
.
centers
.
size
(
0
))),
device
=
self
.
device
)
if
views_to_load
is
not
None
:
self
.
centers
=
self
.
centers
[
views_to_load
]
self
.
indices
=
self
.
indices
[
views_to_load
]
self
.
n_views
=
self
.
centers
.
size
(
0
)
self
.
n_pixels
=
self
.
n_views
*
self
.
res
[
0
]
*
self
.
res
[
1
]
self
.
pano_rays
=
self
.
_get_pano_rays
()
# [H*W, 3]
if
desc
.
get
(
'gl_coord'
):
print
(
'Convert from OGL coordinate to DX coordinate (i. e. flip z axis)'
)
self
.
centers
[:,
2
]
*=
-
1
def
_get_pano_rays
(
self
):
"""
Get unprojected rays of pixels on a panorama
:return `Tensor(H*W, 3)`: rays' directions with one unit length
"""
spher_coords
=
torch
.
cat
([
torch
.
ones
(
*
self
.
res
,
1
),
((
misc
.
meshgrid
(
*
self
.
res
,
normalize
=
True
))
*
torch
.
tensor
([
-
2.0
,
1.0
])
+
torch
.
tensor
([
1.5
,
0.0
]))
*
PI
],
dim
=-
1
).
to
(
device
=
self
.
device
)
coords
=
sphere
.
spherical2cartesian
(
spher_coords
)
return
coords
.
flatten
(
0
,
1
)
# [H*W, 3]
data/view_dataset.py
0 → 100644
View file @
f1dd9e3a
import
os
import
torch
import
torch.nn.functional
as
nn_f
from
typing
import
Tuple
,
Union
from
utils
import
img
from
utils
import
view
from
utils
import
color
from
utils
import
misc
class
ViewDataset
(
object
):
"""
Data loader for spherical view synthesis task
Attributes
--------
data_dir ```str```: the directory of dataset
\n
view_file_pattern ```str```: the filename pattern of view images
\n
cam ```object```: camera intrinsic parameters
\n
view_centers ```Tensor(N, 3)```: centers of views
\n
view_rots ```Tensor(N, 3, 3)```: rotation matrices of views
\n
view_images ```Tensor(N, 3, H, W)```: images of views
\n
view_depths ```Tensor(N, H, W)```: depths of views
\n
"""
class
Chunk
(
object
):
def
__init__
(
self
,
id
,
dataset
,
*
,
indices
:
torch
.
Tensor
,
centers
:
torch
.
Tensor
,
rots
:
torch
.
Tensor
):
"""
[summary]
:param dataset `PanoDataset`: dataset object
:param indices `Tensor(N)`: indices of views
:param centers `Tensor(N, 3)`: centers of views
"""
self
.
id
=
id
self
.
dataset
=
dataset
self
.
indices
=
indices
self
.
centers
=
centers
self
.
rots
=
rots
self
.
n_views
=
self
.
indices
.
size
(
0
)
self
.
n_pixels_per_view
=
self
.
dataset
.
res
[
0
]
*
self
.
dataset
.
res
[
1
]
self
.
colors
=
self
.
depths
=
self
.
bins
=
None
self
.
colors_cpu
=
self
.
depths_cpu
=
self
.
bins_cpu
=
None
self
.
loaded
=
False
def
release
(
self
):
self
.
colors
=
self
.
depths
=
self
.
bins
=
None
self
.
loaded
=
False
def
load
(
self
):
if
self
.
dataset
.
image_path
and
self
.
colors_cpu
is
None
:
images
=
color
.
cvt
(
img
.
load
(
self
.
dataset
.
image_path
%
i
for
i
in
self
.
indices
),
color
.
RGB
,
self
.
dataset
.
c
)
if
self
.
dataset
.
res
!=
list
(
images
.
shape
[
-
2
:]):
images
=
nn_f
.
interpolate
(
images
,
self
.
dataset
.
res
)
self
.
colors_cpu
=
images
.
permute
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
)
if
self
.
colors_cpu
is
not
None
:
self
.
colors
=
self
.
colors_cpu
.
to
(
self
.
dataset
.
device
,
non_blocking
=
True
)
if
self
.
dataset
.
depth_path
and
self
.
depths_cpu
is
None
:
depths
=
self
.
dataset
.
_decode_depth_images
(
img
.
load
(
self
.
depth_path
%
i
for
i
in
self
.
indices
))
if
self
.
dataset
.
res
!=
list
(
depths
.
shape
[
-
2
:]):
depths
=
nn_f
.
interpolate
(
depths
,
self
.
dataset
.
res
)
self
.
depths_cpu
=
depths
.
flatten
(
0
,
2
)
if
self
.
depths_cpu
is
not
None
:
self
.
depths
=
self
.
depths_cpu
.
to
(
self
.
dataset
.
device
,
non_blocking
=
True
)
if
self
.
dataset
.
bins_path
and
self
.
bins_cpu
is
None
:
bins
=
img
.
load
([
self
.
dataset
.
bins_path
%
i
for
i
in
self
.
indices
])
if
self
.
dataset
.
res
!=
list
(
bins
.
shape
[
-
2
:]):
bins
=
nn_f
.
interpolate
(
bins
,
self
.
dataset
.
res
)
self
.
bins_cpu
=
bins
.
permute
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
)
if
self
.
bins_cpu
is
not
None
:
self
.
bins
=
self
.
bins_cpu
.
to
(
self
.
dataset
.
device
,
non_blocking
=
True
)
torch
.
cuda
.
current_stream
(
self
.
dataset
.
device
).
synchronize
()
self
.
loaded
=
True
def
__len__
(
self
):
return
self
.
n_views
*
self
.
n_pixels_per_view
def
__getitem__
(
self
,
idx
):
if
not
self
.
loaded
:
self
.
load
()
view_idx
=
idx
//
self
.
n_pixels_per_view
pix_idx
=
idx
%
self
.
n_pixels_per_view
rays_o
=
self
.
centers
[
view_idx
]
rays_d
=
self
.
dataset
.
cam_rays
[
pix_idx
]
# (N, 3)
r
=
self
.
rots
[
view_idx
].
movedim
(
-
1
,
-
2
)
# (N, 3, 3)
rays_d
=
torch
.
matmul
(
rays_d
,
r
)
extra_data
=
{}
if
self
.
colors
is
not
None
:
extra_data
[
'colors'
]
=
self
.
colors
[
idx
]
if
self
.
depths
is
not
None
:
extra_data
[
'depths'
]
=
self
.
depths
[
idx
]
if
self
.
bins
is
not
None
:
extra_data
[
'bins'
]
=
self
.
bins
[
idx
]
return
idx
,
rays_o
,
rays_d
,
extra_data
def
__init__
(
self
,
desc
:
dict
,
*
,
c
:
int
=
color
.
RGB
,
load_images
:
bool
=
True
,
load_depths
:
bool
=
False
,
load_bins
:
bool
=
False
,
res
:
Tuple
[
int
,
int
]
=
None
,
views_to_load
:
Union
[
range
,
torch
.
Tensor
]
=
None
,
device
:
torch
.
device
=
None
,
**
kwargs
):
"""
Initialize data loader for spherical view synthesis task
The dataset description file is a JSON file with following fields:
- view_file_pattern: string, the path pattern of view images
- view_res: { "x", "y" }, the resolution of view images
- cam: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
- view_centers: [ [ x, y, z ], ... ], centers of views
- view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
:param dataset_desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param load_depths ```bool```: whether load depth images and return in __getitem__()
:param c ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether calculate rays
"""
self
.
c
=
c
self
.
device
=
device
self
.
_load_desc
(
desc
,
res
,
views_to_load
,
load_images
,
load_depths
,
load_bins
)
def
get_data
(
self
):
return
{
'indices'
:
self
.
indices
,
'centers'
:
self
.
centers
,
'rots'
:
self
.
rots
}
def
_decode_depth_images
(
self
,
input
):
disp_range
=
(
1
/
self
.
depth_range
[
0
],
1
/
self
.
depth_range
[
1
])
disp_val
=
(
1
-
input
[...,
0
,
:,
:])
*
(
disp_range
[
1
]
-
disp_range
[
0
])
+
disp_range
[
0
]
return
torch
.
reciprocal
(
disp_val
)
def
_load_desc
(
self
,
desc
:
dict
,
res
:
Tuple
[
int
,
int
],
views_to_load
:
Union
[
range
,
torch
.
Tensor
],
load_images
:
bool
,
load_depths
:
bool
,
load_bins
:
bool
):
if
load_images
and
desc
.
get
(
'view_file_pattern'
):
self
.
image_path
=
os
.
path
.
join
(
self
.
data_dir
,
desc
[
'view_file_pattern'
])
else
:
self
.
image_path
=
None
if
load_depths
and
desc
.
get
(
'depth_file_pattern'
):
self
.
depth_path
=
os
.
path
.
join
(
self
.
data_dir
,
desc
[
'depth_file_pattern'
])
else
:
self
.
depth_path
=
None
if
load_bins
and
desc
.
get
(
'bins_file_pattern'
):
self
.
bins_path
=
os
.
path
.
join
(
self
.
data_dir
,
desc
[
'bins_file_pattern'
])
else
:
self
.
bins_path
=
None
self
.
res
=
res
if
res
else
misc
.
values
(
desc
[
'view_res'
],
'y'
,
'x'
)
self
.
cam
=
view
.
CameraParam
(
desc
[
'cam_params'
],
self
.
res
,
device
=
self
.
device
)
self
.
depth_range
=
misc
.
values
(
desc
[
'depth_range'
],
'min'
,
'max'
)
\
if
'depth_range'
in
desc
else
None
self
.
range
=
misc
.
values
(
desc
[
'range'
],
'min'
,
'max'
)
if
'range'
in
desc
else
None
self
.
samples
=
desc
.
get
(
'samples'
)
self
.
centers
=
torch
.
tensor
(
desc
[
'view_centers'
],
device
=
self
.
device
)
# (N, 3)
self
.
rots
=
torch
.
tensor
(
[
view
.
euler_to_matrix
([
rot
[
1
]
if
desc
.
get
(
'gl_coord'
)
else
-
rot
[
1
],
rot
[
0
],
0
])
for
rot
in
desc
[
'view_rots'
]
]
if
len
(
desc
[
'view_rots'
][
0
])
==
2
else
desc
[
'view_rots'
],
device
=
self
.
device
).
view
(
-
1
,
3
,
3
)
# (N, 3, 3)
self
.
indices
=
torch
.
tensor
(
desc
[
'views'
]
if
'views'
in
desc
else
list
(
range
(
self
.
centers
.
size
(
0
))),
device
=
self
.
device
)
if
views_to_load
is
not
None
:
self
.
centers
=
self
.
centers
[
views_to_load
]
self
.
rots
=
self
.
rots
[
views_to_load
]
self
.
indices
=
self
.
indices
[
views_to_load
]
self
.
n_views
=
self
.
centers
.
size
(
0
)
self
.
n_pixels
=
self
.
n_views
*
self
.
res
[
0
]
*
self
.
res
[
1
]
if
desc
.
get
(
'gl_coord'
):
print
(
'Convert from OGL coordinate to DX coordinate (i. e. flip z axis)'
)
if
not
desc
[
'cam_params'
].
get
(
'fov'
):
self
.
cam
.
f
[
1
]
*=
-
1
self
.
centers
[:,
2
]
*=
-
1
self
.
rots
[:,
2
]
*=
-
1
self
.
rots
[...,
2
]
*=
-
1
self
.
cam_rays
=
self
.
cam
.
get_local_rays
(
flatten
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment