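"""Training entry script: creates a new model from a config file in configs/,
or resumes training from an existing checkpoint, with periodic voxel pruning
and splitting (see --prune/--split)."""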
import argparse
import json
import logging
import os
import sys
from pathlib import Path

import model as mdl
import train
from utils import color, device
from utils.misc import list_epochs, print_and_log
from data.dataset_factory import *  # provides DatasetFactory
from data.loader import DataLoader


RAYS_PER_BATCH = 2 ** 16  # rays per training batch
DATA_LOADER_CHUNK_SIZE = int(1e8)  # max items per preloaded data chunk


parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str,
                    help='Name of the net config file (configs/<name>.json)')
parser.add_argument('-e', '--epochs', type=int, default=50,
                    help='Maximum number of epochs to train')
parser.add_argument('--perf', type=int, default=0,
                    help='Number of frames for performance measurement (0 to disable)')
parser.add_argument('--prune', type=int, default=5,
                    help='Prune voxels every N epochs')
parser.add_argument('--split', type=int, default=10,
                    help='Split voxels every N epochs')
parser.add_argument('--views', type=str,
                    help="Range of views to train on, specified as 'start-end'")
parser.add_argument('path', type=str,
                    help='Path to a dataset description file or a model checkpoint (*.tar)')
args = parser.parse_args()
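
# Example invocations (placeholders in angle brackets are illustrative):
#   python train.py -c <config> -e 50 <dataset_desc>          # train a new model from configs/<config>.json
#   python train.py <dataset_dir>/_nets/<dataset>/<model>     # resume from the latest checkpoint_*.tar
#   python train.py <run_dir>/checkpoint_50.tar               # resume from a specific checkpoint
#   python train.py -c <config> --views 0-60 <dataset_desc>   # train only on views [0, 60)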

argpath = Path(args.path)
# argpath may be either a model path or a data path:
# 1) model path: continue training the specified model
# 2) data path: train a new model on the specified dataset

if argpath.suffix == ".tar":
    args.mdl_path = argpath
else:
    # Look for existing checkpoints under the given directory and resume from the latest one
    existed_epochs = list_epochs(argpath, "checkpoint_*.tar")
    args.mdl_path = (argpath / f"checkpoint_{existed_epochs[-1]}.tar") if existed_epochs else None

if args.mdl_path:
    # Infer the dataset path from the model path, which follows the pattern:
    # <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
    dataset_name = args.mdl_path.parent.parent.name
    dataset_dir = args.mdl_path.parent.parent.parent.parent
    args.data_path = dataset_dir / dataset_name
    args.mdl_path = args.mdl_path.relative_to(dataset_dir)
else:
    args.data_path = argpath
# Parse the view range, e.g. "0-60" -> range(0, 60)
args.views = range(*[int(val) for val in args.views.split('-')]) if args.views else None

dataset = DatasetFactory.load(args.data_path, views_to_load=args.views)
print(f"Dataset loaded: {dataset.root}/{dataset.name}")
os.chdir(dataset.root)  # subsequent relative paths (e.g. run_dir) resolve under the dataset root

if args.mdl_path:
    # Load model to continue training
    model, states = mdl.load(args.mdl_path)
    model_name = args.mdl_path.parent.name
    model_class = model.__class__.__name__
    model_args = model.args
else:
    # Create model from specified configuration
    with Path(f'{sys.path[0]}/configs/{args.config}.json').open() as fp:
        config = json.load(fp)
    model_name = args.config
    model_class = config['model']
    model_args = config['args']
    model_args['bbox'] = dataset.bbox
    model_args['depth_range'] = dataset.depth_range
    model, states = mdl.create(model_class, model_args), None
model.to(device.default()).train()

run_dir = Path(f"_nets/{dataset.name}/{model_name}")
run_dir.mkdir(parents=True, exist_ok=True)

log_file = run_dir / "train.log"
logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=logging.INFO,
                    filename=log_file, filemode='a')  # 'a' appends, creating the file if needed

print_and_log(f"model: {model_name} ({model_class})")
print_and_log(f"args: {json.dumps(model.args0)}")


if __name__ == "__main__":
    # 1. Initialize data loader
    data_loader = DataLoader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                             shuffle=True, enable_preload=True,
                             color=color.from_str(model.args['color']))

    # 2. Initialize model and trainer
    trainer = train.get_trainer(model, run_dir=run_dir, states=states, perf_frames=args.perf,
                                pruning_loop=args.prune, splitting_loop=args.split)

    # 3. Train
    trainer.train(data_loader, args.epochs)