spherical_view_syn.py
import os
import json
import torch
import glm
import torch.nn.functional as nn_f
from typing import Tuple, Union
from utils import img
from utils import device
from utils import view
from utils import color


class SphericalViewSynDataset(object):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of dataset\n
    view_file_pattern ```str```: the filename pattern of view images\n
    cam_params ```object```: camera intrinsic parameters\n
    view_centers ```Tensor(N, 3)```: centers of views\n
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
    view_images ```Tensor(N, 3, H, W)```: images of views\n
    view_depths ```Tensor(N, H, W)```: depths of views\n
    """

    def __init__(self, dataset_desc_path: str, load_images: bool = True,
                 load_depths: bool = False, load_bins: bool = False, c: int = color.RGB,
                 calculate_rays: bool = True, res: Tuple[int, int] = None):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with the following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view images
        - cam_params: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views

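        Example (a minimal, hypothetical description file):

        ```
        {
            "view_file_pattern": "view_%04d.png",
            "view_res": { "x": 512, "y": 512 },
            "cam_params": { "fx": 0.5, "fy": 0.5, "cx": 0.5, "cy": 0.5 },
            "view_centers": [ [ 0.0, 0.0, 1.0 ] ],
            "view_rots": [ [ 1, 0, 0, 0, 1, 0, 0, 0, 1 ] ]
        }
        ```
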
        :param dataset_desc_path ```str```: path to the data description file
        :param load_images ```bool```: whether to load view images and return them in __getitem__()
        :param load_depths ```bool```: whether to load depth images
        :param load_bins ```bool```: whether to load bin images
        :param c ```int```: color space to convert view images to
        :param calculate_rays ```bool```: whether to calculate the global rays of views
        :param res ```Tuple[int, int]```: if specified, resize loaded views to this resolution (H, W)
        """
        super().__init__()
        self.data_dir = os.path.dirname(dataset_desc_path)
        self.load_images = load_images
        self.load_depths = load_depths
        self.load_bins = load_bins

        # Load dataset description file
        self._load_desc(dataset_desc_path, res)

        # Load view images
        if self.load_images:
            self.view_images = color.cvt(
                img.load([self.view_file % i for i in self.view_idxs]).to(device.default()),
                color.RGB, c)
            if res:
                self.view_images = nn_f.interpolate(self.view_images, res)
        else:
            self.view_images = None

        # Load depthmaps
        if self.load_depths:
            self.view_depths = self._decode_depth_images(
                img.load([self.depth_file % i for i in self.view_idxs]).to(device.default()))
            if res:
                self.view_depths = nn_f.interpolate(self.view_depths, res)
        else:
            self.view_depths = None

        # Load bin images
        if self.load_bins:
            self.view_bins = img.load([self.bins_file % i for i in self.view_idxs], permute=False) \
                .to(device.default())
            if res:
                self.view_bins = nn_f.interpolate(self.view_bins, res)
        else:
            self.view_bins = None

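        # Until set_patch_size() is called, the patched_* fields simply alias
        # the full views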
        self.patched_images = self.view_images
        self.patched_depths = self.view_depths
        self.patched_bins = self.view_bins

        if calculate_rays:
            # rays_o & rays_d are both (N, H, W, 3)
            self.rays_o, self.rays_d = self.cam_params.get_global_rays(
                view.Trans(self.view_centers, self.view_rots))
            self.patched_rays_o = self.rays_o
            self.patched_rays_d = self.rays_d

    def _decode_depth_images(self, input):
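        # Pixel values encode disparity linearly: 1 maps to the near bound
        # (depth_range[0]) and 0 to the far bound (depth_range[1]); take the
        # reciprocal to recover depth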
        disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
        disp_val = (1 - input[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
        return torch.reciprocal(disp_val)

    def _euler_to_matrix(self, euler):
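        # Build a quaternion from Euler angles (in degrees) and return the
        # transposed rotation matrix as a nested list (glm matrices are
        # column-major)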
        q = glm.quat(glm.radians(glm.vec3(euler[0], euler[1], euler[2])))
        return glm.transpose(glm.mat3_cast(q)).to_list()

    def _load_desc(self, path, res=None):
        with open(path, 'r', encoding='utf-8') as file:
            data_desc = json.loads(file.read())
        if not data_desc.get('view_file_pattern'):
            self.load_images = False
        else:
            self.view_file = os.path.join(self.data_dir, data_desc['view_file_pattern'])
        if not data_desc.get('depth_file_pattern'):
            self.load_depths = False
        else:
            self.depth_file = os.path.join(self.data_dir, data_desc['depth_file_pattern'])
        if not data_desc.get('bins_file_pattern'):
            self.load_bins = False
        else:
            self.bins_file = os.path.join(self.data_dir, data_desc['bins_file_pattern'])
        self.view_res = res if res else (data_desc['view_res']['y'], data_desc['view_res']['x'])
        self.cam_params = view.CameraParam(data_desc['cam_params'], self.view_res,
                                           device=device.default())
        self.depth_range = [data_desc['depth_range']['min'], data_desc['depth_range']['max']] \
            if 'depth_range' in data_desc else None
        self.range = [data_desc['range']['min'], data_desc['range']['max']] \
            if 'range' in data_desc else None
        self.samples = data_desc['samples'] if 'samples' in data_desc else None
        self.view_centers = torch.tensor(
            data_desc['view_centers'], device=device.default())  # (N, 3)
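        # Each 'view_rots' entry is either a flattened 3x3 matrix or a pair of
        # Euler angles (in degrees), which is expanded to a matrix here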
        self.view_rots = torch.tensor(
            [self._euler_to_matrix([rot[1], rot[0], 0]) for rot in data_desc['view_rots']]
            if len(data_desc['view_rots'][0]) == 2 else data_desc['view_rots'],
            device=device.default()).view(-1, 3, 3)  # (N, 3, 3)
        self.n_views = self.view_centers.size(0)
        self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]
        self.view_idxs = data_desc['views'][:self.n_views] \
            if 'views' in data_desc else range(self.n_views)

        if data_desc.get('gl_coord'):
            print('Converting from OpenGL to DirectX coordinates (i.e. flip the z axis)')
            if not data_desc['cam_params'].get('normalized'):
                self.cam_params.f[1] *= -1
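            # Negate the z component of view centers and the z row and z column
            # of view rotation matrices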
            self.view_centers[:, 2] *= -1
            self.view_rots[:, 2] *= -1
            self.view_rots[..., 2] *= -1

    def set_patch_size(self, patch_size: Union[int, Tuple[int, int]],
                       offset: Union[int, Tuple[int, int]] = 0):
        """
        Set the size of patch and (optional) offset. If patch_size = (1, 1)

        :param patch_size: 
        :param offset: 
        """
        if not isinstance(patch_size, tuple):
            patch_size = (int(patch_size), int(patch_size))
        if not isinstance(offset, tuple):
            offset = (int(offset), int(offset))
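        # Number of whole patches that fit along each axis after the offset,
        # and slices cropping the region they cover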
        patches = ((self.view_res[0] - offset[0]) // patch_size[0],
                   (self.view_res[1] - offset[1]) // patch_size[1])
        slices = (..., slice(offset[0], offset[0] + patches[0] * patch_size[0]),
                  slice(offset[1], offset[1] + patches[1] * patch_size[1]))
        ray_slices = (slice(self.n_views),
                      slice(offset[0], offset[0] + patches[0] * patch_size[0]),
                      slice(offset[1], offset[1] + patches[1] * patch_size[1]))
        if patch_size[0] == 1 and patch_size[1] == 1:
            self.patched_images = self.view_images[slices] \
                .permute(0, 2, 3, 1).flatten(0, 2) if self.load_images else None
            self.patched_depths = self.view_depths[slices].flatten() if self.load_depths else None
            self.patched_bins = self.view_bins[slices].flatten(0, 2) if self.load_bins else None
            self.patched_rays_o = self.rays_o[ray_slices].flatten(0, 2)
            self.patched_rays_d = self.rays_d[ray_slices].flatten(0, 2)
        elif patch_size[0] == self.view_res[0] and patch_size[1] == self.view_res[1]:
            self.patched_images = self.view_images
            self.patched_depths = self.view_depths
            self.patched_bins = self.view_bins
            self.patched_rays_o = self.rays_o
            self.patched_rays_d = self.rays_d
        else:
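            # Crop to the patch-covered region, reshape to
            # (N, C, patches_y, patch_h, patches_x, patch_w), then gather
            # patches into the batch dimension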
            self.patched_images = self.view_images[slices] \
                .view(self.n_views, -1, patches[0], patch_size[0], patches[1], patch_size[1]) \
                .permute(0, 2, 4, 1, 3, 5).flatten(0, 2) if self.load_images else None
            self.patched_depths = self.view_depths[slices] \
                .view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1]) \
                .permute(0, 1, 3, 2, 4).flatten(0, 2) if self.load_depths else None
            self.patched_bins = self.view_bins[slices] \
                .view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
                .permute(0, 1, 3, 2, 4, 5).flatten(0, 2) if self.load_bins else None
            self.patched_rays_o = self.rays_o[ray_slices] \
                .view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
                .permute(0, 1, 3, 2, 4, 5).flatten(0, 2)
            self.patched_rays_d = self.rays_d[ray_slices] \
                .view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
                .permute(0, 1, 3, 2, 4, 5).flatten(0, 2)

    def __len__(self):
        return self.patched_rays_o.size(0)

    def __getitem__(self, idx):
        return idx, self.patched_images[idx] if self.load_images else None, \
            self.patched_rays_o[idx], self.patched_rays_d[idx]
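

# Example usage (a sketch; the description file path and batch size are
# hypothetical):
#
#     dataset = SphericalViewSynDataset('data/train.json', load_images=True)
#     dataset.set_patch_size(1)  # flatten views to per-pixel samples
#     loader = torch.utils.data.DataLoader(dataset, batch_size=4096, shuffle=True)
#     for idx, images, rays_o, rays_d in loader:
#         ...  # images: (B, C), rays_o / rays_d: (B, 3)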