# export_onnx.py
# Export a trained spherical view synthesis network to ONNX format.
import sys
import os
import argparse
import torch
import torch.optim
from torch import onnx

# Make the project root importable when this script runs from its own folder.
sys.path.append(os.path.abspath(os.path.join(sys.path[0], '..')))

# Command-line interface of the exporter.
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0,
                    help='Which CUDA device to use.')
# type=str because the value is an expression evaluated later (see the
# eval() call below); e.g. "4096" or "1600*1440".
parser.add_argument('--batch-size', type=str,
                    help='Batch size as a Python expression, e.g. "4096" or "1600*1440"')
parser.add_argument('--outdir', type=str, default='./',
                    help='Output directory')
parser.add_argument('model', type=str,
                    help='Path of model to export')
opt = parser.parse_args()

# Bind all subsequent CUDA work to the device requested on the command line.
selected_device = opt.device
torch.cuda.set_device(selected_device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import device
from utils import netio
from utils import misc
# Split the model path so the checkpoint can be loaded by file name alone.
dir_path, model_file = os.path.split(opt.model)
# NOTE(review): eval() executes arbitrary code from the command line. It is
# kept because --batch-size accepts expressions such as "1600*1440"; never
# pass untrusted input to this script.
batch_size = eval(opt.batch_size)
# Run from the model's directory so relative lookups resolve. Guard the
# bare-filename case: os.path.split('model.pth') yields dir_path == '' and
# os.chdir('') would raise FileNotFoundError.
if dir_path:
    os.chdir(dir_path)

config = SphericalViewSynConfig()
def load_net(path):
    """Instantiate the network described by *path*'s file name and load its weights.

    The checkpoint's base name doubles as a config id, which is applied to the
    module-level ``config`` before the network is created.

    :param path: checkpoint file to load
    :return: ``(net, name)`` -- the network on the default device, and the config id
    """
    model_id = os.path.splitext(os.path.basename(path))[0]
    config.from_id(model_id)
    # Export-time sampler overrides: deterministic spherical sampling with a
    # fixed, small number of samples per ray.
    for key, value in (('spherical', True),
                       ('perturb_sample', False),
                       ('n_samples', 4)):
        config.sa[key] = value
    config.print()
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net, model_id


if __name__ == "__main__":
    # Export is inference-only; disable autograd bookkeeping throughout.
    with torch.no_grad():
        # Load the checkpoint named on the command line (cwd was already
        # changed to the model's directory, so the bare file name suffices).
        net, _ = load_net(model_file)  # the returned config id is unused here

        # Dummy inputs that fix the exported graph's input shapes:
        # per-ray origins and directions, both of shape (batch_size, 3).
        rays_o = torch.empty(batch_size, 3, device=device.default())
        rays_d = torch.empty(batch_size, 3, device=device.default())

        os.makedirs(opt.outdir, exist_ok=True)

        # Export the model; the output file is named after the config id.
        outpath = os.path.join(opt.outdir, config.to_id() + ".onnx")
        onnx.export(
            net,                 # model being run
            (rays_o, rays_d),    # model inputs (tuple for multiple inputs)
            outpath,
            export_params=True,  # store the trained parameter weights inside the model file
            verbose=True,
            opset_version=9,     # the ONNX opset version to target
            do_constant_folding=True,  # fold constant subgraphs at export time
            input_names=['Rays_o', 'Rays_d'],   # the model's input names
            output_names=['Colors']             # the model's output names
        )
        print('Model exported to ' + outpath)