"""Render a recorded view/gaze sequence (CSV or JSON view file) with the foveated
neural renderer, in mono or stereo, add a gaze-hint overlay and optionally encode videos."""
import json
import sys
import os
import argparse
import math
import shutil
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as nn_f

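# Make the script's parent directory (the project root) importable.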
sys.path.append(os.path.abspath(sys.path[0] + '/../'))

parser = argparse.ArgumentParser()
parser.add_argument('-s', '--stereo', action='store_true')
parser.add_argument('-R', '--replace', action='store_true')
parser.add_argument('--noCE', action='store_true')
parser.add_argument('-i', '--input', type=str)
parser.add_argument('-r', '--range', type=str)
parser.add_argument('-f', '--fps', type=int)
parser.add_argument('--device', type=int, default=0,
                    help='Which CUDA device to use.')
parser.add_argument('scene', type=str)
parser.add_argument('view_file', type=str)
opt = parser.parse_args()

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())
torch.autograd.set_grad_enabled(False)

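# Project modules are imported only after the CUDA device has been selected above.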
from configs.spherical_view_syn import SphericalViewSynConfig
from utils import netio
from utils import misc
from utils import img
from utils import device
from utils.view import *
from utils import sphere
from components.fnr import FoveatedNeuralRenderer
from utils.progress_bar import progress_bar


def load_net(path):
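    """Create the network described by the checkpoint's file name (its config ID) and load its weights."""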
    config = SphericalViewSynConfig()
    config.from_id(path[:-4])
    config.sa['perturb_sample'] = False
    # config.print()
    net = config.create_net().to(device.default())
    netio.load(path, net)
    return net


def find_file(prefix):
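    """Return the first entry in the current working directory whose name starts with `prefix`, or None."""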
    for path in os.listdir():
        if path.startswith(prefix):
            return path
    return None


def clamp_gaze(gaze):
    # Clamping is currently a no-op: the gaze direction is returned unchanged.
    return gaze
    # scoord = sphere.cartesian2spherical(gaze)  # unreachable; kept for reference


def load_csv(data_desc_file) -> Tuple[Trans, torch.Tensor]:
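    """Parse a CSV view file.

    Three layouts are handled, distinguished by how many lines describe each view:
    simple (3 lines/view), old (7 lines/view) and new (9 lines/view). Views whose
    gaze lines are all zeros are skipped.
    """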
    def to_tensor(line_content):
        return torch.tensor([float(val) for val in line_content.split(',')])

    with open(data_desc_file, 'r', encoding='utf-8') as file:
        lines = file.readlines()
        # Detect the layout: a single value on lines[3] indicates the simple format;
        # otherwise a single value on lines[7] indicates the old format.
        simple_fmt = len(lines[3].split(',')) == 1
        old_fmt = not simple_fmt and len(lines[7].split(',')) == 1
        lines_per_view = 3 if simple_fmt else 7 if old_fmt else 9
        n = len(lines) // lines_per_view
        gaze_dirs = torch.empty(n, 2, 3)
        gazes = torch.empty(n, 2, 2)
        view_mats = torch.empty(n, 2, 4, 4)
        view_idx = 0
        for i in range(0, len(lines), lines_per_view):
            if simple_fmt:
                if lines[i + 1].startswith('0,0,0,0'):
                    continue
                gazes[view_idx] = to_tensor(lines[i + 1]).view(2, 2)
                view_mats[view_idx] = to_tensor(lines[i + 2]).view(1, 4, 4).expand(2, -1, -1)
            else:
                if lines[i + 1].startswith('0,0,0') or lines[i + 2].startswith('0,0,0'):
                    continue
                j = i + 1
                gaze_dirs[view_idx, 0] = clamp_gaze(to_tensor(lines[j]))
                j += 1
                gaze_dirs[view_idx, 1] = clamp_gaze(to_tensor(lines[j]))
                j += 1
                if not old_fmt:
                    gazes[view_idx, 0] = to_tensor(lines[j])
                    j += 1
                    gazes[view_idx, 1] = to_tensor(lines[j])
                    j += 1
                view_mats[view_idx, 0] = to_tensor(lines[j]).view(4, 4)
                j += 1
                view_mats[view_idx, 1] = to_tensor(lines[j]).view(4, 4)
            view_idx += 1

    view_mats = view_mats[:view_idx]
    gaze_dirs = gaze_dirs[:view_idx]
    gazes = gazes[:view_idx]

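    # The old format stores only per-eye gaze directions; convert them to pixel
    # offsets on the full-resolution frame using the FOV hard-coded below.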
    if old_fmt:
        gaze2 = -gaze_dirs[..., :2] / gaze_dirs[..., 2:]
        fov = math.radians(55)
        tan = torch.tensor([math.tan(fov), math.tan(fov) * 1600 / 1440])
        gazeper = gaze2 / tan
        gazes = 0.5 * torch.stack([1440 * gazeper[..., 0], 1600 * gazeper[..., 1]], -1)

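    # Head translation is the midpoint of the two eye positions; rotation is taken from the first view matrix.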
    view_t = 0.5 * (view_mats[:, 0, :3, 3] + view_mats[:, 1, :3, 3])

    return Trans(view_t, view_mats[:, 0, :3, :3]), gazes

def load_json(data_desc_file) -> Tuple[Trans, torch.Tensor]:
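    """Load view centers, rotations and optional gaze points (mono or per-eye) from a JSON description."""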
    with open(data_desc_file, 'r', encoding='utf-8') as file:
        data = json.load(file)
        view_t = torch.tensor(data['view_centers'])
        view_r = torch.tensor(data['view_rots']).view(-1, 3, 3)
        if data.get('gazes'):
            if len(data['gazes'][0]) == 2:
                gazes = torch.tensor(data['gazes']).view(-1, 1, 2).expand(-1, 2, -1)
            else:
                gazes = torch.tensor(data['gazes']).view(-1, 2, 2)
        else:
            gazes = torch.zeros(view_t.size(0), 2, 2)
    return Trans(view_t, view_r), gazes

def load_views(data_desc_file: str) -> Tuple[Trans, torch.Tensor]:
    if data_desc_file.endswith('.csv'):
        return load_csv(data_desc_file)
    return load_json(data_desc_file)

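# Per-scene view ranges: rotation limits in degrees and translation limits in scene units.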
rot_range = {
    'classroom': [120, 80],
    'barbershop': [360, 80],
    'lobby': [360, 80],
    'stones': [360, 80]
}
trans_range = {
    'classroom': 0.6,
    'barbershop': 0.3,
    'lobby': 1.0,
    'stones': 1.0
}
fov_list = [20, 45, 110]
res_list = [(256, 256), (256, 256), (400, 360)]
res_full = (1600, 1440)
stereo_disparity = 0.06

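# Load the fovea and periphery networks from the scene's data directory and build the
# foveated renderer (the periphery network is shared by the mid and far layers).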
cwd = os.getcwd()
os.chdir(f"{sys.path[0]}/../data/__new/{opt.scene}_all")
fovea_net = load_net(find_file('fovea'))
periph_net = load_net(find_file('periph'))
renderer = FoveatedNeuralRenderer(fov_list, res_list,
                                  nn.ModuleList([fovea_net, periph_net, periph_net]),
                                  res_full, device=device.default())
os.chdir(cwd)

# Load Dataset
views, gazes = load_views(opt.view_file)
if opt.range:
    opt.range = [int(val) for val in opt.range.split(",")]
    if len(opt.range) == 1:
        opt.range = [0, opt.range[0]]
    views = views.get(range(opt.range[0], opt.range[1]))
    gazes = gazes[opt.range[0]:opt.range[1]]
views = views.to(device.default())
n_views = views.size()[0]
print('Dataset loaded. Views:', n_views)


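# Output layout: rendered frames (inferout), frames with the gaze-hint overlay (hintout) and
# a gaze CSV next to the view file; with --fps, frames are staged in a temporary directory.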
videodir = os.path.dirname(os.path.abspath(opt.view_file))
tempdir = '/dev/shm/dvs_tmp/realvideo'
videoname = f"{os.path.splitext(os.path.split(opt.view_file)[-1])[0]}_{'stereo' if opt.stereo else 'mono'}"
gazeout = f"{videodir}/{videoname}_gaze.csv"
if opt.noCE:
    videoname += "_noCE"
if opt.fps:
    if opt.input:
        videoname = os.path.split(opt.input)[0]
    inferout = f"{videodir}/{opt.input}" if opt.input else f"{tempdir}/{videoname}/%04d.bmp"
    hintout = f"{tempdir}/{videoname}_hint/%04d.bmp"
else:
    inferout = f"{videodir}/{opt.input}" if opt.input else f"{videodir}/{videoname}/%04d.png"
    hintout = f"{videodir}/{videoname}_hint/%04d.png"
if opt.input:
    scale = img.load(inferout % 0).shape[-1] / res_full[1]
else:
    scale = 1
hint = img.load(f"{sys.path[0]}/fovea_hint.png", with_alpha=True).to(device=device.default())
hint = nn_f.interpolate(hint, mode='bilinear', scale_factor=scale, align_corners=False)

print("Video dir:", videodir)
print("Infer out:", inferout)
print("Hint out:", hintout)


def add_hint(image, center, right_center=None):
    """Alpha-blend the gaze-hint sprite onto the frame at the given gaze position(s)."""
    if right_center is not None:
        # Stereo frame: process the left and right halves independently.
        add_hint(image[..., :image.size(-1) // 2], center)
        add_hint(image[..., image.size(-1) // 2:], right_center)
        return
    center = (center[0] * scale, center[1] * scale)
    # Top-left corner of the hint sprite; gaze coordinates are relative to the image center.
    fovea_origin = (
        int(center[0]) + image.size(-1) // 2 - hint.size(-1) // 2,
        int(center[1]) + image.size(-2) // 2 - hint.size(-2) // 2
    )
    fovea_region = (
        ...,
        slice(fovea_origin[1], fovea_origin[1] + hint.size(-2)),
        slice(fovea_origin[0], fovea_origin[0] + hint.size(-1)),
    )
    try:
        image[fovea_region] = image[fovea_region] * (1 - hint[:, 3:]) + \
            hint[:, :3] * hint[:, 3:]
    except Exception:
        # Most likely the hint region extends past the frame border; report and abort.
        print(fovea_region, image.shape, hint.shape)
        sys.exit(1)


misc.create_dir(os.path.dirname(inferout))
misc.create_dir(os.path.dirname(hintout))

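# Unless --replace is given, resume from existing output: count frames already on disk
# and back off by one so a possibly incomplete last frame is re-rendered.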
hint_offset = infer_offset = 0
if not opt.replace:
    for view_idx in range(n_views):
        if os.path.exists(inferout % view_idx):
            infer_offset += 1
        if os.path.exists(hintout % view_idx):
            hint_offset += 1
    hint_offset = max(0, hint_offset - 1)
    infer_offset = n_views if opt.input else max(0, infer_offset - 1)

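# Stereo: the per-eye gaze x coordinates are pulled 40% toward each other and the y
# coordinate is averaged, reducing the gaze disparity between the two rendered eyes.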
if opt.stereo:
    gazes_out = torch.empty(n_views, 4)
    for view_idx in range(n_views):
        shift = gazes[view_idx, 0, 0] - gazes[view_idx, 1, 0]
        # print(shift.item())
        gazel = ((gazes[view_idx, 1, 0] + 0.4 * shift).item(),
                 0.5 * (gazes[view_idx, 0, 1] + gazes[view_idx, 1, 1]).item())
        gazer = ((gazes[view_idx, 0, 0] - 0.4 * shift).item(), gazel[1])
        # gazel = ((gazes[view_idx, 0, 0]).item(),
        #         0.5 * (gazes[view_idx, 0, 1] + gazes[view_idx, 1, 1]).item())
        #gazer = ((gazes[view_idx, 1, 0]).item(), gazel[1])
        gazes_out[view_idx] = torch.tensor([gazel[0], gazel[1], gazer[0], gazer[1]])
        if view_idx < hint_offset:
            continue
        if view_idx < infer_offset:
            frame = img.load(inferout % view_idx).to(device=device.default())
        else:
            view_trans = views.get(view_idx)
            left_images, right_images = renderer(view_trans, gazel, gazer,
                                                 stereo_disparity=stereo_disparity,
                                                 mono_periph_mode=3, ret_raw=True)
            frame = torch.cat([
                left_images['blended_raw'] if opt.noCE else left_images['blended'],
                right_images['blended_raw'] if opt.noCE else right_images['blended']], -1)
            frame = img.translate(frame, (0.5, 0.5))
            img.save(frame, inferout % view_idx)
        add_hint(frame, gazel, gazer)
        img.save(frame, hintout % view_idx)
        progress_bar(view_idx, n_views, 'Frame %4d inferred' % view_idx)
else:
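    # Mono: render a single view at the average of the two recorded gaze points.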
    gazes_out = torch.empty(n_views, 2)
    for view_idx in range(n_views):
        gaze = 0.5 * (gazes[view_idx, 0] + gazes[view_idx, 1])
        gaze = (gaze[0].item(), gaze[1].item())
        gazes_out[view_idx] = torch.tensor([gaze[0], gaze[1]])
        if view_idx < hint_offset:
            continue
        if view_idx < infer_offset:
            frame = img.load(inferout % view_idx).to(device=device.default())
        else:
            view_trans = views.get(view_idx)
            frame = renderer(view_trans, gaze,
                             ret_raw=True)['blended_raw' if opt.noCE else 'blended']
            frame = img.translate(frame, (0.5, 0.5))
            img.save(frame, inferout % view_idx)
        add_hint(frame, gaze)
        img.save(frame, hintout % view_idx)
        progress_bar(view_idx, n_views, 'Frame %4d inferred' % view_idx)

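# Write one line of gaze coordinates per view (x,y for mono; x,y per eye for stereo).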
with open(gazeout, 'w') as fp:
    for i in range(n_views):
        fp.write(','.join([f'{val.item()}' for val in gazes_out[i]]))
        fp.write('\n')

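# With --fps set, encode the frame sequences to H.264 MP4s and remove temporary frame directories.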
if opt.fps:
    # Generate video without hint
    os.system(f'ffmpeg -y -r {opt.fps:d} -i "{inferout}" -c:v libx264 "{videodir}/{videoname}.mp4"')
    if not opt.input:
        shutil.rmtree(os.path.dirname(inferout))

    # Generate video with hint
    os.system(f'ffmpeg -y -r {opt.fps:d} -i "{hintout}" -c:v libx264 "{videodir}/{videoname}_hint.mp4"')
    shutil.rmtree(os.path.dirname(hintout))