split_dataset.py 3.59 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import json
import sys
import os
import argparse
import numpy as np
import torch

sys.path.append(os.path.abspath(sys.path[0] + '/../'))

from utils import misc

parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', type=str, default='train1')
parser.add_argument('dataset', type=str)
args = parser.parse_args()


data_desc_path = args.dataset
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
data_dir = os.path.dirname(data_desc_path) + '/'

with open(data_desc_path, 'r') as fp:
    dataset_desc = json.load(fp)

indices = torch.arange(len(dataset_desc['view_centers'])).view(dataset_desc['samples'])

idx = 0
'''
for i in range(3):
    for j in range(2):
        out_desc_name = f'part{idx:d}'
        out_desc = dataset_desc.copy()
        out_desc['view_file_pattern'] = f'{out_desc_name}/view_%04d.png'
        n_x = out_desc['samples'][3] // 3
        n_y = out_desc['samples'][4] // 2
        views = indices[..., i * n_x:(i + 1) * n_x, j * n_y:(j + 1) * n_y].flatten().tolist()
        out_desc['samples'] = [len(views)]
        out_desc['views'] = views
        out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
        out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
        with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
            json.dump(out_desc, fp, indent=4)
        misc.create_dir(os.path.join(data_dir, out_desc_name))
        for k in range(len(views)):
            os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
                    os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
        idx += 1
'''

'''
for xi in range(0, 4, 2):
    for yi in range(0, 4, 2):
        for zi in range(0, 4, 2):
            out_desc_name = f'part{idx:d}'
            out_desc = dataset_desc.copy()
            out_desc['view_file_pattern'] = f'{out_desc_name}/view_%04d.png'
            views = indices[xi:xi + 2, yi:yi + 2, zi:zi + 2].flatten().tolist()
            out_desc['samples'] = [len(views)]
            out_desc['views'] = views
            out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
            out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
            with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
                json.dump(out_desc, fp, indent=4)
            misc.create_dir(os.path.join(data_dir, out_desc_name))
            for k in range(len(views)):
                os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
                           os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))
            idx += 1
'''
from itertools import product
out_desc_name = args.output
out_desc = dataset_desc.copy()
out_desc['view_file_pattern'] = f"{out_desc_name}/{dataset_desc['view_file_pattern'].split('/')[-1]}"
views = []
for idx in product([0, 2, 4], [0, 2, 4], [0, 2, 4], list(range(9)), [1]):#, [0, 2, 3, 5], [1, 2, 3, 4]):
    views += indices[idx].flatten().tolist()
out_desc['samples'] = [len(views)]
out_desc['views'] = views
out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()
with open(os.path.join(data_dir, f'{out_desc_name}.json'), 'w') as fp:
    json.dump(out_desc, fp, indent=4)
misc.create_dir(os.path.join(data_dir, out_desc_name))
for k in range(len(views)):
    os.symlink(os.path.join('..', dataset_desc['view_file_pattern'] % views[k]),
               os.path.join(data_dir, out_desc['view_file_pattern'] % views[k]))