import torch
import torch.nn.functional as nn_f
from typing import Dict, Tuple, Union
from pathlib import Path

from utils import img
from utils import color
from .dataset import Dataset


class ViewDataset(Dataset):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of dataset\n
    view_file_pattern ```str```: the filename pattern of view images\n
    cam ```object```: camera intrinsic parameters\n
    view_centers ```Tensor(N, 3)```: centers of views\n
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
    view_images ```Tensor(N, 3, H, W)```: images of views\n
    view_depths ```Tensor(N, H, W)```: depths of views\n
    """

    class Chunk(object):

        def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
                     color: int, **kwargs):
            """
            [summary]

            :param dataset `ViewDataset`: dataset object
            :param indices `Tensor(N)`: indices of views
            :param centers `Tensor(N, 3)`: centers of views
            """
            self.id = id
            self.dataset = dataset
            self.indices = chunk_data['indices']
            self.centers = chunk_data['centers']
            self.rots = chunk_data['rots']
            self.color = color
            self.n_views = self.indices.size(0)
            self.n_pixels_per_view = self.dataset.res[0] * self.dataset.res[1]
            self.colors = self.depths = self.bins = None
            self.colors_cpu = self.depths_cpu = self.bins_cpu = None
            self.loaded = False

        def release(self):
            self.colors = self.depths = self.bins = None
            self.loaded = False

        def load(self):
            #print("chunk load")
            try:
                # Lazily load view images on first access, convert them to the
                # chunk's color space and flatten to one row per pixel
                if self.dataset.image_path and self.colors_cpu is None:
                    images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices),
                                       color.RGB, self.color)
                    if self.dataset.res != list(images.shape[-2:]):
                        images = nn_f.interpolate(images, self.dataset.res)
                    self.colors_cpu = images.permute(0, 2, 3, 1).flatten(0, 2)
                if self.colors_cpu is not None:
                    self.colors = self.colors_cpu.to(self.dataset.device, non_blocking=True)

                if self.dataset.depth_path and self.depths_cpu is None:
                    depths = self.dataset._decode_depth_images(
                        img.load(self.dataset.depth_path % i for i in self.indices))
                    if self.dataset.res != list(depths.shape[-2:]):
                        # interpolate() needs a channel dimension, which the decoded
                        # (N, H, W) depth maps lack
                        depths = nn_f.interpolate(depths[:, None], self.dataset.res)[:, 0]
                    self.depths_cpu = depths.flatten(0, 2)
                if self.depths_cpu is not None:
                    self.depths = self.depths_cpu.to(self.dataset.device, non_blocking=True)

                if self.dataset.bins_path and self.bins_cpu is None:
                    bins = img.load([self.dataset.bins_path % i for i in self.indices])
                    if self.dataset.res != list(bins.shape[-2:]):
                        bins = nn_f.interpolate(bins, self.dataset.res)
                    self.bins_cpu = bins.permute(0, 2, 3, 1).flatten(0, 2)
                if self.bins_cpu is not None:
                    self.bins = self.bins_cpu.to(self.dataset.device, non_blocking=True)

                # Wait for the asynchronous (non_blocking) host-to-device copies to
                # finish before marking the chunk as loaded
                torch.cuda.current_stream(self.dataset.device).synchronize()
                self.loaded = True
            except Exception as ex:
                # Fail fast: a partially loaded chunk would yield inconsistent samples
                print(f"Failed to load chunk {self.id}: {ex}")
                raise

        def __len__(self):
            return self.n_views * self.n_pixels_per_view

        def __getitem__(self, idx):
            if not self.loaded:
                self.load()
            view_idx = idx // self.n_pixels_per_view
            pix_idx = idx % self.n_pixels_per_view
            # Global pixel index over the whole dataset: e.g. with 16 pixels per
            # view, local idx 21 addresses view 1, pixel 5 of this chunk
            global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
            rays_o = self.centers[view_idx]  # (N, 3)
            rays_d = self.dataset.cam_rays[pix_idx][:, None]  # (N, 1, 3)
            r = self.rots[view_idx].movedim(-1, -2)  # (N, 3, 3), R^T of each view
            # Row-vector form of R @ d: rotate ray directions from camera to world space
            rays_d = torch.matmul(rays_d, r)[:, 0]  # (N, 3)
            data = {
                'idx': global_idx,
                'rays_o': rays_o,
                'rays_d': rays_d,
                'level': self.dataset.level
            }
            if self.colors is not None:
                data['color'] = self.colors[idx]
            if self.depths is not None:
                data['depth'] = self.depths[idx]
            if self.bins is not None:
                data['bin'] = self.bins[idx]
            return data

    def __init__(self, desc: dict, desc_path: Path, *,
                 load_images: bool = True,
                 load_depths: bool = False,
                 load_bins: bool = False,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None,
                 **kwargs):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view images
        - cam: { "fx", "fy", "cx", "cy" }, the focal and center of camera (in normalized image space)
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - view_rots: [ [ m00, m01, ..., m22 ], ... ], rotation matrices of views

        :param dataset_desc_path ```str```: path to the data description file
        :param load_images ```bool```: whether load view images and return in __getitem__()
        :param load_depths ```bool```: whether load depth images and return in __getitem__()
        :param c ```int```: color space to convert view images to
        :param calculate_rays ```bool```: whether calculate rays
        """
        super().__init__(desc, desc_path, res=res, views_to_load=views_to_load, device=device,
                         load_images=load_images, load_depths=load_depths, load_bins=load_bins)

    def _decode_depth_images(self, images):
        # Depth maps are stored as disparity normalized to [0, 1], where 1 encodes
        # the near plane (depth_range[0]) and 0 the far plane (depth_range[1]);
        # only channel 0 of each image is used
        disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
        disp_val = (1 - images[..., 0, :, :]) * (disp_range[1] - disp_range[0]) + disp_range[0]
        return torch.reciprocal(disp_val)
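
    def _encode_depth_images(self, depths):
        # Inverse of _decode_depth_images, included as a reference sketch of the
        # encoding convention; this hypothetical helper is not used by the loader
        disp_range = (1 / self.depth_range[0], 1 / self.depth_range[1])
        disp_val = torch.reciprocal(depths)
        return 1 - (disp_val - disp_range[0]) / (disp_range[1] - disp_range[0])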

    def _load_desc(self, res: Tuple[int, int], views_to_load: Union[range, torch.Tensor],
                   load_images: bool, load_depths: bool, load_bins: bool):
        super()._load_desc(res, views_to_load)
        # Each path is either False (that data type is disabled) or a printf-style
        # pattern to be filled with a view index
        self.image_path = load_images and self._get_data_path("view")
        self.depth_path = load_depths and self._get_data_path("depth")
        self.bins_path = load_bins and self._get_data_path("bins")
        self.cam_rays = self.cam.get_local_rays(flatten=True)  # (H*W, 3), camera space