import argparse
import json
import logging
import os
from pathlib import Path

import model as mdl
import train
from utils import color
from utils import device
from data.dataset_factory import DatasetFactory
from data.loader import DataLoader
from utils.misc import list_epochs, print_and_log

RAYS_PER_BATCH = 2 ** 12  # Number of rays sampled per training batch
DATA_LOADER_CHUNK_SIZE = 1e8  # Max items per preloaded data chunk

root_dir = Path.cwd()

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str,
                    help='Net config file (name of a JSON file under configs/)')
parser.add_argument('-e', '--epochs', type=int, default=50,
                    help='Max epochs to train')
parser.add_argument('--perf', type=int, default=0,
                    help='Number of frames for performance measurement (0 disables it)')
parser.add_argument('--prune', type=int, default=5,
                    help='Prune voxels every # epochs')
parser.add_argument('--split', type=int, default=10,
                    help='Split voxels every # epochs')
parser.add_argument('--views', type=str,
                    help='Range of views to train, as "start-stop" or "start-stop-step"')
parser.add_argument('path', type=str,
                    help='Dataset description file or model path')
args = parser.parse_args()

argpath = Path(args.path)
# argpath may be either a model path or a data path:
# 1) model path: continue training the specified model
# 2) data path: train a new model on the specified dataset
if argpath.suffix == ".tar":
    args.mdl_path = argpath
else:
    existed_epochs = list_epochs(argpath, "checkpoint_*.tar")
    args.mdl_path = (argpath / f"checkpoint_{existed_epochs[-1]}.tar") if existed_epochs else None

if args.mdl_path:
    # Infer the dataset path from the model path, which follows the rule:
    # <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
    dataset_name = args.mdl_path.parent.parent.name
    dataset_dir = args.mdl_path.parent.parent.parent.parent
    args.data_path = dataset_dir / dataset_name
    args.mdl_path = args.mdl_path.relative_to(dataset_dir)
else:
    args.data_path = argpath
args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None

dataset = DatasetFactory.load(args.data_path, views_to_load=args.views)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")
os.chdir(dataset.root)

if args.mdl_path:  # Load an existing model to continue training
    model, states = mdl.load(args.mdl_path)
    model_name = args.mdl_path.parent.name
    model_class = model.__class__.__name__
    model_args = model.args
else:  # Create a new model from the specified configuration
    with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
        config = json.load(fp)
    model_name = args.config
    model_class = config['model']
    model_args = config['args']
    model_args['bbox'] = dataset.bbox
    model_args['depth_range'] = dataset.depth_range
    model, states = mdl.create(model_class, model_args), None
model.to(device.default())

run_dir = Path(f"_nets/{dataset.name}/{model_name}")
run_dir.mkdir(parents=True, exist_ok=True)
log_file = run_dir / "train.log"
logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=logging.INFO,
                    filename=log_file, filemode='a' if log_file.exists() else 'w')

print_and_log(f"model: {model_name} ({model_class})")
print_and_log(f"args: {json.dumps(model.args0)}")

if __name__ == "__main__":
    # 1. Initialize data loader
    data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                             shuffle=True, enable_preload=True,
                             color=color.from_str(model.args['color']))

    # 2. Initialize trainer
    trainer = train.get_trainer(model, run_dir=run_dir, states=states, perf_frames=args.perf,
                                pruning_loop=args.prune, splitting_loop=args.split)

    # 3. Train
    trainer.train(data_loader, args.epochs)
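
# ----------------------------------------------------------------------------
# Example invocations (a sketch only, based on the CLI defined above; the
# script name "run_train.py" and the dataset/config paths are hypothetical
# placeholders, not files shipped with the repo):
#
#   # Train a new model for 100 epochs using configs/my_config.json:
#   #   python run_train.py -c my_config -e 100 data/my_scene/train.json
#
#   # Continue training from the newest checkpoint_*.tar in a run directory:
#   #   python run_train.py data/_nets/my_scene/my_config
#
#   # Continue from a specific checkpoint, training only views 0..29
#   # (--views is parsed as "start-stop" or "start-stop-step"):
#   #   python run_train.py --views 0-30 data/_nets/my_scene/my_config/checkpoint_50.tar
# ----------------------------------------------------------------------------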