import sys
import os
import argparse
import torch
import torch.optim
from torch import onnx
from typing import Mapping, List

# Make the project root importable (this script lives one level below it).
sys.path.append(os.path.abspath(sys.path[0] + '/../'))

parser = argparse.ArgumentParser()
# Fix: help text previously said 'Resolution' (copy-paste error) — this option
# is a batch-size expression, evaluated later (e.g. "4*1024").
parser.add_argument('-b', '--batch-size', type=str,
                    help="Batch size as an int expression, e.g. '4*1024'")
parser.add_argument('-o', '--output', type=str)
parser.add_argument('-t', '--trt', action="store_true")
parser.add_argument('--device', type=int, default=0, help='Which CUDA device to use.')
parser.add_argument('model', type=str, help='Path of model to export')
opt = parser.parse_args()

# Select device BEFORE importing project modules below — presumably they touch
# CUDA at import time, so the ordering of these mid-file imports is intentional.
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

from nets.snerf_fast import *
from utils import misc
from utils import netio
from utils import device
from configs.spherical_view_syn import SphericalViewSynConfig


dir_path, model_file = os.path.split(opt.model)
# Models named "<config>@<suffix>.pth" carry the config id in the file name;
# otherwise the config id is the parent directory's name.
if model_file.find("@") == -1:
    config_id = os.path.split(dir_path)[-1]
else:
    config_id = os.path.splitext(model_file)[0]

# SECURITY NOTE(review): eval() on a CLI argument executes arbitrary code.
# Acceptable for a local tool, but a restricted parser (e.g. product of ints
# split on '*') would be safer.
batch_size = eval(opt.batch_size)
batch_size_str = opt.batch_size.replace('*', 'x')  # '*' is not filename-friendly

if not opt.output:
    if model_file.find("@") == -1:
        # Strip a fixed 12-char prefix to get the epoch tag — presumably the
        # checkpoint is named "model-epoch_<N>.pth"; TODO confirm against trainer.
        epochs = os.path.splitext(model_file)[0][12:]
        outdir = f"{dir_path}/output_{epochs}"
        output = os.path.join(outdir, f"net@{batch_size_str}.onnx")
    else:
        outdir = f"{dir_path}/export"
        output = os.path.join(outdir, f"{model_file.split('@')[0]}@{batch_size_str}.onnx")
    misc.create_dir(outdir)
else:
    output = opt.output
    # Fix: outdir was left unassigned on this branch, so the later
    # os.chdir(outdir) in __main__ raised NameError whenever -o was supplied.
    outdir = os.path.split(output)[0] or '.'
outname = os.path.splitext(os.path.split(output)[-1])[0]


def load_net():
    """Instantiate the network described by ``config_id`` and load the
    checkpoint given on the command line (``opt.model``) into it."""
    cfg = SphericalViewSynConfig()
    cfg.from_id(config_id)
    # Deterministic sampling for export — no per-call jitter.
    cfg.sa['perturb_sample'] = False
    # Tag the config name with the batch-size suffix used in output file names.
    cfg.name += '@' + batch_size_str
    cfg.print()
    model = cfg.create_net().to(device.default())
    netio.load(opt.model, model)
    return model


def export_net(net: torch.nn.Module, path: str, input: Mapping[str, List[int]],
               output_names: List[str]):
    """Trace `net` with dummy inputs and export it to an ONNX file.

    :param net: module to export (must be traceable)
    :param path: destination ``.onnx`` file path
    :param input: ordered mapping of input name -> tensor shape; one dummy
        tensor per entry drives the trace, in mapping order
    :param output_names: names assigned to the model's outputs
    """
    # Uninitialized dummy tensors suffice for tracing — only shapes matter.
    # (Idiom fix: build the tuple from a generator instead of a throwaway list.)
    input_tensors = tuple(
        torch.empty(size, device=device.default())
        for size in input.values()
    )
    onnx.export(
        net,
        input_tensors,
        path,
        export_params=True,  # store the trained parameter weights inside the model file
        verbose=True,
        opset_version=9,     # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding
        input_names=list(input.keys()),   # the model's input names
        output_names=output_names  # the model's output names
    )
    print('Model exported to ' + path)


if __name__ == "__main__":
    # No gradients needed for export; keeps tracing memory down.
    with torch.no_grad():
        net: SnerfFast = load_net()
        export_net(
            SnerfFastExport(net),
            output,
            {
                'Encoded': [batch_size, net.n_samples, net.pos_encoder.out_dim],
                'Depths': [batch_size, net.n_samples]
            },
            ['Colors'])
        # Fix: `outdir` is undefined when -o/--output is given (only the
        # auto-naming branch assigns it), so chdir(outdir) raised NameError.
        # Derive the directory from the actual output path instead.
        os.chdir(os.path.split(output)[0] or '.')
        if opt.trt:
            # NOTE(review): shell command built from a user-derived file name;
            # fine for a local tool, but quote `outname` if this is ever reused.
            os.system(f'trtexec --onnx={outname}.onnx --fp16 --saveEngine={outname}.trt --workspace=4096 --noDataTransfers')