# export_msl.py - export a trained MslNet checkpoint to an ONNX model file.
import sys
import os
import argparse
import torch
import torch.optim
from torch import onnx

sys.path.append(os.path.abspath(sys.path[0] + '/../'))
# Setting __package__ lets the explicit relative imports below
# (``from .msl_net import ...``) work when this file is run as a script.
# NOTE(review): assumes this script lives inside the ``deep_view_syn`` package
# directory, whose parent was appended to sys.path on the line above.
__package__ = "deep_view_syn"


def _parse_batch_size(text):
    """Parse the ``--batch-size`` argument into a positive integer.

    Accepts a plain integer (``"4096"``) or a small integer arithmetic
    expression built from ``+ - * // **`` (e.g. ``"128*128"``, ``"2**16"``).
    This replaces the former ``eval(opt.batch_size)``, which would execute
    arbitrary Python code taken from the command line.

    Args:
        text: the raw command-line string.

    Returns:
        The evaluated batch size as an ``int`` greater than zero.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not such an expression or
            does not evaluate to a positive integer (argparse turns this into
            a usage error).
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }

    def evaluate(node):
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](evaluate(node.left),
                                             evaluate(node.right))
        # ast.parse yields ast.Constant on Python >= 3.8 and ast.Num before.
        if isinstance(node, ast.Constant):
            literal = node.value
        else:
            literal = getattr(node, 'n', None)
        # ``type(...) is int`` deliberately rejects bool and float literals.
        if type(literal) is int:
            return literal
        raise ValueError('unsupported batch-size expression')

    try:
        # Strip surrounding whitespace as eval() did for string input.
        value = evaluate(ast.parse(text.strip(), mode='eval').body)
    except (SyntaxError, ValueError, ZeroDivisionError) as err:
        raise argparse.ArgumentTypeError(
            'invalid batch size %r: expected a positive integer or an integer '
            'expression such as 128*128' % text) from err
    # A negative exponent (e.g. "2**(0-1)") yields a float, so check the type too.
    if type(value) is not int or value <= 0:
        raise argparse.ArgumentTypeError(
            'invalid batch size %r: must evaluate to a positive integer' % text)
    return value


parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0,
                    help='Which CUDA device to use.')
parser.add_argument('--batch-size', type=_parse_batch_size, required=True,
                    help='Number of rays in the dummy export input: a positive '
                         'integer or integer expression such as 128*128')
parser.add_argument('--outdir', type=str, default='./',
                    help='Output directory')
parser.add_argument('model', type=str,
                    help='Path of model to export')
opt = parser.parse_args()

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

# Project imports must stay after the sys.path / __package__ setup above.
from .msl_net import MslNet
from .configs.spherical_view_syn import SphericalViewSynConfig
from .my import device
from .my import netio
from .my import util

dir_path, model_file = os.path.split(opt.model)
batch_size = opt.batch_size
# Work from the model's directory so load_net() can open ``model_file`` by its
# bare name. As a consequence, a relative --outdir is resolved against the
# model's directory as well. os.chdir('') raises FileNotFoundError, so skip
# the change when the model path has no directory component.
if dir_path:
    os.chdir(dir_path)

config = SphericalViewSynConfig()


def load_net(path):
    """Create an export-mode MslNet for the checkpoint at ``path``.

    The checkpoint file's stem (base name without extension) is passed to
    ``config.from_id``. The export sampling overrides are then written into
    ``config.SAMPLE_PARAMS``: ``spherical=True``, ``perturb_sample=False``,
    ``n_samples=4``. The resulting config is printed, the network is built
    with ``export_mode=True`` and moved to ``device.GetDevice()``, and the
    checkpoint is loaded into it via ``netio.LoadNet``.

    NOTE(review): this mutates the module-level ``config`` object in place;
    presumably ``from_id`` restores hyper-parameters encoded in the id string.

    Args:
        path: checkpoint file path.

    Returns:
        A ``(net, name)`` tuple: the network and the checkpoint stem.
    """
    model_id, _extension = os.path.splitext(os.path.basename(path))
    config.from_id(model_id)

    # Export-time sampling overrides, applied key by key in this order.
    export_overrides = (
        ('spherical', True),
        ('perturb_sample', False),
        ('n_samples', 4),
    )
    for key, value in export_overrides:
        config.SAMPLE_PARAMS[key] = value

    config.print()

    model = MslNet(config.FC_PARAMS, config.SAMPLE_PARAMS, config.GRAY,
                   config.N_ENCODE_DIM, export_mode=True)
    model = model.to(device.GetDevice())
    netio.LoadNet(path, model)
    return model, model_id


if __name__ == "__main__":
    with torch.no_grad():
        # Build the network and load the checkpoint (only the net is needed).
        net, _ = load_net(model_file)

        # Dummy (rays_o, rays_d) inputs, each of shape (batch_size, 3), used
        # only to trace the graph. torch.empty leaves them uninitialized.
        # NOTE(review): this assumes MslNet has no data-dependent control
        # flow in export mode; confirm.
        sample_inputs = tuple(
            torch.empty(batch_size, 3, device=device.GetDevice())
            for _ in range(2)
        )

        util.CreateDirIfNeed(opt.outdir)

        # The file is named after config.to_id() *after* load_net() applied
        # its sampling overrides, so it may differ from the checkpoint name.
        outpath = os.path.join(opt.outdir, config.to_id() + ".onnx")

        export_options = dict(
            export_params=True,        # embed the trained weights in the file
            verbose=True,
            opset_version=9,           # target ONNX opset
            do_constant_folding=True,  # fold constant subgraphs at export
            input_names=['Rays_o', 'Rays_d'],
            output_names=['Colors'],
        )
        onnx.export(net, sample_inputs, outpath, **export_options)
        print('Model exported to ' + outpath)