import sys
import os
import argparse
import torch
import torch.optim
from torch import onnx
from typing import Mapping, List

# Make the parent of the script's directory importable (sys.path[0] is the
# invoking script's directory), so the project-local `nets`, `utils` and
# `configs` packages imported below can be found.
sys.path.append(os.path.abspath(sys.path[0] + '/../'))

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=0,
                    help='Which CUDA device to use.')
# Kept as a string: it is evaluated into a number further down and is also
# embedded (with '*' -> 'x') into the exported config name. It is required
# because nothing below can run without it.
parser.add_argument('--batch-size', type=str, required=True,
                    help='Batch size of the exported model; may be a product '
                         'expression such as "256*256" (e.g. a resolution)')
parser.add_argument('--outdir', type=str, default='./',
                    help='Output directory')
parser.add_argument('model', type=str,
                    help='Path of model to export')
opt = parser.parse_args()

# Select device. This happens before the project modules are imported below,
# presumably so that anything they resolve at import time sees this device.
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

from nets.msl_net import *
from utils import misc
from utils import netio
from utils import device
from configs.spherical_view_syn import SphericalViewSynConfig

# Split the model path so the script can work from the model's directory.
# Relative paths used afterwards (the model file itself and --outdir) are
# therefore resolved relative to that directory.
dir_path, model_file = os.path.split(opt.model)
# NOTE(review): eval() executes arbitrary Python taken from the command line.
# It lets --batch-size accept expressions such as "256*256", but it must never
# be fed untrusted input; consider replacing it with a restricted parser.
batch_size = eval(opt.batch_size)
# os.path.split() yields '' when the model is given as a bare filename, and
# os.chdir('') raises FileNotFoundError; stay in the current directory then.
os.chdir(dir_path or os.curdir)

config = SphericalViewSynConfig()


def load_net(path):
    """Create the network described by a model file's name and load its weights.

    The file's base name without extension is used as the config id. The
    module-level ``config`` is mutated in place: it is reinitialised from that
    id, sample perturbation is switched off, and the batch-size string is
    appended to ``config.name``.

    Args:
        path: path of the saved model file, relative to the current working
            directory.

    Returns:
        ``(net, model_id)``: the loaded network (moved to
        ``device.default()``) and the config id derived from the file name.
    """
    model_id = os.path.splitext(os.path.basename(path))[0]
    config.from_id(model_id)
    # Export a deterministic graph: turn off per-sample perturbation
    # (presumably training-time jitter -- confirm against the config class).
    config.sa['perturb_sample'] = False
    # Tag the config name with the batch size, writing '2*1024' as '2x1024'
    # (presumably to keep ids derived from the name filesystem-friendly).
    batch_size_str: str = opt.batch_size.replace('*', 'x')
    config.name += batch_size_str
    config.print()
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net, model_id


def export_net(net: torch.nn.Module, name: str,
               input: Mapping[str, List[int]], output_names: List[str]):
    """Trace ``net`` and write it to ``<outdir>/<config id>/<name>.onnx``.

    Args:
        net: module to export.
        name: base name (without extension) of the ``.onnx`` file.
        input: ordered mapping from ONNX input name to the shape of the dummy
            tensor used for tracing. The mapping's iteration order is the
            positional order in which the tensors are passed to ``net``.
        output_names: names assigned to the graph outputs, in order.
    """
    outpath = os.path.join(opt.outdir, config.to_id(), name + ".onnx")
    # torch.onnx.export documents input_names as a list of str; a dict_keys
    # view breaks code paths that concatenate it with lists. Build the names
    # once and derive the tensors from the same list so the two stay paired.
    input_names = list(input)
    # torch.empty leaves the values uninitialised. That is fine for tracing as
    # long as net's forward has no data-dependent control flow.
    input_tensors = tuple(
        torch.empty(input[key], device=device.default())
        for key in input_names
    )
    onnx.export(
        net,
        input_tensors,
        outpath,
        export_params=True,  # store the trained parameter weights inside the model file
        verbose=True,
        opset_version=9,     # the ONNX opset version to export the model to
        do_constant_folding=True,  # whether to execute constant folding
        input_names=input_names,   # the model's input names
        output_names=output_names  # the model's output names
    )
    print('Model exported to ' + outpath)


if __name__ == "__main__":
    # No gradients are needed for export.
    with torch.no_grad():
        # Load model. This also mutates the module-level config (including the
        # batch-size suffix on its name), so config.to_id() below presumably
        # reflects the exported variant.
        net, _ = load_net(model_file)

        # Prepare the output directory that export_net() writes into.
        misc.create_dir(os.path.join(opt.outdir, config.to_id()))

        # Export Sampler
        export_net(ExportNet(net), 'msl', {
            'Encoded': [batch_size, net.n_samples, net.input_encoder.out_dim],
            'Depths': [batch_size, net.n_samples]
        }, ['Colors'])