import os
import torch
import torch.nn.functional as nn_f
from typing import Tuple, Union
from utils import img
from utils import view
from utils import color
from utils import misc


class ViewDataset(object):
    """
    Data loader for the spherical view synthesis task

    Attributes
    --------
    data_dir `str`: the directory of the dataset
    view_file_pattern `str`: the filename pattern of view images
    cam `CameraParam`: camera intrinsic parameters
    view_centers `Tensor(N, 3)`: centers of views
    view_rots `Tensor(N, 3, 3)`: rotation matrices of views
    view_images `Tensor(N, 3, H, W)`: images of views
    view_depths `Tensor(N, H, W)`: depths of views
    """

    class Chunk(object):

        def __init__(self, id, dataset, *,
                     indices: torch.Tensor, centers: torch.Tensor, rots: torch.Tensor):
            """
            [summary]

            :param dataset `PanoDataset`: dataset object
            :param indices `Tensor(N)`: indices of views
            :param centers `Tensor(N, 3)`: centers of views
            """
            self.id = id
            self.dataset = dataset
            self.indices = indices
            self.centers = centers
            self.rots = rots
            self.n_views = self.indices.size(0)
            self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
            self.colors = self.depths = self.bins = None
            self.colors_cpu = self.depths_cpu = self.bins_cpu = None
            self.loaded = False

        def release(self):
            self.colors = self.depths = self.bins = None
            self.loaded = False

        def load(self):
            if self.dataset.image_path and self.colors_cpu is None:
                images = color.cvt(
                    img.load([self.dataset.image_path % i for i in self.indices]),
                    color.RGB, self.dataset.c)
                if self.dataset.res != list(images.shape[-2:]):
                    images = nn_f.interpolate(images, self.dataset.res)
                self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
            if self.colors_cpu is not None:
                self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True)

            if self.dataset.depth_path and self.depths_cpu is None:
                depths = self.dataset._decode_depth_images(
                    img.load([self.dataset.depth_path % i for i in self.indices]))
                if self.dataset.res != list(depths.shape[-2:]):
                    # interpolate() needs a channel dim for 2D resizing, so add
                    # one and strip it afterwards
                    depths = nn_f.interpolate(depths[:, None], self.dataset.res)[:, 0]
                self.depths_cpu = depths.flatten(0, 2)
            if self.depths_cpu is not None:
                self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True)

            if self.dataset.bins_path and self.bins_cpu is None:
                bins = img.load([self.dataset.bins_path % i for i in self.indices])
                if self.dataset.res != list(bins.shape[-2:]):
                    bins = nn_f.interpolate(bins, self.dataset.res)
                self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2)
            if self.bins_cpu is not None:
                self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True)

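            # Make sure the asynchronous (non_blocking) host-to-device copies
            # above have completed before the chunk is marked as loaded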
            torch.cuda.current_stream(self.dataset.device).synchronize()
            self.loaded = True

        def __len__(self):
            return self.n_views * self.n_pixels_per_view

        def __getitem__(self, idx):
            if not self.loaded:
                self.load()
            view_idx = idx // self.n_pixels_per_view
            pix_idx = idx % self.n_pixels_per_view
            rays_o = self.centers[view_idx]
            rays_d = self.dataset.cam_rays[pix_idx] # (N, 3)
            r = self.rots[view_idx].movedim(-1, -2)  # (N, 3, 3), transposed rotation matrices
            rays_d = torch.matmul(rays_d, r)  # rotate camera-local rays into world space
            extra_data = {}
            if self.colors is not None:
                extra_data['colors'] = self.colors[idx]
            if self.depths is not None:
                extra_data['depths'] = self.depths[idx]
            if self.bins is not None:
                extra_data['bins'] = self.bins[idx]
            return idx, rays_o, rays_d, extra_data

    def __init__(self, desc: dict, *,
                 c: int = color.RGB,
                 load_images: bool = True,
                 load_depths: bool = False,
                 load_bins: bool = False,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None,
                 **kwargs):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view images
        - cam: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views

        :param dataset_desc_path ```str```: path to the data description file
        :param load_images ```bool```: whether load view images and return in __getitem__()
        :param load_depths ```bool```: whether load depth images and return in __getitem__()
        :param c ```int```: color space to convert view images to
        :param calculate_rays ```bool```: whether calculate rays
        """
        self.c = c
        self.device = device
        # _load_desc() joins the file patterns onto self.data_dir; take it from
        # the description if present (an assumed field), else the current dir
        self.data_dir = desc.get('data_dir', '')
        self._load_desc(desc, res, views_to_load, load_images, load_depths, load_bins)

    def get_data(self):
        return {
            'indices': self.indices,
            'centers': self.centers,
            'rots': self.rots
        }
    
    def _decode_depth_images(self, input):
        """
        Decode normalized disparity images to depth maps.

        The first channel stores disparity scaled to [0, 1], with 1 at the near
        plane (depth_range[0]) and 0 at the far plane (depth_range[1]).
        """
        disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
        disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
        return torch.reciprocal(disp_val)
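
    # Worked example for _decode_depth_images (a sketch): with
    # depth_range = (1, 10), disp_range = (1.0, 0.1), so a pixel value of 1.0
    # decodes to disp_val = 1.0 -> depth 1 (the near plane), and a value of
    # 0.0 decodes to disp_val = 0.1 -> depth 10 (the far plane).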

    def _load_desc(self, desc: dict,
                   res: Tuple[int, int],
                   views_to_load: Union[range, torch.Tensor],
                   load_images: bool,
                   load_depths: bool,
                   load_bins: bool):
        if load_images and desc.get('view_file_pattern'):
            self.image_path = os.path.join(self.data_dir, desc['view_file_pattern'])
        else:
            self.image_path = None
        if load_depths and desc.get('depth_file_pattern'):
            self.depth_path = os.path.join(self.data_dir, desc['depth_file_pattern'])
        else:
            self.depth_path = None
        if load_bins and desc.get('bins_file_pattern'):
            self.bins_path = os.path.join(self.data_dir, desc['bins_file_pattern'])
        else:
            self.bins_path = None
        self.res = res if res else misc.values(desc['view_res'], 'y', 'x')
        self.cam = view.CameraParam(desc['cam_params'], self.res, device=self.device)
        self.depth_range = misc.values(desc['depth_range'], 'min', 'max') \
            if 'depth_range' in desc else None
        self.range = misc.values(desc['range'], 'min', 'max') if 'range' in desc else None
        self.samples = desc.get('samples')
        self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
        self.rots = torch.tensor(
            [
                view.euler_to_matrix([rot[1] if desc.get('gl_coord') else -rot[1], rot[0], 0])
                for rot in desc['view_rots']
            ]
            if len(desc['view_rots'][0]) == 2 else desc['view_rots'],
            device=self.device).view(-1, 3, 3)  # (N, 3, 3)
        self.indices = torch.tensor(
            desc['views'] if 'views' in desc else list(range(self.centers.size(0))),
            device=self.device)

        if views_to_load is not None:
            self.centers = self.centers[views_to_load]
            self.rots = self.rots[views_to_load]
            self.indices = self.indices[views_to_load]

        self.n_views = self.centers.size(0)
        self.n_pixels = self.n_views * self.res[0] * self.res[1]

        if desc.get('gl_coord'):
            print('Convert from OpenGL coordinates to DirectX coordinates (i.e. flip the z axis)')
            if not desc['cam_params'].get('fov'):
                self.cam.f[1] *= -1
            self.centers[:, 2] *= -1
            self.rots[:, 2] *= -1    # negate the z row of each rotation matrix
            self.rots[..., 2] *= -1  # negate the z column of each rotation matrix
        
        self.cam_rays = self.cam.get_local_rays(flatten=True)
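

# Minimal usage sketch (hypothetical paths; assumes a description JSON with the
# fields documented in ViewDataset.__init__ and view images on disk matching
# view_file_pattern):
#
#     import json
#
#     with open('data/scene/dataset.json') as f:
#         desc = json.load(f)
#     dataset = ViewDataset(desc, load_images=True,
#                           device=torch.device('cuda:0'))
#     chunk = ViewDataset.Chunk(0, dataset, **dataset.get_data())
#     idx, rays_o, rays_d, extra = chunk[0]  # ray of the first pixel in view 0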