"""Export a trained SnerfFast network to ONNX, optionally building a TensorRT engine.

Usage: python <script> -b BATCH [-o OUTPUT] [-t] [--device N] MODEL
"""
import sys
import os
import argparse
import subprocess
import torch
import torch.optim
from torch import onnx
from typing import Mapping, List

# Make the parent of the script's directory importable so the project-local
# packages imported below (nets, utils, configs) can be found.
sys.path.append(os.path.abspath(sys.path[0] + '/../'))

parser = argparse.ArgumentParser()
parser.add_argument('-b', '--batch-size', type=str, required=True,
                    help='Batch size, optionally written as a product such as "1024*32"')
parser.add_argument('-o', '--output', type=str,
                    help='Path of the exported ONNX file (derived from MODEL if omitted)')
parser.add_argument('-t', '--trt', action="store_true",
                    help='Also build a TensorRT engine with trtexec')
parser.add_argument('--device', type=int, default=0, help='Which CUDA device to use.')
parser.add_argument('model', type=str, help='Path of model to export')
opt = parser.parse_args()

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

# Project-local imports; these run after the CUDA device has been selected.
from nets.snerf_fast import *  # provides SnerfFast / SnerfFastExport used below
from utils import misc
from utils import netio
from utils import device
from configs.spherical_view_syn import SphericalViewSynConfig

dir_path, model_file = os.path.split(opt.model)
# A model file name without '@' takes its config id from the containing
# directory name; otherwise the file's own stem is used as the config id.
if "@" not in model_file:
    config_id = os.path.split(dir_path)[-1]
else:
    config_id = os.path.splitext(model_file)[0]

# NOTE(review): eval() executes arbitrary Python taken from the command line.
# It is kept so product expressions like "1024*32" keep working; never run this
# script on untrusted arguments.
batch_size = eval(opt.batch_size)
batch_size_str = opt.batch_size.replace('*', 'x')

if not opt.output:
    if "@" not in model_file:
        # NOTE(review): strips a fixed 12-character prefix from the file stem to
        # get the epoch label; presumably names look like "model-epoch_<N>" — confirm.
        epochs = os.path.splitext(model_file)[0][12:]
        outdir = f"{dir_path}/output_{epochs}"
        output = os.path.join(outdir, f"net@{batch_size_str}.onnx")
    else:
        outdir = f"{dir_path}/export"
        output = os.path.join(outdir, f"{model_file.split('@')[0]}@{batch_size_str}.onnx")
    misc.create_dir(outdir)
else:
    output = opt.output
    # outdir is required by os.chdir() below; derive it from the explicit path
    # and make sure the directory exists before onnx.export writes into it.
    outdir = os.path.dirname(os.path.abspath(output))
    os.makedirs(outdir, exist_ok=True)
outname = os.path.splitext(os.path.basename(output))[0]


def load_net():
    """Create the network for ``config_id`` and load weights from ``opt.model``.

    Sets ``config.sa['perturb_sample']`` to False and appends
    ``'@' + batch_size_str`` to the config name before building the network.

    Returns:
        The network, moved to ``device.default()``, with weights loaded.
    """
    config = SphericalViewSynConfig()
    config.from_id(config_id)
    config.sa['perturb_sample'] = False
    config.name += '@' + batch_size_str
    config.print()
    net = config.create_net().to(device.default())
    netio.load(opt.model, net)
    return net


def export_net(net: torch.nn.Module, path: str, input: Mapping[str, List[int]],
               output_names: List[str]):
    """Export ``net`` to an ONNX file at ``path``.

    Args:
        net: module to export.
        path: destination ONNX file path.
        input: ordered mapping of input name -> shape. An uninitialized tensor of
            each shape is created on ``device.default()`` as the example input,
            in the mapping's iteration order.
        output_names: names assigned to the exported graph's outputs.
    """
    input_tensors = tuple(
        torch.empty(size, device=device.default()) for size in input.values())
    onnx.export(
        net,
        input_tensors,
        path,
        export_params=True,  # store the trained parameter weights inside the model file
        verbose=True,
        opset_version=9,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding
        input_names=list(input.keys()),  # the model's input names
        output_names=output_names  # the model's output names
    )
    print('Model exported to ' + path)


if __name__ == "__main__":
    with torch.no_grad():
        net: SnerfFast = load_net()
        export_net(
            SnerfFastExport(net),
            output,
            {
                'Encoded': [batch_size, net.n_samples, net.pos_encoder.out_dim],
                'Depths': [batch_size, net.n_samples]
            },
            ['Colors'])
    # Run trtexec from the output directory so the engine is written next to
    # the ONNX file.
    os.chdir(outdir)
    if opt.trt:
        # Argument list (no shell) so paths are not shell-interpreted; reference
        # the file actually written, whatever its extension; fail loudly on error.
        subprocess.run(
            ['trtexec', f'--onnx={os.path.basename(output)}', '--fp16',
             f'--saveEngine={outname}.trt', '--workspace=4096', '--noDataTransfers'],
            check=True)