split360.py 2.55 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import json
import sys
import os
import argparse
import numpy as np
import shutil
from typing import List
from pathlib import Path

sys.path.append(os.path.abspath(sys.path[0] + '/../../'))

from data import DataDesc
from utils.misc import calculate_autosize

def run(dataset: str, outputs: list[str], views: list[int], random: bool = False) -> None:
    """Split one dataset description into several smaller ones.

    The input description's per-view arrays are treated as 6 consecutive
    groups of ``n_views`` entries each (presumably the six faces of a 360
    capture -- confirm against the dataset format).  A view index ``k``
    selected for a split therefore pulls in entries ``k + j * n_views`` for
    ``j`` in ``0..5``.

    For every split a new JSON description is written next to the input,
    and a sibling directory is (re)created containing relative symlinks to
    the original image files.

    Args:
        dataset: Dataset identifier/path, resolved through
            ``DataDesc.get_json_path``.
        outputs: Suffixes appended to the input file stem
            (``<stem>_<suffix>``) to name each split.
        views: Requested number of views for each split, one per entry of
            ``outputs``.  The values are passed through
            ``calculate_autosize`` before use.
        random: If true, views are assigned to splits from a random
            permutation; otherwise they are taken in their original order.

    Raises:
        ValueError: If ``views`` and ``outputs`` differ in length.
    """
    if len(views) != len(outputs):
        raise ValueError(
            "'views' and 'outputs' must have the same length "
            f"(got {len(views)} view counts for {len(outputs)} outputs)")

    n_groups = 6  # number of consecutive per-view groups stored in the description

    src_path = DataDesc.get_json_path(dataset)
    out_paths = [
        DataDesc.get_json_path(src_path.with_name(f"{src_path.stem}_{appendix}"))
        for appendix in outputs
    ]

    with open(src_path, 'r') as fp:
        input_desc: dict = json.load(fp)

    in_centers = np.array(input_desc["view_centers"])
    n_views = len(in_centers) // n_groups

    # NOTE(review): calculate_autosize presumably resolves placeholder sizes
    # against n_views; only its first return value is needed here.
    views, _ = calculate_autosize(n_views, *views)

    if random:
        indices = np.random.permutation(n_views)
    else:
        indices = np.arange(n_views)

    # Without an explicit "views" list, fall back to one id per stored entry.
    # The fallback must span all groups, because the indices built below
    # reach up to n_groups * n_views - 1.
    if "views" in input_desc:
        in_views = np.array(input_desc["views"])
    else:
        in_views = np.arange(len(in_centers))
    in_rots = np.array(input_desc["view_rots"]) if "view_rots" in input_desc else None

    offset = 0
    for i, out_path in enumerate(out_paths):
        n = views[i]
        # Sort so that each split keeps the original ordering of its views.
        sub_indices = np.sort(indices[offset:offset + n])
        # Expand the selected view indices to the matching entry of every group.
        sub_indices = np.concatenate(
            [sub_indices + j * n_views for j in range(n_groups)], axis=0)

        # Shallow copy: only top-level keys are replaced below.
        output_desc = input_desc.copy()
        output_desc['samples'] = [n * n_groups]
        output_desc['views'] = in_views[sub_indices].tolist()
        output_desc['view_centers'] = in_centers[sub_indices].tolist()
        if in_rots is not None:
            output_desc['view_rots'] = in_rots[sub_indices].tolist()
        with open(out_path, 'w') as fp:
            json.dump(output_desc, fp, indent=4)

        # Create symbol links of images
        out_dir: Path = out_path.with_suffix('')

        # Start from an empty directory so stale links from a previous run
        # do not survive.
        if out_dir.exists():
            shutil.rmtree(out_dir)
        out_dir.mkdir()
        for view in output_desc['views']:
            filename = output_desc['color_file'] % view
            # Relative target: ../<input stem>/<filename>, resolved from out_dir.
            os.symlink(Path("..") / src_path.stem / filename, out_dir / filename)
        offset += n


if __name__ == '__main__':
    # Usage: <script> DATASET -o SUFFIX [SUFFIX ...] -v N [N ...] [--random]
    cli = argparse.ArgumentParser()
    # One suffix per split to produce.
    cli.add_argument("-o", "--outputs", nargs="+", type=str, required=True)
    # One view count per split; must match the number of suffixes.
    cli.add_argument("-v", "--views", nargs="+", type=int, required=True)
    # Assign views to splits from a random permutation instead of in order.
    cli.add_argument("--random", action="store_true")
    # The dataset to split.
    cli.add_argument("dataset", type=str)
    opts = cli.parse_args()
    run(
        dataset=opts.dataset,
        outputs=opts.outputs,
        views=opts.views,
        random=opts.random,
    )