import argparse
import json
import logging
import os
import sys
from pathlib import Path
from typing import List, Union

import model as mdl
import train
from data import *
from utils import device, netio
from utils.misc import print_and_log
# Number of rays sampled per training batch (2^12 = 4096).
RAYS_PER_BATCH = 2 ** 12
# Upper bound on items per data-loader chunk. This counts items, so it is an
# int (was the float literal 1e8, same value).
DATA_LOADER_CHUNK_SIZE = 10 ** 8
# Absolute directory containing this script; used to locate config files.
root_dir = Path(__file__).absolute().parent


# Command-line interface. The positional `path` may point at either a dataset
# description or an existing model checkpoint (resolved further below).
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, help='Net config files')
parser.add_argument('-e', '--epochs', type=int, help='Max epochs for train')
parser.add_argument('--perf', type=int,
                    help='Performance measurement frames (0 for disabling performance measurement)')
parser.add_argument('--prune', type=int, nargs='+', help='Prune voxels on every # epochs')
parser.add_argument('--split', type=int, nargs='+', help='Split voxels on every # epochs')
parser.add_argument('--freeze', type=int, nargs='+', help='freeze levels on epochs')
parser.add_argument('--checkpoint-interval', type=int)
parser.add_argument('--views', type=str, help='Specify the range of views to train')
parser.add_argument('path', type=str, help='Dataset description file')
args = parser.parse_args()

# "--views A-B[-step]" selects a range of view indices; None means all views.
if args.views:
    range_bounds = [int(part) for part in args.views.split('-')]
    views_to_load = range(*range_bounds)
else:
    views_to_load = None
argpath = Path(args.path)
# argpath may be either of:
# 1) a model path: continue training the model stored at that checkpoint
# 2) a data path: train a new model using the specified dataset


def load_dataset(data_path: Path):
    """Load the dataset described by *data_path*.

    On success, changes the working directory to the dataset root and returns
    ``(dataset, dataset_name)``. If loading raises FileNotFoundError, the path
    is assumed to be a directory of sub-datasets and loading is delegated to
    :func:`load_multiscale_dataset`.
    """
    print(f"Loading dataset {data_path}")
    try:
        loaded = DatasetFactory.load(data_path, views_to_load=views_to_load)
        print(f"Dataset loaded: {loaded.root}/{loaded.name}")
        os.chdir(loaded.root)
        return loaded, loaded.name
    except FileNotFoundError:
        return load_multiscale_dataset(data_path)


def load_multiscale_dataset(data_path: Path):
    """Load every sub-dataset (``*.json``) found directly under *data_path*.

    Returns ``(list_of_datasets, directory_name)`` and changes the working
    directory to the parent of *data_path*.

    Raises ValueError when the path is not a directory or contains no
    sub-dataset description files.
    """
    if not data_path.is_dir():
        raise ValueError(f"Path {data_path} is not a directory")
    sub_datasets: List[Union[PanoDataset, ViewDataset]] = []
    for desc_file in data_path.glob("*.json"):
        loaded = DatasetFactory.load(desc_file, views_to_load=views_to_load)
        print(f"Sub-dataset loaded: {loaded.root}/{loaded.name}")
        sub_datasets.append(loaded)
    if not sub_datasets:
        raise ValueError(f"Path {data_path} does not contain sub-datasets")
    os.chdir(data_path.parent)
    return sub_datasets, data_path.name


# Resolve `argpath`: first try to treat it as a model checkpoint; on any
# failure, fall back to treating it as a dataset path and build a fresh
# model state from the JSON config named by --config.
try:
    states, checkpoint_path = netio.load_checkpoint(argpath)
    # Infer dataset path from model path
    # The model path follows such rule: <dataset_dir>/_nets/<dataset_name>/<model_name>/checkpoint_*.tar
    model_name = checkpoint_path.parts[-2]
    dataset, dataset_name = load_dataset(
        Path(*checkpoint_path.parts[:-4]) / checkpoint_path.parts[-3])
except Exception:
    # NOTE(review): this broad except also swallows failures from the
    # load_dataset() call above, silently retrying in "new model" mode —
    # confirm that is intended.
    model_name = args.config
    dataset, dataset_name = load_dataset(argpath)

    # Load state 0 from specified configuration
    with Path(f'{root_dir}/configs/{args.config}.json').open() as fp:
        states = json.load(fp)
    # Seed the model arguments with dataset geometry; for a multiscale
    # (list) dataset the first sub-dataset is taken as representative.
    states['args']['bbox'] = dataset[0].bbox if isinstance(dataset, list) else dataset.bbox
    states['args']['depth_range'] = dataset[0].depth_range if isinstance(dataset, list)\
        else dataset.depth_range

# Command-line options override whatever the checkpoint/config specified.
if 'train' not in states:
    states['train'] = {}
if args.prune is not None:
    states['train']['prune_epochs'] = args.prune
if args.split is not None:
    states['train']['split_epochs'] = args.split
if args.freeze is not None:
    states['train']['freeze_epochs'] = args.freeze
if args.perf is not None:
    states['train']['perf_frames'] = args.perf
if args.checkpoint_interval is not None:
    states['train']['checkpoint_interval'] = args.checkpoint_interval
if args.epochs is not None:
    states['train']['max_epochs'] = args.epochs

# Rebuild the model object from the loaded states and move it to the default device.
model = mdl.deserialize(states).to(device.default())

# Initialize run directory
# Outputs (checkpoints, logs) live under _nets/<dataset_name>/<model_name>,
# relative to the directory selected by load_dataset()'s os.chdir.
run_dir = Path(f"_nets/{dataset_name}/{model_name}")
run_dir.mkdir(parents=True, exist_ok=True)

# Initialize logging
# All records go to <run_dir>/train.log. Mode 'a' appends when the file
# exists and creates it otherwise, so the previous
# "'a' if log_file.exists() else 'w'" conditional was redundant (the 'w'
# branch only ever ran when there was nothing to truncate).
log_file = run_dir / "train.log"
logging.basicConfig(format='%(asctime)s[%(levelname)s] %(message)s', level=logging.INFO,
                    filename=log_file, filemode='a')


def log_exception(exc_type, exc_value, exc_traceback):
    """sys.excepthook replacement: record uncaught exceptions in the log.

    KeyboardInterrupt is not logged; every exception is still forwarded to
    the default hook so the traceback reaches stderr as usual.
    """
    worth_logging = not issubclass(exc_type, KeyboardInterrupt)
    if worth_logging:
        logging.exception(exc_value, exc_info=(exc_type, exc_value, exc_traceback))
    sys.__excepthook__(exc_type, exc_value, exc_traceback)


sys.excepthook = log_exception

# Summarize the model configuration to stdout and the log file.
print_and_log(f"model: {model_name} ({model.cls})")
print_and_log("args:")  # was an f-string with no placeholders (ruff F541)
model.print_config()
print(model)


if __name__ == "__main__":
    # 1. Initialize data loader
    # Rays are shuffled and served in batches of RAYS_PER_BATCH; preloading is
    # disabled and the color space follows the model's configuration.
    data_loader = get_loader(dataset, RAYS_PER_BATCH, chunk_max_items=DATA_LOADER_CHUNK_SIZE,
                             shuffle=True, enable_preload=False, color=model.color)

    # 2. Initialize model and trainer
    trainer = train.get_trainer(model, run_dir, states)

    # 3. Train
    trainer.train(data_loader)