import torch
import argparse
import os
import glob
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image

from torchvision import datasets
from torch.utils.data import DataLoader 
from torch.autograd import Variable

import cv2
from gen_image import *
import json
# Training hyper-parameters
BATCH_SIZE = 5
NUM_EPOCH = 5000

# Block side length folded into / out of channels by interleave / deinterleave.
INTERLEAVE_RATE = 2

# Size of one light-field view in pixels.
IM_H = 480
IM_W = 640

N = 9  # number of input light field views (the dataset slices a 3x3 mosaic)
M = 2  # number of display layers

# Dataset / output locations. The defaults are the original machine-specific
# paths; set the environment variables to run elsewhere without editing code.
DATA_FILE = os.environ.get(
    "DEEPLIGHTFIELD_DATA_FILE", "/home/yejiannan/Project/deeplightfield/data/try")
DATA_JSON = os.environ.get(
    "DEEPLIGHTFIELD_DATA_JSON", "/home/yejiannan/Project/deeplightfield/data/data.json")
OUTPUT_DIR = os.environ.get(
    "DEEPLIGHTFIELD_OUTPUT_DIR", "/home/yejiannan/Project/deeplightfield/output")

class lightFieldDataLoader(torch.utils.data.dataset.Dataset):
    """Dataset of light-field mosaics, ground-truth images and focal depths.

    The JSON file ``file_json`` must contain the keys ``"train"`` (mosaic
    file names), ``"gt"`` (ground-truth file names) and ``"focaldepth"``
    (one value per sample), indexed in parallel. File names are resolved
    relative to ``file_dir_path``.
    """

    # The mosaic holds _GRID x _GRID views of IM_H x IM_W pixels, row-major.
    _GRID = 3

    def __init__(self, file_dir_path, file_json, transforms=None):
        self.file_dir_path = file_dir_path
        # Optional callable applied to the stacked views only (the parameter
        # name shadows the torchvision ``transforms`` module inside here).
        self.transforms = transforms
        with open(file_json, encoding='utf-8') as file:
            self.dastset_desc = json.load(file)

    def __len__(self):
        return len(self.dastset_desc["focaldepth"])

    def __getitem__(self, idx):
        """Return ``(views, gt, focal_depth)`` for sample ``idx``."""
        lightfield_images, gt, fd = self.get_datum(idx)
        if self.transforms:
            lightfield_images = self.transforms(lightfield_images)
        return (lightfield_images, gt, fd)

    @staticmethod
    def _read_rgb(path):
        """Read an image as a float32 RGB array scaled by 1/255.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``path``.
        """
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError("could not read image: %s" % path)
        return cv2.cvtColor(image.astype(np.float32) / 255., cv2.COLOR_BGR2RGB)

    def get_datum(self, idx):
        """Load sample ``idx`` without applying ``transforms``.

        Returns:
            ``(views, gt, fd)``. ``views`` is a float32 array stacking the 9
            mosaic tiles in row-major order: shape (9, IM_H, IM_W, 3) when the
            mosaic is at least 3*IM_H x 3*IM_W. ``gt`` is the ground-truth RGB
            image and ``fd`` the raw ``"focaldepth"`` entry.
        """
        lf_image_path = os.path.join(self.file_dir_path, self.dastset_desc["train"][idx])
        gt_path = os.path.join(self.file_dir_path, self.dastset_desc["gt"][idx])
        mosaic = self._read_rgb(lf_image_path)
        lf_images = [
            mosaic[row * IM_H:(row + 1) * IM_H, col * IM_W:(col + 1) * IM_W, 0:3]
            for row in range(self._GRID)
            for col in range(self._GRID)
        ]
        gt = self._read_rgb(gt_path)
        fd = self.dastset_desc["focaldepth"][idx]
        return (np.asarray(lf_images), gt, fd)

# Channel width and kernel sizes of the convolutional trunk.
OUT_CHANNELS_RB = 128
KERNEL_SIZE_RB = 3
KERNEL_SIZE = 3


class residual_block(torch.nn.Module):
    """Two conv-BN-ELU stages with an identity skip connection.

    Input and output have shape (B, OUT_CHANNELS_RB, H, W). "Same" padding
    keeps H and W unchanged so the skip addition is shape-compatible.
    """

    def __init__(self):
        super(residual_block, self).__init__()
        # k // 2 preserves the spatial size for odd kernels (1 for k = 3).
        padding = KERNEL_SIZE_RB // 2
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(OUT_CHANNELS_RB, OUT_CHANNELS_RB, KERNEL_SIZE_RB, stride=1, padding=padding),
            torch.nn.BatchNorm2d(OUT_CHANNELS_RB),
            torch.nn.ELU()
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(OUT_CHANNELS_RB, OUT_CHANNELS_RB, KERNEL_SIZE_RB, stride=1, padding=padding),
            # BatchNorm2d's second positional parameter is eps, so only
            # num_features may be passed positionally here.
            torch.nn.BatchNorm2d(OUT_CHANNELS_RB),
            torch.nn.ELU()
        )

    def forward(self, input):
        output = self.layer1(input)
        output = self.layer2(output)
        return input + output

class deinterleave(torch.nn.Module):
    """Depth-to-space rearrangement; the inverse of ``interleave``.

    Maps (B, C, H, W) to (B, C / r**2, H * r, W * r) with r = block_size.
    Input channel ``(dy * r + dx) * C_out + c`` at (y, x) lands in output
    channel ``c`` at pixel (y * r + dy, x * r + dx). This is TensorFlow's
    ``depth_to_space`` layout, not the channel order of
    ``torch.nn.functional.pixel_shuffle``. C must be divisible by r**2,
    otherwise the reshape raises.
    """

    def __init__(self, block_size):
        super(deinterleave, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        r = self.block_size
        batch, channels, height, width = input.shape
        out_channels = channels // self.block_size_sq
        # Split the channel axis into (dy, dx, c) ...
        blocks = input.reshape(batch, r, r, out_channels, height, width)
        # ... then move dy next to the row axis and dx next to the column axis.
        blocks = blocks.permute(0, 3, 4, 1, 5, 2)  # (B, c, y, dy, x, dx)
        return blocks.reshape(batch, out_channels, height * r, width * r)

class interleave(torch.nn.Module):
    """Space-to-depth rearrangement: fold r x r pixel blocks into channels.

    Maps (B, C, H, W) to (B, C * r**2, H / r, W / r) with r = block_size.
    Output channel ``(dy * r + dx) * C + c`` at (y, x) holds input channel
    ``c`` at pixel (y * r + dy, x * r + dx). This is TensorFlow's
    ``space_to_depth`` layout, not the channel order of
    ``torch.nn.functional.pixel_unshuffle``. H and W must be divisible by r,
    otherwise the reshape raises.
    """

    def __init__(self, block_size):
        super(interleave, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        r = self.block_size
        batch, channels, height, width = input.shape
        # Split rows into (y, dy) and columns into (x, dx) ...
        blocks = input.reshape(batch, channels, height // r, r, width // r, r)
        # ... then move the in-block offsets (dy, dx) in front of the channel axis.
        blocks = blocks.permute(0, 3, 5, 1, 2, 4)  # (B, dy, dx, c, y, x)
        return blocks.reshape(batch, channels * self.block_size_sq, height // r, width // r)


# interleave() multiplies the channel count by INTERLEAVE_RATE**2 and
# deinterleave() divides it back out, so the conv layers see r**2 times the
# image-space channel counts.
# Input: N views x RGB plus one focal-depth channel (3 * 9 + 1 = 28).
FIRST_LAYER_CHANNELS = (3 * N + 1) * INTERLEAVE_RATE**2
# Output: M display layers x RGB (3 * 2 = 6).
LAST_LAYER_CHANNELS = 3 * M * INTERLEAVE_RATE**2
# Original (misspelled) name, kept for existing references.
FIRSST_LAYER_CHANNELS = FIRST_LAYER_CHANNELS

class model(torch.nn.Module):
    """Map a stack of light-field views to the images for the display layers.

    Data flow: interleave (space-to-depth by INTERLEAVE_RATE) -> conv/BN/ELU
    stem -> three residual blocks -> conv/BN/Tanh head -> deinterleave
    (depth-to-space). With KERNEL_SIZE = 3 the output keeps the input's
    spatial size.
    """

    def __init__(self):
        super(model, self).__init__()
        # Submodules are created and registered in a fixed order; it defines
        # the parameter order seen by the optimizer and the sequence of RNG
        # draws made while the layers are constructed.
        self.interleave = interleave(INTERLEAVE_RATE)
        self.first_layer = self._conv_bn_act(FIRSST_LAYER_CHANNELS, OUT_CHANNELS_RB, torch.nn.ELU())
        self.residual_block1 = residual_block()
        self.residual_block2 = residual_block()
        self.residual_block3 = residual_block()
        self.output_layer = self._conv_bn_act(OUT_CHANNELS_RB, LAST_LAYER_CHANNELS, torch.nn.Tanh())
        self.deinterleave = deinterleave(INTERLEAVE_RATE)

    @staticmethod
    def _conv_bn_act(in_channels, out_channels, activation):
        """Build Sequential(Conv2d(KERNEL_SIZE, stride 1, padding 1), BatchNorm2d, activation)."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, KERNEL_SIZE, stride=1, padding=1),
            torch.nn.BatchNorm2d(out_channels),
            activation,
        )

    def forward(self, lightfield_images, focal_length):
        """Run the network.

        Args:
            lightfield_images: (B, FIRSST_LAYER_CHANNELS / INTERLEAVE_RATE**2,
                H, W) tensor with H and W divisible by INTERLEAVE_RATE. The
                training loop feeds 9 RGB views plus a focal-depth channel.
            focal_length: accepted for API compatibility; not used here.

        Returns:
            (B, LAST_LAYER_CHANNELS / INTERLEAVE_RATE**2, H, W) tensor with
            Tanh-bounded values.
        """
        features = self.first_layer(self.interleave(lightfield_images))
        for block in (self.residual_block1, self.residual_block2, self.residual_block3):
            features = block(features)
        return self.deinterleave(self.output_layer(features))

class Conf:
    """Display and eye geometry used by the retinal-image simulation."""

    def __init__(self):
        self.pupil_size = 0.02  # 2cm
        # Resolutions as (rows, cols); separate tensor objects per attribute.
        self.retinal_res = torch.tensor((480, 640))
        self.layer_res = torch.tensor((480, 640))
        # Per-layer geometry; both lists have n_layers entries.
        self.n_layers = 2
        self.d_layer = [1.75, 3.5]  # layers' distance
        self.h_layer = [1., 2.]  # layers' height

#### Image Gen
# Module-level rendering inputs; the training loop passes conf, v and u to
# GenRetinalFromLayersBatch.
conf = Conf()

# Per-axis scale for the normalized retinal coordinates used in
# GenRetinalFromLayersBatch: h_layer[0] / d_layer[0] for rows, multiplied by
# the layer aspect ratio (cols / rows) for columns.
v = torch.tensor([conf.h_layer[0] / conf.d_layer[0],
     conf.h_layer[0] / conf.d_layer[0] * conf.layer_res[1] / conf.layer_res[0]])

# Pupil samples from GenSamplesInPupil (provided by `from gen_image import *`).
# GenRetinalFromLayersBatch treats u.size(0) as the number of views and
# broadcasts u against (..., M, 2) coordinates, so u is presumably (M, 2).
# NOTE(review): the meaning of the argument 5 is defined in gen_image — confirm.
u = GenSamplesInPupil(conf, 5)

def GenRetinalFromLayersBatch(layers, conf, df, v, u):
    """Simulate the retinal image produced by a multi-layer display.

    For each batch item, every retinal pixel is projected onto each display
    layer once per pupil sample ``u[j]``. The projection takes the nearest
    layer pixel, clamped to the layer borders. The colours fetched from all
    layers are added, and the sum is averaged over the pupil samples.

    Args:
        layers: tensor of shape (BS, C, H_l, W_l) with C >= 3 * conf.n_layers
            and spatial size at least conf.layer_res; channels ``3*i:3*i+3``
            are the RGB image of layer ``i``.
        conf: object with ``retinal_res`` and ``layer_res`` (2-element integer
            tensors, rows then columns), ``n_layers``, and per-layer
            sequences ``d_layer`` (distance) and ``h_layer`` (height).
        df: 1-D tensor of focal depths, one per batch item.
        v: 2-element tensor scaling the normalized retinal coordinates.
        u: pupil samples of shape (M, 2); M = ``u.size(0)`` views.

    Returns:
        Tensor of shape (BS, 3, H_r, W_r) on the device and dtype of ``layers``.
    """
    H_r = int(conf.retinal_res[0])
    W_r = int(conf.retinal_res[1])
    D_r = conf.retinal_res.double()
    n_views = u.size(0)
    BS = df.shape[0]
    max_row = int(conf.layer_res[0]) - 1
    max_col = int(conf.layer_res[1]) - 1

    # Retinal pixel coordinates (row, col), shape (H_r, W_r, 1, 2); the
    # singleton axis broadcasts against the M pupil samples below.
    grid_rows = torch.arange(H_r).view(H_r, 1).expand(H_r, W_r)
    grid_cols = torch.arange(W_r).view(1, W_r).expand(H_r, W_r)
    p_r = torch.stack([grid_rows, grid_cols], 2).unsqueeze(2)

    ci = conf.layer_res / 2  # offset added to the projected coordinates (half the layer resolution)
    # The focus-independent part of each layer's projection is the same for
    # every batch item, so it is computed once per layer.
    layer_geometry = []
    for i in range(conf.n_layers):
        dpi = conf.h_layer[i] / float(conf.layer_res[0])  # layer height per pixel row
        di = conf.d_layer[i]  # layer distance
        pi_r = di * v * (1. / D_r * (p_r + 0.5) - 0.5) / dpi
        layer_geometry.append((di, dpi, pi_r + ci))

    if BS == 0:
        return layers.new_zeros((0, 3, H_r, W_r))

    retinal = []
    for bs in range(BS):
        acc = layers.new_zeros((3, H_r, W_r))
        for i, (di, dpi, base) in enumerate(layer_geometry):
            # Per-sample shift; it vanishes when the layer sits at the focal depth.
            wi = (1 - di / df[bs]) / dpi
            pi = torch.floor(base + wi * u)  # (H_r, W_r, M, 2)
            idx_r = pi[..., 0].clamp(0, max_row).long().to(layers.device)
            idx_c = pi[..., 1].clamp(0, max_col).long().to(layers.device)
            # Gather (3, H_r, W_r, M) and sum over the pupil samples.
            acc = acc + layers[bs, i * 3:i * 3 + 3][:, idx_r, idx_c].sum(-1)
        retinal.append(acc / n_views)
    return torch.stack(retinal, 0)
#### Image Gen End

def merge_two(near, far):
    """Combine RGB channels 0:3 of ``near`` with RGB channels 3:6 of ``far``.

    Both inputs are indexed as (B, C, H, W); the result is (B, 3, H, W).

    NOTE(review): by operator precedence only ``far`` is halved
    (``near + far / 2``). If an average ``(near + far) / 2`` was intended,
    add parentheses — confirm with the author before changing.
    """
    return near[:, 0:3, :, :] + far[:, 3:6, :, :] / 2.0

def loss_two_images(generated, gt):
    """Return the mean absolute (L1) error between ``generated`` and ``gt`` as a scalar tensor."""
    return torch.nn.functional.l1_loss(generated, gt)

# Weight-variance scale; not used by the code in this file.
weightVarScale = 0.25
# Standard deviation of the normal distribution used for conv biases.
bias_stddev = 0.01


def weight_init_normal(m):
    """Initialise one module in place; intended for ``model.apply(weight_init_normal)``.

    * Modules whose class name contains "Conv": Xavier-normal weight and
      bias drawn from N(0, bias_stddev).
    * Modules whose class name contains "BatchNorm2d": weight from
      N(1, 0.02) and bias 0.
    Parameters that are absent or None (``bias=False`` convs,
    ``affine=False`` batch norms) are skipped; other modules are untouched.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            torch.nn.init.xavier_normal_(weight)
        if bias is not None:
            torch.nn.init.normal_(bias, mean=0.0, std=bias_stddev)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias, 0.0)

def var_or_cuda(x):
    """Return ``x`` copied to the current CUDA device if CUDA is available, else ``x`` itself."""
    return x.cuda(non_blocking=True) if torch.cuda.is_available() else x

if __name__ == "__main__":
    # save_image() does not create missing directories.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    train_data_loader = torch.utils.data.DataLoader(dataset=lightFieldDataLoader(DATA_FILE, DATA_JSON),
                                                    batch_size=BATCH_SIZE,
                                                    num_workers=0,
                                                    pin_memory=True,
                                                    shuffle=False,
                                                    drop_last=False)
    print(len(train_data_loader))
    lf_model = model()
    lf_model.apply(weight_init_normal)
    if torch.cuda.is_available():
        lf_model = torch.nn.DataParallel(lf_model).cuda()
    lf_model.train()
    optimizer = torch.optim.Adam(lf_model.parameters(), lr=5e-3, betas=(0.9, 0.999))

    for epoch in range(NUM_EPOCH):
        for batch_idx, (image_set, gt, df) in enumerate(train_data_loader):
            # (B, views, H, W, C) -> (B, views, C, H, W) -> (B, views * C, H, W)
            image_set = image_set.permute(0, 1, 4, 2, 3)
            image_set = image_set.reshape(image_set.shape[0], -1, image_set.shape[3], image_set.shape[4])
            # One extra input channel per sample, filled with its focal depth.
            depth_layer = df.to(image_set.dtype).view(-1, 1, 1, 1).expand(
                -1, 1, image_set.shape[2], image_set.shape[3])
            image_set = var_or_cuda(torch.cat((image_set, depth_layer), dim=1))
            gt = var_or_cuda(gt.permute(0, 3, 1, 2))  # (B, H, W, C) -> (B, C, H, W)

            optimizer.zero_grad()
            output = lf_model(image_set, 0)  # display-layer images
            output = GenRetinalFromLayersBatch(output, conf, df, v, u)  # simulated retinal images
            loss = loss_two_images(output, gt)
            print("Epoch:", epoch, ",Iter:", batch_idx, ",loss:", loss.item())
            loss.backward()
            optimizer.step()
            # The last batch may be smaller than BATCH_SIZE (drop_last=False).
            # NOTE: file names omit batch_idx, so each batch overwrites the
            # previous batch's images within the same epoch.
            for i in range(output.shape[0]):
                save_image(output[i][0:3].detach(),
                           os.path.join(OUTPUT_DIR, "cuda_lr_5e-3_o%d_%d.png" % (epoch, i)))