import math
from typing import Tuple
import torch
import torchvision.transforms.functional as trans_f
import json
from ..my import util
from ..my import device


def _convert_camera_params(input_camera_params, view_res):
    """
    Check and convert camera parameters in config file to pixel-space

    :param cam_params: { ["fx", "fy" | "fov"], "cx", "cy", ["normalized"] },
        the parameters of camera
    :return: camera parameters
    """
    input_is_normalized = bool(input_camera_params.get('normalized'))
    camera_params = {}
    if 'fov' in input_camera_params:
        camera_params['fx'] = camera_params['fy'] = \
            (1 if input_is_normalized else view_res[0]) / \
            util.Fov2Length(input_camera_params['fov'])
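        # Negate fy: image-space y grows downward while camera-space y points
        # up (presumably the convention assumed by the rest of this codebase)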
        camera_params['fy'] *= -1
    else:
        camera_params['fx'] = input_camera_params['fx']
        camera_params['fy'] = input_camera_params['fy']
    camera_params['cx'] = input_camera_params['cx']
    camera_params['cy'] = input_camera_params['cy']
    if input_is_normalized:
        camera_params['fx'] *= view_res[1]
        camera_params['fy'] *= view_res[0]
        camera_params['cx'] *= view_res[1]
        camera_params['cy'] *= view_res[0]
    return camera_params
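
# A minimal sketch of the conversion above (all values hypothetical): with
# view_res = (240, 320) and normalized parameters,
#
#     _convert_camera_params(
#         {'fx': 0.5, 'fy': -0.5, 'cx': 0.5, 'cy': 0.5, 'normalized': True},
#         (240, 320))
#
# returns {'fx': 160.0, 'fy': -120.0, 'cx': 160.0, 'cy': 120.0}: fx and cx
# scale with the view width, fy and cy with the view height.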


class SphericalViewSynDataset(torch.utils.data.dataset.Dataset):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of dataset\n
    view_file_pattern ```str```: the filename pattern of view images\n
    cam_params ```object```: camera intrinsic parameters\n
    view_centers ```Tensor(N, 3)```: centers of views\n
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
    view_images ```Tensor(N, 3, H, W)```: images of views\n
    """

    def __init__(self, dataset_desc_path: str, load_images: bool = True, gray: bool = False,
                 ray_as_item: bool = False):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with the following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view images
        - cam_params: { ["fx", "fy" | "fov"], "cx", "cy", ["normalized"] }, the parameters of camera
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views
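
        For reference, a description file might look like (all values are
        hypothetical):

        ```
        {
            "view_file_pattern": "view_%04d.png",
            "view_res": { "x": 320, "y": 240 },
            "cam_params": { "fov": 90, "cx": 0.5, "cy": 0.5, "normalized": true },
            "view_centers": [ [ 0.0, 0.0, 0.0 ] ],
            "view_rots": [ [ 1, 0, 0, 0, 1, 0, 0, 0, 1 ] ]
        }
        ```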

        :param dataset_desc_path ```str```: path to the dataset description file
        :param load_images ```bool```: whether to load view images and return them in __getitem__()
        :param gray ```bool```: whether to convert view images to grayscale
        :param ray_as_item ```bool```: whether to treat each ray in a view as an item
        """
        self.data_dir = dataset_desc_path.rsplit('/', 1)[0] + '/'
        self.load_images = load_images
        self.ray_as_item = ray_as_item

        # Load dataset description file
        with open(dataset_desc_path, 'r', encoding='utf-8') as file:
            data_desc = json.load(file)
        if data_desc['view_file_pattern'] == '':
            self.load_images = False
        else:
            self.view_file_pattern: str = self.data_dir + \
                data_desc['view_file_pattern']
        self.view_res = (data_desc['view_res']['y'],
                         data_desc['view_res']['x'])
        self.cam_params = _convert_camera_params(
            data_desc['cam_params'], self.view_res)
        self.view_centers = torch.tensor(data_desc['view_centers'])  # (N, 3)
        self.view_rots = torch.tensor(data_desc['view_rots']) \
            .view(-1, 3, 3)  # (N, 3, 3)

        # Load view images
        if self.load_images:
            self.view_images = util.ReadImageTensor(
                [self.view_file_pattern % i for i in range(self.view_centers.size(0))])
            if gray:
                self.view_images = trans_f.rgb_to_grayscale(self.view_images)
        else:
            self.view_images = None

        local_view_rays = util.GetLocalViewRays(self.cam_params,
                                                self.view_res,
                                                flatten=True)  # (M, 3)
        # Transpose matrix so we can perform vec x mat
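        # (for a row vector v, v @ R^T == (R @ v^T)^T, so a single batched
        # matmul below rotates every local ray direction into world space)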
        view_rots_t = self.view_rots.permute(0, 2, 1)

        # rays_o & rays_d are both (N, M, 3)
        self.rays_o = self.view_centers.unsqueeze(1) \
            .expand(-1, local_view_rays.size(0), -1)
        self.rays_d = torch.matmul(local_view_rays, view_rots_t)

        # Flatten rays if ray_as_item = True
        if ray_as_item:
            self.view_pixels = self.view_images.permute(0, 2, 3, 1).flatten(
                0, 2) if self.view_images is not None else None
BobYeah's avatar
BobYeah committed
114
115
            self.rays_o = self.rays_o.flatten(0, 1)
            self.rays_d = self.rays_d.flatten(0, 1)

    def __len__(self):
        return self.rays_o.size(0)

    def __getitem__(self, idx):
        if self.load_images:
            if self.ray_as_item:
                return idx, self.view_pixels[idx], self.rays_o[idx], self.rays_d[idx]
            return idx, self.view_images[idx], self.rays_o[idx], self.rays_d[idx]
        return idx, False, self.rays_o[idx], self.rays_d[idx]
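
# A minimal usage sketch (path and batch size are hypothetical):
#
#     dataset = SphericalViewSynDataset('data/train.json', ray_as_item=True)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=4096, shuffle=True)
#     for idx, pixels, rays_o, rays_d in loader:
#         ...  # pixels: (B, C), rays_o / rays_d: (B, 3) when ray_as_item=True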


class FastSphericalViewSynDataset(torch.utils.data.dataset.Dataset):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of dataset\n
    view_file_pattern ```str```: the filename pattern of view images\n
    cam_params ```object```: camera intrinsic parameters\n
    view_centers ```Tensor(N, 3)```: centers of views\n
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
    view_images ```Tensor(N, 3, H, W)```: images of views\n
    """

    def __init__(self, dataset_desc_path: str, load_images: bool = True, gray: bool = False):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with the following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view images
        - cam_params: { ["fx", "fy" | "fov"], "cx", "cy", ["normalized"] }, the parameters of camera
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views

        :param dataset_desc_path ```str```: path to the dataset description file
        :param load_images ```bool```: whether to load view images and return them in __getitem__()
        :param gray ```bool```: whether to convert view images to grayscale
        """
        super().__init__()
        self.data_dir = dataset_desc_path.rsplit('/', 1)[0] + '/'
        self.load_images = load_images

        # Load dataset description file
        with open(dataset_desc_path, 'r', encoding='utf-8') as file:
            data_desc = json.load(file)
        if data_desc['view_file_pattern'] == '':
            self.load_images = False
        else:
            self.view_file_pattern: str = self.data_dir + \
                data_desc['view_file_pattern']
        self.view_res = (data_desc['view_res']['y'],
                         data_desc['view_res']['x'])
        self.cam_params = _convert_camera_params(
            data_desc['cam_params'], self.view_res)
        self.view_centers = torch.tensor(
            data_desc['view_centers'], device=device.GetDevice())  # (N, 3)
        self.view_rots = torch.tensor(
            data_desc['view_rots'], device=device.GetDevice()).view(-1, 3, 3)  # (N, 3, 3)
        self.n_views = self.view_centers.size(0)
        self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]

        # Load view images
        if self.load_images:
            self.view_images = util.ReadImageTensor(
                [self.view_file_pattern % i
                 for i in range(self.view_centers.size(0))]
            ).to(device.GetDevice())
            if gray:
                self.view_images = trans_f.rgb_to_grayscale(self.view_images)
        else:
            self.view_images = None

        local_view_rays = util.GetLocalViewRays(self.cam_params, self.view_res,
                                                flatten=True).to(device.GetDevice())  # (HW, 3)
        # Transpose matrix so we can perform vec x mat
        view_rots_t = self.view_rots.permute(0, 2, 1)

        # rays_o & rays_d are both (N, H, W, 3)
        self.rays_o = self.view_centers[:, None, None, :] \
            .expand(-1, self.view_res[0], self.view_res[1], -1)
        self.rays_d = torch.matmul(local_view_rays, view_rots_t) \
            .view_as(self.rays_o)
        self.patched_images = self.view_images  # (N, 1|3, H, W)
        self.patched_rays_o = self.rays_o  # (N, H, W, 3)
        self.patched_rays_d = self.rays_d  # (N, H, W, 3)

    def set_patch_size(self, patch_size: Tuple[int, int], offset: Tuple[int, int] = (0, 0)):
        """
        Set the size of patch and (optional) offset. If patch_size = (1, 1)

        :param patch_size: 
        :param offset: 
        """
        patches = ((self.view_res[0] - offset[0]) // patch_size[0],
                   (self.view_res[1] - offset[1]) // patch_size[1])
        slices = (..., slice(offset[0], offset[0] + patches[0] * patch_size[0]),
                  slice(offset[1], offset[1] + patches[1] * patch_size[1]))
        if patch_size[0] == 1 and patch_size[1] == 1:
            self.patched_images = self.view_images[slices] \
                .permute(0, 2, 3, 1).flatten(0, 2) if self.load_images else None
            self.patched_rays_o = self.rays_o[slices].flatten(0, 2)
            self.patched_rays_d = self.rays_d[slices].flatten(0, 2)
        elif patch_size[0] == self.view_res[0] and patch_size[1] == self.view_res[1]:
            self.patched_images = self.view_images
            self.patched_rays_o = self.rays_o
            self.patched_rays_d = self.rays_d
        else:
            self.patched_images = self.view_images[slices] \
                .view(self.n_views, -1, patches[0], patch_size[0], patches[1], patch_size[1]) \
                .permute(0, 2, 4, 1, 3, 5).flatten(0, 2) if self.load_images else None
            self.patched_rays_o = self.rays_o[slices] \
                .view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
                .permute(0, 1, 3, 2, 4, 5).flatten(0, 2)
            self.patched_rays_d = self.rays_d[slices] \
                .view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
                .permute(0, 1, 3, 2, 4, 5).flatten(0, 2)
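
    # Example (hypothetical numbers): with view_res = (240, 320),
    # set_patch_size((80, 80)) splits each view into a 3x4 patch grid, so
    # patched_rays_o has shape (N * 12, 80, 80, 3) and __getitem__ returns
    # one 80x80 patch per item.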

    def __len__(self):
        return self.patched_rays_o.size(0)

    def __getitem__(self, idx):
        if self.load_images:
            return idx, self.patched_images[idx], self.patched_rays_o[idx], \
                self.patched_rays_d[idx]
        return idx, False, self.patched_rays_o[idx], self.patched_rays_d[idx]
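
# A minimal usage sketch pairing this dataset with FastDataLoader (path and
# sizes are hypothetical):
#
#     dataset = FastSphericalViewSynDataset('data/train.json')
#     dataset.set_patch_size((1, 1))  # train on individual rays
#     loader = FastDataLoader(dataset, 4096, shuffle=True, drop_last=False)
#     for idx, pixels, rays_o, rays_d in loader:
#         ...  # every tensor already lives on device.GetDevice()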


class FastDataLoader(object):

    class Iter(object):

        def __init__(self, dataset, batch_size, shuffle, drop_last) -> None:
            super().__init__()
            self.indices = torch.randperm(len(dataset), device=device.GetDevice()) \
                if shuffle else torch.arange(len(dataset), device=device.GetDevice())
            self.offset = 0
            self.batch_size = batch_size
            self.dataset = dataset
            self.drop_last = drop_last

        def __next__(self):
            # Stop when the dataset is exhausted; with drop_last, also stop
            # when fewer than batch_size items remain
            remaining = len(self.dataset) - self.offset
            if remaining <= 0 or (self.drop_last and remaining < self.batch_size):
                raise StopIteration()
            indices = self.indices[self.offset:self.offset + self.batch_size]
            self.offset += self.batch_size
            return self.dataset[indices]

    def __init__(self, dataset, batch_size, shuffle, drop_last, **kwargs) -> None:
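        # Extra keyword arguments are accepted only for drop-in compatibility
        # with the torch.utils.data.DataLoader signature; they are ignored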
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        return FastDataLoader.Iter(self.dataset, self.batch_size,
                                   self.shuffle, self.drop_last)

    def __len__(self):
        return math.floor(len(self.dataset) / self.batch_size) if self.drop_last \
            else math.ceil(len(self.dataset) / self.batch_size)
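

# A minimal self-test sketch for FastDataLoader (synthetic data; the dataset
# classes above need a description file on disk, so a toy dataset that
# supports tensor indexing stands in for them). Run as a module, e.g.
# `python -m <package>.spherical_view_syn`, so the relative imports resolve.
if __name__ == '__main__':
    class _ToyDataset:
        def __init__(self, n: int):
            self.data = torch.arange(n, device=device.GetDevice())

        def __len__(self):
            return self.data.size(0)

        def __getitem__(self, idx):
            # idx is a tensor of indices, as produced by FastDataLoader.Iter
            return self.data[idx]

    loader = FastDataLoader(_ToyDataset(10), batch_size=4,
                            shuffle=False, drop_last=False)
    print(len(loader))  # 3 batches: 4 + 4 + 2
    for batch in loader:
        print(batch)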