"""Export a saved model checkpoint to ONNX format.

Usage: script.py [--batch-size N] [--outdir DIR] MODEL_PATH

Loads the checkpoint given on the command line, deserializes the model,
and writes an ONNX export next to the checkpoint (under --outdir).
"""
import sys
import argparse
from pathlib import Path

import torch
import torch.optim  # NOTE(review): not referenced below — presumably needed as an import side effect for checkpoint loading; verify before removing

# Make the project root importable before pulling in project-local modules.
sys.path.append(str(Path(__file__).absolute().parent.parent))
from utils import netio
import model

parser = argparse.ArgumentParser(description='Export a model checkpoint to ONNX')
# Fix 1: help text previously said 'Resolution', which does not match the flag.
# Fix 2: previously type=str followed by eval() on the raw argument — arbitrary
# code execution from the command line. Parse directly as an int instead;
# omitting the flag leaves it None, matching the old `opt.batch_size and ...`
# short-circuit behavior.
parser.add_argument('--batch-size', type=int, default=None,
                    help='Batch size for the exported model (optional)')
parser.add_argument('--outdir', type=str, default='onnx', help='Output directory')
parser.add_argument('model', type=str, help='Path of model to export')
opt = parser.parse_args()

# inference_mode disables autograd tracking/versioning for the whole export.
with torch.inference_mode():
    states, model_path = netio.load_checkpoint(opt.model)
    batch_size = opt.batch_size  # already an int (or None) — no eval() needed
    # Export directory sits next to the checkpoint file.
    out_dir = model_path.parent / opt.outdir
    model.deserialize(states).eval().export_onnx(out_dir, batch_size)
    print(f'Model exported to {out_dir}')