import sys
import os
import argparse
import torch
import torch.optim
from tensorboardX import SummaryWriter
from torch import nn

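# Make the repository root importable and set the package name so that the
# relative imports below work when this file is run as a script.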
sys.path.append(os.path.abspath(sys.path[0] + '/../'))
__package__ = "deep_view_syn"

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=int, default=3,
                    help='Which CUDA device to use.')
parser.add_argument('--config', type=str,
                    help='Net config file')
parser.add_argument('--config-id', type=str,
                    help='Net config id')
parser.add_argument('--dataset', type=str, required=True,
                    help='Dataset description file')
parser.add_argument('--cont', type=str,
                    help='Continue training from the given model file')
parser.add_argument('--epochs', type=int,
                    help='Max number of epochs to train')
parser.add_argument('--test', type=str,
                    help='Model file to test')
parser.add_argument('--test-samples', type=int,
                    help='Number of samples to use for test (overrides the config)')
parser.add_argument('--res', type=str,
                    help='Resolution, given as two integers separated by "x"')
parser.add_argument('--output-gt', action='store_true',
                    help='Also output ground truth images if they exist')
parser.add_argument('--output-alongside', action='store_true',
                    help='Output generated images alongside the ground truth images')
parser.add_argument('--output-video', action='store_true',
                    help='Output test results as video')
parser.add_argument('--perf', action='store_true',
                    help='Test performance')
parser.add_argument('--simple-log', action='store_true',
                    help='Use simple per-epoch logging instead of a progress bar')

opt = parser.parse_args()
if opt.res:
    opt.res = tuple(int(s) for s in opt.res.split('x'))

# Select device
torch.cuda.set_device(opt.device)
print("Set CUDA:%d as current device." % torch.cuda.current_device())

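# Project modules are imported only after the CUDA device has been selected
# above, so that code querying the current device sees the selection.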
from .my import netio
from .my import util
from .my import device
from .my import loss
from .my.progress_bar import progress_bar
from .my.simple_perf import SimplePerf
from .data.spherical_view_syn import *
from .data.loader import FastDataLoader
from .configs.spherical_view_syn import SphericalViewSynConfig


config = SphericalViewSynConfig()

# Toggles
ROT_ONLY = False
EVAL_TIME_PERFORMANCE = False
# ========
#ROT_ONLY = True
#EVAL_TIME_PERFORMANCE = True

# Train
BATCH_SIZE = 4096
EPOCH_RANGE = range(0, opt.epochs if opt.epochs else 300)
SAVE_INTERVAL = 10

# Test
TEST_BATCH_SIZE = 1
TEST_MAX_RAYS = 32768

# Paths
data_desc_path = opt.dataset
data_desc_name = os.path.splitext(os.path.basename(data_desc_path))[0]
if opt.test:
    test_net_path = opt.test
    test_net_name = os.path.splitext(os.path.basename(test_net_path))[0]
    run_dir = os.path.dirname(test_net_path) + '/'
    run_id = os.path.basename(run_dir[:-1])
    output_dir = run_dir + 'output/%s/%s%s/' % (test_net_name, data_desc_name,
                                                '_%dx%d' % (opt.res[0], opt.res[1]) if opt.res else '')
    config.from_id(run_id)
    train_mode = False
    if opt.test_samples:
        config.SAMPLE_PARAMS['n_samples'] = opt.test_samples
        output_dir = run_dir + 'output/%s/%s_s%d/' % \
            (test_net_name, data_desc_name, opt.test_samples)
else:
    data_dir = os.path.dirname(data_desc_path) + '/'
    if opt.cont:
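        # Resume training: recover the starting epoch from the checkpoint
        # file name ('model-epoch_<N>.pth').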
        train_net_name = os.path.splitext(os.path.basename(opt.cont))[0]
        EPOCH_RANGE = range(int(train_net_name[12:]), EPOCH_RANGE.stop)
        run_dir = os.path.dirname(opt.cont) + '/'
        run_id = os.path.basename(run_dir[:-1])
        config.from_id(run_id)
    else:
        if opt.config:
            config.load(opt.config)
        if opt.config_id:
            config.from_id(opt.config_id)
        run_id = config.to_id()
        run_dir = data_dir + run_id + '/'
    log_dir = run_dir + 'log/'
    output_dir = None
    train_mode = True

config.print()
print("dataset: ", data_desc_path)
print("train_mode: ", train_mode)
print("run_dir: ", run_dir)
if not train_mode:
    print("output_dir", output_dir)

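# Sample perturbation is only applied during training.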
config.SAMPLE_PARAMS['perturb_sample'] = \
    config.SAMPLE_PARAMS['perturb_sample'] and train_mode

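# Loss functions keyed by config.LOSS (currently bypassed in train()).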
LOSSES = {
    'mse': lambda: nn.MSELoss(),
    'mse_grad': lambda: loss.CombinedLoss(
        [nn.MSELoss(), loss.GradLoss()], [1.0, 0.5])
}

# Initialize model
model = config.create_net().to(device.GetDevice())
loss_mse = nn.MSELoss().to(device.GetDevice())
loss_grad = loss.GradLoss().to(device.GetDevice())


def train_loop(data_loader, optimizer, loss, perf, writer, epoch, iters):
    sub_iters = 0
    iters_in_epoch = len(data_loader)
    loss_min = 1e5
    loss_max = 0
    loss_avg = 0
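    # Epoch-level timer used for the one-line-per-epoch simple log.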
    perf1 = SimplePerf(opt.simple_log, True)
    for _, gt, rays_o, rays_d in data_loader:
        patch = (len(gt.size()) == 4)
        gt = gt.to(device.GetDevice())
        rays_o = rays_o.to(device.GetDevice())
        rays_d = rays_d.to(device.GetDevice())
        perf.Checkpoint("Load")

        out = model(rays_o, rays_d)
        perf.Checkpoint("Forward")

        optimizer.zero_grad()
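        # In YCbCr mode the channels are weighted differently in the MSE loss.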
        if config.COLOR == color_mode.YCbCr:
            loss_mse_value = 0.3 * loss_mse(out[..., 0:2], gt[..., 0:2]) + \
                0.7 * loss_mse(out[..., 2], gt[..., 2])
        else:
            loss_mse_value = loss_mse(out, gt)
        loss_grad_value = loss_grad(out, gt) if patch else None
        loss_value = loss_mse_value  # + 0.5 * loss_grad_value if patch \
        # else loss_mse_value
        perf.Checkpoint("Compute loss")

        loss_value.backward()
        perf.Checkpoint("Backward")

        optimizer.step()
        perf.Checkpoint("Update")

        loss_value = loss_value.item()
        loss_min = min(loss_min, loss_value)
        loss_max = max(loss_max, loss_value)
        loss_avg = (loss_avg * sub_iters + loss_value) / (sub_iters + 1)
        if not opt.simple_log:
            progress_bar(sub_iters, iters_in_epoch,
                         "Loss: %.2e (%.2e/%.2e/%.2e)" % (loss_value, loss_min, loss_avg, loss_max),
                         "Epoch {:<3d}".format(epoch))

        # Write tensorboard logs.
        writer.add_scalar("loss mse", loss_value, iters)
        # if patch and iters % 100 == 0:
        #    output_vs_gt = torch.cat([out[0:4], gt[0:4]], 0).detach()
        #    writer.add_image("Output_vs_gt", torchvision.utils.make_grid(
        #        output_vs_gt, nrow=4).cpu().numpy(), iters)

        iters += 1
        sub_iters += 1
    if opt.simple_log:
        perf1.Checkpoint('Epoch %d (%.2e/%.2e/%.2e)' % (epoch, loss_min, loss_avg, loss_max), True)
    return iters


def train():
    # 1. Initialize data loader
    print("Load dataset: " + data_desc_path)
    train_dataset = SphericalViewSynDataset(
        data_desc_path, color=config.COLOR, res=opt.res)
    train_dataset.set_patch_size(1)
    train_data_loader = FastDataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=False,
        pin_memory=True)

    # 2. Initialize components
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=config.OPT_DECAY)
    loss = 0  # LOSSES[config.LOSS]().to(device.GetDevice())

    if EPOCH_RANGE.start > 0:
        iters = netio.LoadNet('%smodel-epoch_%d.pth' % (run_dir, EPOCH_RANGE.start),
                              model, solver=optimizer)
    else:
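        # Fresh start: optionally sweep the training rays once to estimate
        # the model's normalization ranges.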
        if config.NORMALIZE:
            for _, _, rays_o, rays_d in train_data_loader:
                model.update_normalize_range(rays_o, rays_d)
            print('Depth/diopter range: ', model.depth_range)
            print('Angle range: ', model.angle_range / 3.14159 * 180)
        iters = 0
    epoch = None

    # 3. Train
    model.train()

    util.CreateDirIfNeed(run_dir)
    util.CreateDirIfNeed(log_dir)

    perf = SimplePerf(EVAL_TIME_PERFORMANCE, start=True)
    writer = SummaryWriter(log_dir)

    print("Begin training...")
    for epoch in EPOCH_RANGE:
        iters = train_loop(train_data_loader, optimizer, loss,
                           perf, writer, epoch, iters)
        # Save checkpoint
        if ((epoch + 1) % SAVE_INTERVAL == 0):
            netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1), model,
                          solver=optimizer, iters=iters)
    print("Train finished")
    netio.SaveNet('%smodel-epoch_%d.pth' % (run_dir, epoch + 1), model,
                  solver=optimizer, iters=iters)


def perf():
    with torch.no_grad():
        # 1. Load dataset
        print("Load dataset: " + data_desc_path)
        test_dataset = SphericalViewSynDataset(data_desc_path,
                                               load_images=True,
                                               color=config.COLOR, res=opt.res)
        test_data_loader = FastDataLoader(
            dataset=test_dataset,
            batch_size=1,
            shuffle=False,
            drop_last=False,
            pin_memory=True)

        # 2. Load trained model
        netio.LoadNet(test_net_path, model)

        # 3. Test on dataset
        print("Begin perf, batch size is %d" % TEST_BATCH_SIZE)

        perf = SimplePerf(True, start=True)
        loss = nn.MSELoss()
        i = 0
        n = test_dataset.n_views
        chns = 1 if config.COLOR == color_mode.GRAY else 3
        out_view_images = torch.empty(n, chns, test_dataset.view_res[0],
                                      test_dataset.view_res[1],
                                      device=device.GetDevice())
        perf_times = torch.empty(n)
        perf_errors = torch.empty(n)
        for view_idxs, gt, rays_o, rays_d in test_data_loader:
            perf.Checkpoint("%d - Load" % i)
            rays_o = rays_o.to(device.GetDevice()).view(-1, 3)
            rays_d = rays_d.to(device.GetDevice()).view(-1, 3)
            n_rays = rays_o.size(0)
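            # Evaluate the rays in chunks of at most TEST_MAX_RAYS to bound memory use.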
            chunk_size = min(n_rays, TEST_MAX_RAYS)
            out_pixels = torch.empty(n_rays, chns, device=device.GetDevice())
            for offset in range(0, n_rays, chunk_size):
                idx = slice(offset, offset + chunk_size)
                out_pixels[idx] = model(rays_o[idx], rays_d[idx])
            if config.COLOR == color_mode.YCbCr:
                out_pixels = util.ycbcr2rgb(out_pixels)
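            # Reassemble the flat per-ray output into (N, C, H, W) view images.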
            out_view_images[view_idxs] = out_pixels.view(
                TEST_BATCH_SIZE, test_dataset.view_res[0],
                test_dataset.view_res[1], -1).permute(0, 3, 1, 2)
            perf_times[view_idxs] = perf.Checkpoint("%d - Infer" % i)
            if config.COLOR == color_mode.YCbCr:
                gt = util.ycbcr2rgb(gt)
            error = loss(out_view_images[view_idxs], gt).item()
            print("%d - Error: %f" % (i, error))
            perf_errors[view_idxs] = error
            i += 1

        # 4. Save results
        perf_mean_time = torch.mean(perf_times).item()
        perf_mean_error = torch.mean(perf_errors).item()
        with open(run_dir + 'perf_%s_%s_%.1fms_%.2e.txt' %
                  (test_net_name, data_desc_name, perf_mean_time, perf_mean_error), 'w') as fp:
            fp.write('View, Time, Error\n')
            fp.writelines(['%d, %f, %f\n' % (
                i, perf_times[i].item(), perf_errors[i].item()) for i in range(n)])


def test():
    with torch.no_grad():
        # 1. Load dataset
        print("Load dataset: " + data_desc_path)
        test_dataset = SphericalViewSynDataset(data_desc_path,
                                               load_images=opt.output_gt or opt.output_alongside,
                                               color=config.COLOR,
                                               res=opt.res)
        test_data_loader = FastDataLoader(
            dataset=test_dataset,
            batch_size=1,
            shuffle=False,
            drop_last=False,
            pin_memory=True)

        # 2. Load trained model
        netio.LoadNet(test_net_path, model)

        # 3. Test on dataset
        print("Begin test, batch size is %d" % TEST_BATCH_SIZE)
        util.CreateDirIfNeed(output_dir)

        perf = SimplePerf(True, start=True)
        i = 0
        n = test_dataset.n_views
        chns = 1 if config.COLOR == color_mode.GRAY else 3
        out_view_images = torch.empty(n, chns, test_dataset.view_res[0],
                                      test_dataset.view_res[1],
                                      device=device.GetDevice())
        for view_idxs, _, rays_o, rays_d in test_data_loader:
            perf.Checkpoint("%d - Load" % i)
            rays_o = rays_o.to(device.GetDevice()).view(-1, 3)
            rays_d = rays_d.to(device.GetDevice()).view(-1, 3)
            n_rays = rays_o.size(0)
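            # Evaluate the rays in chunks of at most TEST_MAX_RAYS to bound memory use.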
            chunk_size = min(n_rays, TEST_MAX_RAYS)
            out_pixels = torch.empty(n_rays, chns, device=device.GetDevice())
            for offset in range(0, n_rays, chunk_size):
                idx = slice(offset, offset + chunk_size)
                out_pixels[idx] = model(rays_o[idx], rays_d[idx])
            if config.COLOR == color_mode.YCbCr:
                out_pixels = util.ycbcr2rgb(out_pixels)
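            # Reassemble the flat per-ray output into (N, C, H, W) view images.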
            out_view_images[view_idxs] = out_pixels.view(
                TEST_BATCH_SIZE, test_dataset.view_res[0],
                test_dataset.view_res[1], -1).permute(0, 3, 1, 2)
            perf.Checkpoint("%d - Infer" % i)
            i += 1

        # 4. Save results
        if opt.output_video:
            util.generate_video(out_view_images, output_dir +
                                'out.mp4', 24, 3, True)
        else:
            gt_paths = [
                '%sgt_view_%04d.png' % (output_dir, i) for i in range(n)
            ]
            out_paths = [
                '%sout_view_%04d.png' % (output_dir, i) for i in range(n)
            ]
            if test_dataset.load_images:
                if opt.output_alongside:
                    util.WriteImageTensor(
                        torch.cat([
                            test_dataset.view_images,
                            out_view_images
                        ], 3), out_paths)
                else:
                    util.WriteImageTensor(out_view_images, out_paths)
                    util.WriteImageTensor(test_dataset.view_images, gt_paths)
            else:
                util.WriteImageTensor(out_view_images, out_paths)


if __name__ == "__main__":
    if train_mode:
        train()
    elif opt.perf:
        perf()
    else:
        test()