import torch
import torch.nn.functional as nn_f
from typing import Dict, Tuple, Union
from operator import itemgetter
from pathlib import Path

from utils import img
from utils import color
from utils import sphere
from utils.mem_profiler import *
from utils.constants import *


class PanoDataset(object):
    """
    Data loader for the spherical view synthesis task.

    Attributes
    --------
    root `Path`: root directory of the dataset
    name `str`: name of the dataset
    image_path `str`: path pattern of view images, or None if images are not loaded
    res `(int, int)`: resolution (H, W) of views
    depth_range `(float, float)`: depth range of views, or None if not specified
    centers `Tensor(N, 3)`: centers of views
    indices `Tensor(N)`: global indices of views
    pixels `Tensor(M, 2)`: pixel coordinates of valid pixels in the panorama
    rays `Tensor(M, 3)`: unit-length directions of rays through the valid pixels
    """

    class Chunk(object):

        @property
        def n_views(self):
            return self.indices.size(0)

        @property
        def n_pixels_per_view(self):
            return self.dataset.n_pixels_per_view

        def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
                     color: int, **kwargs):
            """
            A loadable subset of views in a dataset.

            :param id `int`: index of this chunk
            :param dataset `PanoDataset`: dataset object
            :param chunk_data `Dict[str, Tensor]`: chunk data, with keys 'indices'
                `Tensor(N)` (indices of views) and 'centers' `Tensor(N, 3)` (centers of views)
            :param color `int`: color space to convert view images to
            """
            self.id = id
            self.dataset = dataset
            self.indices = chunk_data['indices']
            self.centers = chunk_data['centers']
            self.color = color
            self.colors_cpu = None
            self.colors = None
            self.loaded = False

        def release(self):
            # Free the on-device colors; keep the CPU copy so the chunk can be
            # reloaded without re-reading the image files
            self.colors = None
            self.loaded = False
            MemProfiler.print_memory_stats(f'Chunk #{self.id} released')

        def load(self):
            if self.dataset.image_path is not None and self.colors_cpu is None:
                # Load the view images, convert them to the target color space and
                # resize them to the dataset's resolution if necessary
                images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices),
                                   color.RGB, self.color)
                if self.dataset.res != tuple(images.shape[-2:]):
                    images = nn_f.interpolate(images, self.dataset.res)
                # (N, C, H, W) -> (N, H, W, C), keep only valid panorama pixels,
                # then flatten views and pixels together to (N * P, C)
                self.colors_cpu = images.permute(0, 2, 3, 1)[
                    :, self.dataset.pixels[:, 0], self.dataset.pixels[:, 1]].flatten(0, 1)
            if self.colors_cpu is not None:
                self.colors = self.colors_cpu.to(self.dataset.device)
            self.loaded = True
            size_mb = self.colors.numel() * self.colors.element_size() / 1024 / 1024 \
                if self.colors is not None else 0.
            MemProfiler.print_memory_stats(
                f'Chunk #{self.id} ({self.n_views} views, {size_mb:.2f}MB) loaded')

        def __len__(self):
            return self.n_views * self.n_pixels_per_view

        def __getitem__(self, idx):
            if not self.loaded:
                self.load()
            # Decompose the flat index into a view index and a pixel index
            view_idx = idx // self.n_pixels_per_view
            pix_idx = idx % self.n_pixels_per_view
            global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
            extra_data = {}
            if self.colors is not None:
                extra_data['color'] = self.colors[idx]
            rays_o = self.centers[view_idx]
            rays_d = self.dataset.rays[pix_idx]
            return global_idx, rays_o, rays_d, extra_data

    @property
    def n_views(self):
        return self.centers.size(0)

    @property
    def n_pixels_per_view(self):
        return self.pixels.size(0)

    @property
    def n_pixels(self):
        return self.n_views * self.n_pixels_per_view

    def __init__(self, desc: dict, root: Path, name: str, *,
                 load_images: bool = True,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None, **kwargs):
        """
        Initialize the data loader for the spherical view synthesis task.

        The dataset description is a JSON object with the following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of views
        - depth_range: { "min", "max" }, the depth range
        - range: { "min": [...], "max": [...] }, the range of translation and rotation
        - view_centers: [ [ x, y, z ], ...
          ], centers of views

        :param desc `dict`: the dataset description
        :param root `Path`: root directory of the dataset
        :param name `str`: name of the dataset
        :param load_images `bool`: whether to load view images and return them in __getitem__()
        :param res `(int, int)`: if specified, resize view images to this resolution (H, W)
        :param views_to_load `range | Tensor`: if specified, load only this subset of views
        :param device `torch.device`: device to place tensors on
        """
        self.root = root
        self.name = name
        self.device = device
        self._load_desc(desc, res, views_to_load, load_images)

    def get_data(self):
        return {
            'indices': self.indices,
            'centers': self.centers
        }

    def _load_desc(self, desc: dict, res: Tuple[int, int],
                   views_to_load: Union[range, torch.Tensor],
                   load_images: bool):
        if load_images and desc.get('view_file_pattern'):
            file_pattern = desc['view_file_pattern']
            if "/" not in file_pattern:
                file_pattern = f"{self.name}/{file_pattern}"
            self.image_path = str(self.root / file_pattern)
        else:
            self.image_path = None
        self.res = res if res else itemgetter("y", "x")(desc['view_res'])
        self.depth_range = itemgetter("min", "max")(desc['depth_range']) \
            if 'depth_range' in desc else None
        self.bbox = None
        self.samples = desc.get('samples')
        self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
        self.indices = torch.tensor(desc.get('views') or [*range(self.centers.size(0))],
                                    device=self.device)
        if views_to_load is not None:
            self.centers = self.centers[views_to_load]
            self.indices = self.indices[views_to_load]

        self.pixels, self.rays = self._get_pano_rays()

        if desc.get('gl_coord'):
            print('Converting from the OpenGL to the DirectX coordinate system (i.e. flipping the z-axis)')
            self.centers[:, 2] *= -1

    def _get_pano_rays(self):
        """
        Get unprojected rays of valid pixels on a panorama.

        :return `Tensor(N, 2)`: rays' pixel coordinates in the pano image
        :return `Tensor(N, 3)`: rays' directions with unit length
        """
        # Latitude angle of each row's center, in (0, pi)
        phi = (torch.arange(self.res[0], device=self.device) + 0.5) / self.res[0] * PI  # (H)
        # Number of valid columns in each row; rows near the poles cover fewer pixels
        length = (phi.sin() * self.res[1] * 0.5).ceil() * 2
        cols = torch.arange(self.res[1], device=self.device)[None, :].expand(*self.res)  # (H, W)
        # A pixel is valid if its column lies within the centered span of its row
        mask = torch.logical_and(cols >= (self.res[1] - length[:, None]) / 2,
                                 cols < (self.res[1] + length[:, None]) / 2)  # (H, W)
        pixs = mask.nonzero()  # (N, 2)
        # Convert pixel coordinates to spherical angles: latitude in (-pi/2, pi/2) and
        # longitude in (-pi, pi), the latter normalized by the valid span of each row
        pixs_phi = (0.5 - (pixs[:, 0] + 0.5) / self.res[0]) * PI
        pixs_theta = (pixs[:, 1] * 2 + 1 - self.res[1]) / length[pixs[:, 0]] * PI
        spher_coords = torch.stack([torch.ones_like(pixs_phi), pixs_theta, pixs_phi], dim=-1)
        return pixs, sphere.spherical2cartesian(spher_coords)  # (N, 3)
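
# A minimal, self-contained usage sketch, not part of the dataset API: it feeds
# PanoDataset a hand-written description dict instead of a JSON file read from
# disk, and disables image loading so no view files are required. The field
# values, scene name and root path below are assumptions for illustration only;
# only the PanoDataset / Chunk calls come from the code above.
if __name__ == '__main__':
    desc = {
        'view_file_pattern': 'view_%04d.png',  # hypothetical file pattern
        'view_res': {'x': 8, 'y': 4},          # tiny panorama for demonstration
        'view_centers': [[0., 0., 0.], [.1, 0., 0.]],
    }
    dataset = PanoDataset(desc, Path('data'), 'demo_scene', load_images=False)
    print(f'{dataset.n_views} views, {dataset.n_pixels_per_view} rays per view')

    # Wrap all views into one chunk and fetch a single training sample; with
    # load_images=False no images are read and extra_data stays empty
    chunk = PanoDataset.Chunk(0, dataset, dataset.get_data(), color=color.RGB)
    global_idx, rays_o, rays_d, extra_data = chunk[0]
    print(global_idx.item(), rays_o.tolist(), rays_d.tolist(), extra_data)
    chunk.release()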