import os
import torch
import torch.nn.functional as nn_f
from typing import Dict, Tuple, Union
from operator import itemgetter
from pathlib import Path

from utils import img
from utils import color
from utils import sphere
from utils.mem_profiler import *
from utils.constants import *


class PanoDataset(object):
    """
    Dataset of panoramic views for the spherical view synthesis task

    Every view shares the same set of valid panorama pixels and the same ray directions
    (see `_get_pano_rays`); views differ only by their center and their image.

    Attributes
    --------
    root ```Path```: root directory the image path pattern is resolved against\n
    name ```str```: dataset name, used as sub-directory when the pattern has no "/"\n
    device ```torch.device```: device on which the tensors of this dataset are created\n
    image_path ```str|None```: %-style path pattern of view images, None if images are not loaded\n
    res ```(int, int)```: resolution of views as (height, width)\n
    depth_range ```(min, max)|None```: depth range given by the description, if any\n
    bbox: always None here\n
    samples: value of the "samples" field of the description, if any\n
    centers ```Tensor(N, 3)```: centers of views\n
    indices ```Tensor(N)```: indices of views\n
    pixels ```Tensor(P, 2)```: (row, column) coordinates of the valid panorama pixels\n
    rays ```Tensor(P, 3)```: ray directions of the valid panorama pixels\n
    """

    class Chunk(object):
        """
        A group of views of a `PanoDataset` whose pixel colors are loaded and released together

        Items are addressed by a flat index over (view, pixel) pairs of this chunk.
        """

        @property
        def n_views(self):
            return self.indices.size(0)

        @property
        def n_pixels_per_view(self):
            return self.dataset.n_pixels_per_view

        def __init__(self, id: int, dataset, chunk_data: Dict[str, torch.Tensor], *,
                     color: int, **kwargs):
            """
            Create a chunk; no image is read until `load()` (or the first `__getitem__`)

            :param id `int`: identifier of the chunk, only used in log messages
            :param dataset `PanoDataset`: the dataset this chunk belongs to
            :param chunk_data `dict`: must contain
                'indices' `Tensor(N)` (indices of views) and
                'centers' `Tensor(N, 3)` (centers of views)
            :param color `int`: target color space handed to `color.cvt` when loading images
            :param kwargs: accepted and ignored
            """
            self.id = id
            self.dataset = dataset
            self.indices = chunk_data['indices']
            self.centers = chunk_data['centers']
            self.color = color
            # Colors as extracted from the image files; kept across release() so that a
            # later load() does not have to read the files again.
            self.colors_cpu = None
            # Copy of `colors_cpu` on `dataset.device`; dropped by release().
            self.colors = None
            self.loaded = False

        def release(self):
            """Drop the device copy of the colors (the cached `colors_cpu` is kept)."""
            self.colors = None
            self.loaded = False
            MemProfiler.print_memory_stats(f'Chunk #{self.id} released')

        def _colors_size_mb(self):
            """Size of the device-side color tensor in MB, 0.0 when no colors are loaded."""
            if self.colors is None:
                return 0.0
            return self.colors.numel() * self.colors.element_size() / 1024 / 1024

        def load(self):
            """
            Read the view images (first call only) and move the pixel colors to the device

            When the dataset has no image path nothing is read and `colors` stays None;
            the chunk is still marked as loaded.
            """
            dataset = self.dataset
            if dataset.image_path is not None and self.colors_cpu is None:
                images = color.cvt(img.load(dataset.image_path % i for i in self.indices),
                                   color.RGB, self.color)
                # Compare as tuples so that a list-valued `res` does not force a resize.
                if tuple(dataset.res) != tuple(images.shape[-2:]):
                    images = nn_f.interpolate(images, dataset.res)
                # `dataset.pixels` lives on `dataset.device`; the images may live elsewhere
                # (presumably on CPU - depends on `img.load`). Index tensors must be on CPU
                # or on the device of the indexed tensor, so move them next to the images.
                pixels = dataset.pixels.to(images.device)
                # (N, C, H, W) -> (N, H, W, C) -> (N, P, C) -> (N * P, C)
                self.colors_cpu = images.permute(0, 2, 3, 1) \
                    [:, pixels[:, 0], pixels[:, 1]].flatten(0, 1)
            if self.colors_cpu is not None:
                self.colors = self.colors_cpu.to(dataset.device)
            self.loaded = True
            # `colors` is None when images are not loaded; report 0MB instead of failing.
            MemProfiler.print_memory_stats(
                f'Chunk #{self.id} ({self.n_views} views, '
                f'{self._colors_size_mb():.2f}MB) loaded')

        def __len__(self):
            return self.n_views * self.n_pixels_per_view

        def __getitem__(self, idx):
            """
            Get the ray(s) at flat index `idx` of this chunk (loads the chunk if needed)

            :param idx: flat index, `view_in_chunk * n_pixels_per_view + pixel`
            :return: (global_idx, rays_o, rays_d, extra_data) where `global_idx` is the flat
                index computed from the view's dataset-level index, `rays_o` is the view
                center, `rays_d` the ray direction, and `extra_data` holds 'color' when
                colors are available
            """
            if not self.loaded:
                self.load()
            view_idx = idx // self.n_pixels_per_view
            pix_idx = idx % self.n_pixels_per_view
            global_idx = self.indices[view_idx] * self.n_pixels_per_view + pix_idx
            extra_data = {}
            if self.colors is not None:
                extra_data['color'] = self.colors[idx]
            rays_o = self.centers[view_idx]
            rays_d = self.dataset.rays[pix_idx]
            return global_idx, rays_o, rays_d, extra_data

    @property
    def n_views(self):
        return self.centers.size(0)

    @property
    def n_pixels_per_view(self):
        return self.pixels.size(0)

    @property
    def n_pixels(self):
        return self.n_views * self.n_pixels_per_view

    def __init__(self, desc: dict, root: Path, name: str, *,
                 load_images: bool = True,
                 res: Tuple[int, int] = None,
                 views_to_load: Union[range, torch.Tensor] = None,
                 device: torch.device = None,
                 **kwargs):
        """
        Initialize the dataset from an already parsed dataset description

        Fields of `desc` read by this class:

        - view_file_pattern: string, optional, the %-style path pattern of view images
        - view_res: { "x", "y" }, the resolution of views (used when `res` is not given)
        - depth_range: { "min", "max" }, optional, the depth range
        - samples: optional, stored as is
        - view_centers: [ [ x, y, z ], ... ], centers of views
        - views: [ ... ], optional, indices of views (defaults to 0..N-1)
        - gl_coord: optional, when truthy the z component of the centers is negated

        :param desc ```dict```: the dataset description
        :param root ```Path```: root directory the image path pattern is resolved against
        :param name ```str```: dataset name, prepended to a pattern that contains no "/"
        :param load_images ```bool```: whether chunks should load view images
        :param res ```(int, int)```: (height, width) overriding `desc['view_res']`
        :param views_to_load ```range|Tensor```: selection applied to centers and indices
        :param device ```torch.device```: device on which tensors are created
        :param kwargs: accepted and ignored
        """
        self.root = root
        self.name = name
        self.device = device
        self._load_desc(desc, res, views_to_load, load_images)

    def get_data(self):
        """Return the per-view data ('indices' and 'centers') in the form `Chunk` expects."""
        return {
            'indices': self.indices,
            'centers': self.centers
        }

    def _load_desc(self, desc: dict,
                   res: Tuple[int, int],
                   views_to_load: Union[range, torch.Tensor],
                   load_images: bool):
        """
        Populate the dataset attributes from the description (see `__init__` for the fields)

        :param desc ```dict```: the dataset description
        :param res ```(int, int)```: (height, width), falls back to `desc['view_res']`
        :param views_to_load ```range|Tensor```: selection of views, None keeps all views
        :param load_images ```bool```: when False `image_path` is set to None
        """
        if load_images and desc.get('view_file_pattern'):
            file_pattern = desc['view_file_pattern']
            # A bare file name is taken to live in the sub-directory named after the dataset
            if "/" not in file_pattern:
                file_pattern = f"{self.name}/{file_pattern}"
            self.image_path = str(self.root / file_pattern)
        else:
            self.image_path = None
        # Resolution is stored as (height, width), hence the ("y", "x") order
        self.res = res if res else itemgetter("y", "x")(desc['view_res'])
        self.depth_range = itemgetter("min", "max")(desc['depth_range']) \
            if 'depth_range' in desc else None
        self.bbox = None
        self.samples = desc.get('samples')
        self.centers = torch.tensor(desc['view_centers'], device=self.device)  # (N, 3)
        self.indices = torch.tensor(desc.get('views') or list(range(self.centers.size(0))),
                                    device=self.device)

        if views_to_load is not None:
            self.centers = self.centers[views_to_load]
            self.indices = self.indices[views_to_load]

        self.pixels, self.rays = self._get_pano_rays()

        if desc.get('gl_coord'):
            print('Convert from OGL coordinate to DX coordinate (i. e. flip z axis)')
            self.centers[:, 2] *= -1

    def _get_pano_rays(self):
        """
        Get unprojected rays of the valid pixels on a panorama

        Each image row keeps only a horizontally centered run of pixels whose length
        shrinks with the sine of the row's polar angle, so rows near the top and bottom
        contribute fewer pixels than rows near the middle.

        :return `Tensor(P, 2)`: (row, column) coordinates of the valid pixels in pano image
        :return `Tensor(P, 3)`: rays' directions, converted by `sphere.spherical2cartesian`
            from spherical coordinates with radius 1
        """
        height, width = self.res[0], self.res[1]
        # Angle of each row center, from 0 (top) to PI (bottom) - assumes PI is the constant pi
        phi = (torch.arange(height, device=self.device) + 0.5) / height * PI  # (H)
        # Even number of valid pixels per row, proportional to sin(phi)
        length = (phi.sin() * width * 0.5).ceil() * 2
        cols = torch.arange(width, device=self.device)[None, :].expand(*self.res)  # (H, W)
        # Valid pixels are those inside the centered span of `length` columns
        mask = torch.logical_and(cols >= (width - length[:, None]) / 2,
                                 cols < (width + length[:, None]) / 2)  # (H, W)
        pixs = mask.nonzero()  # (P, 2)
        # Row center mapped to [PI/2 (top) .. -PI/2 (bottom)]
        pixs_phi = (0.5 - (pixs[:, 0] + 0.5) / height) * PI
        # Column center mapped to (-PI, PI) across the valid span of its row
        pixs_theta = (pixs[:, 1] * 2 + 1 - width) / length[pixs[:, 0]] * PI
        spher_coords = torch.stack([torch.ones_like(pixs_phi), pixs_theta, pixs_phi], dim=-1)
        return pixs, sphere.spherical2cartesian(spher_coords)  # (P, 3)