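"""Export a trained MSL view synthesis network (nets.msl_net) to ONNX.

Usage sketch (assumed from the arguments defined below; paths are illustrative):

    python export_msl.py --device 0 --batch-size "1440*810" --outdir ./onnx path/to/model.pth

The --batch-size value is evaluated as a Python expression, so a resolution
product such as "1440*810" is accepted.
"""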
import sys
import os
import argparse
import torch
import torch.optim
from torch import onnx
from typing import Mapping, List

sys.path.append(os.path.abspath(sys.path[0] + '/../'))

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0,
                    help='Which CUDA device to use.')
parser.add_argument('--batch-size', type=str,
                    help='Batch size, given as a Python expression such as "1440*810"')
parser.add_argument('--outdir', type=str, default='./',
                    help='Output directory')
parser.add_argument('model', type=str,
                    help='Path of model to export')
opt = parser.parse_args()

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

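# Project modules are imported only after the CUDA device has been selected,
# presumably so that utils.device resolves to the intended default device.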
from nets.msl_net import *
from utils import misc
from utils import netio
from utils import device
from configs.spherical_view_syn import SphericalViewSynConfig

dir_path, model_file = os.path.split(opt.model)
# The batch size argument is a Python expression (e.g. "1440*810") and is
# evaluated to an integer here.
batch_size = eval(opt.batch_size)
# Work relative to the model's directory so the checkpoint can be loaded by file name.
os.chdir(dir_path)

config = SphericalViewSynConfig()


def load_net(path):
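    """Load a network whose configuration is encoded in the checkpoint's file name."""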
    id = os.path.splitext(os.path.basename(path))[0]
    config.from_id(id)
    config.sa['perturb_sample'] = False
    batch_size_str: str = opt.batch_size.replace('*', 'x')
    config.name += batch_size_str
    config.print()
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net, id


def export_net(net: torch.nn.Module, name: str,
               input: Mapping[str, List[int]], output_names: List[str]):
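    """Export `net` to ONNX.

    `input` maps each input name to its tensor shape; dummy tensors of those
    shapes are created and traced through the network by torch.onnx.export.
    """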
    outpath = os.path.join(opt.outdir, config.to_id(), name + ".onnx")
    input_tensors = tuple([
        torch.empty(size, device=device.default())
        for size in input.values()
    ])
    onnx.export(
        net,
        input_tensors,
        outpath,
        export_params=True,  # store the trained parameter weights inside the model file
        verbose=True,
        opset_version=9,     # the ONNX version to export the model to
        do_constant_folding=True, # whether to execute constant folding
        input_names=list(input.keys()),  # the model's input names
        output_names=output_names # the model's output names
    )
    print('Model exported to ' + outpath)


if __name__ == "__main__":
    with torch.no_grad():
        # Load model
        net, name = load_net(model_file)

        misc.create_dir(os.path.join(opt.outdir, config.to_id()))

        # Export Sampler
        export_net(ExportNet(net), 'msl', {
            'Encoded': [batch_size, net.n_samples, net.input_encoder.out_dim],
            'Depths': [batch_size, net.n_samples]
        }, ['Colors'])