spherical_view_syn.py 7.43 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
import os
import json
3
4
import torch
import torchvision.transforms.functional as trans_f
BobYeah's avatar
sync    
BobYeah committed
5
from typing import Tuple, Union
6
from ..my import util
BobYeah's avatar
BobYeah committed
7
from ..my import device
BobYeah's avatar
sync    
BobYeah committed
8
from ..my import view
BobYeah's avatar
BobYeah committed
9
10


BobYeah's avatar
sync    
BobYeah committed
11
class SphericalViewSynDataset(object):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of dataset\n
    view_file_pattern ```str```: the filename pattern of view images (```None``` if not specified)\n
    depth_file_pattern ```str```: the filename pattern of depth images (```None``` if not specified)\n
    cam_params ```object```: camera intrinsic parameters\n
    view_centers ```Tensor(N, 3)```: centers of views\n
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
    view_images ```Tensor(N, 1|3, H, W)```: images of views (```None``` if not loaded)\n
    view_depths ```Tensor```: decoded depth maps of views (```None``` if not loaded)\n
    rays_o, rays_d ```Tensor(N, H, W, 3)```: global rays (```None``` if not calculated)\n
    patched_images, patched_rays_o, patched_rays_d: the data actually served by
    ```__getitem__()```, re-arranged by ```set_patch_size()```\n
    """

    def __init__(self, dataset_desc_path: str, load_images: bool = True,
                 load_depths: bool = False, gray: bool = False, calculate_rays: bool = True):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with following fields:

        - view_file_pattern: string, the path pattern of view images (optional)
        - depth_file_pattern: string, the path pattern of depth images (optional)
        - view_res: { "x", "y" }, the resolution of view images
        - cam_params: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
        - depth_range: { "min", "max" }, used to decode depth images (optional)
        - range: { "min", "max" } (optional)
        - samples: (optional)
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views

        :param dataset_desc_path ```str```: path to the data description file
        :param load_images ```bool```: whether load view images and return in __getitem__()
        :param load_depths ```bool```: whether load depth images
        :param gray ```bool```: whether convert view images to grayscale
        :param calculate_rays ```bool```: whether calculate rays
        """
        super().__init__()
        self.data_dir = os.path.dirname(dataset_desc_path)
        self.load_images = load_images
        self.load_depths = load_depths

        # Load dataset description file. This may switch off load_images /
        # load_depths when the corresponding file pattern is empty or missing.
        self._load_desc(dataset_desc_path)

        # Load view images
        if self.load_images:
            self.view_images = util.ReadImageTensor(
                [self.view_file_pattern % i for i in range(self.n_views)]
            ).to(device.GetDevice())
            if gray:
                self.view_images = trans_f.rgb_to_grayscale(self.view_images)
        else:
            self.view_images = None

        # Load depthmaps
        if self.load_depths:
            self.view_depths = self._decode_depth_images(
                util.ReadImageTensor(
                    [self.depth_file_pattern % i for i in range(self.n_views)]
                ).to(device.GetDevice()),
                self.cam_params.get_local_rays())
        else:
            self.view_depths = None

        self.patched_images = self.view_images  # (N, 1|3, H, W)

        if calculate_rays:
            # rays_o & rays_d are both (N, H, W, 3)
            self.rays_o, self.rays_d = self.cam_params.get_global_rays(
                self.view_centers, self.view_rots)
        else:
            # Always define the ray attributes so that set_patch_size(),
            # __len__() and __getitem__() work on a ray-less dataset.
            self.rays_o = None
            self.rays_d = None
        self.patched_rays_o = self.rays_o
        self.patched_rays_d = self.rays_d

    def _decode_depth_images(self, input, local_rays):
        """
        Decode depth images: ```depth_range[0] / input[..., 0, :, :] / local_rays[..., 2]```

        Only the first channel of ```input``` is used.

        :param input ```Tensor(..., C, H, W)```: encoded depth images
        :param local_rays ```Tensor(..., 3)```: local rays, whose last component divides the result
        :return ```Tensor(..., H, W)```: decoded depths
        :raise ValueError: if the dataset description has no "depth_range"
        """
        if self.depth_range is None:
            raise ValueError('"depth_range" is required in the dataset '
                             'description to decode depth images')
        output = self.depth_range[0] / input[..., 0, :, :]
        output /= local_rays[..., 2]
        return output

    def _load_desc(self, path):
        """
        Load the dataset description file and fill the description-derived
        attributes (file patterns, view_res, cam_params, depth_range, range,
        samples, view_centers, view_rots, n_views, n_pixels)

        An empty or missing file pattern disables the corresponding loading flag.

        :param path ```str```: path to the data description file
        """
        with open(path, 'r', encoding='utf-8') as file:
            data_desc = json.load(file)

        # Patterns are optional: a missing key is handled like an empty string
        view_pattern = data_desc.get('view_file_pattern', '')
        if view_pattern:
            self.view_file_pattern = os.path.join(self.data_dir, view_pattern)
        else:
            self.view_file_pattern = None
            self.load_images = False
        depth_pattern = data_desc.get('depth_file_pattern', '')
        if depth_pattern:
            self.depth_file_pattern = os.path.join(self.data_dir, depth_pattern)
        else:
            self.depth_file_pattern = None
            self.load_depths = False

        # Resolution is stored as (rows, columns), i.e. (y, x)
        self.view_res = (data_desc['view_res']['y'],
                         data_desc['view_res']['x'])
        self.cam_params = view.CameraParam(data_desc['cam_params'],
                                           self.view_res,
                                           device=device.GetDevice())
        # Each optional range is guarded by its own key
        self.depth_range = self._read_min_max(data_desc, 'depth_range')
        self.range = self._read_min_max(data_desc, 'range')
        self.samples = data_desc.get('samples')
        self.view_centers = torch.tensor(data_desc['view_centers'],
                                         device=device.GetDevice())  # (N, 3)
        self.view_rots = torch.tensor(
            data_desc['view_rots'], device=device.GetDevice()).view(-1, 3, 3)  # (N, 3, 3)
        self.n_views = self.view_centers.size(0)
        self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]

    @staticmethod
    def _read_min_max(data_desc, key):
        """Return ```[min, max]``` of the optional field ```key```, or ```None``` if absent"""
        if key not in data_desc:
            return None
        return [data_desc[key]['min'], data_desc[key]['max']]

    @staticmethod
    def _as_pair(value):
        """Expand a scalar to ```(v, v)```; convert a 2-element tuple/list to a tuple of ints"""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError('expect a scalar or a pair, got %r' % (value,))
            return (int(value[0]), int(value[1]))
        return (int(value), int(value))

    @staticmethod
    def _patch_images(images, region, patches, patch_size):
        """Split images ```(N, C, H, W)``` into patches ```(N * Py * Px, C, Sy, Sx)```"""
        cropped = images[(...,) + region]
        if patch_size == (1, 1):
            # Pixel mode: one sample per pixel, channels last -> (N * H' * W', C)
            return cropped.permute(0, 2, 3, 1).flatten(0, 2)
        # reshape() only splits H' and W' here; unlike view() it never fails
        # on the non-contiguous tensor produced by cropping
        return cropped \
            .reshape(images.size(0), -1, patches[0], patch_size[0], patches[1], patch_size[1]) \
            .permute(0, 2, 4, 1, 3, 5).flatten(0, 2)

    @staticmethod
    def _patch_rays(rays, region, patches, patch_size):
        """Split rays ```(N, H, W, 3)``` into patches ```(N * Py * Px, Sy, Sx, 3)```"""
        cropped = rays[(slice(None),) + region]
        if patch_size == (1, 1):
            # Pixel mode: one sample per pixel -> (N * H' * W', 3)
            return cropped.flatten(0, 2)
        return cropped \
            .reshape(rays.size(0), patches[0], patch_size[0], patches[1], patch_size[1], -1) \
            .permute(0, 1, 3, 2, 4, 5).flatten(0, 2)

    def set_patch_size(self, patch_size: Union[int, Tuple[int, int]],
                       offset: Union[int, Tuple[int, int]] = 0):
        """
        Set the size of patch and (optional) offset, then rebuild
        ```patched_images```, ```patched_rays_o``` and ```patched_rays_d```

        - If patch_size = (1, 1), every pixel is a sample: images become
          ```(N * H' * W', C)``` and rays become ```(N * H' * W', 3)```
        - If patch_size equals the view resolution, whole views are served
          and the offset is ignored
        - Otherwise images become ```(N * Py * Px, C, Sy, Sx)``` and rays become
          ```(N * Py * Px, Sy, Sx, 3)```; pixels which do not fill a complete patch
          are dropped

        Images or rays which are not loaded/calculated stay ```None```.

        :param patch_size: size of a patch, a scalar or (rows, columns)
        :param offset: top-left offset of the patched region, a scalar or (rows, columns)
        :raise ValueError: if patch_size is not positive
        """
        patch_size = self._as_pair(patch_size)
        offset = self._as_pair(offset)
        if patch_size[0] <= 0 or patch_size[1] <= 0:
            raise ValueError('patch_size must be positive, got %r' % (patch_size,))

        if patch_size == (self.view_res[0], self.view_res[1]) and patch_size != (1, 1):
            # A patch is a whole view: serve the original tensors
            self.patched_images = self.view_images
            self.patched_rays_o = self.rays_o
            self.patched_rays_d = self.rays_d
            return

        # Number of complete patches along each axis and the region they cover
        patches = ((self.view_res[0] - offset[0]) // patch_size[0],
                   (self.view_res[1] - offset[1]) // patch_size[1])
        region = (slice(offset[0], offset[0] + patches[0] * patch_size[0]),
                  slice(offset[1], offset[1] + patches[1] * patch_size[1]))

        if self.load_images:
            self.patched_images = self._patch_images(
                self.view_images, region, patches, patch_size)
        else:
            self.patched_images = None
        if self.rays_o is None:
            self.patched_rays_o = None
            self.patched_rays_d = None
        else:
            self.patched_rays_o = self._patch_rays(
                self.rays_o, region, patches, patch_size)
            self.patched_rays_d = self._patch_rays(
                self.rays_d, region, patches, patch_size)

    def __len__(self):
        """
        Number of samples (pixels, patches or views, depending on the patch size)

        Images and rays share the same first dimension, so the images are used
        when rays are not calculated; 0 is returned if neither is available.
        """
        if self.patched_rays_o is not None:
            return self.patched_rays_o.size(0)
        if self.patched_images is not None:
            return self.patched_images.size(0)
        return 0

    def __getitem__(self, idx):
        """
        Get sample(s) by index

        :return: ```(idx, image, rays_o, rays_d)```, where ```image``` is ```False``` if
        images are not loaded and ```rays_o```/```rays_d``` are ```False``` if rays are
        not calculated
        """
        image = self.patched_images[idx] if self.load_images else False
        if self.patched_rays_o is None:
            return idx, image, False, False
        return idx, image, self.patched_rays_o[idx], self.patched_rays_d[idx]