import os
import torch
import torch.nn.functional as nn_f
from typing import Dict, Tuple, Union
from operator import itemgetter
from pathlib import Path

from utils import img
from utils import color
from utils import sphere
from utils import math
from utils.mem_profiler import *
from .dataset import Dataset


class PanoDataset(Dataset):
    """
    Data loader for spherical (panorama) view synthesis task.

    Each view is sampled with the same set of panorama pixels (`pixels`) and the
    corresponding unit ray directions (`rays`); items are individual rays.

    Attributes set in this class
    --------
    image_path: the view image path pattern (``%``-formatted with a view index),
        or a falsy value when images are not loaded\n
    pixels ```Tensor(P, 2)```: (row, col) coordinates of the sampled panorama pixels\n
    rays ```Tensor(P, 3)```: unit ray directions of the sampled pixels\n

    Other attributes used here (`indices`, `centers`, `res`, `device`, `level`)
    are presumably provided by the base `Dataset` — confirm against `.dataset`.
    """

    class Chunk(object):
        """
        A group of views whose pixel colors are loaded/released together.

        Items are indexed per ray: local index ``idx`` maps to view ``idx // n_pixels_per_view``
        of this chunk and pixel ``idx % n_pixels_per_view`` of the dataset's sampling pattern.
        """

        @property
        def n_views(self):
            return self.indices.size(0)

        @property
        def n_pixels_per_view(self):
            return self.dataset.n_pixels_per_view

        def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
                     color: int, **kwargs):
            """
            Create a chunk; colors are not loaded until `load()` (or first item access).

            :param id `int`: identifier of the chunk (used in log messages)
            :param dataset `PanoDataset`: owning dataset object
            :param chunk_data `Dict[str, Tensor]`: must contain 'indices' (`Tensor(N)`, indices
                of views) and 'centers' (`Tensor(N, 3)`, centers of views)
            :param color `int`: target color space passed to `color.cvt` when loading images
            """
            self.id = id
            self.dataset = dataset
            self.indices = chunk_data['indices']
            self.centers = chunk_data['centers']
            self.color = color
            self.colors_cpu = None  # cached sampled colors, filled on first load()
            self.colors = None      # copy of colors_cpu on dataset.device while loaded
            self.loaded = False

        def release(self):
            """Drop the device copy of the colors (the cached `colors_cpu` is kept)."""
            self.colors = None
            self.loaded = False
            MemProfiler.print_memory_stats(f'Chunk #{self.id} released')

        def load(self):
            """
            Load this chunk's pixel colors onto `dataset.device`.

            Images are read, converted and sampled only once and cached in `colors_cpu`;
            later calls only copy the cache to the device. When the dataset has no image
            path, `colors` stays None and only `loaded` is set.
            """
            if self.dataset.image_path and self.colors_cpu is None:
                images = color.cvt(
                    img.load(self.dataset.image_path % i for i in self.indices),
                    color.RGB, self.color)
                if self.dataset.res != tuple(images.shape[-2:]):
                    images = nn_f.interpolate(images, self.dataset.res)
                # `dataset.pixels` lives on `dataset.device`; move the indices to the
                # images' device so advanced indexing works when the devices differ.
                pixels = self.dataset.pixels.to(images.device)
                # (N, C, H, W) -> (N, H, W, C), keep only the sampled panorama pixels,
                # then flatten views and pixels into one item dimension.
                self.colors_cpu = images.permute(0, 2, 3, 1)[:, pixels[:, 0], pixels[:, 1]] \
                    .flatten(0, 1)
            if self.colors_cpu is not None:
                self.colors = self.colors_cpu.to(self.dataset.device)
            self.loaded = True
            # `colors` is None for datasets created without images; report 0MB then
            # instead of failing on `None.numel()`.
            size_mb = 0. if self.colors is None \
                else self.colors.numel() * self.colors.element_size() / 1024 / 1024
            MemProfiler.print_memory_stats(
                f'Chunk #{self.id} ({self.n_views} views, {size_mb:.2f}MB) loaded')

        def __len__(self):
            return self.n_views * self.n_pixels_per_view

        def __getitem__(self, idx):
            """
            Return the ray data for local item `idx`, loading the chunk on first access.

            :return `dict`: 'idx' (global ray index), 'rays_o', 'rays_d', 'level', and
                'color' when colors are loaded
            """
            if not self.loaded:
                self.load()
            view_idx = idx // self.n_pixels_per_view
            pix_idx = idx % self.n_pixels_per_view
            # Global index: position of this ray across all views of the dataset
            global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
            data = {
                'idx': global_idx,
                'rays_o': self.centers[view_idx],
                'rays_d': self.dataset.rays[pix_idx],
                'level': self.dataset.level
            }
            if self.colors is not None:
                data['color'] = self.colors[idx]
            return data

    @property
    def n_pixels_per_view(self):
        return self.pixels.size(0)

    def __init__(self, desc: dict, desc_path: Path, *,
                 load_images: bool = True,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None,
                 **kwargs):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view
        - depth_range: { "min", "max" }, the depth range
        - range: { "min": [...], "max": [...] }, the range of translation and rotation
        - centers: [ [ x, y, z ], ... ], centers of views

        :param desc `dict`: the parsed dataset description
        :param desc_path `Path`: path to the data description file
        :param load_images `bool`: whether to load view images and return colors in items
        :param res `(int, int)`: resolution (H, W) to use; forwarded to the base class
        :param views_to_load `range|Tensor`: subset of views to load; forwarded to the base class
        :param device `torch.device`: device for rays/colors; forwarded to the base class
        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(desc, desc_path, res=res, views_to_load=views_to_load, device=device,
                         load_images=load_images)

    def get_data(self):
        """Return the per-view data ('indices', 'centers') used to build chunks."""
        return {
            'indices': self.indices,
            'centers': self.centers
        }

    def _load_desc(self, res: Tuple[int, int], views_to_load: Union[range, torch.Tensor],
                   load_images: bool):
        super()._load_desc(res, views_to_load)
        # False when images are disabled; otherwise the "view" data path pattern
        self.image_path = load_images and self._get_data_path("view")
        self.pixels, self.rays = self._get_pano_rays()

    def _get_pano_rays(self):
        """
        Get unprojected rays of pixels on a panorama

        Each row keeps a horizontally centered span of columns whose width is
        proportional to the sine of the row's polar angle (rounded up to an even
        count), so rows near the poles contribute fewer samples.

        :return `Tensor(N, 2)`: rays' pixel coordinates (row, col) in pano image
        :return `Tensor(N, 3)`: rays' directions with one unit length
        """
        # Polar angle of each row's center, in (0, pi)
        phi = (torch.arange(self.res[0], device=self.device) + 0.5) / self.res[0] * math.pi  # (H)
        # Number of columns kept in each row (even, centered)
        length = (phi.sin() * self.res[1] * 0.5).ceil() * 2
        cols = torch.arange(self.res[1], device=self.device)[None, :].expand(*self.res)  # (H, W)
        mask = torch.logical_and(cols >= (self.res[1] - length[:, None]) / 2,
                                 cols < (self.res[1] + length[:, None]) / 2)  # (H, W)
        pixs = mask.nonzero()  # (N, 2)
        # Latitude: +pi/2 at the top row, -pi/2 at the bottom row
        pixs_phi = (0.5 - (pixs[:, 0] + 0.5) / self.res[0]) * math.pi
        # Azimuth: spread the row's kept columns evenly over (-pi, pi)
        pixs_theta = (pixs[:, 1] * 2 + 1 - self.res[1]) / length[pixs[:, 0]] * math.pi
        # Unit radius; component order as expected by sphere.spherical2cartesian
        spher_coords = torch.stack([torch.ones_like(pixs_phi), pixs_theta, pixs_phi], dim=-1)
        return pixs, sphere.spherical2cartesian(spher_coords)  # (N, 3)