From f1dd9e3a8aad6daa779881ec487ce1f8251bfe0e Mon Sep 17 00:00:00 2001
From: Nianchen Deng <dengnianchen@sjtu.edu.cn>
Date: Mon, 6 Sep 2021 10:22:39 +0800
Subject: [PATCH] tog'21 baseline

---
 data/dataset_factory.py |  23 +++
 data/loader.py          | 113 +++++++++++++++++++----
 data/pano_dataset.py    | 159 ++++++++++++++++++++++++++++++++
 data/view_dataset.py    | 198 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 475 insertions(+), 18 deletions(-)
 create mode 100644 data/dataset_factory.py
 create mode 100644 data/pano_dataset.py
 create mode 100644 data/view_dataset.py

diff --git a/data/dataset_factory.py b/data/dataset_factory.py
new file mode 100644
index 0000000..7a1f7c2
--- /dev/null
+++ b/data/dataset_factory.py
@@ -0,0 +1,23 @@
+import os
+import json
+import utils.device
+from .pano_dataset import PanoDataset
+from .view_dataset import ViewDataset
+
+
+class DatasetFactory(object):
+
+    @staticmethod
+    def load(path, device=None, **kwargs):
+        device = device or utils.device.default()
+        data_dir = os.path.dirname(path)
+        with open(path, 'r', encoding='utf-8') as file:
+            data_desc = json.load(file)
+        # Dataset paths are relative to the description file, so chdir into
+        # its directory; restore the working directory even on failure.
+        cwd = os.getcwd()
+        os.chdir(data_dir)
+        try:
+            if data_desc.get('type') == 'pano':
+                dataset = PanoDataset(data_desc, device=device, **kwargs)
+            else:
+                dataset = ViewDataset(data_desc, device=device, **kwargs)
+        finally:
+            os.chdir(cwd)
+        return dataset
\ No newline at end of file
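For reference, the factory dispatches on the optional "type" field of the description file and forwards everything else to the chosen dataset class. A minimal usage sketch (not part of the patch; the description path and batch sizes below are illustrative only):

    import torch
    from data.dataset_factory import DatasetFactory
    from data.loader import DataLoader

    # Hypothetical description file; any JSON containing the fields read by
    # PanoDataset/ViewDataset ("view_res", "view_centers", ...) works the same way.
    dataset = DatasetFactory.load('data/my_scene/train.json', device=torch.device('cuda:0'))
    loader = DataLoader(dataset, 4096, chunk_max_items=2 ** 24, shuffle=True)
    for idx, rays_o, rays_d, extra_data in loader:
        pass  # feed rays_o/rays_d to the model; extra_data may hold 'colors' etc.
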
diff --git a/data/loader.py b/data/loader.py
index 9510acd..49163cd 100644
--- a/data/loader.py
+++ b/data/loader.py
@@ -1,39 +1,116 @@
+import threading
 import torch
 import math
-from utils import device
 
-class FastDataLoader(object):
+class Preloader(object):
+
+    def __init__(self, device=None) -> None:
+        super().__init__()
+        self.stream = torch.cuda.Stream(device)
+        self.event_chunk_loaded = None
+
+    def preload_chunk(self, chunk):
+        if self.event_chunk_loaded is not None:
+            self.event_chunk_loaded.wait()
+        if chunk.loaded:
+            return
+        self.event_chunk_loaded = threading.Event()
+        threading.Thread(target=Preloader._load_chunk, args=(self, chunk)).start()
+
+    def _load_chunk(self, chunk):
+        # Load on a side stream so host-to-device copies overlap with compute
+        with torch.cuda.stream(self.stream):
+            chunk.load()
+        self.event_chunk_loaded.set()
+
+
+class DataLoader(object):
 
     class Iter(object):
 
-        def __init__(self, dataset, batch_size, shuffle, drop_last) -> None:
+        def __init__(self, chunks, batch_size, shuffle, device: torch.device, preloader: Preloader):
             super().__init__()
-            self.indices = torch.randperm(len(dataset), device=device.default()) \
-                if shuffle else torch.arange(len(dataset), device=device.default())
-            self.offset = 0
             self.batch_size = batch_size
-            self.dataset = dataset
-            self.drop_last = drop_last
+            self.chunks = chunks
+            self.offset = -1
+            self.chunk_idx = -1
+            self.current_chunk = None
+            self.indices = None
+            self.shuffle = shuffle
+            self.device = device
+            self.preloader = preloader
+
+        def __del__(self):
+            if self.preloader is not None and self.preloader.event_chunk_loaded is not None:
+                self.preloader.event_chunk_loaded.wait()
+            chunks_to_reserve = 1 if self.preloader is None else 2
+            for i in range(chunks_to_reserve, len(self.chunks)):
+                if self.chunks[i].loaded:
+                    self.chunks[i].release()
 
         def __next__(self):
-            if self.offset + (self.batch_size if self.drop_last else 0) >= len(self.dataset):
+            if self.offset == -1:
+                self._next_chunk()
+            stop = min(self.offset + self.batch_size, len(self.current_chunk))
+            if self.indices is not None:
+                indices = self.indices[self.offset:stop]
+            else:
+                indices = torch.arange(self.offset, stop, device=self.device)
+            self.offset = stop
+            if self.offset >= len(self.current_chunk):
+                self.offset = -1
+            return self.current_chunk[indices]
+
+        def _next_chunk(self):
+            if self.current_chunk is not None:
+                chunks_to_reserve = 1 if self.preloader is None else 2
+                if len(self.chunks) > chunks_to_reserve:
+                    self.current_chunk.release()
+            if self.chunk_idx >= len(self.chunks) - 1:
                 raise StopIteration()
-            indices = self.indices[self.offset:self.offset + self.batch_size]
-            self.offset += self.batch_size
-            return self.dataset[indices]
+            self.chunk_idx += 1
+            self.current_chunk = self.chunks[self.chunk_idx]
+            self.offset = 0
+            self.indices = torch.randperm(len(self.current_chunk), device=self.device) \
+                if self.shuffle else None
+            if self.preloader is not None:
+                self.preloader.preload_chunk(self.chunks[(self.chunk_idx + 1) % len(self.chunks)])
 
-    def __init__(self, dataset, batch_size, shuffle, drop_last=False, **kwargs) -> None:
+    def __init__(self, dataset, batch_size, *,
+                 chunk_max_items=None, shuffle=False, enable_preload=True):
         super().__init__()
         self.dataset = dataset
         self.batch_size = batch_size
         self.shuffle = shuffle
-        self.drop_last = drop_last
+        self.preloader = Preloader(self.dataset.device) if enable_preload else None
+        self._init_chunks(chunk_max_items)
 
     def __iter__(self):
-        return FastDataLoader.Iter(self.dataset, self.batch_size,
-                                   self.shuffle, self.drop_last)
+        return DataLoader.Iter(self.chunks, self.batch_size, self.shuffle, self.dataset.device,
+                               self.preloader)
 
     def __len__(self):
-        return math.floor(len(self.dataset) / self.batch_size) if self.drop_last \
-            else math.ceil(len(self.dataset) / self.batch_size)
+        return sum(math.ceil(len(chunk) / self.batch_size) for chunk in self.chunks)
+
+    def _init_chunks(self, chunk_max_items):
+        data = self.dataset.get_data()
+        if self.shuffle:
+            rand_seq = torch.randperm(self.dataset.n_views, device=self.dataset.device)
+            for key in data:
+                data[key] = data[key][rand_seq]
+        self.chunks = []
+        n_chunks = 1 if chunk_max_items is None else \
+            math.ceil(self.dataset.n_pixels / chunk_max_items)
+        views_per_chunk = math.ceil(self.dataset.n_views / n_chunks)
+        for offset in range(0, self.dataset.n_views, views_per_chunk):
+            sel = slice(offset, offset + views_per_chunk)
+            chunk_data = {}
+            for key in data:
+                chunk_data[key] = data[key][sel]
+            self.chunks.append(self.dataset.Chunk(len(self.chunks), self.dataset, **chunk_data))
+        if self.preloader is not None:
+            self.preloader.preload_chunk(self.chunks[0])
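The loader above serves batches chunk by chunk and asks the Preloader to stage the next chunk on a side CUDA stream while the current one is being consumed. A stripped-down sketch of that double-buffering pattern (standalone; the tensor below stands in for a chunk, and note that non-blocking copies only overlap compute when the source memory is pinned):

    import threading
    import torch

    def preload(t_cpu, device, stream):
        # Stage a host-to-device copy on a side stream; the returned event is
        # set once the copy has been issued and the stream has drained.
        done = threading.Event()
        box = {}

        def worker():
            with torch.cuda.stream(stream):
                box['gpu'] = t_cpu.to(device, non_blocking=True)
            stream.synchronize()
            done.set()

        threading.Thread(target=worker).start()
        return done, box

    device = torch.device('cuda:0')
    side_stream = torch.cuda.Stream(device)
    chunk_cpu = torch.empty(1 << 20, pin_memory=True)  # stand-in for a chunk
    done, box = preload(chunk_cpu, device, side_stream)
    # ... compute on the default stream here ...
    done.wait()  # box['gpu'] is now resident on the GPU
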
diff --git a/data/pano_dataset.py b/data/pano_dataset.py
new file mode 100644
index 0000000..9953c8f
--- /dev/null
+++ b/data/pano_dataset.py
@@ -0,0 +1,159 @@
+import os
+import torch
+import torch.nn.functional as nn_f
+from typing import Tuple, Union
+from utils import img
+from utils import color
+from utils import misc
+from utils import sphere
+from utils.mem_profiler import *
+from utils.constants import *
+
+
+class PanoDataset(object):
+    """
+    Data loader for spherical view synthesis task
+
+    Attributes
+    --------
+    image_path ```str```: the path pattern of view images\n
+    res ```Tuple[int, int]```: resolution (H, W) of views\n
+    centers ```Tensor(N, 3)```: centers of views\n
+    indices ```Tensor(N)```: indices of views\n
+    pano_rays ```Tensor(H*W, 3)```: unit-length ray directions of panorama pixels\n
+    """
+
+    class Chunk(object):
+
+        def __init__(self, id, dataset, *,
+                     indices: torch.Tensor, centers: torch.Tensor):
+            """
+            A slice of views whose pixel data is loaded and released on demand
+
+            :param id `int`: chunk index
+            :param dataset `PanoDataset`: dataset object
+            :param indices `Tensor(N)`: indices of views
+            :param centers `Tensor(N, 3)`: centers of views
+            """
+            self.id = id
+            self.dataset = dataset
+            self.indices = indices
+            self.centers = centers
+            self.n_views = self.indices.size(0)
+            self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
+            self.colors_cpu = None
+            self.colors = None
+            self.loaded = False
+
+        def release(self):
+            self.colors = None
+            self.loaded = False
+            MemProfiler.print_memory_stats(f'Chunk #{self.id} released')
+
+        def load(self):
+            if self.dataset.image_path is not None and self.colors_cpu is None:
+                images = color.cvt(
+                    img.load(self.dataset.image_path % i for i in self.indices),
+                    color.RGB, self.dataset.c)
+                if self.dataset.res != list(images.shape[-2:]):
+                    images = nn_f.interpolate(images, self.dataset.res)
+                self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
+            if self.colors_cpu is not None:
+                self.colors = self.colors_cpu.to(self.dataset.device)
+                MemProfiler.print_memory_stats(
+                    f'Chunk #{self.id} ({self.n_views} views, '
+                    f'{self.colors.numel() * self.colors.element_size() / 1024 / 1024:.2f}MB) loaded')
+            self.loaded = True
+
+        def __len__(self):
+            return self.n_views * self.n_pixels_per_view
+
+        def __getitem__(self, idx):
+            if not self.loaded:
+                self.load()
+            view_idx = idx // self.n_pixels_per_view
+            pix_idx = idx % self.n_pixels_per_view
+            extra_data = {}
+            if self.colors is not None:
+                extra_data['colors'] = self.colors[idx]
+            rays_o = self.centers[view_idx]
+            rays_d = self.dataset.pano_rays[pix_idx]
+            return idx, rays_o, rays_d, extra_data
+
+    def __init__(self, desc: dict, *,
+                 c: int = color.RGB,
+                 load_images: bool = True,
+                 res: Tuple[int, int] = None,
+                 views_to_load: Union[range, torch.Tensor] = None,
+                 device: torch.device = None,
+                 **kwargs):
+        """
+        Initialize data loader for spherical view synthesis task
+
+        The dataset description file is a JSON file with following fields:
+
+        - view_file_pattern: string, the path pattern of view images
+        - view_res: { "x", "y" }, the resolution of views
+        - depth_range: { "min", "max" }, the depth range
+        - range: { "min": [...], "max": [...] }, the range of translation and rotation
+        - view_centers: [ [ x, y, z ], ... ], centers of views
+
+        :param desc ```dict```: the dataset description loaded from the file above
+        :param c ```int```: color space to convert view images to
+        :param load_images ```bool```: whether to load view images and return them in __getitem__()
+        :param res ```Tuple[int, int]```: resolution (H, W) to resample views to, optional
+        :param views_to_load ```range | Tensor```: subset of views to load, optional
+        :param device ```torch.device```: device to store tensors on, optional
+        """
+        self.c = c
+        self.device = device
+        self._load_desc(desc, res, views_to_load, load_images)
+
+    def get_data(self):
+        return {
+            'indices': self.indices,
+            'centers': self.centers
+        }
+
+    def _load_desc(self, desc: dict,
+                   res: Tuple[int, int],
+                   views_to_load: Union[range, torch.Tensor],
+                   load_images: bool):
+        if load_images and desc.get('view_file_pattern'):
+            self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
+        else:
+            self.image_path = None
+        self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
+        self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
+            if 'depth_range' in desc else None
+        self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
+        self.samples = desc.get('samples')
+        self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
+        self.indices = torch.tensor(
+            desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
+            device=self.device)
+
+        if views_to_load is not None:
+            self.centers = self.centers[views_to_load]
+            self.indices = self.indices[views_to_load]
+
+        self.n_views = self.centers.size(0)
+        self.n_pixels = self.n_views * self.res[0] * self.res[1]
+        self.pano_rays = self._get_pano_rays()  # [H*W, 3]
+
+        if desc.get('gl_coord'):
+            print('Convert from OGL coordinate to DX coordinate (i.e. flip z axis)')
+            self.centers[:, 2] *= -1
+
+    def _get_pano_rays(self):
+        """
+        Get unprojected rays of pixels on a panorama
+
+        :return `Tensor(H*W, 3)`: rays' directions with one unit length
+        """
+        spher_coords = torch.cat([
+            torch.ones(*self.res, 1),
+            ((misc.meshgrid(*self.res, normalize=True)) *
+             torch.tensor([-2.0, 1.0]) + torch.tensor([1.5, 0.0])) * PI
+        ], dim=-1).to(device=self.device)
+        coords = sphere.spherical2cartesian(spher_coords)
+        return coords.flatten(0, 1)  # [H*W, 3]
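_get_pano_rays above maps every pixel of the equirectangular panorama to spherical angles through the [-2.0, 1.0]/[1.5, 0.0] affine and unprojects them with the project's sphere.spherical2cartesian helper. The sketch below reproduces the idea in plain torch; the latitude/longitude and axis conventions are assumptions of mine, not necessarily the ones encoded in misc.meshgrid and sphere:

    import math
    import torch

    def equirect_rays(H, W):
        # Pixel-center coordinates normalized to [0, 1]
        v = (torch.arange(H, dtype=torch.float32) + 0.5) / H  # down the image
        u = (torch.arange(W, dtype=torch.float32) + 0.5) / W  # across the image
        lat = (0.5 - v) * math.pi      # +pi/2 at the top row, -pi/2 at the bottom
        lon = (u - 0.5) * 2 * math.pi  # -pi .. +pi across the width
        lat, lon = torch.meshgrid(lat, lon, indexing='ij')
        x = torch.cos(lat) * torch.sin(lon)
        y = torch.sin(lat)
        z = -torch.cos(lat) * torch.cos(lon)  # -z forward (DX-style, assumed)
        return torch.stack([x, y, z], dim=-1).reshape(-1, 3)  # (H*W, 3), unit length
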
diff --git a/data/view_dataset.py b/data/view_dataset.py
new file mode 100644
index 0000000..477629b
--- /dev/null
+++ b/data/view_dataset.py
@@ -0,0 +1,198 @@
+import os
+import torch
+import torch.nn.functional as nn_f
+from typing import Tuple, Union
+from utils import img
+from utils import view
+from utils import color
+from utils import misc
+
+
+class ViewDataset(object):
+    """
+    Data loader for view synthesis task
+
+    Attributes
+    --------
+    image_path ```str```: the path pattern of view images\n
+    res ```Tuple[int, int]```: resolution (H, W) of views\n
+    cam ```view.CameraParam```: camera intrinsic parameters\n
+    centers ```Tensor(N, 3)```: centers of views\n
+    rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
+    cam_rays ```Tensor(H*W, 3)```: ray directions of pixels in camera space\n
+    """
+
+    class Chunk(object):
+
+        def __init__(self, id, dataset, *,
+                     indices: torch.Tensor, centers: torch.Tensor, rots: torch.Tensor):
+            """
+            A slice of views whose pixel data is loaded and released on demand
+
+            :param id `int`: chunk index
+            :param dataset `ViewDataset`: dataset object
+            :param indices `Tensor(N)`: indices of views
+            :param centers `Tensor(N, 3)`: centers of views
+            :param rots `Tensor(N, 3, 3)`: rotation matrices of views
+            """
+            self.id = id
+            self.dataset = dataset
+            self.indices = indices
+            self.centers = centers
+            self.rots = rots
+            self.n_views = self.indices.size(0)
+            self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
+            self.colors = self.depths = self.bins = None
+            self.colors_cpu = self.depths_cpu = self.bins_cpu = None
+            self.loaded = False
+
+        def release(self):
+            self.colors = self.depths = self.bins = None
+            self.loaded = False
+
+        def load(self):
+            if self.dataset.image_path and self.colors_cpu is None:
+                images = color.cvt(
+                    img.load(self.dataset.image_path % i for i in self.indices),
+                    color.RGB, self.dataset.c)
+                if self.dataset.res != list(images.shape[-2:]):
+                    images = nn_f.interpolate(images, self.dataset.res)
+                self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
+            if self.colors_cpu is not None:
+                self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True)
+
+            if self.dataset.depth_path and self.depths_cpu is None:
+                depths = self.dataset._decode_depth_images(
+                    img.load(self.dataset.depth_path % i for i in self.indices))
+                if self.dataset.res != list(depths.shape[-2:]):
+                    # interpolate() needs a channel dim on the (N, H, W) depths
+                    depths = nn_f.interpolate(depths[:, None], self.dataset.res)[:, 0]
+                self.depths_cpu = depths.flatten(0, 2)
+            if self.depths_cpu is not None:
+                self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True)
+
+            if self.dataset.bins_path and self.bins_cpu is None:
+                bins = img.load([self.dataset.bins_path % i for i in self.indices])
+                if self.dataset.res != list(bins.shape[-2:]):
+                    bins = nn_f.interpolate(bins, self.dataset.res)
+                self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2)
+            if self.bins_cpu is not None:
+                self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True)
+
+            torch.cuda.current_stream(self.dataset.device).synchronize()
+            self.loaded = True
+        def __len__(self):
+            return self.n_views * self.n_pixels_per_view
+
+        def __getitem__(self, idx):
+            if not self.loaded:
+                self.load()
+            view_idx = idx // self.n_pixels_per_view
+            pix_idx = idx % self.n_pixels_per_view
+            rays_o = self.centers[view_idx]  # (N, 3)
+            rays_d = self.dataset.cam_rays[pix_idx]  # (N, 3)
+            r = self.rots[view_idx].movedim(-1, -2)  # (N, 3, 3)
+            # Rotate rays into world space; keep the matmul batched as
+            # (N, 1, 3) x (N, 3, 3) -> (N, 1, 3), then drop the middle dim
+            rays_d = torch.matmul(rays_d.unsqueeze(-2), r).squeeze(-2)
+            extra_data = {}
+            if self.colors is not None:
+                extra_data['colors'] = self.colors[idx]
+            if self.depths is not None:
+                extra_data['depths'] = self.depths[idx]
+            if self.bins is not None:
+                extra_data['bins'] = self.bins[idx]
+            return idx, rays_o, rays_d, extra_data
+
+    def __init__(self, desc: dict, *,
+                 c: int = color.RGB,
+                 load_images: bool = True,
+                 load_depths: bool = False,
+                 load_bins: bool = False,
+                 res: Tuple[int, int] = None,
+                 views_to_load: Union[range, torch.Tensor] = None,
+                 device: torch.device = None,
+                 **kwargs):
+        """
+        Initialize data loader for view synthesis task
+
+        The dataset description file is a JSON file with following fields:
+
+        - view_file_pattern: string, the path pattern of view images
+        - view_res: { "x", "y" }, the resolution of view images
+        - cam_params: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
+        - view_centers: [ [ x, y, z ], ... ], centers of views
+        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
+
+        :param desc ```dict```: the dataset description loaded from the file above
+        :param c ```int```: color space to convert view images to
+        :param load_images ```bool```: whether to load view images and return them in __getitem__()
+        :param load_depths ```bool```: whether to load depth images and return them in __getitem__()
+        :param load_bins ```bool```: whether to load bin masks and return them in __getitem__()
+        :param res ```Tuple[int, int]```: resolution (H, W) to resample views to, optional
+        :param views_to_load ```range | Tensor```: subset of views to load, optional
+        :param device ```torch.device```: device to store tensors on, optional
+        """
+        self.c = c
+        self.device = device
+        self._load_desc(desc, res, views_to_load, load_images, load_depths, load_bins)
+
+    def get_data(self):
+        return {
+            'indices': self.indices,
+            'centers': self.centers,
+            'rots': self.rots
+        }
+
+    def _decode_depth_images(self, input):
+        disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
+        disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
+        return torch.reciprocal(disp_val)
+    def _load_desc(self, desc: dict,
+                   res: Tuple[int, int],
+                   views_to_load: Union[range, torch.Tensor],
+                   load_images: bool,
+                   load_depths: bool,
+                   load_bins: bool):
+        # Paths are relative to the current working directory, which
+        # DatasetFactory.load() has set to the description file's directory
+        if load_images and desc.get('view_file_pattern'):
+            self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
+        else:
+            self.image_path = None
+        if load_depths and desc.get('depth_file_pattern'):
+            self.depth_path = os.path.join(os.getcwd(), desc['depth_file_pattern'])
+        else:
+            self.depth_path = None
+        if load_bins and desc.get('bins_file_pattern'):
+            self.bins_path = os.path.join(os.getcwd(), desc['bins_file_pattern'])
+        else:
+            self.bins_path = None
+        self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
+        self.cam = view.CameraParam(desc['cam_params'], self.res, device=self.device)
+        self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
+            if 'depth_range' in desc else None
+        self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
+        self.samples = desc.get('samples')
+        self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
+        self.rots = torch.tensor(
+            [
+                view.euler_to_matrix([rot[1] if desc.get('gl_coord') else -rot[1], rot[0], 0])
+                for rot in desc['view_rots']
+            ]
+            if len(desc['view_rots'][0]) == 2 else desc['view_rots'],
+            device=self.device).view(-1, 3, 3)  # (N, 3, 3)
+        self.indices = torch.tensor(
+            desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
+            device=self.device)
+
+        if views_to_load is not None:
+            self.centers = self.centers[views_to_load]
+            self.rots = self.rots[views_to_load]
+            self.indices = self.indices[views_to_load]
+
+        self.n_views = self.centers.size(0)
+        self.n_pixels = self.n_views * self.res[0] * self.res[1]
+
+        if desc.get('gl_coord'):
+            print('Convert from OGL coordinate to DX coordinate (i.e. flip z axis)')
+            if not desc['cam_params'].get('fov'):
+                self.cam.f[1] *= -1
+            self.centers[:, 2] *= -1
+            self.rots[:, 2] *= -1
+            self.rots[..., 2] *= -1
+
+        self.cam_rays = self.cam.get_local_rays(flatten=True)
-- 
GitLab
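_decode_depth_images in view_dataset.py assumes depth maps are stored as disparity normalized to the scene's disparity range and flipped so that a pixel value of 1 marks the near plane. A small round-trip check of that encoding (the (1, 10) meter depth range is an arbitrary example):

    import torch

    def decode_depth(channel, depth_range):
        # Mirrors _decode_depth_images: undo the flip and the normalization,
        # then invert disparity back to metric depth.
        d0, d1 = 1 / depth_range[0], 1 / depth_range[1]
        disp = (1 - channel) * (d1 - d0) + d0
        return torch.reciprocal(disp)

    def encode_depth(depth, depth_range):
        d0, d1 = 1 / depth_range[0], 1 / depth_range[1]
        return 1 - (1 / depth - d0) / (d1 - d0)

    depths = torch.tensor([1.0, 2.5, 10.0])
    encoded = encode_depth(depths, (1.0, 10.0))
    print(decode_depth(encoded, (1.0, 10.0)))  # tensor([ 1.0000, 2.5000, 10.0000])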