"""Print a human-readable summary of a checkpoint file.

Usage: python <this script> CKPT_PATH

The checkpoint is expected to hold an ``"args"`` mapping (model / trainer
names and their hyper-parameters) and a ``"states"`` mapping whose
``"model"`` entry is a state dict.
"""
import argparse
from operator import itemgetter

import torch


def describe_checkpoint(ckpt):
    """Return the report lines for an already-loaded checkpoint.

    Args:
        ckpt: Mapping with keys ``"args"`` and ``"states"``. ``ckpt["args"]``
            must provide ``"model"`` and ``"model_args"``; ``"trainer"`` and
            ``"trainer_args"`` are optional. ``ckpt["states"]["model"]`` is
            iterated as ``name -> value`` pairs (normally tensors).

    Returns:
        A list of strings. Passing each one to ``print`` reproduces the report.

    Raises:
        KeyError: if a required key (``args``, ``states``, ``model``,
            ``model_args``, ``states["model"]``) is missing.
    """
    args, states = itemgetter("args", "states")(ckpt)

    lines = [f"Model: {args['model']} >>>>"]
    lines.extend(f"{key}: {value}" for key, value in args["model_args"].items())
    # Printing "\n" emits two newlines, which leaves a blank gap between sections.
    lines.append("\n")

    # The trainer section is optional. A missing or falsy "trainer" skips it
    # instead of raising KeyError.
    trainer = args.get("trainer")
    if trainer:
        lines.append(f"Trainer: {trainer} >>>>")
        lines.extend(
            f"{key}={value}" for key, value in args.get("trainer_args", {}).items()
        )
        lines.append("\n")

    lines.append("Model states >>>>")
    for key, value in states["model"].items():
        shape = getattr(value, "shape", None)
        if shape is None:
            # Entries without a shape are reported by type name and do not abort the report.
            lines.append(f"{key}: {type(value).__name__}")
        else:
            lines.append(f"{key}: Tensor{list(shape)}")
    return lines


def main(argv=None):
    """Parse the command line, load the checkpoint and print its summary.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("ckpt_path", type=str)
    cli_args = parser.parse_args(argv)

    # map_location="cpu": without it, a checkpoint saved from GPU tensors fails
    # to load on a machine without CUDA. Only shapes are inspected here, so CPU
    # is always enough.
    # SECURITY: depending on the torch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects. Only inspect trusted files.
    ckpt = torch.load(cli_args.ckpt_path, map_location="cpu")

    for line in describe_checkpoint(ckpt):
        print(line)


if __name__ == "__main__":
    main()