dataset.py
import torch
from operator import itemgetter
from typing import Tuple, Union
from pathlib import Path

from utils import view
from .utils import get_data_path


class Dataset(object):
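    """Base dataset class.

    Wraps a dataset description dict and exposes view centers, rotations,
    view indices and camera parameters as tensors on the configured device.
    """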
    desc: dict
    desc_path: Path
    device: torch.device

    @property
    def name(self):
        return self.desc_path.stem

    @property
    def root(self):
        return self.desc_path.parent

    @property
    def n_views(self):
        return self.centers.size(0)

    @property
    def n_pixels_per_view(self):
        return self.res[0] * self.res[1]

    @property
    def n_pixels(self):
        return self.n_views * self.n_pixels_per_view

    def __init__(self, desc: dict, desc_path: Path, *,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None, **kwargs) -> None:
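        """Initialize the dataset from a parsed description.

        :param desc: dataset description (parsed dict)
        :param desc_path: path to the description file; its stem and parent give `name` and `root`
        :param res: optional (height, width) override for the view resolution
        :param views_to_load: optional subset of views to keep (range or index tensor)
        :param device: device on which tensors are created
        """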
        super().__init__()
        self.desc = desc
        self.desc_path = desc_path.absolute()
        self.device = device
        self._load_desc(res, views_to_load, **kwargs)

    def get_data(self):
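        """Return the loaded view data: view indices, view centers and, if available, rotations."""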
        data = {
            'indices': self.indices,
            'centers': self.centers
        }
        if self.rots is not None:
            data['rots'] = self.rots
        return data

    def _get_data_path(self, name: str) -> Union[str, None]:
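        """Resolve the path for data `name` from its `<name>_file_pattern` entry, or None if absent."""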
        path_pattern = self.desc.get(f"{name}_file_pattern", None)
        return path_pattern and get_data_path(self.desc_path, path_pattern)

    def _load_desc(self, res: Tuple[int, int], views_to_load: Union[range, torch.Tensor],
                   **kwargs):
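        """Parse the description dict into attributes: resolution, camera, depth range and view poses."""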
        self.level = self.desc.get('level', 0)
        self.res = res or itemgetter("y", "x")(self.desc['view_res'])
        self.cam = view.CameraParam(self.desc['cam_params'], self.res, device=self.device)\
            if 'cam_params' in self.desc else None
        self.depth_range = itemgetter("min", "max")(self.desc['depth_range']) \
            if 'depth_range' in self.desc else None
        self.range = itemgetter("min", "max")(self.desc['range']) if 'range' in self.desc else None
        self.bbox = self.desc.get('bbox')
        self.samples = self.desc.get('samples')
        self.centers = torch.tensor(self.desc['view_centers'], device=self.device)  # (N, 3)
        if 'view_rots' in self.desc:
            # 2-element entries are Euler angles and are converted to rotation matrices here;
            # otherwise the entries are already full 3x3 matrices
            view_rots = self.desc['view_rots']
            if len(view_rots[0]) == 2:
                view_rots = [
                    view.euler_to_matrix([rot[1] if self.desc.get('gl_coord') else -rot[1],
                                          rot[0], 0])
                    for rot in view_rots
                ]
            self.rots = torch.tensor(view_rots, device=self.device).view(-1, 3, 3)  # (N, 3, 3)
        else:
            self.rots = None
        self.indices = torch.tensor(self.desc.get('views') or [*range(self.centers.size(0))],
                                    device=self.device)

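        # Optionally keep only the requested subset of views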
        if views_to_load is not None:
            self.centers = self.centers[views_to_load]
            self.rots = self.rots[views_to_load] if self.rots is not None else None
            self.indices = self.indices[views_to_load]

        if self.desc.get('gl_coord'):
            print('Converting from OpenGL to DirectX coordinates (i.e. flipping the z axis)')
            self.centers[:, 2] *= -1
            if self.cam is not None and not self.desc['cam_params'].get('fov'):
                self.cam.f[1] *= -1
            if self.rots is not None:
                self.rots[:, 2] *= -1    # negate the z row of each rotation matrix
                self.rots[..., 2] *= -1  # negate the z column of each rotation matrix