import torch
import torch.nn.functional as nn_f
from typing import Dict, Tuple, Union
from pathlib import Path
from pathlib import Path

from utils import img
from utils import color
from utils import sphere
from utils import math
from utils.mem_profiler import MemProfiler
from .dataset import Dataset


class PanoDataset(Dataset):
    """
    Data loader for spherical view synthesis task

    Attributes
    --------
    data_dir ```str```: the directory of dataset\n
    view_file_pattern ```str```: the filename pattern of view images\n
    cam_params ```object```: camera intrinsic parameters\n
    centers ```Tensor(N, 3)```: centers of views\n
    view_rots ```Tensor(N, 3, 3)```: rotation matrices of views\n
    images ```Tensor(N, 3, H, W)```: images of views\n
    view_depths ```Tensor(N, H, W)```: depths of views\n
    """

    class Chunk(object):

        @property
        def n_views(self):
            return self.indices.size(0)

        @property
        def n_pixels_per_view(self):
            return self.dataset.n_pixels_per_view

        def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
                     color: int, **kwargs):
            """
            [summary]

            :param dataset `PanoDataset`: dataset object
            :param indices `Tensor(N)`: indices of views
            :param centers `Tensor(N, 3)`: centers of views
            """
            self.id = id
            self.dataset = dataset
            self.indices = chunk_data['indices']
            self.centers = chunk_data['centers']
            self.color = color
            self.colors_cpu = None
            self.colors = None
            self.loaded = False

        def release(self):
            self.colors = None
            self.loaded = False
            MemProfiler.print_memory_stats(f'Chunk #{self.id} released')

        def load(self):
            if self.dataset.image_path and self.colors_cpu is None:
                # Load view images, convert them from RGB to the target color
                # space and match the dataset's resolution
                images = color.cvt(img.load(self.dataset.image_path % i for i in self.indices),
                                   color.RGB, self.color)
                if self.dataset.res != tuple(images.shape[-2:]):
                    images = nn_f.interpolate(images, self.dataset.res)
                # Keep only the valid pano pixels and flatten views and pixels
                # into one sample dimension
                self.colors_cpu = images.permute(
                    0, 2, 3, 1)[:, self.dataset.pixels[:, 0], self.dataset.pixels[:, 1]].flatten(0, 1)
            if self.colors_cpu is not None:
                self.colors = self.colors_cpu.to(self.dataset.device)
            self.loaded = True
            # Guard against self.colors being None when no images are loaded
            colors_size_mb = (self.colors.numel() * self.colors.element_size() / 1024 / 1024
                              if self.colors is not None else 0)
            MemProfiler.print_memory_stats(
                f'Chunk #{self.id} ({self.n_views} views, {colors_size_mb:.2f}MB) loaded')

        def __len__(self):
            return self.n_views * self.n_pixels_per_view

        def __getitem__(self, idx):
            if not self.loaded:
                self.load()
            # Decompose the flat sample index into a view index and a per-view
            # pixel index
            view_idx = idx // self.n_pixels_per_view
            pix_idx = idx % self.n_pixels_per_view
            # Index of this sample across the whole dataset
            global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
            rays_o = self.centers[view_idx]
            rays_d = self.dataset.rays[pix_idx]
            data = {
                'idx': global_idx,
                'rays_o': rays_o,
                'rays_d': rays_d,
                'level': self.dataset.level
            }
            if self.colors is not None:
                data['color'] = self.colors[idx]
            return data
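
        # Illustration of the flat indexing above (hypothetical sizes): with
        # n_pixels_per_view == P, index k addresses pixel k % P of view k // P,
        # so for a chunk holding two views, chunk[P] is the first valid pixel
        # of the second view:
        #
        #   sample = chunk[P]
        #   sample['rays_o']  # center of the second view, Tensor(3)
        #   sample['rays_d']  # direction of its first valid pixel, Tensor(3)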

    @property
    def n_pixels_per_view(self):
        return self.pixels.size(0)

    def __init__(self, desc: dict, desc_path: Path, *,
                 load_images: bool = True,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None,
                 **kwargs):
        """
        Initialize data loader for spherical view synthesis task

        The dataset description file is a JSON file with following fields:

        - view_file_pattern: string, the path pattern of view images
        - view_res: { "x", "y" }, the resolution of view
        - depth_range: { "min", "max" }, the depth range
        - range: { "min": [...], "max": [...] }, the range of translation and rotation
        - centers: [ [ x, y, z ], ... ], centers of views

        :param desc_path ```str```: path to the data description file
        :param load_images ```bool```: whether load view images and return in __getitem__()
        :param c ```int```: color space to convert view images to
        :param calculate_rays ```bool```: whether calculate rays
        """
        super().__init__(desc, desc_path, res=res, views_to_load=views_to_load, device=device,
                         load_images=load_images)
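
    # A minimal sketch of a dataset description JSON, with the fields listed in
    # the docstring above (all values are hypothetical; the "range" vectors are
    # left elided as in the field list):
    #
    #   {
    #     "view_file_pattern": "view_%04d.png",
    #     "view_res": { "x": 512, "y": 256 },
    #     "depth_range": { "min": 0.5, "max": 10.0 },
    #     "range": { "min": [...], "max": [...] },
    #     "centers": [ [0.0, 0.0, 0.0], [0.1, 0.0, 0.0] ]
    #   }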

    def get_data(self):
        return {
            'indices': self.indices,
            'centers': self.centers
        }

    def _load_desc(self, res: Tuple[int, int], views_to_load: Union[range, torch.Tensor],
                   load_images: bool):
        super()._load_desc(res, views_to_load)
        # image_path is False when view images should not be loaded
        self.image_path = load_images and self._get_data_path("view")
        self.pixels, self.rays = self._get_pano_rays()

    def _get_pano_rays(self):
        """
        Get unprojected rays of pixels on a panorama

        :return `Tensor(N, 2)`: rays' pixel coordinates in pano image
        :return `Tensor(N, 3)`: rays' directions with one unit length
        """
        phi = (torch.arange(self.res[0], device=self.device) + 0.5) / self.res[0] * math.pi  # (H)
        length = (phi.sin() * self.res[1] * 0.5).ceil() * 2
        cols = torch.arange(self.res[1], device=self.device)[None, :].expand(*self.res)  # (H, W)
        mask = torch.logical_and(cols >= (self.res[1] - length[:, None]) / 2,
                                 cols < (self.res[1] + length[:, None]) / 2)  # (H, W)
        pixs = mask.nonzero()  # (N, 2)
        pixs_phi = (0.5 - (pixs[:, 0] + 0.5) / self.res[0]) * math.pi
        pixs_theta = (pixs[:, 1] * 2 + 1 - self.res[1]) / length[pixs[:, 0]] * math.pi
        spher_coords = torch.stack([torch.ones_like(pixs_phi), pixs_theta, pixs_phi], dim=-1)
        return pixs, sphere.spherical2cartesian(spher_coords)  # (N, 3)
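

# A minimal usage sketch, assuming a description file at the hypothetical path
# 'data/pano/train.json' and a CUDA device; the Chunk is constructed from
# get_data() here purely for illustration:
#
#   import json
#   desc_path = Path('data/pano/train.json')
#   dataset = PanoDataset(json.loads(desc_path.read_text()), desc_path,
#                         load_images=True, res=(256, 512),
#                         device=torch.device('cuda'))
#   chunk = PanoDataset.Chunk(0, dataset, dataset.get_data(), color=color.RGB)
#   sample = chunk[0]   # dict with 'idx', 'rays_o', 'rays_d', 'level' (and 'color')
#   chunk.release()     # drop the GPU copy; the next indexing reloads it lazily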