From f1dd9e3a8aad6daa779881ec487ce1f8251bfe0e Mon Sep 17 00:00:00 2001
From: Nianchen Deng <dengnianchen@sjtu.edu.cn>
Date: Mon, 6 Sep 2021 10:22:39 +0800
Subject: [PATCH] tog'21 baseline

---
 data/dataset_factory.py |  23 +++++
 data/loader.py          | 113 +++++++++++++++++++----
 data/pano_dataset.py    | 159 ++++++++++++++++++++++++++++++++
 data/view_dataset.py    | 198 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 475 insertions(+), 18 deletions(-)
 create mode 100644 data/dataset_factory.py
 create mode 100644 data/pano_dataset.py
 create mode 100644 data/view_dataset.py

diff --git a/data/dataset_factory.py b/data/dataset_factory.py
new file mode 100644
index 0000000..7a1f7c2
--- /dev/null
+++ b/data/dataset_factory.py
@@ -0,0 +1,23 @@
+import os
+import json
+import utils.device
+from .pano_dataset import PanoDataset
+from .view_dataset import ViewDataset
+
+
+class DatasetFactory(object):
+
+    @staticmethod
+    def load(path, device=None, **kwargs):
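+        """
+        Load a dataset from a JSON description file.
+
+        Creates a `PanoDataset` if the description declares `"type": "pano"`,
+        otherwise a `ViewDataset`. The working directory is temporarily
+        switched to the description file's directory so that relative path
+        patterns in the description resolve correctly.
+
+        :param path `str`: path to the JSON description file
+        :param device `torch.device`: target device, defaults to `utils.device.default()`
+        :param kwargs: extra options forwarded to the dataset constructor
+        :return: the loaded dataset
+        """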
+        device = device or utils.device.default()
+        data_dir = os.path.dirname(path)
+        with open(path, 'r', encoding='utf-8') as file:
+            data_desc = json.loads(file.read())
+        cwd = os.getcwd()
+        os.chdir(data_dir)
+        if 'type' in data_desc and data_desc['type'] == 'pano':
+            dataset = PanoDataset(data_desc, device=device, **kwargs)
+        else:
+            dataset = ViewDataset(data_desc, device=device, **kwargs)
+        os.chdir(cwd)
+        return dataset
\ No newline at end of file
diff --git a/data/loader.py b/data/loader.py
index 9510acd..49163cd 100644
--- a/data/loader.py
+++ b/data/loader.py
@@ -1,39 +1,116 @@
+import threading
 import torch
 import math
-from utils import device
 
 
-class FastDataLoader(object):
+class Preloader(object):
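+    """
+    Loads a data chunk in the background.
+
+    Loading runs on a separate thread and issues its device copies on a
+    dedicated CUDA stream, so the next chunk can be transferred while the
+    current one is being consumed.
+    """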
+
+    def __init__(self, device=None) -> None:
+        super().__init__()
+        self.stream = torch.cuda.Stream(device)
+        self.event_chunk_loaded = None
+
+    def preload_chunk(self, chunk):
+        if self.event_chunk_loaded is not None:
+            self.event_chunk_loaded.wait()
+        if chunk.loaded:
+            return
+        # print(f'Preloader: preload chunk #{chunk.id}')
+        self.event_chunk_loaded = threading.Event()
+        threading.Thread(target=Preloader._load_chunk, args=(self, chunk)).start()
+
+    def _load_chunk(self, chunk):
+        with torch.cuda.stream(self.stream):
+            chunk.load()
+        self.event_chunk_loaded.set()
+        # print(f'Preloader: chunk #{chunk.id} is loaded')
+
+
+class DataLoader(object):
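+    """
+    Batched iterator over a dataset that is split into chunks.
+
+    Only one chunk needs to reside on the device at a time; with a
+    `Preloader`, the next chunk is loaded concurrently while the current
+    chunk is being iterated.
+    """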
 
     class Iter(object):
 
-        def __init__(self, dataset, batch_size, shuffle, drop_last) -> None:
+        def __init__(self, chunks, batch_size, shuffle, device: torch.device, preloader: Preloader):
             super().__init__()
-            self.indices = torch.randperm(len(dataset), device=device.default()) \
-                if shuffle else torch.arange(len(dataset), device=device.default())
-            self.offset = 0
             self.batch_size = batch_size
-            self.dataset = dataset
-            self.drop_last = drop_last
+            self.chunks = chunks
+            self.offset = -1
+            self.chunk_idx = -1
+            self.current_chunk = None
+            self.shuffle = shuffle
+            self.device = device
+            self.preloader = preloader
+
+        def __del__(self):
+            #print('DataLoader.Iter: clean chunks')
+            if self.preloader is not None and self.preloader.event_chunk_loaded is not None:
+                self.preloader.event_chunk_loaded.wait()
+            chunks_to_reserve = 1 if self.preloader is None else 2
+            for i in range(chunks_to_reserve, len(self.chunks)):
+                if self.chunks[i].loaded:
+                    self.chunks[i].release()
 
         def __next__(self):
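+            # draw the next batch from the current chunk; offset == -1 means
+            # the iterator must first advance to the next chunk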
-            if self.offset + (self.batch_size if self.drop_last else 0) >= len(self.dataset):
+            if self.offset == -1:
+                self._next_chunk()
+            stop = min(self.offset + self.batch_size, len(self.current_chunk))
+            if self.indices is not None:
+                indices = self.indices[self.offset:stop]
+            else:
+                indices = torch.arange(self.offset, stop, device=self.device)
+            self.offset = stop
+            if self.offset >= len(self.current_chunk):
+                self.offset = -1
+            return self.current_chunk[indices]
+
+        def _next_chunk(self):
+            if self.current_chunk is not None:
+                chunks_to_reserve = 1 if self.preloader is None else 2
+                if len(self.chunks) > chunks_to_reserve:
+                    self.current_chunk.release()
+            if self.chunk_idx >= len(self.chunks) - 1:
                 raise StopIteration()
-            indices = self.indices[self.offset:self.offset + self.batch_size]
-            self.offset += self.batch_size
-            return self.dataset[indices]
+            self.chunk_idx += 1
+            self.current_chunk = self.chunks[self.chunk_idx]
+            self.offset = 0
+            self.indices = torch.randperm(len(self.current_chunk), device=self.device) \
+                if self.shuffle else None
+            if self.preloader is not None:
+                self.preloader.preload_chunk(self.chunks[(self.chunk_idx + 1) % len(self.chunks)])
 
-    def __init__(self, dataset, batch_size, shuffle, drop_last=False, **kwargs) -> None:
+    def __init__(self, dataset, batch_size, *,
+                 chunk_max_items=None, shuffle=False, enable_preload=True):
         super().__init__()
         self.dataset = dataset
         self.batch_size = batch_size
         self.shuffle = shuffle
-        self.drop_last = drop_last
+        self.preloader = Preloader(self.dataset.device) if enable_preload else None
+        self._init_chunks(chunk_max_items)
 
     def __iter__(self):
-        return FastDataLoader.Iter(self.dataset, self.batch_size,
-                                   self.shuffle, self.drop_last)
+        return DataLoader.Iter(self.chunks, self.batch_size, self.shuffle, self.dataset.device,
+                               self.preloader)
 
     def __len__(self):
-        return math.floor(len(self.dataset) / self.batch_size) if self.drop_last \
-            else math.ceil(len(self.dataset) / self.batch_size)
+        return sum(math.ceil(len(chunk) / self.batch_size) for chunk in self.chunks)
+
+    def _init_chunks(self, chunk_max_items):
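+        # partition the (optionally globally shuffled) views into chunks so
+        # that each chunk holds at most `chunk_max_items` pixels; a single
+        # chunk is used when no limit is given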
+        data = self.dataset.get_data()
+        if self.shuffle:
+            rand_seq = torch.randperm(self.dataset.n_views, device=self.dataset.device)
+            for key in data:
+                data[key] = data[key][rand_seq]
+        self.chunks = []
+        n_chunks = 1 if chunk_max_items is None else \
+            math.ceil(self.dataset.n_pixels / chunk_max_items)
+        views_per_chunk = math.ceil(self.dataset.n_views / n_chunks)
+        for offset in range(0, self.dataset.n_views, views_per_chunk):
+            sel = slice(offset, offset + views_per_chunk)
+            chunk_data = {}
+            for key in data:
+                chunk_data[key] = data[key][sel]
+            self.chunks.append(self.dataset.Chunk(len(self.chunks), self.dataset, **chunk_data))
+        if self.preloader is not None:
+            self.preloader.preload_chunk(self.chunks[0])
diff --git a/data/pano_dataset.py b/data/pano_dataset.py
new file mode 100644
index 0000000..9953c8f
--- /dev/null
+++ b/data/pano_dataset.py
@@ -0,0 +1,159 @@
+import os
+import torch
+import torch.nn.functional as nn_f
+from typing import Tuple, Union
+from utils import img
+from utils import color
+from utils import misc
+from utils import sphere
+from utils.mem_profiler import *
+from utils.constants import *
+
+
+class PanoDataset(object):
+    """
+    Data loader for spherical view synthesis task
+
+    Attributes
+    --------
+    data_dir ```str```: the directory of dataset\n
+    view_file_pattern ```str```: the filename pattern of view images\n
+    cam_params ```object```: camera intrinsic parameters\n
+    centers ```Tensor(N, 3)```: centers of views\n
+    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
+    images ```Tensor(N, 3, H, W)```: images of views\n
+    view_depths ```Tensor(N, H, W)```: depths of views\n
+    """
+
+    class Chunk(object):
+
+        def __init__(self, id, dataset, *,
+                     indices: torch.Tensor, centers: torch.Tensor):
+            """
+            [summary]
+
+            :param dataset `PanoDataset`: dataset object
+            :param indices `Tensor(N)`: indices of views
+            :param centers `Tensor(N, 3)`: centers of views
+            """
+            self.id = id
+            self.dataset = dataset
+            self.indices = indices
+            self.centers = centers
+            self.n_views = self.indices.size(0)
+            self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
+            self.colors_cpu = None
+            self.colors = None
+            self.loaded = False
+
+        def release(self):
+            self.colors = None
+            self.loaded = False
+            MemProfiler.print_memory_stats(f'Chunk #{self.id} released')
+
+        def load(self):
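+            # images are decoded once on the CPU and cached in colors_cpu;
+            # (re)loading the chunk then only copies the cache to the device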
+            if self.dataset.image_path is not None and self.colors_cpu is None:
+                images = color.cvt(
+                    img.load(self.dataset.image_path % i for i in self.indices),
+                    color.RGB, self.dataset.c)
+                if self.dataset.res != list(images.shape[-2:]):
+                    images = nn_f.interpolate(images, self.dataset.res)
+                self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
+            if self.colors_cpu is not None:
+                self.colors = self.colors_cpu.to(self.dataset.device)
+            self.loaded = True
+            MemProfiler.print_memory_stats(
+                f'Chunk #{self.id} ({self.n_views} views, '
+                f'{self.colors.numel() * self.colors.element_size() / 1024 / 1024:.2f}MB) loaded')
+
+        def __len__(self):
+            return self.n_views * self.n_pixels_per_view
+
+        def __getitem__(self, idx):
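+            # a flat pixel index splits into a view index and an intra-view
+            # pixel index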
+            if not self.loaded:
+                self.load()
+            view_idx = idx // self.n_pixels_per_view
+            pix_idx = idx % self.n_pixels_per_view
+            extra_data = {}
+            if self.colors is not None:
+                extra_data['colors'] = self.colors[idx]
+            rays_o = self.centers[view_idx]
+            rays_d = self.dataset.pano_rays[pix_idx]
+            return idx, rays_o, rays_d, extra_data
+
+    def __init__(self, desc: dict, *,
+                 c: int = color.RGB,
+                 load_images: bool = True,
+                 res: Tuple[int, int] = None,
+                 views_to_load: Union[range, torch.Tensor] = None,
+                 device: torch.device = None,
+                 **kwargs):
+        """
+        Initialize data loader for spherical view synthesis task
+
+        The dataset description file is a JSON file with following fields:
+
+        - view_file_pattern: string, the path pattern of view images
+        - view_res: { "x", "y" }, the resolution of view
+        - depth_range: { "min", "max" }, the depth range
+        - range: { "min": [...], "max": [...] }, the range of translation and rotation
+        - centers: [ [ x, y, z ], ... ], centers of views
+
+        :param desc_path ```str```: path to the data description file
+        :param load_images ```bool```: whether load view images and return in __getitem__()
+        :param c ```int```: color space to convert view images to
+        :param calculate_rays ```bool```: whether calculate rays
+        """
+        self.c = c
+        self.device = device
+        self._load_desc(desc, res, views_to_load, load_images)
+
+    def get_data(self):
+        return {
+            'indices': self.indices,
+            'centers': self.centers
+        }
+
+    def _load_desc(self, desc: dict,
+                   res: Tuple[int, int],
+                   views_to_load: Union[range, torch.Tensor],
+                   load_images: bool):
+        if load_images and desc.get('view_file_pattern'):
+            self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
+        else:
+            self.image_path = None
+        self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
+        self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
+            if 'depth_range' in desc else None
+        self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
+        self.samples = desc.get('samples')
+        self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
+        self.indices = torch.tensor(
+            desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
+            device=self.device)
+
+        if views_to_load is not None:
+            self.centers = self.centers[views_to_load]
+            self.indices = self.indices[views_to_load]
+
+        self.n_views = self.centers.size(0)
+        self.n_pixels = self.n_views * self.res[0] * self.res[1]
+        self.pano_rays = self._get_pano_rays()  # [H*W, 3]
+
+        if desc.get('gl_coord'):
+            print('Converting from OpenGL to DirectX coordinates (i.e. flipping the z-axis)')
+            self.centers[:, 2] *= -1
+
+    def _get_pano_rays(self):
+        """
+        Get unprojected rays of pixels on a panorama
+
+        :return `Tensor(H*W, 3)`: rays' directions with one unit length
+        """
+        spher_coords = torch.cat([
+            torch.ones(*self.res, 1),
+            ((misc.meshgrid(*self.res, normalize=True)) *
+             torch.tensor([-2.0, 1.0]) + torch.tensor([1.5, 0.0])) * PI
+        ], dim=-1).to(device=self.device)
+        coords = sphere.spherical2cartesian(spher_coords)
+        return coords.flatten(0, 1)  # [H*W, 3]
diff --git a/data/view_dataset.py b/data/view_dataset.py
new file mode 100644
index 0000000..477629b
--- /dev/null
+++ b/data/view_dataset.py
@@ -0,0 +1,198 @@
+import os
+import torch
+import torch.nn.functional as nn_f
+from typing import Tuple, Union
+from utils import img
+from utils import view
+from utils import color
+from utils import misc
+
+
+class ViewDataset(object):
+    """
+    Data loader for spherical view synthesis task
+
+    Attributes
+    --------
+    data_dir ```str```: the directory of dataset\n
+    view_file_pattern ```str```: the filename pattern of view images\n
+    cam ```object```: camera intrinsic parameters\n
+    view_centers ```Tensor(N, 3)```: centers of views\n
+    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
+    view_images ```Tensor(N, 3, H, W)```: images of views\n
+    view_depths ```Tensor(N, H, W)```: depths of views\n
+    """
+
+    class Chunk(object):
+
+        def __init__(self, id, dataset, *,
+                     indices: torch.Tensor, centers: torch.Tensor, rots: torch.Tensor):
+            """
+            [summary]
+
+            :param dataset `PanoDataset`: dataset object
+            :param indices `Tensor(N)`: indices of views
+            :param centers `Tensor(N, 3)`: centers of views
+            """
+            self.id = id
+            self.dataset = dataset
+            self.indices = indices
+            self.centers = centers
+            self.rots = rots
+            self.n_views = self.indices.size(0)
+            self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
+            self.colors = self.depths = self.bins = None
+            self.colors_cpu = self.depths_cpu = self.bins_cpu = None
+            self.loaded = False
+
+        def release(self):
+            self.colors = self.depths = self.bins = None
+            self.loaded = False
+
+        def load(self):
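+            # CPU-decoded tensors are cached once; device copies use
+            # non_blocking=True so they can overlap on the preloader's
+            # stream, and the stream is synchronized before the chunk is
+            # marked as loaded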
+            if self.dataset.image_path and self.colors_cpu is None:
+                images = color.cvt(
+                    img.load(self.dataset.image_path % i for i in self.indices),
+                    color.RGB, self.dataset.c)
+                if self.dataset.res != list(images.shape[-2:]):
+                    images = nn_f.interpolate(images, self.dataset.res)
+                self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
+            if self.colors_cpu is not None:
+                self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True)
+
+            if self.dataset.depth_path and self.depths_cpu is None:
+                depths = self.dataset._decode_depth_images(
+                    img.load(self.dataset.depth_path % i for i in self.indices))
+                if self.dataset.res != list(depths.shape[-2:]):
+                    depths = nn_f.interpolate(depths, self.dataset.res)
+                self.depths_cpu = depths.flatten(0, 2)
+            if self.depths_cpu is not None:
+                self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True)
+
+            if self.dataset.bins_path and self.bins_cpu is None:
+                bins = img.load([self.dataset.bins_path % i for i in self.indices])
+                if self.dataset.res != list(bins.shape[-2:]):
+                    bins = nn_f.interpolate(bins, self.dataset.res)
+                self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2)
+            if self.bins_cpu is not None:
+                self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True)
+
+            torch.cuda.current_stream(self.dataset.device).synchronize()
+            self.loaded = True
+
+        def __len__(self):
+            return self.n_views * self.n_pixels_per_view
+
+        def __getitem__(self, idx):
+            if not self.loaded:
+                self.load()
+            view_idx = idx // self.n_pixels_per_view
+            pix_idx = idx % self.n_pixels_per_view
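+            # camera-space ray directions are rotated into world space;
+            # movedim(-1, -2) transposes each rotation so that, with row
+            # vectors, d_world = d_cam @ R^T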
+            rays_o = self.centers[view_idx]
+            rays_d = self.dataset.cam_rays[pix_idx] # (N, 3)
+            r = self.rots[view_idx].movedim(-1, -2) # (N, 3, 3)
+            rays_d = torch.matmul(rays_d, r)
+            extra_data = {}
+            if self.colors is not None:
+                extra_data['colors'] = self.colors[idx]
+            if self.depths is not None:
+                extra_data['depths'] = self.depths[idx]
+            if self.bins is not None:
+                extra_data['bins'] = self.bins[idx]
+            return idx, rays_o, rays_d, extra_data
+
+    def __init__(self, desc: dict, *,
+                 c: int = color.RGB,
+                 load_images: bool = True,
+                 load_depths: bool = False,
+                 load_bins: bool = False,
+                 res: Tuple[int, int] = None,
+                 views_to_load: Union[range, torch.Tensor] = None,
+                 device: torch.device = None,
+                 **kwargs):
+        """
+        Initialize data loader for spherical view synthesis task
+
+        The dataset description file is a JSON file with following fields:
+
+        - view_file_pattern: string, the path pattern of view images
+        - view_res: { "x", "y" }, the resolution of view images
+        - cam: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
+        - view_centers: [ [ x, y, z ], ... ], centers of views
+        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
+
+        :param dataset_desc_path ```str```: path to the data description file
+        :param load_images ```bool```: whether load view images and return in __getitem__()
+        :param load_depths ```bool```: whether load depth images and return in __getitem__()
+        :param c ```int```: color space to convert view images to
+        :param calculate_rays ```bool```: whether calculate rays
+        """
+        self.c = c
+        self.device = device
+        self._load_desc(desc, res, views_to_load, load_images, load_depths, load_bins)
+
+    def get_data(self):
+        return {
+            'indices': self.indices,
+            'centers': self.centers,
+            'rots': self.rots
+        }
+
+    def _decode_depth_images(self, input):
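+        # depth images store normalized disparity in the first channel:
+        # a pixel value of 1 maps to the near limit (min depth) and 0 to the
+        # far limit (max depth); metric depth is the reciprocal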
+        disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
+        disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
+        return torch.reciprocal(disp_val)
+
+    def _load_desc(self, desc: dict,
+                   res: Tuple[int, int],
+                   views_to_load: Union[range, torch.Tensor],
+                   load_images: bool,
+                   load_depths: bool,
+                   load_bins: bool):
+        if load_images and desc.get('view_file_pattern'):
+            self.image_path = os.path.join(os.getcwd(), desc['view_file_pattern'])
+        else:
+            self.image_path = None
+        if load_depths and desc.get('depth_file_pattern'):
+            self.depth_path = os.path.join(os.getcwd(), desc['depth_file_pattern'])
+        else:
+            self.depth_path = None
+        if load_bins and desc.get('bins_file_pattern'):
+            self.bins_path = os.path.join(os.getcwd(), desc['bins_file_pattern'])
+        else:
+            self.bins_path = None
+        self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
+        self.cam = view.CameraParam(desc['cam_params'], self.res, device=self.device)
+        self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
+            if 'depth_range' in desc else None
+        self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
+        self.samples = desc.get('samples')
+        self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
+        self.rots = torch.tensor(
+            [
+                view.euler_to_matrix([rot[1] if desc.get('gl_coord') else -rot[1], rot[0], 0])
+                for rot in desc['view_rots']
+            ]
+            if len(desc['view_rots'][0]) == 2 else desc['view_rots'],
+            device=self.device).view(-1, 3, 3)  # (N, 3, 3)
+        self.indices = torch.tensor(
+            desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
+            device=self.device)
+
+        if views_to_load is not None:
+            self.centers = self.centers[views_to_load]
+            self.rots = self.rots[views_to_load]
+            self.indices = self.indices[views_to_load]
+
+        self.n_views = self.centers.size(0)
+        self.n_pixels = self.n_views * self.res[0] * self.res[1]
+
+        if desc.get('gl_coord'):
+            print('Converting from OpenGL to DirectX coordinates (i.e. flipping the z-axis)')
+            if not desc['cam_params'].get('fov'):
+                self.cam.f[1] *= -1
+            self.centers[:, 2] *= -1
+            self.rots[:, 2] *= -1
+            self.rots[..., 2] *= -1
+
+        self.cam_rays = self.cam.get_local_rays(flatten=True)
-- 
GitLab