# export_snerf_fast.py
# Export a trained SnerfFast model to ONNX and, optionally (-t/--trt), build a
# TensorRT engine from the exported file with `trtexec`.
# Last change: "sync" (Nianchen Deng).
import sys
import os
import argparse
import torch
import torch.optim
from torch import onnx
from typing import Mapping, List

# Make the parent of this script's directory importable so the project packages
# (nets, utils, configs) below can be found.
sys.path.append(os.path.abspath(sys.path[0] + '/../'))

parser = argparse.ArgumentParser(
    description='Export a SnerfFast model to ONNX (and optionally a TensorRT engine).')
# The batch size is mandatory: it fixes the leading dimension of every exported input.
parser.add_argument('-b', '--batch-size', type=str, required=True,
                    help="Batch size as a Python expression, e.g. '4096' or '64*64'")
parser.add_argument('-o', '--output', type=str,
                    help='Path of the exported .onnx file (derived from the model path if omitted)')
parser.add_argument('-t', '--trt', action="store_true",
                    help='Also build a TensorRT engine (.trt) from the exported model with trtexec')
parser.add_argument('--device', type=int, default=0, help='Which CUDA device to use.')
parser.add_argument('model', type=str, help='Path of model to export')
opt = parser.parse_args()

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

# NOTE(review): these project imports are deliberately placed after sys.path is extended
# and the CUDA device is selected; presumably some of them query the current device at
# import time, so keep this ordering.
from nets.snerf_fast import *
from utils import misc
from utils import netio
from utils import device
from configs.spherical_view_syn import SphericalViewSynConfig

# Derive the config id from the model path. A model file whose name contains "@" is
# identified by its own stem; otherwise the name of its parent directory is used.
dir_path, model_file = os.path.split(opt.model)
is_exported_name = "@" in model_file
if is_exported_name:
    config_id = os.path.splitext(model_file)[0]
else:
    config_id = os.path.basename(dir_path)

# NOTE(review): eval() executes the --batch-size argument as arbitrary Python. It comes
# from the command line of the user running this script, so it is not treated as
# untrusted here -- never feed it external input.
batch_size = eval(opt.batch_size)
if not isinstance(batch_size, int) or batch_size <= 0:
    raise ValueError(
        f"--batch-size must evaluate to a positive integer, got {batch_size!r}")
# Filesystem-friendly tag for the batch size, e.g. "64*64" -> "64x64".
batch_size_str = opt.batch_size.replace('*', 'x')

if not opt.output:
    if not is_exported_name:
        # NOTE(review): [12:] drops a fixed 12-character prefix from the file stem
        # (presumably an epoch-prefix naming scheme); confirm against how models are saved.
        epochs = os.path.splitext(model_file)[0][12:]
        outdir = f"{dir_path}/output_{epochs}"
        output = os.path.join(outdir, f"net@{batch_size_str}.onnx")
    else:
        outdir = f"{dir_path}/export"
        output = os.path.join(outdir, f"{model_file.split('@')[0]}@{batch_size_str}.onnx")
else:
    output = opt.output
    # `outdir` must be defined on every path: later code uses it as the export
    # directory. An explicit path without a directory part means the current directory.
    outdir = os.path.dirname(output) or '.'
# Create the destination directory so the ONNX export can write there.
os.makedirs(outdir, exist_ok=True)
outname = os.path.splitext(os.path.basename(output))[0]


def load_net():
    """Create the network for ``config_id`` and load its weights from ``opt.model``.

    Before the network is built, sample perturbation is switched off and the
    batch-size tag is appended to the config name. The config is printed, the
    net is moved to ``device.default()``, and the checkpoint is loaded into it.

    Returns:
        The loaded network.
    """
    cfg = SphericalViewSynConfig()
    cfg.from_id(config_id)
    # Presumably disables random jitter of sample positions so the export is deterministic.
    cfg.sa['perturb_sample'] = False
    cfg.name += '@' + batch_size_str
    cfg.print()
    model = cfg.create_net()
    model = model.to(device.default())
    netio.load(opt.model, model)
    return model


def export_net(net: torch.nn.Module, path: str, input: Mapping[str, List[int]],
               output_names: List[str], *, opset_version: int = 9, verbose: bool = True):
    """Trace ``net`` with placeholder inputs and save it to ``path`` as an ONNX model.

    Args:
        net: module to export. It is called with one positional tensor per entry of
            ``input``.
        path: destination file for the ONNX model.
        input: mapping from ONNX input name to tensor shape. Its iteration order
            defines both the input names and the positional order of the tensors
            passed to ``net``.
        output_names: names assigned to the ONNX graph outputs, in order.
        opset_version: ONNX opset to target (defaults to 9, the previous fixed value).
        verbose: forwarded to ``torch.onnx.export``.
    """
    # The placeholder tensors only drive tracing, so their (uninitialized) contents are
    # irrelevant as long as the graph has no data-dependent control flow -- only shape
    # and device matter.
    input_tensors = tuple(
        torch.empty(size, device=device.default())
        for size in input.values()
    )
    onnx.export(
        net,
        input_tensors,
        path,
        export_params=True,  # store the trained parameter weights inside the model file
        verbose=verbose,
        opset_version=opset_version,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding
        input_names=list(input.keys()),  # the model's input names
        output_names=output_names  # the model's output names
    )
    print('Model exported to ' + path)


if __name__ == "__main__":
    import subprocess

    with torch.no_grad():
        net: SnerfFast = load_net()
        export_net(
            SnerfFastExport(net),
            output,
            {
                'Encoded': [batch_size, net.n_samples, net.pos_encoder.out_dim],
                'Depths': [batch_size, net.n_samples]
            },
            ['Colors'])

    if opt.trt:
        # Run trtexec inside the export directory so the engine is written next to the
        # ONNX file. The directory is derived from `output` itself because it is the one
        # path known on every code path (an explicit -o included). An argument list (no
        # shell) keeps paths containing spaces or shell metacharacters intact.
        export_dir = os.path.dirname(os.path.abspath(output))
        onnx_file = os.path.basename(output)
        result = subprocess.run(
            ['trtexec', f'--onnx={onnx_file}', '--fp16', f'--saveEngine={outname}.trt',
             '--workspace=4096', '--noDataTransfers'],
            cwd=export_dir)
        if result.returncode != 0:
            # Surface a failed engine build instead of silently exiting with success.
            sys.exit(f'trtexec failed with exit code {result.returncode}')