dump_checkpoint.py 659 Bytes
Newer Older
Nianchen Deng's avatar
sync    
Nianchen Deng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import argparse
from operator import itemgetter

def _print_section(title, items, sep):
    """Print ``<title> >>>>``, then one ``<key><sep><value>`` line per item, then a gap.

    Args:
        title: Section heading (the `` >>>>`` marker is appended here).
        items: Mapping whose entries are listed in iteration order.
        sep: Text placed between each key and its value.
    """
    print(f"{title} >>>>")
    for key, value in items.items():
        print(f"{key}{sep}{value}")
    # print("\n") emits two newlines (the literal plus print's own terminator),
    # leaving a blank gap before the next section.
    print("\n")


def _shape_str(value):
    """Describe a state-dict entry as ``Tensor[d0, d1, ...]``.

    Entries without a ``shape`` attribute are described by their type name
    instead of crashing the dump.
    """
    shape = getattr(value, "shape", None)
    if shape is None:
        return type(value).__name__
    return f"Tensor{list(shape)}"


def main(argv=None):
    """Print the model/trainer configuration and model parameter shapes of a checkpoint.

    The checkpoint is expected to be a mapping with ``"args"`` and ``"states"``
    entries; ``args`` holds ``model``/``model_args`` and optionally
    ``trainer``/``trainer_args``, and ``states["model"]`` maps names to tensors.

    Args:
        argv: Command-line arguments without the program name; ``None`` means
            ``sys.argv[1:]`` (the ``argparse`` default).
    """
    parser = argparse.ArgumentParser(description="Summarize a training checkpoint.")
    parser.add_argument("ckpt_path", type=str, help="path to the checkpoint file")
    cli_args = parser.parse_args(argv)

    # NOTE(security): torch.load unpickles the file; only inspect trusted checkpoints.
    # map_location="cpu" lets checkpoints saved from GPU tensors be inspected on a
    # machine without CUDA -- only shapes are printed, so device placement is irrelevant.
    checkpoint = torch.load(cli_args.ckpt_path, map_location="cpu")
    args, states = itemgetter("args", "states")(checkpoint)

    _print_section(f"Model: {args['model']}", args["model_args"], ": ")

    # The trainer section is optional: tolerate both a missing key and a falsy value.
    if args.get("trainer"):
        _print_section(f"Trainer: {args['trainer']}", args.get("trainer_args", {}), "=")

    print("Model states >>>>")
    for key, value in states["model"].items():
        print(f"{key}: {_shape_str(value)}")


if __name__ == "__main__":
    main()