spherical_view_syn.py 7.96 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
import os
import json
3
4
import torch
import torchvision.transforms.functional as trans_f
Nianchen Deng's avatar
sync    
Nianchen Deng committed
5
import torch.nn.functional as nn_f
BobYeah's avatar
sync    
BobYeah committed
6
from typing import Tuple, Union
7
from ..my import util
BobYeah's avatar
BobYeah committed
8
from ..my import device
BobYeah's avatar
sync    
BobYeah committed
9
from ..my import view
Nianchen Deng's avatar
sync    
Nianchen Deng committed
10
from ..my import color_mode
BobYeah's avatar
BobYeah committed
11
12


BobYeah's avatar
sync    
BobYeah committed
13
class SphericalViewSynDataset(object):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of dataset
    view_file_pattern ```str```: the filename pattern of view images
    depth_file_pattern ```str```: the filename pattern of depth images
    cam_params ```object```: camera intrinsic parameters
    view_centers ```Tensor(N, 3)```: centers of views
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views
    view_images ```Tensor(N, 1|3, H, W)```: images of views (None if not loaded)
    view_depths ```Tensor(N, H, W)```: decoded depth maps of views (None if not loaded)
    """

    def __init__(self, dataset_desc_path: str, load_images: bool = True,
                 load_depths: bool = False, color: int = color_mode.RGB,
                 calculate_rays: bool = True, res: Tuple[int, int] = None):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with following fields:

        - view_file_pattern: string, the path pattern of view images
          (an empty string disables image loading)
        - depth_file_pattern: string, optional, the path pattern of depth images
          (missing or empty disables depth loading)
        - view_res: { "x", "y" }, the resolution of view images
        - cam_params: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
        - depth_range: { "min", "max" }, optional, required to decode depth images
        - range: { "min", "max" }, optional
        - samples: optional, stored as-is in ``samples``
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views

        :param dataset_desc_path ```str```: path to the data description file
        :param load_images ```bool```: whether load view images and return in __getitem__()
        :param load_depths ```bool```: whether load depth images (stored in ``view_depths``;
                                       they are not returned by __getitem__())
        :param color ```int```: color space to convert view images to
        :param calculate_rays ```bool```: whether calculate rays (required by
                                          set_patch_size(), __len__() and __getitem__())
        :param res ```(int, int)```: optional (H, W) to resize views, depths and camera to
        """
        super().__init__()
        self.data_dir = os.path.dirname(dataset_desc_path)
        self.load_images = load_images
        self.load_depths = load_depths

        # Load dataset description file. This may clear load_images / load_depths
        # when the corresponding file pattern is empty.
        self._load_desc(dataset_desc_path, res)

        # Load view images
        if self.load_images:
            self.view_images = util.ReadImageTensor(
                [self.view_file_pattern % i for i in range(self.n_views)]
            ).to(device.GetDevice())
            if color == color_mode.GRAY:
                self.view_images = trans_f.rgb_to_grayscale(self.view_images)
            elif color == color_mode.YCbCr:
                self.view_images = util.rgb2ycbcr(self.view_images)
            if res:
                self.view_images = nn_f.interpolate(self.view_images, res)
        else:
            self.view_images = None

        # Load depthmaps
        if self.load_depths:
            raw_depths = util.ReadImageTensor(
                [self.depth_file_pattern % i for i in range(self.n_views)]
            ).to(device.GetDevice())  # presumably (N, C, H, W), like view images
            if res:
                # Resize the raw 4-D images *before* decoding: interpolate() needs an
                # explicit channel dim to accept a 2-D target size, and cam_params was
                # already resized in _load_desc(), so its local rays presumably match res.
                raw_depths = nn_f.interpolate(raw_depths, res)
            self.view_depths = self._decode_depth_images(
                raw_depths, self.cam_params.get_local_rays())
        else:
            self.view_depths = None

        self.patched_images = self.view_images  # (N, 1|3, H, W)

        if calculate_rays:
            # rays_o & rays_d are both (N, H, W, 3)
            self.rays_o, self.rays_d = self.cam_params.get_global_rays(
                self.view_centers, self.view_rots)
            self.patched_rays_o = self.rays_o
            self.patched_rays_d = self.rays_d

    def _decode_depth_images(self, input, local_rays):
        """
        Decode depth images into depth maps.

        Channel 0 of ``input`` is mapped to ``depth_range[0] / value`` and the result is
        then divided by the z component of ``local_rays``.

        :param input ```Tensor(..., C, H, W)```: encoded depth images (only channel 0 is used)
        :param local_rays ```Tensor```: camera-local rays whose last dim holds (x, y, z);
                                       must broadcast against (H, W)
        :return ```Tensor(..., H, W)```: decoded depth maps
        :raises ValueError: if the description file has no ``depth_range``
        """
        if self.depth_range is None:
            raise ValueError(
                'Cannot decode depth images: dataset description has no "depth_range"')
        output = self.depth_range[0] / input[..., 0, :, :]
        output /= local_rays[..., 2]
        return output

    def _load_desc(self, path, res=None):
        """
        Parse the dataset description JSON file and store its fields as attributes.

        :param path: path to the description file (read as UTF-8 JSON)
        :param res: optional (H, W) overriding ``view_res``; ``cam_params`` is resized to it
        """
        with open(path, 'r', encoding='utf-8') as file:
            data_desc = json.load(file)
        if data_desc['view_file_pattern'] == '':
            self.load_images = False
        else:
            self.view_file_pattern: str = os.path.join(
                self.data_dir, data_desc['view_file_pattern'])
        # depth_file_pattern is optional: a missing key behaves like an empty pattern
        depth_file_pattern = data_desc.get('depth_file_pattern', '')
        if depth_file_pattern == '':
            self.load_depths = False
        else:
            self.depth_file_pattern: str = os.path.join(
                self.data_dir, depth_file_pattern)
        # view_res is stored as (H, W), i.e. (y, x)
        self.view_res = (data_desc['view_res']['y'],
                         data_desc['view_res']['x'])
        self.cam_params = view.CameraParam(data_desc['cam_params'],
                                           self.view_res,
                                           device=device.GetDevice())
        if res:
            self.view_res = res
            self.cam_params.resize(res)
        # Guard on the same key that is read, so a description with only one of
        # "depth_range" / "range" is handled correctly.
        self.depth_range = [
            data_desc['depth_range']['min'],
            data_desc['depth_range']['max']
        ] if 'depth_range' in data_desc else None
        self.range = [data_desc['range']['min'], data_desc['range']['max']] \
            if 'range' in data_desc else None
        self.samples = data_desc.get('samples')
        self.view_centers = torch.tensor(data_desc['view_centers'],
                                         device=device.GetDevice())  # (N, 3)
        self.view_rots = torch.tensor(
            data_desc['view_rots'], device=device.GetDevice()).view(-1, 3, 3)  # (N, 3, 3)
        self.n_views = self.view_centers.size(0)
        self.n_pixels = self.n_views * self.view_res[0] * self.view_res[1]

    def set_patch_size(self, patch_size: Union[int, Tuple[int, int]],
                       offset: Union[int, Tuple[int, int]] = 0):
        """
        Split every view (and its rays) into patches that __getitem__() iterates over.

        The region starting at ``offset`` is cropped to a whole number of patches; border
        pixels that do not fill a complete patch are dropped. Resulting layouts:

        - patch_size == (1, 1): single pixels; patched_images is (M, C) and
          patched_rays_o / patched_rays_d are (M, 3)
        - patch_size == view_res: whole views, unchanged (offset is ignored)
        - otherwise: patched_images is (M, C, ph, pw) and patched_rays_o /
          patched_rays_d are (M, ph, pw, 3)

        Requires rays, i.e. the dataset must be created with calculate_rays=True.

        :param patch_size: patch (height, width), or a single int for square patches
        :param offset: (row, col) of the first patch's top-left pixel, or a single int
                       used for both
        """
        if not isinstance(patch_size, tuple):
            patch_size = (int(patch_size), int(patch_size))
        if not isinstance(offset, tuple):
            offset = (int(offset), int(offset))
        # Number of whole patches that fit along each axis after the offset
        patches = ((self.view_res[0] - offset[0]) // patch_size[0],
                   (self.view_res[1] - offset[1]) // patch_size[1])
        rows = slice(offset[0], offset[0] + patches[0] * patch_size[0])
        cols = slice(offset[1], offset[1] + patches[1] * patch_size[1])
        slices = (..., rows, cols)  # images are (N, C, H, W)
        ray_slices = (slice(self.n_views), rows, cols)  # rays are (N, H, W, 3)
        if patch_size[0] == 1 and patch_size[1] == 1:
            self.patched_images = self.view_images[slices] \
                .permute(0, 2, 3, 1).flatten(0, 2) if self.load_images else None
            self.patched_rays_o = self.rays_o[ray_slices].flatten(0, 2)
            self.patched_rays_d = self.rays_d[ray_slices].flatten(0, 2)
        elif patch_size[0] == self.view_res[0] and patch_size[1] == self.view_res[1]:
            self.patched_images = self.view_images
            self.patched_rays_o = self.rays_o
            self.patched_rays_d = self.rays_d
        else:
            # (N, C, pH*ph, pW*pw) -> (N, pH, pW, C, ph, pw) -> (N*pH*pW, C, ph, pw)
            self.patched_images = self.view_images[slices] \
                .view(self.n_views, -1, patches[0], patch_size[0], patches[1], patch_size[1]) \
                .permute(0, 2, 4, 1, 3, 5).flatten(0, 2) if self.load_images else None
            # (N, pH*ph, pW*pw, 3) -> (N, pH, pW, ph, pw, 3) -> (N*pH*pW, ph, pw, 3)
            self.patched_rays_o = self.rays_o[ray_slices] \
                .view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
                .permute(0, 1, 3, 2, 4, 5).flatten(0, 2)
            self.patched_rays_d = self.rays_d[ray_slices] \
                .view(self.n_views, patches[0], patch_size[0], patches[1], patch_size[1], -1) \
                .permute(0, 1, 3, 2, 4, 5).flatten(0, 2)

    def __len__(self):
        """Return the number of items (views, patches or pixels) per current patching."""
        return self.patched_rays_o.size(0)

    def __getitem__(self, idx):
        """
        Return ``(idx, image, rays_o, rays_d)`` for the item(s) at ``idx``.

        ``image`` is the placeholder ``False`` when images are not loaded.
        """
        if self.load_images:
            return idx, self.patched_images[idx], self.patched_rays_o[idx], \
                self.patched_rays_d[idx]
        return idx, False, self.patched_rays_o[idx], self.patched_rays_d[idx]