import os
import torch
import torch.nn.functional as nn_f
from typing import Tuple, Union

from utils import img
from utils import view
from utils import color
from utils import misc


class ViewDataset(object):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    image_path ```str```: the path pattern of view images
    cam ```view.CameraParam```: camera intrinsic parameters
    centers ```Tensor(N, 3)```: centers of views
    rots ```Tensor(N, 3, 3)```: rotation matrices of views
    indices ```Tensor(N)```: indices of views
    """

    class Chunk(object):

        def __init__(self, id, dataset, *, indices: torch.Tensor,
                     centers: torch.Tensor, rots: torch.Tensor):
            """
            Initialize a chunk of views in the dataset.

            :param dataset `ViewDataset`: dataset object
            :param indices `Tensor(N)`: indices of views
            :param centers `Tensor(N, 3)`: centers of views
            :param rots `Tensor(N, 3, 3)`: rotation matrices of views
            """
            self.id = id
            self.dataset = dataset
            self.indices = indices
            self.centers = centers
            self.rots = rots
            self.n_views = self.indices.size(0)
            self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
            self.colors = self.depths = self.bins = None
            self.colors_cpu = self.depths_cpu = self.bins_cpu = None
            self.loaded = False

        def release(self):
            self.colors = self.depths = self.bins = None
            self.loaded = False

        def load(self):
            # Load and cache each modality on CPU once, then transfer to the
            # target device on every (re-)load
            if self.dataset.image_path and self.colors_cpu is None:
                images = color.cvt(
                    img.load(self.dataset.image_path % i for i in self.indices),
                    color.RGB, self.dataset.c)
                if self.dataset.res != list(images.shape[-2:]):
                    images = nn_f.interpolate(images, self.dataset.res)
                self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
            if self.colors_cpu is not None:
                self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True)

            if self.dataset.depth_path and self.depths_cpu is None:
                depths = self.dataset._decode_depth_images(
                    img.load(self.dataset.depth_path % i for i in self.indices))
                if self.dataset.res != list(depths.shape[-2:]):
                    depths = nn_f.interpolate(depths, self.dataset.res)
                self.depths_cpu = depths.flatten(0, 2)
            if self.depths_cpu is not None:
                self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True)

            if self.dataset.bins_path and self.bins_cpu is None:
                bins = img.load([self.dataset.bins_path % i for i in self.indices])
                if self.dataset.res != list(bins.shape[-2:]):
                    bins = nn_f.interpolate(bins, self.dataset.res)
                self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2)
            if self.bins_cpu is not None:
                self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True)

            torch.cuda.current_stream(self.dataset.device).synchronize()
            self.loaded = True

        def __len__(self):
            return self.n_views * self.n_pixels_per_view

        def __getitem__(self, idx):
            if not self.loaded:
                self.load()
            view_idx = idx // self.n_pixels_per_view
            pix_idx = idx % self.n_pixels_per_view
            rays_o = self.centers[view_idx]  # (N, 3)
            rays_d = self.dataset.cam_rays[pix_idx]  # (N, 3)
            r = self.rots[view_idx].movedim(-1, -2)  # (N, 3, 3)
            rays_d = (rays_d[:, None] @ r)[:, 0]  # (N, 3)
            extra_data = {}
            if self.colors is not None:
                extra_data['colors'] = self.colors[idx]
            if self.depths is not None:
                extra_data['depths'] = self.depths[idx]
            if self.bins is not None:
                extra_data['bins'] = self.bins[idx]
            return idx, rays_o, rays_d, extra_data
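    # A hedged illustration (not in the original source) of how Chunk.__getitem__
    # maps flat sample indices to views and pixels. Assuming res = (4, 8), so
    # n_pixels_per_view = 32, a batch idx = torch.tensor([5, 37]) gives:
    #   view_idx = idx // 32  ->  [0, 1]   (which view each sample belongs to)
    #   pix_idx  = idx %  32  ->  [5, 5]   (which pixel within that view)
    # Each camera-space ray cam_rays[pix_idx] is then rotated into world space:
    # the row-vector product (rays_d[:, None] @ R.T)[:, 0] is equivalent to
    # multiplying R by the direction as a column vector.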
    def __init__(self, desc: dict, *,
                 c: int = color.RGB,
                 load_images: bool = True,
                 load_depths: bool = False,
                 load_bins: bool = False,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None,
                 **kwargs):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description is a JSON object with the following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view images
        - cam_params: { "fx", "fy", "cx", "cy" }, the focal lengths and center of
          camera (in normalized image space)
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views

        :param desc ```dict```: the dataset description
        :param c ```int```: color space to convert view images to
        :param load_images ```bool```: whether to load view images and return them in __getitem__()
        :param load_depths ```bool```: whether to load depth images and return them in __getitem__()
        :param load_bins ```bool```: whether to load bin images and return them in __getitem__()
        :param res ```Tuple[int, int]```: if specified, resample views to this resolution (H, W)
        :param views_to_load ```range | Tensor```: if specified, load only this subset of views
        :param device ```torch.device```: the device to store tensors on
        """
        self.c = c
        self.device = device
        self._load_desc(desc, res, views_to_load, load_images, load_depths, load_bins)

    def get_data(self):
        return {
            'indices': self.indices,
            'centers': self.centers,
            'rots': self.rots
        }

    def _decode_depth_images(self, input):
        # The red channel encodes disparity (inverse depth), linearly mapped
        # from depth_range; recover depth as the reciprocal
        disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
        disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
        return torch.reciprocal(disp_val)

    def _load_desc(self, desc: dict, res: Tuple[int, int],
                   views_to_load: Union[range, torch.Tensor],
                   load_images: bool, load_depths: bool, load_bins: bool):
        if load_images and desc.get('view_file_pattern'):
            self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
        else:
            self.image_path = None
        if load_depths and desc.get('depth_file_pattern'):
            self.depth_path = os.path.join(os.getcwd(), desc['depth_file_pattern'])
        else:
            self.depth_path = None
        if load_bins and desc.get('bins_file_pattern'):
            self.bins_path = os.path.join(os.getcwd(), desc['bins_file_pattern'])
        else:
            self.bins_path = None
        self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
        self.cam = view.CameraParam(desc['cam_params'], self.res, device=self.device)
        self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
            if 'depth_range' in desc else None
        self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
        self.samples = desc.get('samples')
        self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
        # view_rots may be given either as 2-element Euler angles or as
        # flattened 3x3 matrices
        self.rots = torch.tensor(
            [
                view.euler_to_matrix([rot[1] if desc.get('gl_coord') else -rot[1], rot[0], 0])
                for rot in desc['view_rots']
            ] if len(desc['view_rots'][0]) == 2 else desc['view_rots'],
            device=self.device).view(-1, 3, 3)  # (N, 3, 3)
        self.indices = torch.tensor(
            desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
            device=self.device)

        if views_to_load is not None:
            self.centers = self.centers[views_to_load]
            self.rots = self.rots[views_to_load]
            self.indices = self.indices[views_to_load]

        self.n_views = self.centers.size(0)
        self.n_pixels = self.n_views * self.res[0] * self.res[1]

        if desc.get('gl_coord'):
            print('Convert from OGL coordinate to DX coordinate (i.e. flip z axis)')
            if not desc['cam_params'].get('fov'):
                self.cam.f[1] *= -1
            self.centers[:, 2] *= -1
            self.rots[:, 2] *= -1
            self.rots[..., 2] *= -1

        self.cam_rays = self.cam.get_local_rays(flatten=True)
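

# A minimal usage sketch, not part of the original module. It assumes a dataset
# description JSON (the path "data/train.json" is a hypothetical placeholder)
# whose fields match what _load_desc() reads above, and a CUDA device, since
# Chunk.load() synchronizes a CUDA stream.
if __name__ == '__main__':
    import json

    with open('data/train.json', 'r') as fp:
        desc = json.load(fp)
    dataset = ViewDataset(desc, load_images=True, device=torch.device('cuda'))
    # get_data() returns exactly the keyword arguments Chunk.__init__ expects
    chunk = ViewDataset.Chunk(0, dataset, **dataset.get_data())
    # Fetch a small batch of rays; each entry of extra_data aligns with idx
    idx, rays_o, rays_d, extra_data = chunk[torch.arange(8)]
    print(idx.shape, rays_o.shape, rays_d.shape, sorted(extra_data.keys()))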