"""Train a model, either from scratch on a dataset or by resuming a checkpoint.

The positional PATH argument is first looked up as a checkpoint via
``netio.find_checkpoint``. If a checkpoint is found, training resumes from it
using the model/trainer arguments stored inside. Otherwise PATH is treated as
a dataset description file and a new model is trained, configured from
``configs/<CONFIG>.ini`` (``-c/--config``) and the command line.
"""
from operator import itemgetter
from pathlib import Path

import torch
from configargparse import ArgumentParser, SUPPRESS

from model import Model
from train import Trainer
from utils import device, netio
# NOTE(review): presumably this star import also supplies ``Color`` (used in
# load_dataset); it is kept last so any names it re-exports take precedence.
from utils.types import *
from data import Dataset


def load_dataset(data_path: Path, color: str, coord: str):
    """Load a ``Dataset`` and print where it was loaded from.

    Args:
        data_path: Path to the dataset description file.
        color: Name of a ``Color`` member (e.g. ``"rgb"``), looked up as
            ``Color[color]``.
        coord: Coordinate-system name forwarded as ``coord_sys`` (e.g. ``"gl"``).

    Returns:
        The loaded ``Dataset``.
    """
    dataset = Dataset(data_path, color_mode=Color[color], coord_sys=coord)
    print(f"Load dataset: {dataset.root}/{dataset.name}")
    return dataset


initial_parser = ArgumentParser()
initial_parser.add_argument('-c', '--config', type=str, default=SUPPRESS,
                            help='Config name, ignored if path is a checkpoint path')
initial_parser.add_argument('--expname', type=str, default=SUPPRESS,
                            help='Experiment name, defaults to config name, '
                                 'ignored if path is a checkpoint path')
initial_parser.add_argument('path', type=str,
                            help='Path to dataset description file or checkpoint file')
# parse_known_args: model/trainer options are left for the parsers used later.
# default=SUPPRESS keeps unset options out of the namespace, so
# `"config" in initial_args` means "--config was actually given".
initial_args = vars(initial_parser.parse_known_args()[0])

root_dir = Path(__file__).absolute().parent
argpath = Path(initial_args["path"])

# PATH may be:
#   1) a checkpoint path: continue training a model
#   2) a dataset path: train a new model using the specified dataset
ckpt_path = netio.find_checkpoint(argpath)
if ckpt_path:
    # Continue training from a checkpoint
    print(f"Load checkpoint {ckpt_path}")
    # torch.load is pickle-based: only resume from checkpoint files you trust.
    # map_location places tensors on the current default device, so a
    # checkpoint saved on a GPU can still be resumed on another/CPU-only host.
    args, states = itemgetter("args", "states")(
        torch.load(ckpt_path, map_location=device.default()))
    # args: "model", "model_args", "trainer", "trainer_args"
    ModelCls = Model.get_class(args["model"])
    TrainerCls = Trainer.get_class(args["trainer"])
    model_args = ModelCls.Args(**args["model_args"])
    trainer_args = TrainerCls.Args(**args["trainer_args"]).parse()
    # Reload the training set with the color mode / coordinate system stored
    # in the model's args.
    trainset = load_dataset(trainer_args.trainset, model_args.color, model_args.coord)
    run_dir = ckpt_path.parent
else:
    # Start a new training run
    expname = initial_args.get("expname", initial_args.get("config", "unnamed"))
    if "config" in initial_args:
        config_path = root_dir / "configs" / f"{initial_args['config']}.ini"
        if not config_path.exists():
            raise ValueError(f"Config {initial_args['config']} is not found in "
                             f"{root_dir / 'configs'}.")
        print(f"Load config {config_path}")
    else:
        config_path = None

    # First parse model class and trainer class from config file or command-line arguments
    parser = ArgumentParser(default_config_files=[f"{config_path}"] if config_path else [])
    parser.add_argument('--color', type=str, default="rgb", help='The color mode')
    parser.add_argument('--model', type=str, required=True, help='The model to train')
    parser.add_argument('--trainer', type=str, default="Trainer",
                        help='The trainer to use for training')
    args = parser.parse_known_args()[0]
    ModelCls = Model.get_class(args.model)
    TrainerCls = Trainer.get_class(args.trainer)
    trainset_path = argpath
    trainset = load_dataset(trainset_path, args.color, "gl")

    # Then parse model's and trainer's args. Some model args are inferred from
    # the training set. color/coord are inferred unconditionally (not only when
    # the dataset has a depth range): the resume branch above reloads the
    # dataset from model_args.color / model_args.coord, so they must describe
    # the dataset this model is actually trained on.
    inferred_model_args = dict(
        color=trainset.color_mode.name,
        white_bg=trainset.white_bg,
        coord=trainset.coord_sys,
    )
    if trainset.depth_range:
        inferred_model_args.update(near=trainset.depth_range[0],
                                   far=trainset.depth_range[1])
    model_args = ModelCls.Args(**inferred_model_args)
    model_args.parse(config_path)
    trainer_args = TrainerCls.Args(trainset=f"{trainset_path}").parse(config_path)
    states = None
    run_dir = trainset.root / "_nets" / trainset.name / expname

run_dir.mkdir(parents=True, exist_ok=True)

m = ModelCls(model_args).to(device.default())
trainer = TrainerCls(m, run_dir, trainer_args)
if states:
    # Restore trainer state only when resuming from a checkpoint
    trainer.load_state_dict(states)

# Start training
trainer.train(trainset)