import sys
import argparse
import torch
import torch.optim
from pathlib import Path

sys.path.append(str(Path(__file__).absolute().parent.parent))

from utils import netio
import model

# Command-line interface for the ONNX export script.
parser = argparse.ArgumentParser()
# BUGFIX: help text said 'Resolution' (copy-paste from another script);
# this flag is the export batch size. It stays a string because the value
# is evaluated as a Python expression by the export step below.
parser.add_argument('--batch-size', type=str,
                    help='Batch size for the exported model '
                         '(evaluated as a Python expression); optional')
parser.add_argument('--outdir', type=str, default='onnx',
                    help='Output directory (created next to the checkpoint file)')
parser.add_argument('model', type=str,
                    help='Path of model to export')
opt = parser.parse_args()
# Export the checkpoint to ONNX. The whole pipeline runs under inference
# mode so no autograd bookkeeping is created during export.
with torch.inference_mode():
    # Load the checkpoint; load_checkpoint also returns the resolved file path.
    states, model_path = netio.load_checkpoint(opt.model)

    # --batch-size is a free-form string evaluated as Python.
    # NOTE(review): eval of a CLI string is unsafe on untrusted input;
    # consider ast.literal_eval / int() if expressions are not required.
    if opt.batch_size:
        batch_size = eval(opt.batch_size)
    else:
        batch_size = opt.batch_size  # flag omitted (None) or empty string

    # ONNX files go into a subdirectory next to the checkpoint.
    out_dir = model_path.parent / opt.outdir

    net = model.deserialize(states).eval()
    net.export_onnx(out_dir, batch_size)

    print(f'Model exported to {out_dir}')