train.py 3.86 KB
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
from operator import itemgetter
from configargparse import ArgumentParser, SUPPRESS
Nianchen Deng's avatar
sync    
Nianchen Deng committed
3

Nianchen Deng's avatar
sync    
Nianchen Deng committed
4
5
6
7
8
from model import Model
from train import Trainer
from utils import device, netio
from utils.types import *
from data import Dataset
Nianchen Deng's avatar
sync    
Nianchen Deng committed
9
10


Nianchen Deng's avatar
sync    
Nianchen Deng committed
11
12
13
14
def load_dataset(data_path: Path, color: str, coord: str):
    """Build the dataset at `data_path` and report where it was loaded from.

    Args:
        data_path: path handed to `Dataset` as its first argument.
        color: name of a `Color` member; looked up with `Color[color]`.
        coord: value forwarded to `Dataset` as `coord_sys`.

    Returns:
        The constructed `Dataset` instance.
    """
    color_mode = Color[color]
    loaded = Dataset(data_path, color_mode=color_mode, coord_sys=coord)
    print(f"Load dataset: {loaded.root}/{loaded.name}")
    return loaded
Nianchen Deng's avatar
sync    
Nianchen Deng committed
15
16


Nianchen Deng's avatar
sync    
Nianchen Deng committed
17
18
19
20
21
22
23
24
# Pre-parse only the few arguments needed up front. `default=SUPPRESS` keeps
# an option's key out of the result when it is not given on the command line,
# and `parse_known_args` tolerates any arguments this parser does not declare.
initial_parser = ArgumentParser()
for _flags, _options in (
    (('-c', '--config'),
     dict(type=str, default=SUPPRESS,
          help='Config name, ignored if path is a checkpoint path')),
    (('--expname',),
     dict(type=str, default=SUPPRESS,
          help='Experiment name, defaults to config name, ignored if path is a checkpoint path')),
    (('path',),
     dict(type=str,
          help='Path to dataset description file or checkpoint file')),
):
    initial_parser.add_argument(*_flags, **_options)
_initial_namespace = initial_parser.parse_known_args()[0]
initial_args = vars(_initial_namespace)
Nianchen Deng's avatar
sync    
Nianchen Deng committed
25

Nianchen Deng's avatar
sync    
Nianchen Deng committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Directory containing this script; configs are looked up in its "configs" subfolder.
root_dir = Path(__file__).absolute().parent
# The positional `path` argument may be either:
# 1) a checkpoint path: continue training a model
# 2) a dataset path: train a new model using the specified dataset
argpath = Path(initial_args["path"])

ckpt_path = netio.find_checkpoint(argpath)
if ckpt_path:
    # Continue training from a checkpoint
    print(f"Load checkpoint {ckpt_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    args, states = itemgetter("args", "states")(torch.load(ckpt_path))
    # args: "model", "model_args", "trainer", "trainer_args"
    ModelCls = Model.get_class(args["model"])
    TrainerCls = Trainer.get_class(args["trainer"])
    model_args = ModelCls.Args(**args["model_args"])
    trainer_args = TrainerCls.Args(**args["trainer_args"]).parse()
    # The dataset is reloaded with the color mode and coordinate system stored in the model args.
    trainset = load_dataset(trainer_args.trainset, model_args.color, model_args.coord)
    run_dir = ckpt_path.parent
else:
    # Start a new train
    # Experiment name falls back to the config name, then to "unnamed".
    expname = initial_args.get("expname", initial_args.get("config", "unnamed"))
    if "config" in initial_args:
        config_path = root_dir / "configs" / f"{initial_args['config']}.ini"
        if not config_path.exists():
            raise ValueError(f"Config {initial_args['config']} is not found in "
                             f"{root_dir / 'configs'}.")
        print(f"Load config {config_path}")
    else:
        config_path = None

    # First parse model class and trainer class from config file or command-line arguments
    parser = ArgumentParser(default_config_files=[f"{config_path}"] if config_path else [])
    parser.add_argument('--color', type=str, default="rgb",
                        help='The color mode')
    parser.add_argument('--model', type=str, required=True,
                        help='The model to train')
    parser.add_argument('--trainer', type=str, default="Trainer",
                        help='The trainer to use for training')
    args = parser.parse_known_args()[0]
    ModelCls = Model.get_class(args.model)
    TrainerCls = Trainer.get_class(args.trainer)
    trainset_path = argpath
    trainset = load_dataset(trainset_path, args.color, "gl")

    # Then parse model's and trainer's args

    # Some model's args are inferred from training dataset. The color mode and
    # coordinate system are always taken from the dataset, because resuming from
    # a checkpoint reloads the dataset using `model_args.color` / `model_args.coord`;
    # previously they were dropped when the dataset had no depth range.
    inferred_model_args = dict(
        color=trainset.color_mode.name,
        white_bg=trainset.white_bg,
        coord=trainset.coord_sys,
    )
    if trainset.depth_range:
        # near/far are only available when the dataset provides a depth range.
        inferred_model_args.update(
            near=trainset.depth_range[0],
            far=trainset.depth_range[1],
        )
    model_args = ModelCls.Args(**inferred_model_args)
    model_args.parse(config_path)
    trainer_args = TrainerCls.Args(trainset=f"{trainset_path}").parse(config_path)
    states = None  # nothing to restore for a fresh run

    run_dir = trainset.root / "_nets" / trainset.name / expname
    run_dir.mkdir(parents=True, exist_ok=True)

m = ModelCls(model_args).to(device.default())
trainer = TrainerCls(m, run_dir, trainer_args)
if states:
    # Only set when resuming from a checkpoint.
    trainer.load_state_dict(states)

# Start train
trainer.train(trainset)