import json
import sys
import os
import argparse
import torch
from pathlib import Path

# Make the repository root importable so `data` and `utils` resolve.
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))

from data import DataDesc
from utils.misc import calculate_autosize

parser = argparse.ArgumentParser()
parser.add_argument('-o', '--outputs', type=str, nargs="+", required=True,
                    help="names of the output datasets; a leading ~ is replaced by the name of the input dataset")
parser.add_argument("-v", "--views", type=str, nargs="+", required=True,
                    help="views of each output dataset: -1 (take the remainder), a positive number, "
                         "or a colon-separated slice such as ::8")
parser.add_argument("--random", action="store_true",
                    help="shuffle the unassigned views before distributing them")
parser.add_argument('dataset', type=str)
parser.usage = """
Split a dataset into one or more datasets.

Examples:
> python split.py path_to_dataset.json -o train test -v 20 -1
This creates two datasets "train" and "test" in the folder where the input dataset
is located, with the first 20 views in "train" and the remaining views in "test".

> python split.py path_to_dataset.json -o ~_train ~_test -v -1 ::8
This creates two datasets "<dataset>_train" and "<dataset>_test" in the folder where
the input dataset is located, with every 8th view in "<dataset>_test" and the
remaining views in "<dataset>_train".
"""
args = parser.parse_args()

input_path = DataDesc.get_json_path(args.dataset)
outputs = [
    DataDesc.get_json_path(input_path.with_name(
        f"{input_path.stem}{appendix[1:]}" if appendix.startswith("~") else appendix))
    for appendix in args.outputs
]

with open(input_path, 'r') as fp:
    input_desc: dict = json.load(fp)
n_views = len(input_desc['centers'])

assert len(args.views) == len(outputs), "every output dataset needs a view specification"

# First pass: slice specifications claim their views directly; plain numbers
# (including -1 for "the remainder") are resolved in a second pass below.
indices = torch.arange(n_views)
indices_assigned = torch.zeros(n_views, dtype=torch.bool)
output_dataset_indices: list[torch.Tensor] = [None] * len(outputs)
output_dataset_views: dict[int, int] = {}
for i, output_views in enumerate(args.views):
    arr = output_views.split(":")
    if len(arr) > 1:
        # A colon-separated specification like "::8" is interpreted as a slice.
        view_slice = slice(*[int(value) if value != "" else None for value in arr])
        output_dataset_indices[i] = indices[view_slice]
        indices_assigned[view_slice] = True
    else:
        output_dataset_views[i] = int(arr[0])

# Second pass: distribute the views not claimed by any slice. calculate_autosize
# resolves -1 entries to concrete counts that fill up the remainder.
indices_remain = indices[indices_assigned.logical_not()]
n_views_remain = len(indices_remain)
output_dataset_views = dict(zip(
    output_dataset_views,
    calculate_autosize(n_views_remain, *output_dataset_views.values())[0]))
if args.random:
    indices_remain = indices_remain[torch.randperm(n_views_remain)]
offset = 0
for key, value in output_dataset_views.items():
    output_dataset_indices[key] = indices_remain[offset:offset + value]
    offset += value

in_views = torch.tensor(input_desc["views"]) if "views" in input_desc else torch.arange(n_views)
in_centers = torch.tensor(input_desc["centers"])
in_rots = torch.tensor(input_desc["rots"]) if "rots" in input_desc else None
for i in range(len(outputs)):
    sub_indices = output_dataset_indices[i].sort()[0]

    # Write the output descriptor with the selected subset of views.
    output_desc = input_desc.copy()
    output_desc['samples'] = [len(sub_indices)]
    output_desc['views'] = in_views[sub_indices].tolist()
    output_desc['centers'] = in_centers[sub_indices].tolist()
    if in_rots is not None:
        output_desc['rots'] = in_rots[sub_indices].tolist()
    with open(outputs[i], 'w') as fp:
        json.dump(output_desc, fp, indent=4)

    # Create relative symbolic links to the images of the input dataset.
    out_dir = outputs[i].with_suffix('')
    out_dir.mkdir(exist_ok=True)
    for view in output_desc['views']:
        link = out_dir / (input_desc['color_file'] % view)
        # Skip links that already exist so the script can be re-run safely.
        if not os.path.lexists(link):
            os.symlink(Path("..") / input_path.stem / (input_desc['color_file'] % view), link)
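
# Illustrative walkthrough (assuming a hypothetical 100-view dataset "lego.json"):
#   python split.py lego.json -o ~_train ~_test -v -1 ::8
# The slice "::8" claims views 0, 8, ..., 96 (13 views) for the second output;
# "-1" lets calculate_autosize assign the remaining 87 views to the first.
# Because "~" expands to the input stem, the descriptors are written as
# lego_train.json and lego_test.json next to the input file, each with a
# sibling directory of image symlinks.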