import argparse
import json
import logging
import os
import sys
from pathlib import Path
from typing import List, Union

import model as mdl
import train
from utils import device
from utils import netio
from data import *
from utils.misc import print_and_log

RAYS_PER_BATCH = 2 ** 12
DATA_LOADER_CHUNK_SIZE = 1e8

root_dir = Path(__file__).absolute().parent

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str,
                    help='Net config file')
parser.add_argument('-e', '--epochs', type=int,
                    help='Max epochs to train')
parser.add_argument('--perf', type=int,
                    help='Performance measurement frames (0 to disable performance measurement)')
parser.add_argument('--prune', type=int, nargs='+',
                    help='Prune voxels on every # epochs')
parser.add_argument('--split', type=int, nargs='+',
                    help='Split voxels on every # epochs')
parser.add_argument('--freeze', type=int, nargs='+',
                    help='Freeze levels on every # epochs')
parser.add_argument('--checkpoint-interval', type=int)
parser.add_argument('--views', type=str,
                    help='Specify the range of views to train, e.g. "0-30"')
parser.add_argument('path', type=str,
                    help='Dataset description file or model path')
args = parser.parse_args()

views_to_load = range(*[int(val) for val in args.views.split('-')]) if args.views else None

argpath = Path(args.path)
# argpath may be either a model path or a data path:
# 1) model path: continue training on the specified model
# 2) data path: train a new model using the specified dataset


def load_dataset(data_path: Path):
    print(f"Loading dataset {data_path}")
    try:
        dataset = DatasetFactory.load(data_path, views_to_load=views_to_load)
        print(f"Dataset loaded: {dataset.root}/{dataset.name}")
        os.chdir(dataset.root)
        return dataset, dataset.name
    except FileNotFoundError:
        return load_multiscale_dataset(data_path)


def load_multiscale_dataset(data_path: Path):
    if not data_path.is_dir():
        raise ValueError(f"Path {data_path} is not a directory")
    dataset: List[Union[PanoDataset, ViewDataset]] = []
    for sub_data_desc_path in data_path.glob("*.json"):
        sub_dataset = DatasetFactory.load(sub_data_desc_path, views_to_load=views_to_load)
        print(f"Sub-dataset loaded: {sub_dataset.root}/{sub_dataset.name}")
        dataset.append(sub_dataset)
    if len(dataset) == 0:
        raise ValueError(f"Path {data_path} does not contain sub-datasets")
    os.chdir(data_path.parent)
    return dataset, data_path.name


try:
    states, checkpoint_path = netio.load_checkpoint(argpath)
    # Infer the dataset path from the model path, which follows the rule:
    # <data_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
    model_name = checkpoint_path.parts[-2]
    dataset, dataset_name = load_dataset(
        Path(*checkpoint_path.parts[:-4]) / checkpoint_path.parts[-3])
except Exception:
    model_name = args.config
    dataset, dataset_name = load_dataset(argpath)

    # Load state 0 from the specified configuration
    with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
        states = json.load(fp)
    states['args']['bbox'] = dataset[0].bbox if isinstance(dataset, list) else dataset.bbox
    states['args']['depth_range'] = dataset[0].depth_range if isinstance(dataset, list) \
        else dataset.depth_range

# Command-line arguments override the training options stored in `states`
if 'train' not in states:
    states['train'] = {}
if args.prune is not None:
    states['train']['prune_epochs'] = args.prune
if args.split is not None:
    states['train']['split_epochs'] = args.split
if args.freeze is not None:
    states['train']['freeze_epochs'] = args.freeze
if args.perf is not None:
    states['train']['perf_frames'] = args.perf
if args.checkpoint_interval is not None:
    states['train']['checkpoint_interval'] = args.checkpoint_interval
if args.epochs is not None:
    states['train']['max_epochs'] = args.epochs
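# At this point `states` carries (at least) the keys consumed below; the sketch
# is illustrative and derived only from the assignments above:
#   states['args']  -> model construction arguments, incl. 'bbox' and 'depth_range'
#   states['train'] -> trainer options, e.g. 'max_epochs', 'prune_epochs',
#                      'split_epochs', 'freeze_epochs', 'perf_frames',
#                      'checkpoint_interval'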
model = mdl.deserialize(states).to(device.default())

# Initialize run directory
run_dir = Path(f"_nets/{dataset_name}/{model_name}")
run_dir.mkdir(parents=True, exist_ok=True)

# Initialize logging
log_file = run_dir / "train.log"
logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s',
                    level=logging.INFO,
                    filename=log_file,
                    filemode='a' if log_file.exists() else 'w')


def log_exception(exc_type, exc_value, exc_traceback):
    # Log any uncaught exception (except Ctrl-C) before the default hook handles it
    if not issubclass(exc_type, KeyboardInterrupt):
        logging.exception(exc_value, exc_info=(exc_type, exc_value, exc_traceback))
    sys.__excepthook__(exc_type, exc_value, exc_traceback)


sys.excepthook = log_exception

print_and_log(f"model: {model_name} ({model.cls})")
print_and_log("args:")
model.print_config()
print(model)

if __name__ == "__main__":
    # 1. Initialize data loader
    data_loader = get_loader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                             shuffle=True, enable_preload=False, color=model.color)

    # 2. Initialize model and trainer
    trainer = train.get_trainer(model, run_dir, states)

    # 3. Train
    trainer.train(data_loader)
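# Example invocations (config name, view range, and paths are illustrative):
#   Train a new model from a dataset description, reading configs/example.json:
#       python train.py -c example -e 50 --views 0-30 scene/train.json
#   Resume training from a checkpoint; model config and dataset are inferred
#   from the <data_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar rule:
#       python train.py scene/_nets/train/example/checkpoint_50.tar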