import json
import math
from typing import Tuple

import torch
import torchvision.transforms.functional as trans_f

from ..my import util
from ..my import device


def _convert_camera_params(input_camera_params, view_res):
    """
    Check and convert camera parameters in config file to pixel-space

    :param input_camera_params: { ["fx", "fy" | "fov"], "cx", "cy", ["normalized"] }, the parameters of camera
    :param view_res: the resolution of view images, (height, width)
    :return: camera parameters in pixel-space
    """
    input_is_normalized = bool(input_camera_params.get('normalized'))
    camera_params = {}
    if 'fov' in input_camera_params:
        camera_params['fx'] = camera_params['fy'] = \
            (1 if input_is_normalized else view_res[0]) / \
            util.Fov2Length(input_camera_params['fov'])
        camera_params['fy'] *= -1
    else:
        camera_params['fx'] = input_camera_params['fx']
        camera_params['fy'] = input_camera_params['fy']
    camera_params['cx'] = input_camera_params['cx']
    camera_params['cy'] = input_camera_params['cy']
    if input_is_normalized:
        camera_params['fx'] *= view_res[1]
        camera_params['fy'] *= view_res[0]
        camera_params['cx'] *= view_res[1]
        camera_params['cy'] *= view_res[0]
    return camera_params
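

# A minimal sketch of how _convert_camera_params is used; the concrete values
# below are hypothetical, and util.Fov2Length is assumed to map a field-of-view
# angle to the corresponding image-plane length:
#
#   cam = _convert_camera_params(
#       {'fov': 90.0, 'cx': 0.5, 'cy': 0.5, 'normalized': True},
#       (1440, 1600))  # view_res is (height, width)
#   # cam['fx'], cam['fy'], cam['cx'], cam['cy'] are now in pixel units,
#   # with 'fy' negated (see the 'fov' branch above)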


class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of dataset\n
    view_file_pattern ```str```: the filename pattern of view images\n
    cam_params ```object```: camera intrinsic parameters\n
    view_centers ```Tensor(N, 3)```: centers of views\n
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
    view_images ```Tensor(N, 3, H, W)```: images of views\n
    """

    def __init__(self, dataset_desc_path: str, load_images: bool = True,
                 gray: bool = False, ray_as_item: bool = False):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view images
        - cam_params: { ["fx", "fy" | "fov"], "cx", "cy", ["normalized"] }, the parameters of camera
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views

        :param dataset_desc_path ```str```: path to the data description file
        :param load_images ```bool```: whether to load view images and return them in __getitem__()
        :param gray ```bool```: whether to convert view images to grayscale
        :param ray_as_item ```bool```: whether to treat each ray in a view as an item
        """
        self.data_dir = dataset_desc_path.rsplit('/', 1)[0] + '/'
        self.load_images = load_images
        self.ray_as_item = ray_as_item

        # Load dataset description file
        with open(dataset_desc_path, 'r', encoding='utf-8') as file:
            data_desc = json.loads(file.read())
        if data_desc['view_file_pattern'] == '':
            self.load_images = False
        else:
            self.view_file_pattern: str = self.data_dir + \
                data_desc['view_file_pattern']
        self.view_res = (data_desc['view_res']['y'],
                         data_desc['view_res']['x'])
        self.cam_params = _convert_camera_params(
            data_desc['cam_params'], self.view_res)
        self.view_centers = torch.tensor(data_desc['view_centers'])  # (N, 3)
        self.view_rots = torch.tensor(data_desc['view_rots']) \
            .view(-1, 3, 3)  # (N, 3, 3)

        # Load view images
        if self.load_images:
            self.view_images = util.ReadImageTensor(
                [self.view_file_pattern % i
                 for i in range(self.view_centers.size(0))])
            if gray:
                self.view_images = trans_f.rgb_to_grayscale(self.view_images)
        else:
            self.view_images = None

        local_view_rays = util.GetLocalViewRays(self.cam_params,
                                                self.view_res,
                                                flatten=True)  # (M, 3)
        # Transpose matrix so we can perform vec x mat
        view_rots_t = self.view_rots.permute(0, 2, 1)

        # rays_o & rays_d are both (N, M, 3)
        self.rays_o = self.view_centers.unsqueeze(1) \
            .expand(-1, local_view_rays.size(0), -1)
        self.rays_d = torch.matmul(local_view_rays, view_rots_t)

        # Flatten rays if ray_as_item = True
        if ray_as_item:
            self.view_pixels = self.view_images.permute(0, 2, 3, 1).flatten(0, 2) \
                if self.view_images is not None else None
            self.rays_o = self.rays_o.flatten(0, 1)
            self.rays_d = self.rays_d.flatten(0, 1)

    def __len__(self):
        return self.rays_o.size(0)

    def __getitem__(self, idx):
        if self.load_images:
            if self.ray_as_item:
                return idx, self.view_pixels[idx], self.rays_o[idx], self.rays_d[idx]
            return idx, self.view_images[idx], self.rays_o[idx], self.rays_d[idx]
        return idx, False, self.rays_o[idx], self.rays_d[idx]
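

# Sketch of typical usage with a standard PyTorch DataLoader. The description
# file name and the field values below are hypothetical:
#
#   data/dataset.json:
#     {
#       "view_file_pattern": "view_%04d.png",
#       "view_res": {"x": 1600, "y": 1440},
#       "cam_params": {"fov": 90.0, "cx": 0.5, "cy": 0.5, "normalized": true},
#       "view_centers": [[0.0, 0.0, 0.0], ...],
#       "view_rots": [[1, 0, 0, 0, 1, 0, 0, 0, 1], ...]
#     }
#
#   dataset = SphericalViewSynDataset('data/dataset.json', ray_as_item=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4096, shuffle=True)
#   for idx, pixels, rays_o, rays_d in loader:
#       ...  # pixels: (B, C), rays_o / rays_d: (B, 3)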


class FastSphericalViewSynDataset(torch.utils.data.dataset.Dataset):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of dataset\n
    view_file_pattern ```str```: the filename pattern of view images\n
    cam_params ```object```: camera intrinsic parameters\n
    view_centers ```Tensor(N, 3)```: centers of views\n
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
    view_images ```Tensor(N, 3, H, W)```: images of views\n
    """

    def __init__(self, dataset_desc_path: str, load_images: bool = True,
                 gray: bool = False):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view images
        - cam_params: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views

        :param dataset_desc_path ```str```: path to the data description file
        :param load_images ```bool```: whether to load view images and return them in __getitem__()
        :param gray ```bool```: whether to convert view images to grayscale
        """
        super().__init__()
        self.data_dir = dataset_desc_path.rsplit('/', 1)[0] + '/'
        self.load_images = load_images

        # Load dataset description file
        with open(dataset_desc_path, 'r', encoding='utf-8') as file:
            data_desc = json.loads(file.read())
        if data_desc['view_file_pattern'] == '':
            self.load_images = False
        else:
            self.view_file_pattern: str = self.data_dir + \
                data_desc['view_file_pattern']
        self.view_res = (data_desc['view_res']['y'],
                         data_desc['view_res']['x'])
        self.cam_params = _convert_camera_params(
            data_desc['cam_params'], self.view_res)
        self.view_centers = torch.tensor(
            data_desc['view_centers'], device=device.GetDevice())  # (N, 3)
        self.view_rots = torch.tensor(
            data_desc['view_rots'], device=device.GetDevice()).view(-1, 3, 3)  # (N, 3, 3)
        self.n_views = self.view_centers.size(0)
        self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]

        # Load view images
        if self.load_images:
            self.view_images = util.ReadImageTensor(
                [self.view_file_pattern % i
                 for i in range(self.view_centers.size(0))]
            ).to(device.GetDevice())
            if gray:
                self.view_images = trans_f.rgb_to_grayscale(self.view_images)
        else:
            self.view_images = None

        local_view_rays = util.GetLocalViewRays(self.cam_params, self.view_res, True) \
            .to(device.GetDevice())  # (HW, 3)
        # Transpose matrix so we can perform vec x mat
        view_rots_t = self.view_rots.permute(0, 2, 1)

        # rays_o & rays_d are both (N, H, W, 3)
        self.rays_o = self.view_centers[:, None, None, :] \
            .expand(-1, self.view_res[0], self.view_res[1], -1)
        self.rays_d = torch.matmul(local_view_rays, view_rots_t) \
            .view_as(self.rays_o)

        self.patched_images = self.view_images  # (N, 1|3, H, W)
        self.patched_rays_o = self.rays_o  # (N, H, W, 3)
        self.patched_rays_d = self.rays_d  # (N, H, W, 3)

    def set_patch_size(self, patch_size: Tuple[int, int],
                       offset: Tuple[int, int] = (0, 0)):
        """
        Set the size of patch and (optional) offset.

        If patch_size is (1, 1), items are individual rays; if patch_size equals
        the view resolution, items are whole views; otherwise views are split
        into patches of the given size, starting from the given offset.

        :param patch_size: the size (rows, columns) of a patch
        :param offset: the (row, column) offset of the first patch in a view
        """
        # Number of whole patches along each axis
        patches = ((self.view_res[0] - offset[0]) // patch_size[0],
                   (self.view_res[1] - offset[1]) // patch_size[1])
        # Crop to the region covered by whole patches
        slices = (..., slice(offset[0], offset[0] + patches[0] * patch_size[0]),
                  slice(offset[1], offset[1] + patches[1] * patch_size[1]))
        if patch_size[0] == 1 and patch_size[1] == 1:
            self.patched_images = self.view_images[slices] \
                .permute(0, 2, 3, 1).flatten(0, 2) if self.load_images else None
            self.patched_rays_o = self.rays_o[slices].flatten(0, 2)
            self.patched_rays_d = self.rays_d[slices].flatten(0, 2)
        elif patch_size[0] == self.view_res[0] and patch_size[1] == self.view_res[1]:
            self.patched_images = self.view_images
            self.patched_rays_o = self.rays_o
            self.patched_rays_d = self.rays_d
        else:
            self.patched_images = self.view_images[slices] \
                .view(self.n_views, -1, patches[0], patch_size[0],
                      patches[1], patch_size[1]) \
                .permute(0, 2, 4, 1, 3, 5).flatten(0, 2) if self.load_images else None
            self.patched_rays_o = self.rays_o[slices] \
                .view(self.n_views, patches[0], patch_size[0],
                      patches[1], patch_size[1], -1) \
                .permute(0, 1, 3, 2, 4, 5).flatten(0, 2)
            self.patched_rays_d = self.rays_d[slices] \
                .view(self.n_views, patches[0], patch_size[0],
                      patches[1], patch_size[1], -1) \
                .permute(0, 1, 3, 2, 4, 5).flatten(0, 2)

    def __len__(self):
        return self.patched_rays_o.size(0)

    def __getitem__(self, idx):
        if self.load_images:
            return idx, self.patched_images[idx], self.patched_rays_o[idx], \
                self.patched_rays_d[idx]
        return idx, False, self.patched_rays_o[idx], self.patched_rays_d[idx]
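

# Sketch of the three batching modes set_patch_size() selects between (the
# description file and the patch sizes below are hypothetical; C is 1 or 3):
#
#   dataset = FastSphericalViewSynDataset('data/dataset.json')
#   dataset.set_patch_size((1, 1))           # items are single rays:
#                                            #   image (C,), rays_o / rays_d (3,)
#   dataset.set_patch_size(dataset.view_res) # items are whole views:
#                                            #   image (C, H, W), rays (H, W, 3)
#   dataset.set_patch_size((64, 64))         # items are 64x64 patches:
#                                            #   image (C, 64, 64), rays (64, 64, 3);
#                                            # pixels beyond a whole number of
#                                            # patches are dropped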


class FastDataLoader(object):

    class Iter(object):

        def __init__(self, dataset, batch_size, shuffle, drop_last) -> None:
            super().__init__()
            self.indices = torch.randperm(len(dataset), device=device.GetDevice()) \
                if shuffle else torch.arange(len(dataset), device=device.GetDevice())
            self.offset = 0
            self.batch_size = batch_size
            self.dataset = dataset
            self.drop_last = drop_last

        def __next__(self):
            # Stop when no items are left, or when the remaining items cannot
            # form a full batch and drop_last is set
            if self.offset + (self.batch_size if self.drop_last else 1) > len(self.dataset):
                raise StopIteration()
            indices = self.indices[self.offset:self.offset + self.batch_size]
            self.offset += self.batch_size
            return self.dataset[indices]

    def __init__(self, dataset, batch_size, shuffle, drop_last, **kwargs) -> None:
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        return FastDataLoader.Iter(self.dataset, self.batch_size,
                                   self.shuffle, self.drop_last)

    def __len__(self):
        return math.floor(len(self.dataset) / self.batch_size) if self.drop_last \
            else math.ceil(len(self.dataset) / self.batch_size)
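

# A minimal smoke test, assuming a dataset description file exists at the
# hypothetical path below (run as a module so the relative imports resolve).
# FastDataLoader indexes the dataset with a whole tensor of indices per step,
# so batches are assembled directly on the device chosen by device.GetDevice()
# instead of being collated item by item.
if __name__ == '__main__':
    dataset = FastSphericalViewSynDataset('data/dataset.json')
    dataset.set_patch_size((1, 1))  # train on individual rays
    loader = FastDataLoader(dataset, batch_size=4096, shuffle=True, drop_last=False)
    for idx, images, rays_o, rays_d in loader:
        print(idx.size(), rays_o.size(), rays_d.size())
        break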