import torch
import torch.nn.functional as nn_f
from typing import Dict, Tuple, Union
from pathlib import Path

from utils import img
from utils import color
from .dataset import Dataset


class ViewDataset(Dataset):
    """
    Data loader for the spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of the dataset
    view_file_pattern ```str```: the filename pattern of view images
    cam ```object```: camera intrinsic parameters
    view_centers ```Tensor(N, 3)```: centers of views
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views
    view_images ```Tensor(N, 3, H, W)```: images of views
    view_depths ```Tensor(N, H, W)```: depths of views
    """

    class Chunk(object):

        def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
                     color: int, **kwargs):
            """
            Initialize a chunk of views whose images are loaded lazily on first access

            :param id ```int```: index of this chunk within the dataset
            :param dataset ```ViewDataset```: the owning dataset object
            :param chunk_data ```Dict[str, Tensor]```: per-view data of this chunk,
                containing 'indices' ```(N)```, 'centers' ```(N, 3)``` and 'rots' ```(N, 3, 3)```
            :param color ```int```: color space to convert view images to
            """
            self.id = id
            self.dataset = dataset
            self.indices = chunk_data['indices']
            self.centers = chunk_data['centers']
            self.rots = chunk_data['rots']
            self.color = color
            self.n_views = self.indices.size(0)
            self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
            self.colors = self.depths = self.bins = None
            self.colors_cpu = self.depths_cpu = self.bins_cpu = None
            self.loaded = False

        def release(self):
            # Drop only the device copies; the CPU caches are kept so a later
            # load() can re-upload without re-reading images from disk
            self.colors = self.depths = self.bins = None
            self.loaded = False

        def load(self):
            #print("chunk load")
            try:
                if self.dataset.image_path and self.colors_cpu is None:
                    images = color.cvt(
                        img.load([self.dataset.image_path % i for i in self.indices]),
                        color.RGB, self.color)
                    if self.dataset.res != list(images.shape[-2:]):
                        images = nn_f.interpolate(images, self.dataset.res)
                    # (N, C, H, W) -> (N*H*W, C) so pixels can be indexed flat
                    self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
                if self.colors_cpu is not None:
                    self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True)

                if self.dataset.depth_path and self.depths_cpu is None:
                    depths = self.dataset._decode_depth_images(
                        img.load([self.dataset.depth_path % i for i in self.indices]))
                    if self.dataset.res != list(depths.shape[-2:]):
                        # Depth maps are single-channel (N, H, W); interpolate
                        # needs a channel dim, so add and remove one around it
                        depths = nn_f.interpolate(depths[:, None], self.dataset.res)[:, 0]
                    self.depths_cpu = depths.flatten(0, 2)
                if self.depths_cpu is not None:
                    self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True)

                if self.dataset.bins_path and self.bins_cpu is None:
                    bins = img.load([self.dataset.bins_path % i for i in self.indices])
                    if self.dataset.res != list(bins.shape[-2:]):
                        bins = nn_f.interpolate(bins, self.dataset.res)
                    self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2)
                if self.bins_cpu is not None:
                    self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True)

                # Wait for the asynchronous host-to-device copies to finish
                # before marking the chunk as loaded
                torch.cuda.current_stream(self.dataset.device).synchronize()
                self.loaded = True
            except Exception as ex:
                print(ex)
                exit(-1)

        def __len__(self):
            return self.n_views * self.n_pixels_per_view

        def __getitem__(self, idx):
            if not self.loaded:
                self.load()
            # `idx` is expected to be a 1-D tensor of N sample indices; each
            # flat index is split into a view index and a pixel index
            view_idx = idx // self.n_pixels_per_view
            pix_idx = idx % self.n_pixels_per_view
            global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
            rays_o = self.centers[view_idx]
            rays_d = self.dataset.cam_rays[pix_idx][:, None]  # (N, 1, 3)
            r = self.rots[view_idx].movedim(-1, -2)  # (N, 3, 3), transposed rotations
            # Rotate local camera rays into world space:
            # (N, 1, 3) x (N, 3, 3) -> (N, 1, 3) -> (N, 3)
            rays_d = torch.matmul(rays_d, r)[:, 0]
            data = {
                'idx': global_idx,
                'rays_o': rays_o,
                'rays_d': rays_d,
                'level': self.dataset.level
            }
            if self.colors is not None:
                data['color'] = self.colors[idx]
            if self.depths is not None:
                data['depth'] = self.depths[idx]
            if self.bins is not None:
                data['bin'] = self.bins[idx]
            #data['view_idx'] = view_idx
            #data['pix_idx'] = pix_idx
            return data
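
    # Usage sketch for the lazy-loading lifecycle above (illustrative only; the
    # `chunk_data` dict would normally be produced by the base `Dataset` class):
    #
    #   chunk = ViewDataset.Chunk(0, dataset, chunk_data, color=color.RGB)
    #   batch = chunk[torch.randint(len(chunk), (4096,))]  # first access triggers load()
    #   chunk.release()  # frees device tensors; CPU caches persist for cheap reload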

    def __init__(self, desc: dict, desc_path: Path, *,
                 load_images: bool = True,
                 load_depths: bool = False,
                 load_bins: bool = False,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None,
                 **kwargs):
        """
        Initialize the data loader for the spherical view synthesis task

        The dataset description file is a JSON file with the following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view images
        - cam: { "fx", "fy", "cx", "cy" }, the focal lengths and center of the camera
          (in normalized image space)
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views

        :param desc ```dict```: the parsed dataset description
        :param desc_path ```Path```: path to the dataset description file
        :param load_images ```bool```: whether to load view images and return them in __getitem__()
        :param load_depths ```bool```: whether to load depth images and return them in __getitem__()
        :param load_bins ```bool```: whether to load bin images and return them in __getitem__()
        :param res ```Tuple[int, int]```: if specified, resize views to this resolution
        :param views_to_load ```range | Tensor```: if specified, load only this subset of views
        :param device ```torch.device```: the device to load data onto
        """
        super().__init__(desc, desc_path, res=res, views_to_load=views_to_load,
                         device=device, load_images=load_images,
                         load_depths=load_depths, load_bins=load_bins)

    def _decode_depth_images(self, input):
        # Depth images store disparity (inverse depth) in the first channel,
        # linearly mapped so that pixel value 1 is the near plane and 0 the far plane
        disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
        disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
        return torch.reciprocal(disp_val)

    def _load_desc(self, res: Tuple[int, int],
                   views_to_load: Union[range, torch.Tensor],
                   load_images: bool, load_depths: bool, load_bins: bool):
        super()._load_desc(res, views_to_load)
        self.image_path = load_images and self._get_data_path("view")
        self.depth_path = load_depths and self._get_data_path("depth")
        self.bins_path = load_bins and self._get_data_path("bins")
        # Precompute per-pixel ray directions in the camera's local frame
        self.cam_rays = self.cam.get_local_rays(flatten=True)
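

if __name__ == "__main__":
    # Minimal smoke test, run via `python -m <package>.<module>` (a sketch only:
    # the description path below is a hypothetical placeholder, and this assumes
    # the base `Dataset` class parses the JSON description and exposes `res`,
    # `device`, `depth_range`, `cam` and `_get_data_path()` as used above)
    import json

    desc_path = Path("data/example_scene/train.json")  # hypothetical path
    dataset = ViewDataset(json.loads(desc_path.read_text()), desc_path,
                          load_images=True, load_depths=False)
    print(f"resolution: {dataset.res}, rays per view: {dataset.cam_rays.size(0)}")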