lf_syn.py 4.28 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import json
import os
from typing import List, Tuple

import torch

from ..my import util


def ReadLightField(path: str, views: Tuple[int, int], flatten_views: bool = False) -> torch.Tensor:
    """
    Read a light field stored as a grid of views tiled into a single image

    The image is cut into ``views[0]`` rows by ``views[1]`` columns of equally sized
    tiles; trailing pixels that do not fill a whole tile are ignored.

    :param path: path of the tiled light field image
    :param views: (rows, columns) of views in the image
    :param flatten_views: whether to merge the row and column dimensions into one
                          (row-major: index = row * columns + column)
    :return: ``rows x columns x 3 x h x w`` tensor, or ``(rows * columns) x 3 x h x w``
             when ``flatten_views`` is true
    """
    rows, cols = views[0], views[1]
    tiled = util.ReadImageTensor(path, batch_dim=False)
    tile_h = tiled.size(1) // rows
    tile_w = tiled.size(2) // cols
    # Crop to whole tiles, expose the grid as (C, rows, tile_h, cols, tile_w), then move
    # the two grid axes in front of the channel axis.
    grid = (
        tiled[:, :rows * tile_h, :cols * tile_w]
        .reshape(tiled.size(0), rows, tile_h, cols, tile_w)
        .permute(1, 3, 0, 2, 4)
    )
    # Fresh 3-channel tensor of the default dtype; copy_ casts the source dtype and
    # broadcasts a single-channel source across the 3 channels.
    lf = torch.empty(rows, cols, 3, tile_h, tile_w)
    lf.copy_(grid)
    if not flatten_views:
        return lf
    return lf.reshape(rows * cols, 3, tile_h, tile_w)


def DecodeDepth(depth_images: torch.Tensor) -> torch.Tensor:
    """
    Decode depth values from depth images

    Only the first channel (dimension 1) is used; it is multiplied by 255 and then
    divided by 10. Presumably this maps normalized [0, 1] pixel values back to 8-bit
    levels, with each level worth 0.1 depth units (TODO confirm against the dataset's
    depth encoding).

    :param depth_images: batch of depth images with channels on dimension 1
    :return: decoded depths, keeping a single channel on dimension 1
    """
    first_channel = depth_images[:, 0:1]
    return first_channel * 255 / 10


class LightFieldSynDataset(torch.utils.data.dataset.Dataset):
    """
    Data loader for light field synthesis task

    Attributes
    --------
    data_dir ```string```: the directory of dataset, i.e. the directory containing the data
        description file, always ending with a path separator (``./`` for a bare file name)\n
    data_desc ```dict```: the parsed data description\n
    n_views ```tuple(int, int)```: rows and columns of views\n
    num_views ```int```: number of views\n
    view_images ```N x 3 x H x W Tensor```: images of views\n
    view_depths ```N x 1 x H x W Tensor```: depths of views\n
    view_positions ```N x 3 Tensor```: positions of views\n
    cam_params ```dict```: camera parameters, as given in the data description\n
    depth_range ```[min, max]```: range of depth, as given in the data description\n
    depth_layers ```int```: number of depth layers, as given in the data description\n
    diopter_of_layers ```list(float)```: 0 followed by the diopters of depth layers\n
    sparse_view_images ```N' x 3 x H x W Tensor```: images of sparse (corner) views\n
    sparse_view_depths ```N' x 1 x H x W Tensor```: depths of sparse (corner) views\n
    sparse_view_positions ```N' x 3 Tensor```: positions of sparse (corner) views\n
    """

    def __init__(self, data_desc_path: str):
        """
        Initialize data loader for light field synthesis task

        The data description file is a JSON file with following fields:

        - lf: string, the path of light field image
        - lf_depth: string, the path of light field depth image
        - n_views: { "x",  "y" }, columns and rows of views
        - cam_params: { "f", "c" }, the focal and center of camera (in normalized image space)
        - depth_range: [ min, max ], the range of depth in depth maps
        - depth_layers: int, number of layers in depth maps
        - view_positions: [ [ x, y, z ], ... ], positions of views

        Paths of ``lf`` and ``lf_depth`` are relative to the directory of the description file.

        :param data_desc_path: path to the data description file
        """
        self.data_dir = self._GetDataDir(data_desc_path)
        with open(data_desc_path, 'r', encoding='utf-8') as file:
            self.data_desc = json.load(file)
        n_views_desc = self.data_desc['n_views']
        # Stored as (rows, columns) to match ReadLightField's ``views`` argument
        self.n_views = (n_views_desc['y'], n_views_desc['x'])
        self.num_views = self.n_views[0] * self.n_views[1]
        self.view_images = ReadLightField(
            self.data_dir + self.data_desc['lf'], self.n_views, True)
        self.view_depths = DecodeDepth(ReadLightField(
            self.data_dir + self.data_desc['lf_depth'], self.n_views, True))
        self.cam_params = self.data_desc['cam_params']
        self.depth_range = self.data_desc['depth_range']
        self.depth_layers = self.data_desc['depth_layers']
        self.view_positions = torch.tensor(self.data_desc['view_positions'])
        _, self.sparse_view_images, self.sparse_view_depths, self.sparse_view_positions \
            = self._GetCornerViews()
        self.diopter_of_layers = self._GetDiopterOfLayers()

    @staticmethod
    def _GetDataDir(data_desc_path: str) -> str:
        """
        Get the directory containing the data description file, with a trailing separator

        A path without a directory part (a bare file name) resolves to ``./``, so that
        ``data_dir + relative_path`` always points next to the description file.

        :param data_desc_path: path to the data description file
        :return: directory of the description file, ending with a path separator
        """
        return os.path.join(os.path.dirname(data_desc_path) or '.', '')

    def __len__(self):
        """
        :return: number of views in the light field
        """
        return self.num_views

    def __getitem__(self, idx):
        """
        Get view(s) by index

        :param idx: index of a view, or any index the view tensors accept
                    (``_GetCornerViews`` passes a boolean mask)
        :return: (idx, image, depth, position) of the selected view(s)
        """
        return idx, self.view_images[idx], self.view_depths[idx], self.view_positions[idx]

    def _GetCornerViews(self):
        """
        Select the corner views of the view grid

        Views are flattened row-major, so the corners are the first and last views of the
        first and last rows. A boolean mask keeps them in index order and collapses
        duplicates when the grid has a single row or column.

        :return: ``__getitem__`` result for the corner views (idx is the boolean mask)
        """
        corner_selector = torch.zeros(self.num_views, dtype=torch.bool)
        corner_selector[0] = corner_selector[self.n_views[1] - 1] \
            = corner_selector[self.num_views - self.n_views[1]] \
            = corner_selector[self.num_views - 1] = True
        return self.__getitem__(corner_selector)

    def _GetDiopterOfLayers(self) -> List[float]:
        """
        Get the diopters (inverse depths) of depth layers

        Layers are spaced uniformly in diopter from 1 / depth_range[1] (farthest) to
        1 / depth_range[0] (nearest), both inclusive. A single layer is placed at
        1 / depth_range[1], matching ``torch.linspace`` with one step. A leading 0
        (infinitely far) is prepended.

        :return: list of ``depth_layers + 1`` diopters (for ``depth_layers >= 1``)
        """
        far_diopter = 1 / self.depth_range[1]
        near_diopter = 1 / self.depth_range[0]
        if self.depth_layers > 1:
            step = (near_diopter - far_diopter) / (self.depth_layers - 1)
        else:
            # Avoid dividing by zero: a lone layer sits at the far end of the range
            step = 0
        diopter_of_layers = [far_diopter + step * i for i in range(self.depth_layers)]
        diopter_of_layers.insert(0, 0)
        return diopter_of_layers