gen_subset.py 3.43 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import json
import sys
import os
import argparse
import numpy as np

sys.path.append(os.path.abspath(sys.path[0] + '/../'))

from utils import misc

# Command-line interface. Both range options are optional; when one is omitted
# the corresponding extents are derived from the dataset description instead.
parser = argparse.ArgumentParser()
# The script indexes exactly two rotation extents and three translation
# extents, so let argparse enforce those counts up front rather than failing
# later with an IndexError (too few values) or writing surplus entries into
# the subset's 'range' (too many values).
parser.add_argument('-r', '--rot-range', nargs=2, type=int,
                    help='full extent of the two rotation components; '
                         'views within plus/minus half of it are kept')
parser.add_argument('-t', '--trans-range', nargs=3, type=float,
                    help='full extent of the three view-center components; '
                         'views within plus/minus half of it are kept')
parser.add_argument('-k', '--trainset-ratio', type=float, default=0.7,
                    help='fraction of the selected views assigned to the '
                         'training subset (default: 0.7)')
parser.add_argument('dataset', type=str,
                    help='path to the dataset description JSON file')
args = parser.parse_args()


# Locate the dataset description and the directory that contains it.
data_desc_path = args.dataset
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
# os.path.dirname() returns '' for a bare filename; fall back to '.' so that
# appending '/' yields './' (current directory) instead of '/' (filesystem
# root), which would otherwise become the prefix of every output path.
data_dir = (os.path.dirname(data_desc_path) or '.') + '/'

with open(data_desc_path, 'r') as fp:
    dataset_desc = json.load(fp)

# Translation extents: taken from the command line, otherwise computed as
# max - min over the first three components of the dataset's 'range'.
if args.trans_range:
    trans_range = np.array(args.trans_range)
else:
    trans_range = np.array(dataset_desc['range']['max'][0:3]) - \
        np.array(dataset_desc['range']['min'][0:3])
# Rotation extents: same scheme, using components 3 and 4 of 'range'.
if args.rot_range:
    rot_range = np.array(args.rot_range)
else:
    rot_range = np.array(dataset_desc['range']['max'][3:5]) - \
        np.array(dataset_desc['range']['min'][3:5])
# Combined extents, translation first, then rotation.
filter_range = np.concatenate([trans_range, rot_range])

# Output directory name encodes the chosen extents, e.g. 'r20x10_t0.3x0.3x0.3/'.
out_data_dir = data_dir + 'r%dx%d_t%.1fx%.1fx%.1f/' % (
    int(rot_range[0]), int(rot_range[1]),
    trans_range[0], trans_range[1], trans_range[2]
)

# Advance the version counter until 'train_<version>' does not exist inside
# the output directory, then derive both subset names from that version.
dataset_version = 0
while os.path.exists(out_data_dir + f'train_{dataset_version}'):
    dataset_version += 1
out_trainset_name = f'train_{dataset_version}'
out_testset_name = f'test_{dataset_version}'


def in_range(val, range):
    """Return whether ``val`` lies in the closed interval centered on zero.

    The interval is ``[-range / 2, range / 2]``; both end points count as
    inside. The parameter name ``range`` shadows the builtin only within this
    function and is kept as-is so existing callers are unaffected.
    """
    return -range / 2 <= val <= range / 2


# Collect the indices of all views whose first two rotation components and
# three center components each pass in_range against the matching extent.
views = []
for i, center in enumerate(dataset_desc['view_centers']):
    rot = dataset_desc['view_rots'][i]
    if in_range(rot[0], rot_range[0]) and \
            in_range(rot[1], rot_range[1]) and \
            in_range(center[0], trans_range[0]) and \
            in_range(center[1], trans_range[1]) and \
            in_range(center[2], trans_range[2]):
        views.append(i)

# Refuse to build a subset from fewer than 100 views.
if len(views) < 100:
    print(f'Number of views in range is too small ({len(views)})')
    # Signal the failure through a non-zero exit status. The bare exit()
    # helper is injected by the site module (not always available) and ended
    # the process with status 0, which made the failure look like a success.
    sys.exit(1)

# Shuffle the selected indices, split them by the requested ratio, and sort
# each part so the view indices of a subset are in ascending order.
views = np.random.permutation(views)
n_train_views = int(len(views) * args.trainset_ratio)
train_views = np.sort(views[:n_train_views])
test_views = np.sort(views[n_train_views:])

print('Train set views: ', len(train_views))
print('Test set views: ', len(test_views))

def create_subset(views, out_desc_name):
    """Write one subset: its description file plus symlinks to its images.

    Args:
        views: array of view indices into the source dataset; it must provide
            ``tolist()`` (the script passes sorted NumPy arrays).
        out_desc_name: name of the subset; used both for the
            ``<name>.json`` description and the ``<name>/`` link directory
            inside ``out_data_dir``.

    Relies on the module-level ``dataset_desc``, ``filter_range`` and
    ``out_data_dir``.
    """
    view_ids = views.tolist()
    src_pattern = dataset_desc['view_file_pattern']
    # Keep only the file-name part of the source pattern and place it under
    # the subset's own directory.
    dst_pattern = f"{out_desc_name}/{src_pattern.split('/')[-1]}"

    # Shallow copy is enough: every entry that differs is rebound below.
    subset_desc = dataset_desc.copy()
    subset_desc.update({
        'view_file_pattern': dst_pattern,
        'range': {
            'min': list(-filter_range / 2),
            'max': list(filter_range / 2)
        },
        'samples': [len(view_ids)],
        'views': view_ids,
        'view_centers': np.array(dataset_desc['view_centers'])[view_ids].tolist(),
        'view_rots': np.array(dataset_desc['view_rots'])[view_ids].tolist(),
    })

    desc_path = os.path.join(out_data_dir, f'{out_desc_name}.json')
    with open(desc_path, 'w') as fp:
        json.dump(subset_desc, fp, indent=4)

    misc.create_dir(os.path.join(out_data_dir, out_desc_name))
    # Link targets are relative: '../../' steps from the subset directory up
    # past out_data_dir, where the source pattern is then resolved.
    for view in view_ids:
        link_target = os.path.join('../../', src_pattern % view)
        link_path = os.path.join(out_data_dir, dst_pattern % view)
        os.symlink(link_target, link_path)


# Prepare the output directory (misc.create_dir is a project helper;
# presumably it creates the directory if missing), then write both subsets.
misc.create_dir(out_data_dir)
create_subset(train_views, out_trainset_name)
# The test subset must be built from the held-out test views; passing
# train_views here would make the test set a duplicate of the training set.
create_subset(test_views, out_testset_name)