"""Create a subset of a view-based dataset description.

Reads a dataset description JSON, selects a subset of its views either by
grid indices (``-g``) or by a view-center distance bound and/or a random view
budget (``-t`` / ``-v``), writes a new description, and creates a directory of
symlinks to the selected view images.
"""
import json
import sys
import os
import argparse
import numpy as np
import torch
from itertools import product, repeat
from pathlib import Path

# NOTE(review): presumably lets project-level packages be imported when this
# file is run directly; nothing in the visible script imports from there.
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))

parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', type=str, default='train1',
                    help='name of the output description, relative to the '
                         'input directory (".json" is appended if missing)')
parser.add_argument("-t", "--trans", type=float,
                    help='keep only views whose center lies within this '
                         'distance of the origin')
parser.add_argument("-v", "--views", type=int,
                    help='randomly keep at most this many views')
parser.add_argument('-g', '--grids', nargs='+', type=int,
                    help='grid indices to keep along each of the first three '
                         'sample axes (takes precedence over -t/-v)')
parser.add_argument('--seed', type=int, default=None,
                    help='random seed for --views sampling')
parser.add_argument('dataset', type=str,
                    help='path to the input dataset description '
                         '(".json" is appended if missing)')
args = parser.parse_args()
if args.views is not None and args.views < 0:
    parser.error("--views must be non-negative")

if not args.dataset.endswith(".json"):
    args.dataset = args.dataset.rstrip("/") + ".json"
if not args.output.endswith(".json"):
    args.output = args.output.rstrip("/") + ".json"

in_desc_path = Path(args.dataset)
in_name = in_desc_path.stem
root_dir = in_desc_path.parent
out_desc_path: Path = root_dir / args.output
# Images of the new dataset live in a directory named after its description.
out_dir = out_desc_path.with_suffix("")

with open(in_desc_path, 'r') as fp:
    dataset_desc = json.load(fp)


def extract_by_grid(*grid_indices):
    """Select views by grid position.

    ``dataset_desc['samples']`` is used as the shape of the view-index grid.
    Each positional argument lists the indices to keep along one leading axis;
    for every index tuple in their Cartesian product, the flattened view
    indices of the addressed sub-block are appended in order.

    Returns:
        list[int]: positions into ``dataset_desc['view_centers']``.
    """
    indices = torch.arange(len(dataset_desc['view_centers'])).view(dataset_desc['samples'])
    views = []
    for grid_idx in product(*grid_indices):
        views += indices[grid_idx].flatten().tolist()
    return views


def extract_by_trans(max_trans, max_views, seed=None):
    """Select views by center distance and/or a random view budget.

    Args:
        max_trans: if not None, keep only views whose center has Euclidean
            norm <= max_trans.
        max_views: if not None, randomly keep at most this many of the
            remaining views.
        seed: seed for the random choice; None draws fresh OS entropy.

    Returns:
        list[int]: ascending positions into ``dataset_desc['view_centers']``.
    """
    if max_trans is not None:
        centers = np.array(dataset_desc['view_centers'])
        trans = np.linalg.norm(centers, axis=-1)
        indices = np.nonzero(trans <= max_trans)[0]
    else:
        indices = np.arange(len(dataset_desc['view_centers']))
    if max_views is not None:
        rng = np.random.default_rng(seed)
        indices = np.sort(rng.permutation(indices)[:max_views])
    return indices.tolist()


if args.grids:
    # The same grid indices are applied to each of the first three axes.
    views = extract_by_grid(*repeat(args.grids, 3))
else:
    views = extract_by_trans(args.trans, args.views, args.seed)

# A description written by this script names its image files by the ids stored
# in 'views' (see the symlink loop below), not by position. Map the selected
# positions through those ids so subsets of subsets link the right images.
src_view_ids = dataset_desc.get('views', list(range(len(dataset_desc['view_centers']))))
view_ids = [src_view_ids[i] for i in views]

# Image path pattern relative to root_dir; a bare file pattern means the
# images sit in a directory named after the input description.
image_path = dataset_desc['view_file_pattern']
if "/" not in image_path:
    image_path = in_name + "/" + image_path

# Build the new dataset description.
out_desc = dataset_desc.copy()
out_desc['view_file_pattern'] = image_path.split('/')[-1]
out_desc['samples'] = [len(views)]
out_desc['views'] = view_ids
out_desc['view_centers'] = np.array(dataset_desc['view_centers'])[views].tolist()
if 'view_rots' in dataset_desc:
    out_desc['view_rots'] = np.array(dataset_desc['view_rots'])[views].tolist()

# Create the image directory before writing the description, so an existing
# output directory aborts the run without leaving a half-written dataset.
out_dir.mkdir(parents=True)
with open(out_desc_path, 'w') as fp:
    json.dump(out_desc, fp, indent=4)

# Create symbolic links to the selected images. Use a relative link when the
# new image directory sits directly in root_dir; otherwise use an absolute one.
if out_dir.parent.absolute() == root_dir.absolute():
    image_root = Path("..")
else:
    image_root = root_dir.absolute()
for view_id in view_ids:
    os.symlink(image_root / (image_path % view_id),
               out_dir / (out_desc['view_file_pattern'] % view_id))