export_onnx.py 2.41 KB
Newer Older
BobYeah's avatar
sync    
BobYeah committed
1
2
3
4
5
6
7
import sys
import os
import argparse
import torch
import torch.optim
from torch import onnx

Nianchen Deng's avatar
sync    
Nianchen Deng committed
8
9
# Put the directory two levels above this script on sys.path and name the
# package by hand, so the relative imports further down (``from ..configs``,
# ``from ..my``) work when this file is executed directly as a script.
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
__package__ = "deep_view_syn.tools"

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0,
                    help='Which CUDA device to use.')
# Required: the value is later passed to eval(), which fails on None when the
# flag is omitted. Because it is eval'd, expressions such as 256*256 work.
parser.add_argument('--batch-size', type=str, required=True,
                    help='Batch size (number of rays) of the exported model '
                         'input; may be an expression such as 256*256')
parser.add_argument('--outdir', type=str, default='./',
                    help='Output directory')
parser.add_argument('model', type=str,
                    help='Path of model to export')
opt = parser.parse_args()

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

Nianchen Deng's avatar
sync    
Nianchen Deng committed
26
27
28
29
from ..configs.spherical_view_syn import SphericalViewSynConfig
from ..my import device
from ..my import netio
from ..my import util
BobYeah's avatar
sync    
BobYeah committed
30

Nianchen Deng's avatar
sync    
Nianchen Deng committed
31
32
33
34
35
# The model is later loaded by bare file name from inside its own directory.
# NOTE: because of the chdir below, a relative --outdir is resolved against
# the model's directory, not the directory the script was launched from.
dir_path, model_file = os.path.split(opt.model)
# NOTE(review): eval() executes arbitrary Python taken from the command line.
# Tolerable for a local tool, but an ast-based arithmetic parser would be safer.
batch_size = eval(opt.batch_size)
# os.path.split() yields '' as the directory for a bare file name, and
# os.chdir('') raises FileNotFoundError, so only change directory if needed.
if dir_path:
    os.chdir(dir_path)

config = SphericalViewSynConfig()
BobYeah's avatar
sync    
BobYeah committed
36
37
38

def load_net(path):
    """Create the network identified by *path* and load its trained weights.

    The id is the file name of ``path`` with directory and extension removed.
    It is applied to the module-level ``config`` (mutated in place, then
    printed). A few sampling parameters are overridden for export, and the
    weights stored at ``path`` are loaded into the newly created network.

    Args:
        path: Model file path; its stem is passed to ``config.from_id``.

    Returns:
        Tuple ``(net, name)``: the network moved to ``device.GetDevice()``,
        and the id string derived from ``path``.
    """
    stem, _ext = os.path.splitext(os.path.basename(path))
    config.from_id(stem)
    # Sampling overrides used for export, applied in this fixed order.
    overrides = (
        ('spherical', True),
        ('perturb_sample', False),
        ('n_samples', 4),
    )
    for key, value in overrides:
        config.SAMPLE_PARAMS[key] = value
    config.print()
    model = config.create_net()
    model = model.to(device.GetDevice())
    netio.LoadNet(path, model)
    return model, stem


if __name__ == "__main__":
    with torch.no_grad():
        # Build the network and restore its trained weights.
        net, name = load_net(model_file)

        # Placeholder inputs for tracing, each of shape (batch_size, 3).
        # torch.empty leaves them uninitialized; presumably only their
        # shape/device matter (assumes no data-dependent control flow in
        # the net -- TODO confirm).
        rays_o, rays_d = (
            torch.empty(batch_size, 3, device=device.GetDevice())
            for _ in range(2)
        )

        util.CreateDirIfNeed(opt.outdir)

        # Export the traced model, named after the (updated) config id.
        onnx_path = os.path.join(opt.outdir, config.to_id() + ".onnx")
        export_kwargs = dict(
            export_params=True,           # embed trained weights in the file
            verbose=True,
            opset_version=9,              # target ONNX opset
            do_constant_folding=True,     # fold constant subgraphs at export
            input_names=['Rays_o', 'Rays_d'],
            output_names=['Colors'],
        )
        onnx.export(net, (rays_o, rays_d), onnx_path, **export_kwargs)
        print('Model exported to ' + onnx_path)