Commit f1dd9e3a authored by Nianchen Deng's avatar Nianchen Deng

tog'21 baseline

parent c10f614f
import os
import json
import utils.device
from .pano_dataset import PanoDataset
from .view_dataset import ViewDataset
class DatasetFactory(object):
@staticmethod
def load(path, device=None, **kwargs):
device = device or utils.device.default()
data_dir = os.path.dirname(path)
        with open(path, 'r', encoding='utf-8') as file:
            data_desc = json.load(file)
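        # Paths inside the description are relative to the dataset directory,
        # so temporarily chdir there while constructing the dataset.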
cwd = os.getcwd()
os.chdir(data_dir)
        if data_desc.get('type') == 'pano':
dataset = PanoDataset(data_desc, device=device, **kwargs)
else:
dataset = ViewDataset(data_desc, device=device, **kwargs)
os.chdir(cwd)
return dataset
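
# Usage sketch (hedged): 'data/scene/train.json' is a hypothetical description
# file; extra keyword arguments such as `res` are forwarded to the dataset.
#
#   dataset = DatasetFactory.load('data/scene/train.json', res=(240, 320))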
import threading
import torch
import math
from utils import device
class Preloader(object):
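    """Loads the next data chunk in a background thread on a dedicated CUDA
    stream, overlapping host-to-device transfers with computation on the
    default stream."""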
def __init__(self, device=None) -> None:
super().__init__()
self.stream = torch.cuda.Stream(device)
self.event_chunk_loaded = None
def preload_chunk(self, chunk):
if self.event_chunk_loaded is not None:
self.event_chunk_loaded.wait()
if chunk.loaded:
return
# print(f'Preloader: preload chunk #{chunk.id}')
self.event_chunk_loaded = threading.Event()
threading.Thread(target=Preloader._load_chunk, args=(self, chunk)).start()
def _load_chunk(self, chunk):
with torch.cuda.stream(self.stream):
chunk.load()
self.event_chunk_loaded.set()
# print(f'Preloader: chunk #{chunk.id} is loaded')
class DataLoader(object):
class Iter(object):
def __init__(self, chunks, batch_size, shuffle, device: torch.device, preloader: Preloader):
super().__init__()
self.batch_size = batch_size
self.chunks = chunks
self.offset = -1
self.chunk_idx = -1
self.current_chunk = None
self.shuffle = shuffle
self.device = device
self.preloader = preloader
def __del__(self):
#print('DataLoader.Iter: clean chunks')
if self.preloader is not None and self.preloader.event_chunk_loaded is not None:
self.preloader.event_chunk_loaded.wait()
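            # Keep the first chunk (two with a preloader) resident for reuse by
            # the next epoch; release the rest.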
chunks_to_reserve = 1 if self.preloader is None else 2
for i in range(chunks_to_reserve, len(self.chunks)):
if self.chunks[i].loaded:
self.chunks[i].release()
def __next__(self):
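            # offset == -1 marks "no batches left in the current chunk":
            # advance to the next chunk (raises StopIteration after the last).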
if self.offset == -1:
self._next_chunk()
stop = min(self.offset + self.batch_size, len(self.current_chunk))
if self.indices is not None:
indices = self.indices[self.offset:stop]
else:
indices = torch.arange(self.offset, stop, device=self.device)
self.offset = stop
if self.offset >= len(self.current_chunk):
self.offset = -1
return self.current_chunk[indices]
def _next_chunk(self):
if self.current_chunk is not None:
chunks_to_reserve = 1 if self.preloader is None else 2
if len(self.chunks) > chunks_to_reserve:
self.current_chunk.release()
if self.chunk_idx >= len(self.chunks) - 1:
raise StopIteration()
self.chunk_idx += 1
self.current_chunk = self.chunks[self.chunk_idx]
self.offset = 0
self.indices = torch.randperm(len(self.current_chunk), device=self.device) \
if self.shuffle else None
if self.preloader is not None:
self.preloader.preload_chunk(self.chunks[(self.chunk_idx + 1) % len(self.chunks)])
def __init__(self, dataset, batch_size, *,
chunk_max_items=None, shuffle=False, enable_preload=True):
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.preloader = Preloader(self.dataset.device) if enable_preload else None
self._init_chunks(chunk_max_items)
def __iter__(self):
return DataLoader.Iter(self.chunks, self.batch_size, self.shuffle, self.dataset.device,
self.preloader)
def __len__(self):
return sum(math.ceil(len(chunk) / self.batch_size) for chunk in self.chunks)
def _init_chunks(self, chunk_max_items):
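        # Optionally shuffle whole views, then split them into chunks of
        # consecutive views, sized so that each chunk holds roughly at most
        # chunk_max_items pixels (a single chunk if no limit is given).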
data = self.dataset.get_data()
if self.shuffle:
rand_seq = torch.randperm(self.dataset.n_views, device=self.dataset.device)
for key in data:
data[key] = data[key][rand_seq]
self.chunks = []
n_chunks = 1 if chunk_max_items is None else \
math.ceil(self.dataset.n_pixels / chunk_max_items)
views_per_chunk = math.ceil(self.dataset.n_views / n_chunks)
for offset in range(0, self.dataset.n_views, views_per_chunk):
sel = slice(offset, offset + views_per_chunk)
chunk_data = {}
for key in data:
chunk_data[key] = data[key][sel]
self.chunks.append(self.dataset.Chunk(len(self.chunks), self.dataset, **chunk_data))
if self.preloader is not None:
self.preloader.preload_chunk(self.chunks[0])
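
# Usage sketch (hedged; the batch size and chunk budget are placeholder
# values, and `dataset` is assumed to come from DatasetFactory.load):
#
#   loader = DataLoader(dataset, 4096, chunk_max_items=2 ** 24, shuffle=True)
#   for idx, rays_o, rays_d, extra_data in loader:
#       colors = extra_data['colors']  # present when the dataset loads images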
import os
import torch
import torch.nn.functional as nn_f
from typing import Tuple, Union
from utils import img
from utils import color
from utils import misc
from utils import sphere
from utils.mem_profiler import *
from utils.constants import *
class PanoDataset(object):
"""
Data loader for spherical view synthesis task
Attributes
--------
data_dir ```str```: the directory of dataset\n
view_file_pattern ```str```: the filename pattern of view images\n
cam_params ```object```: camera intrinsic parameters\n
centers ```Tensor(N, 3)```: centers of views\n
view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
images ```Tensor(N, 3, H, W)```: images of views\n
view_depths ```Tensor(N, H, W)```: depths of views\n
"""
class Chunk(object):
def __init__(self, id, dataset, *,
indices: torch.Tensor, centers: torch.Tensor):
"""
[summary]
:param dataset `PanoDataset`: dataset object
:param indices `Tensor(N)`: indices of views
:param centers `Tensor(N, 3)`: centers of views
"""
self.id = id
self.dataset = dataset
self.indices = indices
self.centers = centers
self.n_views = self.indices.size(0)
self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
self.colors_cpu = None
self.colors = None
self.loaded = False
def release(self):
self.colors = None
self.loaded = False
MemProfiler.print_memory_stats(f'Chunk #{self.id} released')
def load(self):
if self.dataset.image_path is not None and self.colors_cpu is None:
images = color.cvt(
img.load(self.dataset.image_path % i for i in self.indices),
color.RGB, self.dataset.c)
if self.dataset.res != list(images.shape[-2:]):
images = nn_f.interpolate(images, self.dataset.res)
self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
if self.colors_cpu is not None:
self.colors = self.colors_cpu.to(self.dataset.device)
self.loaded = True
MemProfiler.print_memory_stats(
f'Chunk #{self.id} ({self.n_views} views, '
f'{self.colors.numel() * self.colors.element_size() / 1024 / 1024:.2f}MB) loaded')
def __len__(self):
return self.n_views * self.n_pixels_per_view
def __getitem__(self, idx):
if not self.loaded:
self.load()
view_idx = idx // self.n_pixels_per_view
pix_idx = idx % self.n_pixels_per_view
extra_data = {}
if self.colors is not None:
extra_data['colors'] = self.colors[idx]
rays_o = self.centers[view_idx]
rays_d = self.dataset.pano_rays[pix_idx]
return idx, rays_o, rays_d, extra_data
def __init__(self, desc: dict, *,
c: int = color.RGB,
load_images: bool = True,
res: Tuple[int, int] = None,
views_to_load: Union[range, torch.Tensor] = None,
device: torch.device = None,
**kwargs):
"""
Initialize data loader for spherical view synthesis task
The dataset description file is a JSON file with following fields:
- view_file_pattern: string, the path pattern of view images
- view_res: { "x", "y" }, the resolution of view
- depth_range: { "min", "max" }, the depth range
- range: { "min": [...], "max": [...] }, the range of translation and rotation
- centers: [ [ x, y, z ], ... ], centers of views
:param desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param c ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether calculate rays
"""
self.c = c
self.device = device
self._load_desc(desc, res, views_to_load, load_images)
def get_data(self):
return {
'indices': self.indices,
'centers': self.centers
}
def _load_desc(self, desc: dict,
res: Tuple[int, int],
views_to_load: Union[range, torch.Tensor],
load_images: bool):
if load_images and desc.get('view_file_pattern'):
self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
else:
self.image_path = None
self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
if 'depth_range' in desc else None
self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
self.samples = desc.get('samples')
self.centers = torch.tensor(desc['view_centers'], device=self.device) # (N, 3)
self.indices = torch.tensor(
desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
device=self.device)
if views_to_load is not None:
self.centers = self.centers[views_to_load]
self.indices = self.indices[views_to_load]
self.n_views = self.centers.size(0)
self.n_pixels = self.n_views * self.res[0] * self.res[1]
self.pano_rays = self._get_pano_rays() # [H*W, 3]
if desc.get('gl_coord'):
            print('Convert from OGL coordinate to DX coordinate (i.e. flip z-axis)')
self.centers[:, 2] *= -1
def _get_pano_rays(self):
"""
Get unprojected rays of pixels on a panorama
:return `Tensor(H*W, 3)`: rays' directions with one unit length
"""
spher_coords = torch.cat([
torch.ones(*self.res, 1),
((misc.meshgrid(*self.res, normalize=True)) *
torch.tensor([-2.0, 1.0]) + torch.tensor([1.5, 0.0])) * PI
], dim=-1).to(device=self.device)
coords = sphere.spherical2cartesian(spher_coords)
return coords.flatten(0, 1) # [H*W, 3]
import os
import torch
import torch.nn.functional as nn_f
from typing import Tuple, Union
from utils import img
from utils import view
from utils import color
from utils import misc
class ViewDataset(object):
"""
Data loader for spherical view synthesis task
Attributes
--------
data_dir ```str```: the directory of dataset\n
view_file_pattern ```str```: the filename pattern of view images\n
cam ```object```: camera intrinsic parameters\n
view_centers ```Tensor(N, 3)```: centers of views\n
view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
view_images ```Tensor(N, 3, H, W)```: images of views\n
view_depths ```Tensor(N, H, W)```: depths of views\n
"""
class Chunk(object):
def __init__(self, id, dataset, *,
indices: torch.Tensor, centers: torch.Tensor, rots: torch.Tensor):
"""
[summary]
:param dataset `PanoDataset`: dataset object
:param indices `Tensor(N)`: indices of views
:param centers `Tensor(N, 3)`: centers of views
"""
self.id = id
self.dataset = dataset
self.indices = indices
self.centers = centers
self.rots = rots
self.n_views = self.indices.size(0)
self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
self.colors = self.depths = self.bins = None
self.colors_cpu = self.depths_cpu = self.bins_cpu = None
self.loaded = False
def release(self):
self.colors = self.depths = self.bins = None
self.loaded = False
def load(self):
if self.dataset.image_path and self.colors_cpu is None:
images = color.cvt(
img.load(self.dataset.image_path % i for i in self.indices),
color.RGB, self.dataset.c)
if self.dataset.res != list(images.shape[-2:]):
images = nn_f.interpolate(images, self.dataset.res)
self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
if self.colors_cpu is not None:
self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True)
if self.dataset.depth_path and self.depths_cpu is None:
                depths = self.dataset._decode_depth_images(
                    img.load(self.dataset.depth_path % i for i in self.indices))
if self.dataset.res != list(depths.shape[-2:]):
depths = nn_f.interpolate(depths, self.dataset.res)
self.depths_cpu = depths.flatten(0, 2)
if self.depths_cpu is not None:
self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True)
if self.dataset.bins_path and self.bins_cpu is None:
bins = img.load([self.dataset.bins_path % i for i in self.indices])
if self.dataset.res != list(bins.shape[-2:]):
bins = nn_f.interpolate(bins, self.dataset.res)
self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2)
if self.bins_cpu is not None:
self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True)
torch.cuda.current_stream(self.dataset.device).synchronize()
self.loaded = True
def __len__(self):
return self.n_views * self.n_pixels_per_view
def __getitem__(self, idx):
if not self.loaded:
self.load()
view_idx = idx // self.n_pixels_per_view
pix_idx = idx % self.n_pixels_per_view
rays_o = self.centers[view_idx]
rays_d = self.dataset.cam_rays[pix_idx] # (N, 3)
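            # Rotate camera-space ray directions into world space: with rays as
            # row vectors, d @ R^T == (R @ d)^T, hence the transpose below.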
r = self.rots[view_idx].movedim(-1, -2) # (N, 3, 3)
rays_d = torch.matmul(rays_d, r)
extra_data = {}
if self.colors is not None:
extra_data['colors'] = self.colors[idx]
if self.depths is not None:
extra_data['depths'] = self.depths[idx]
if self.bins is not None:
extra_data['bins'] = self.bins[idx]
return idx, rays_o, rays_d, extra_data
def __init__(self, desc: dict, *,
c: int = color.RGB,
load_images: bool = True,
load_depths: bool = False,
load_bins: bool = False,
res: Tuple[int, int] = None,
views_to_load: Union[range, torch.Tensor] = None,
device: torch.device = None,
**kwargs):
"""
Initialize data loader for spherical view synthesis task
The dataset description file is a JSON file with following fields:
- view_file_pattern: string, the path pattern of view images
- view_res: { "x", "y" }, the resolution of view images
- cam: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
- view_centers: [ [ x, y, z ], ... ], centers of views
- view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
:param dataset_desc_path ```str```: path to the data description file
:param load_images ```bool```: whether load view images and return in __getitem__()
:param load_depths ```bool```: whether load depth images and return in __getitem__()
:param c ```int```: color space to convert view images to
:param calculate_rays ```bool```: whether calculate rays
"""
self.c = c
self.device = device
self._load_desc(desc, res, views_to_load, load_images, load_depths, load_bins)
def get_data(self):
return {
'indices': self.indices,
'centers': self.centers,
'rots': self.rots
}
def _decode_depth_images(self, input):
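        # Views store normalized disparity in the first channel: pixel value 1
        # maps to the near plane (depth_range[0]) and 0 to the far plane
        # (depth_range[1]); decode to metric depth via the reciprocal.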
disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
return torch.reciprocal(disp_val)
def _load_desc(self, desc: dict,
res: Tuple[int, int],
views_to_load: Union[range, torch.Tensor],
load_images: bool,
load_depths: bool,
load_bins: bool):
if load_images and desc.get('view_file_pattern'):
            self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
else:
self.image_path = None
if load_depths and desc.get('depth_file_pattern'):
            self.depth_path = os.path.join(os.getcwd(), desc['depth_file_pattern'])
else:
self.depth_path = None
if load_bins and desc.get('bins_file_pattern'):
            self.bins_path = os.path.join(os.getcwd(), desc['bins_file_pattern'])
else:
self.bins_path = None
self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
self.cam = view.CameraParam(desc['cam_params'], self.res, device=self.device)
self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
if 'depth_range' in desc else None
self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
self.samples = desc.get('samples')
self.centers = torch.tensor(desc['view_centers'], device=self.device) # (N, 3)
self.rots = torch.tensor(
[
view.euler_to_matrix([rot[1] if desc.get('gl_coord') else -rot[1], rot[0], 0])
for rot in desc['view_rots']
]
if len(desc['view_rots'][0]) == 2 else desc['view_rots'],
device=self.device).view(-1, 3, 3) # (N, 3, 3)
self.indices = torch.tensor(
desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
device=self.device)
if views_to_load is not None:
self.centers = self.centers[views_to_load]
self.rots = self.rots[views_to_load]
self.indices = self.indices[views_to_load]
self.n_views = self.centers.size(0)
self.n_pixels = self.n_views * self.res[0] * self.res[1]
if desc.get('gl_coord'):
            print('Convert from OGL coordinate to DX coordinate (i.e. flip z-axis)')
if not desc['cam_params'].get('fov'):
self.cam.f[1] *= -1
self.centers[:, 2] *= -1
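            # Conjugate each rotation by S = diag(1, 1, -1): flipping both the
            # z row and the z column gives R' = S @ R @ S, which stays a valid
            # rotation in the flipped frame.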
self.rots[:, 2] *= -1
self.rots[..., 2] *= -1
self.cam_rays = self.cam.get_local_rays(flatten=True)