"""Export a trained spherical view-synthesis network to an ONNX file.

Steps performed by this script:

1. The network configuration is rebuilt from the model's file name.
2. The weights are loaded.
3. The network is exported with ``torch.onnx.export`` using dummy inputs
   ``Rays_o`` and ``Rays_d`` of shape ``(batch_size, 3)``. The output is
   named ``Colors``.
4. The file ``<config id>.onnx`` is written to ``--outdir``.

Note: the working directory is changed to the model's directory before
export. A relative ``--outdir`` is therefore resolved against the model's
directory, not against the directory the script was launched from.
"""
import sys
import os
import argparse
import torch
import torch.optim
from torch import onnx

# Make the directory two levels above this script importable (presumably the
# parent of the deep_view_syn package) and declare this module part of that
# package, so the relative imports below resolve when run as a plain script.
sys.path.append(os.path.abspath(sys.path[0] + '/../../'))
__package__ = "deep_view_syn.tools"

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0,
                    help='Which CUDA device to use.')
# The value is evaluated with eval() below, so arithmetic expressions such as
# "64*64" are accepted. Making it required gives a clear usage error instead
# of a TypeError from eval(None) when the option is omitted.
parser.add_argument('--batch-size', type=str, required=True,
                    help='Number of rays per exported batch (a Python '
                         'integer expression, e.g. "4096" or "64*64")')
parser.add_argument('--opset', type=int, default=9,
                    help='ONNX opset version to export with (default: 9)')
parser.add_argument('--outdir', type=str, default='./',
                    help='Output directory')
parser.add_argument('model', type=str, help='Path of model to export')
opt = parser.parse_args()

# Select the CUDA device BEFORE importing project modules below.
# NOTE(review): presumably ..my.device captures the current CUDA device at
# import time, which is why these imports come after set_device -- confirm.
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

from ..configs.spherical_view_syn import SphericalViewSynConfig
from ..my import device
from ..my import netio
from ..my import util

dir_path, model_file = os.path.split(opt.model)

# NOTE(review): eval() executes arbitrary code. That is acceptable only
# because the value comes from the person running this script; never pass
# untrusted input here.
batch_size = eval(opt.batch_size)
if not isinstance(batch_size, int) or batch_size <= 0:
    parser.error('--batch-size must evaluate to a positive integer, got %r'
                 % (batch_size,))

# os.path.split() returns '' as the directory for a bare file name, and
# os.chdir('') raises, so only change directory when there is one.
if dir_path:
    os.chdir(dir_path)

config = SphericalViewSynConfig()


def load_net(path):
    """Build the network described by *path*'s file name and load its weights.

    Side effect: this mutates the module-level ``config``.
    - ``config.from_id`` is called with the file's base name without its
      extension. Presumably it parses hyperparameters encoded in the name;
      confirm in SphericalViewSynConfig.
    - The following sampling parameters are then forced:
      ``spherical=True``, ``perturb_sample=False`` and ``n_samples=4``.

    Args:
        path: Path of the saved model, resolved against the current working
            directory.

    Returns:
        A tuple ``(net, name)``. ``net`` is the network moved to
        ``device.GetDevice()`` with its weights loaded by
        ``netio.LoadNet``. ``name`` is the file's base name without its
        extension.
    """
    name = os.path.splitext(os.path.basename(path))[0]
    config.from_id(name)
    config.SAMPLE_PARAMS['spherical'] = True
    config.SAMPLE_PARAMS['perturb_sample'] = False
    config.SAMPLE_PARAMS['n_samples'] = 4
    config.print()
    net = config.create_net().to(device.GetDevice())
    netio.LoadNet(path, net)
    return net, name


if __name__ == "__main__":
    with torch.no_grad():
        # Load model (model_file is relative to the directory entered above).
        net, name = load_net(model_file)

        # Dummy inputs for export. torch.empty leaves their contents
        # uninitialized; presumably only shape/dtype/device matter to the
        # exported graph -- confirm the network has no value-dependent
        # control flow.
        rays_o = torch.empty(batch_size, 3, device=device.GetDevice())
        rays_d = torch.empty(batch_size, 3, device=device.GetDevice())

        # Relative opt.outdir is resolved against the model's directory
        # because of the chdir above.
        util.CreateDirIfNeed(opt.outdir)

        # Export the model
        outpath = os.path.join(opt.outdir, config.to_id() + ".onnx")
        onnx.export(
            net,                        # model being run
            (rays_o, rays_d),           # model input (or a tuple for multiple inputs)
            outpath,
            export_params=True,         # store the trained parameter weights inside the model file
            verbose=True,
            opset_version=opt.opset,    # the ONNX version to export the model to
            do_constant_folding=True,   # whether to execute constant folding
            input_names=['Rays_o', 'Rays_d'],  # the model's input names
            output_names=['Colors'],    # the model's output names
        )
        print('Model exported to ' + outpath)