import os
import torch
import torch.nn.functional as nn_f
from typing import Tuple, Union
from utils import img
from utils import color
from utils import misc
from utils import sphere
from utils.mem_profiler import *
from utils.constants import *


class PanoDataset(object):
    """
    Data loader for spherical view synthesis task

    Attributes (all assigned in `__init__` / `_load_desc`)
    --------
    c ```int```: color space the view images are converted to
    device ```torch.device```: device that tensors are placed on
    image_path ```str|None```: path pattern of view images, `None` if images are not loaded
    res ```(H, W)```: resolution of views
    depth_range ```(min, max)|None```: depth range read from the description
    range ```(min, max)|None```: value of the description's "range" field
    samples: value of the description's "samples" field (`None` if absent)
    centers ```Tensor(N, 3)```: centers of views
    indices ```Tensor(N)```: indices of views
    n_views ```int```: number of views
    n_pixels ```int```: total number of pixels, i.e. N * H * W
    pano_rays ```Tensor(H*W, 3)```: ray directions of the pixels of one panorama
    """

    class Chunk(object):
        """
        A group of views whose pixel colors are loaded onto the dataset's device on demand.

        Items are individual pixels (rays): a flat index `idx` addresses pixel
        `idx % (H*W)` of view `idx // (H*W)` within this chunk.
        """

        def __init__(self, id, dataset, *,
                     indices: torch.Tensor,
                     centers: torch.Tensor):
            """
            Create a chunk without loading any image data.

            :param id: identifier of the chunk, only used in log messages
            :param dataset `PanoDataset`: dataset object
            :param indices `Tensor(N)`: indices of views
            :param centers `Tensor(N, 3)`: centers of views
            """
            self.id = id
            self.dataset = dataset
            self.indices = indices
            self.centers = centers
            self.n_views = self.indices.size(0)
            self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
            # `colors_cpu` caches the decoded images so that a released chunk
            # can be re-loaded without reading the image files again.
            self.colors_cpu = None
            self.colors = None
            self.loaded = False

        def release(self):
            """
            Drop the on-device colors. The cached `colors_cpu` copy is kept.
            """
            self.colors = None
            self.loaded = False
            MemProfiler.print_memory_stats(f'Chunk #{self.id} released')

        def load(self):
            """
            Read the view images (first call only) and move the colors to the dataset's device.

            When the dataset has no image path, no colors are loaded; the chunk is
            still marked as loaded so that `__getitem__` does not retry on every access.
            """
            if self.dataset.image_path is not None and self.colors_cpu is None:
                images = color.cvt(
                    img.load(self.dataset.image_path % i for i in self.indices),
                    color.RGB, self.dataset.c)
                # Compare as tuples: `res` may be a tuple or a list, and a tuple never
                # equals a list, which previously forced a resize on every load.
                if tuple(self.dataset.res) != tuple(images.shape[-2:]):
                    images = nn_f.interpolate(images, self.dataset.res)
                # (N, C, H, W) -> (N*H*W, C), matching the flat pixel indexing of __getitem__
                self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
            n_bytes = 0
            if self.colors_cpu is not None:
                self.colors = self.colors_cpu.to(self.dataset.device)
                n_bytes = self.colors.numel() * self.colors.element_size()
            self.loaded = True
            MemProfiler.print_memory_stats(
                f'Chunk #{self.id} ({self.n_views} views, '
                f'{n_bytes / 1024 / 1024:.2f}MB) loaded')

        def __len__(self):
            """
            :return: number of pixels (rays) in this chunk
            """
            return self.n_views * self.n_pixels_per_view

        def __getitem__(self, idx):
            """
            Get the rays addressed by flat pixel index(es), loading the chunk if needed.

            :param idx: flat pixel index(es) in range [0, len(self))
            :return: `idx`, ray origins, ray directions and a dict of extra data
                (contains 'colors' when images are loaded)
            """
            if not self.loaded:
                self.load()
            view_idx = idx // self.n_pixels_per_view
            pix_idx = idx % self.n_pixels_per_view
            extra_data = {}
            if self.colors is not None:
                extra_data['colors'] = self.colors[idx]
            rays_o = self.centers[view_idx]
            rays_d = self.dataset.pano_rays[pix_idx]
            return idx, rays_o, rays_d, extra_data

    def __init__(self, desc: dict, *,
                 c: int = color.RGB,
                 load_images: bool = True,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None,
                 **kwargs):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description is a dict with following fields:

        - view_file_pattern: string, the path pattern of view images (relative to
          the current working directory)
        - view_res: { "x", "y" }, the resolution of view
        - depth_range: { "min", "max" }, the depth range (optional)
        - range: { "min": [...], "max": [...] } (optional)
        - samples: optional, stored as-is
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - views: [ ... ], indices of views (optional, defaults to 0..N-1)
        - gl_coord: optional, if truthy the z axis of centers is flipped

        :param desc ```dict```: the dataset description
        :param c ```int```: color space to convert view images to
        :param load_images ```bool```: whether load view images and return in __getitem__()
        :param res ```(H, W)```: resolution of views, overrides `view_res` of the description
        :param views_to_load ```range|Tensor```: subset of views to load, `None` for all views
        :param device ```torch.device```: device to place tensors on
        :param kwargs: accepted for call compatibility and ignored
        """
        self.c = c
        self.device = device
        # `_load_desc` takes no data directory; the previous call passed an undefined
        # `data_dir` as an extra positional argument and could never succeed.
        self._load_desc(desc, res, views_to_load, load_images)

    def get_data(self):
        """
        :return: dict with the 'indices' and 'centers' tensors of the loaded views
        """
        return {
            'indices': self.indices,
            'centers': self.centers
        }

    def _load_desc(self, desc: dict, res: Tuple[int, int],
                   views_to_load: Union[range, torch.Tensor],
                   load_images: bool):
        """
        Populate the dataset attributes from a description dict.

        :param desc ```dict```: the dataset description (see `__init__`)
        :param res ```(H, W)```: resolution override, falsy to use `view_res` of the description
        :param views_to_load ```range|Tensor```: subset of views to keep, `None` for all views
        :param load_images ```bool```: whether the path pattern of images should be set
        """
        if load_images and desc.get('view_file_pattern'):
            self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
        else:
            self.image_path = None
        # Resolution is read in (y, x) order, i.e. (H, W)
        self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
        self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
            if 'depth_range' in desc else None
        self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
        self.samples = desc.get('samples')
        self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
        self.indices = torch.tensor(
            desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
            device=self.device)

        if views_to_load is not None:
            self.centers = self.centers[views_to_load]
            self.indices = self.indices[views_to_load]

        self.n_views = self.centers.size(0)
        self.n_pixels = self.n_views * self.res[0] * self.res[1]
        self.pano_rays = self._get_pano_rays()  # [H*W, 3]

        if desc.get('gl_coord'):
            print('Convert from OGL coordinate to DX coordinate (i. e. flip z axis)')
            self.centers[:, 2] *= -1

    def _get_pano_rays(self):
        """
        Get unprojected rays of pixels on a panorama

        :return `Tensor(H*W, 3)`: rays' directions with one unit length
        """
        # The first component is a constant 1 for every pixel (the unit length
        # mentioned above); the remaining components are the normalized pixel
        # grid scaled by (-2, 1), offset by (1.5, 0) and multiplied by PI.
        # NOTE(review): presumably `misc.meshgrid` returns a (H, W, 2) grid in
        # [0, 1] and the components are (radius, theta, phi) - confirm against
        # `misc.meshgrid` / `sphere.spherical2cartesian`.
        spher_coords = torch.cat([
            torch.ones(*self.res, 1),
            ((misc.meshgrid(*self.res, normalize=True)) * torch.tensor([-2.0, 1.0])
             + torch.tensor([1.5, 0.0])) * PI
        ], dim=-1).to(device=self.device)
        coords = sphere.spherical2cartesian(spher_coords)
        return coords.flatten(0, 1)  # [H*W, 3]