# extract.py (last commit: "sync", Nianchen Deng)
import json
import sys
import os
import argparse
import numpy as np
import torch
Nianchen Deng's avatar
sync    
Nianchen Deng committed
7
8
from itertools import product, repeat
from pathlib import Path
Nianchen Deng's avatar
sync    
Nianchen Deng committed
9

Nianchen Deng's avatar
sync    
Nianchen Deng committed
10
# Make the directory two levels above this script importable.
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))

parser = argparse.ArgumentParser(
    description="Extract a subset of views from a dataset description into a new "
                "description plus a directory of symbolic links to the selected images.")
parser.add_argument('-o', '--output', type=str, default='train1',
                    help="name of the output description, resolved relative to the "
                         "input description's directory ('.json' is appended if missing)")
parser.add_argument("-t", "--trans", type=float,
                    help="keep only views whose center lies within this distance "
                         "of the origin")
parser.add_argument("-v", "--views", type=int,
                    help="randomly keep at most this many of the views selected "
                         "by --trans (or of all views when --trans is omitted)")
parser.add_argument('-g', '--grids', nargs='+', type=int,
                    help="sample positions selected on each of the first three "
                         "sample axes; takes precedence over --trans/--views")
parser.add_argument('dataset', type=str,
                    help="input dataset description ('.json' is appended if missing)")
args = parser.parse_args()

Nianchen Deng's avatar
sync    
Nianchen Deng committed
20
21
22
23
# Normalise both CLI names so they always refer to a ".json" description
# file; a trailing "/" is dropped before the suffix is appended.
if not args.dataset.endswith(".json"):
    args.dataset = f"{args.dataset.rstrip('/')}.json"
if not args.output.endswith(".json"):
    args.output = f"{args.output.rstrip('/')}.json"

in_desc_path = Path(args.dataset)
root_dir = in_desc_path.parent                  # directory holding the input description
in_name = in_desc_path.stem                     # input description name without ".json"
out_desc_path: Path = root_dir / args.output    # output is resolved relative to root_dir
out_dir = out_desc_path.with_suffix("")         # directory that will receive the image links
Nianchen Deng's avatar
sync    
Nianchen Deng committed
30

Nianchen Deng's avatar
sync    
Nianchen Deng committed
31
# Load the input dataset description. JSON text is UTF-8 by specification, so
# the encoding is pinned instead of relying on the platform locale default.
with open(in_desc_path, 'r', encoding='utf-8') as fp:
    dataset_desc = json.load(fp)


def extract_by_grid(*grid_indices, desc=None):
    """Select the views lying on a Cartesian grid of sample positions.

    The views of the description are arranged as an index array of shape
    ``desc['samples']``. Every combination of one entry from each sequence in
    *grid_indices* (``itertools.product``) indexes the leading axes of that
    array, and all view indices under each selected cell are collected in
    product order.

    Args:
        *grid_indices: one sequence of integer positions per leading sample axis.
        desc: dataset description providing ``'centers'`` and ``'samples'``;
            defaults to the module-level ``dataset_desc``.

    Returns:
        list[int]: the selected view indices (repeated grid entries yield
        repeated views).

    Raises:
        ValueError: if ``len(desc['centers'])`` does not match the product of
            ``desc['samples']``.
        IndexError: if a grid position is out of range for its axis.
    """
    if desc is None:
        desc = dataset_desc
    # numpy is enough to build the index grid; no tensor library is needed.
    indices = np.arange(len(desc['centers'])).reshape(desc['samples'])
    views = []
    for cell in product(*grid_indices):
        # np.ravel also handles the case where `cell` addresses a single scalar.
        views.extend(np.ravel(indices[cell]).tolist())
    return views


def extract_by_trans(max_trans, max_views, *, desc=None, rng=None):
    """Select views whose center lies within a given distance of the origin.

    Args:
        max_trans: maximum Euclidean norm of a view's center; ``None`` keeps
            every view.
        max_views: if not ``None``, randomly keep at most this many of the
            candidate views.
        desc: dataset description providing ``'centers'``; defaults to the
            module-level ``dataset_desc``.
        rng: optional ``numpy.random.Generator`` for reproducible subsampling;
            when ``None`` numpy's global random state is used (as before).

    Returns:
        list[int]: the selected view indices in ascending order.
    """
    if desc is None:
        desc = dataset_desc
    n_views = len(desc['centers'])
    if max_trans is None or n_views == 0:
        # An empty centers list would make the norm below collapse to a 0-d
        # scalar, so nonzero() would report a nonexistent view 0 (or raise).
        candidates = np.arange(n_views)
    else:
        centers = np.array(desc['centers'])
        dist = np.linalg.norm(centers, axis=-1)
        candidates = np.nonzero(dist <= max_trans)[0]
    if max_views is not None:
        order = (np.random.permutation(candidates.shape[0]) if rng is None
                 else rng.permutation(candidates.shape[0]))
        candidates = np.sort(candidates[order[:max_views]])
    return candidates.tolist()


# --grids takes precedence: the same grid positions are applied to each of the
# first three sample axes. Otherwise select by distance and/or view count.
views = (
    extract_by_grid(*repeat(args.grids, 3))
    if args.grids
    else extract_by_trans(args.trans, args.views)
)

Nianchen Deng's avatar
sync    
Nianchen Deng committed
104
# Image filename pattern of the input dataset, relative to root_dir. A bare
# pattern (no "/") is taken to live in a directory named after the input
# description, i.e. "<in_name>/<pattern>".
image_path = dataset_desc['color_file']
if "/" not in image_path:
    image_path = in_name + "/" + image_path

# Build the new dataset description: same metadata, restricted to `views`.
out_desc = dataset_desc.copy()
# Images are linked directly into out_dir, so only the basename pattern is kept.
out_desc['color_file'] = image_path.split('/')[-1]
out_desc['samples'] = [len(views)]
out_desc['views'] = views
out_desc['centers'] = np.array(dataset_desc['centers'])[views].tolist()
if 'rots' in dataset_desc:
    out_desc['rots'] = np.array(dataset_desc['rots'])[views].tolist()

# Create the link directory BEFORE writing the description. If it already
# exists (a re-run, or an output name equal to the input name), mkdir() raises
# here and no existing description file has been overwritten.
out_dir.mkdir()

# Write new data desc
with open(out_desc_path, 'w', encoding='utf-8') as fp:
    json.dump(out_desc, fp, indent=4)

# Create symbolic links to the selected images. When out_dir sits directly in
# root_dir a relative link ("../<image>") is used, otherwise an absolute one.
if out_dir.parent.absolute() == root_dir.absolute():
    link_base = Path("..")
else:
    link_base = root_dir.absolute()
# NOTE(review): source images are looked up by each view's position in the
# input description. If the input is itself an extracted dataset (with its own
# "views" key), its images are named by those ids instead; confirm before
# chaining extractions.
for view in views:
    os.symlink(link_base / (image_path % view),
               out_dir / (out_desc['color_file'] % view))