from math import ceil
import torch
import numpy as np
from typing import List, NoReturn, Tuple, Union
from torch import nn
from plyfile import PlyData, PlyElement

from utils.geometry import *
from utils.constants import *
from utils.voxels import *
from utils.perf import perf
from clib import *


class Intersections:
    """
    Container for the per-ray results of a ray/voxel intersection query.

    All fields share the leading ray dimension `N`, so indexing an instance
    applies the same index to every field.
    """

    min_depths: torch.Tensor
    """`Tensor(N, P)` Min ray depths of intersected voxels"""

    max_depths: torch.Tensor
    """`Tensor(N, P)` Max ray depths of intersected voxels"""

    voxel_indices: torch.Tensor
    """`Tensor(N, P)` Indices of intersected voxels"""

    hits: torch.Tensor
    """`Tensor(N)` Number of hits"""

    def __init__(self, min_depths: torch.Tensor, max_depths: torch.Tensor,
                 voxel_indices: torch.Tensor, hits: torch.Tensor) -> None:
        self.min_depths, self.max_depths = min_depths, max_depths
        self.voxel_indices, self.hits = voxel_indices, hits

    @property
    def size(self):
        """Number of rays, taken from the first dimension of `hits`."""
        n_rays = self.hits.size(0)
        return n_rays

    def __getitem__(self, index):
        """
        Select a subset of rays.

        :param index: any index accepted by `Tensor.__getitem__`, applied to every field
        :return `Intersections`: new container holding the indexed fields
        """
        field_names = ("min_depths", "max_depths", "voxel_indices", "hits")
        selected = {field: getattr(self, field)[index] for field in field_names}
        return Intersections(**selected)


class Space(nn.Module):
    """
    Base class of space representations.

    Only the optional bounding box is handled here; embedding, ray intersection,
    pruning and splitting are left to subclasses and raise `NotImplementedError`.
    """

    bbox: Union[torch.Tensor, None]
    """`Tensor(2, 3)` Bounding box"""

    def __init__(self, *, bbox: List[float] = None, **kwargs):
        """
        :param bbox `List[float]`: six values reshaped to `(2, 3)` (min corner, max corner),
            or `None` for an unbounded space
        """
        super().__init__()
        if bbox is not None:
            bounds = torch.Tensor(bbox).reshape(2, 3)
            # Non-persistent: the bounding box is not saved in the state dict
            self.register_buffer('bbox', bounds, persistent=False)
        else:
            self.bbox = None

    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
        raise NotImplementedError

    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
                          name: str = 'default') -> torch.Tensor:
        raise NotImplementedError

    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
        raise NotImplementedError

    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Get voxel indices of points: 0 for every point, or -1 if the point lies
        outside the bounding box (when one is set).

        :param pts `Tensor(N..., 3)`: points
        :return `Tensor(N...)`: corresponding voxel indices
        """
        indices = torch.zeros_like(pts[..., 0], dtype=torch.long)
        if self.bbox is None:
            return indices
        lower, upper = self.bbox[0], self.bbox[1]
        # The box is half-open: the lower bound is inside, the upper bound is outside
        outside = torch.logical_or(pts < lower, pts >= upper).any(-1)  # (N...)
        indices[outside] = -1
        return indices

    @torch.no_grad()
    def pruning(self, score_fn, threshold: float = 0.5, train_stats=False):
        raise NotImplementedError()

    @torch.no_grad()
    def splitting(self):
        raise NotImplementedError()


class Voxels(Space):
    """
    Voxel-based space: the bounding box is divided into `steps` cells along each dimension.

    Voxel centers, corner positions and the voxel-to-corner mapping are kept as buffers.
    `voxel_indices_in_grid` maps a flat grid index to an index in the voxel list
    (or -1 for pruned space). Embeddings are stored on voxel corners.
    """

    steps: torch.Tensor
    """`Tensor(3)` Steps along each dimension"""

    corners: torch.Tensor
    """`Tensor(C, 3)` Corner positions"""

    voxels: torch.Tensor
    """`Tensor(M, 3)` Voxel centers"""

    corner_indices: torch.Tensor
    """`Tensor(M, 8)` Voxel corner indices"""

    voxel_indices_in_grid: torch.Tensor
    """`Tensor(G)` Indices in voxel list or -1 for pruned space"""

    @property
    def dims(self) -> int:
        """`int` Number of dimensions"""
        return self.steps.size(0)

    @property
    def n_voxels(self) -> int:
        """`int` Number of voxels"""
        return self.voxels.size(0)

    @property
    def n_corners(self) -> int:
        """`int` Number of corners"""
        return self.corners.size(0)

    @property
    def n_corner(self) -> int:
        """`int` Number of corners (alias of `n_corners`, kept for backward compatibility)"""
        return self.n_corners

    @property
    def voxel_size(self) -> torch.Tensor:
        """`Tensor(3)` Voxel size"""
        return (self.bbox[1] - self.bbox[0]) / self.steps

    @property
    def device(self) -> torch.device:
        """`device` Device holding the voxel buffers"""
        return self.voxels.device

    def __init__(self, *, voxel_size: float = None,
                 steps: Union[torch.Tensor, Tuple[int, int, int]] = None, **kwargs) -> None:
        """
        Initialize the voxel grid from either a voxel size or explicit steps.

        :param voxel_size `float`: voxel size; takes precedence over `steps` when given
        :param steps `Tensor(3)|Tuple[int, int, int]`: steps along each dimension
        :param kwargs: forwarded to `Space.__init__` (must contain `bbox`)
        :raise ValueError: if `bbox` is missing, or both `voxel_size` and `steps` are missing
        """
        super().__init__(**kwargs)
        if self.bbox is None:
            raise ValueError("Missing argument 'bbox'")
        if voxel_size is not None:
            self.register_buffer('steps', get_grid_steps(self.bbox, voxel_size))
        elif steps is not None:
            self.register_buffer('steps', torch.tensor(steps, dtype=torch.long))
        else:
            # Fail with a clear message instead of an opaque error from torch.tensor(None)
            raise ValueError("Missing argument 'voxel_size' or 'steps'")
        self.register_buffer('voxels', init_voxels(self.bbox, self.steps))
        corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
        self.register_buffer("corners", corners)
        self.register_buffer("corner_indices", corner_indices)
        self.register_buffer('voxel_indices_in_grid', torch.arange(self.n_voxels))
        self._register_load_state_dict_pre_hook(self._before_load_state_dict)

    def create_embedding(self, n_dims: int, name: str = 'default') -> torch.nn.Embedding:
        """
        Create a embedding on voxel corners.

        :param name `str`: embedding name
        :param n_dims `int`: embedding dimension
        :return `Embedding(n_corners, n_dims)`: new embedding on voxel corners
        """
        name = f'emb_{name}'
        emb = torch.nn.Embedding(self.n_corners, n_dims)
        self.add_module(name, emb)
        return emb

    def get_embedding(self, name: str = 'default') -> torch.nn.Embedding:
        """
        Get the embedding registered under `name`.

        :param name `str`: embedding name, default to 'default'
        :return `Embedding`: the embedding module
        :raise AttributeError: if no such embedding has been created
        """
        return getattr(self, f'emb_{name}')

    def extract_embedding(self, pts: torch.Tensor, voxel_indices: torch.Tensor,
                          name: str = 'default') -> torch.Tensor:
        """
        Extract embedding values at given points using trilinear interpolation.

        :param pts `Tensor(N, 3)`: points to extract values
        :param voxel_indices `Tensor(N)`: corresponding voxel indices
        :param name `str`: embedding name, default to 'default'
        :return `Tensor(N, X)`: extracted values
        :raise KeyError: if the embedding doesn't exist
        """
        # `get_embedding` raises AttributeError for an unknown name; convert it so the
        # documented KeyError below is actually reachable.
        try:
            emb = self.get_embedding(name)
        except AttributeError:
            emb = None
        if emb is None:
            raise KeyError(f"Embedding '{name}' doesn't exist")
        voxels = self.voxels[voxel_indices]  # (N, 3)
        corner_indices = self.corner_indices[voxel_indices]  # (N, 8)
        p = (pts - voxels) / self.voxel_size + 0.5  # (N, 3) normed-coords in voxel
        features = emb(corner_indices).reshape(pts.size(0), 8, -1)  # (N, 8, X)
        return trilinear_interp(p, features)

    @perf
    def ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Intersections:
        """
        Calculate intersections of rays and voxels.

        :param rays_o `Tensor(N, 3)`: rays' origin
        :param rays_d `Tensor(N, 3)`: rays' direction
        :param n_max_hits `int`: maximum number of hits (for allocating enough space)
        :return `Intersections`: intersections of rays and voxels, sorted by min depth
        """
        # Prepend a dim to meet the requirement of external call
        rays_o = rays_o[None].contiguous()
        rays_d = rays_d[None].contiguous()

        voxel_indices, min_depths, max_depths = self._ray_intersect(rays_o, rays_d, n_max_hits)
        invalid_voxel_mask = voxel_indices.eq(-1)
        hits = n_max_hits - invalid_voxel_mask.sum(-1)

        # Sort intersections according to their depths
        # (invalid slots get a huge depth so they end up after all valid hits)
        min_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
        max_depths.masked_fill_(invalid_voxel_mask, HUGE_FLOAT)
        min_depths, sorted_idx = min_depths.sort(dim=-1)
        max_depths = max_depths.gather(-1, sorted_idx)
        voxel_indices = voxel_indices.gather(-1, sorted_idx)

        # Strip the prepended dim again
        return Intersections(
            min_depths=min_depths[0],
            max_depths=max_depths[0],
            voxel_indices=voxel_indices[0],
            hits=hits[0]
        )

    @perf
    def get_voxel_indices(self, pts: torch.Tensor) -> torch.Tensor:
        """
        Get voxel indices of points.

        If a point is not in any valid voxels, its corresponding voxel index is -1.

        :param pts `Tensor(N..., 3)`: points
        :return `Tensor(N...)`: corresponding voxel indices
        """
        grid_indices, out_mask = to_grid_indices(pts, self.bbox, steps=self.steps)
        # Use a valid placeholder index for the lookup; the result is overwritten below
        grid_indices[out_mask] = 0
        voxel_indices = self.voxel_indices_in_grid[grid_indices]
        voxel_indices[out_mask] = -1
        return voxel_indices

    @torch.no_grad()
    def splitting(self) -> Tuple[int, int]:
        """
        Split voxels into smaller voxels with half size.

        :return `Tuple[int, int]`: number of voxels before and after splitting
        """
        n_voxels_before = self.n_voxels
        self.steps *= 2
        # NOTE: `voxel_size` is derived from `steps`, so it is already halved here
        self.voxels = split_voxels(self.voxels, self.voxel_size, 2, align_border=False)\
            .reshape(-1, 3)
        self._update_corners()
        self._update_voxel_indices_in_grid()
        return n_voxels_before, self.n_voxels

    @torch.no_grad()
    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
        """
        Keep only the selected voxels.

        :param keeps `Tensor(M)`: boolean mask of voxels to keep
        :return `Tuple[int, int]`: number of voxels before and after pruning
        """
        self.voxels = self.voxels[keeps]
        self.corner_indices = self.corner_indices[keeps]
        self._update_voxel_indices_in_grid()
        return keeps.size(0), keeps.sum().item()

    @torch.no_grad()
    def pruning(self, score_fn, threshold: float = 0.5) -> Tuple[int, int]:
        """
        Prune voxels whose maximum sampled score does not exceed `threshold`.

        :param score_fn: callable `(pts, voxel_indices) -> scores` evaluated on points sampled in voxels
        :param threshold `float`: voxels with max score above this value are kept
        :return `Tuple[int, int]`: number of voxels before and after pruning
        """
        scores = self._get_scores(score_fn, lambda x: torch.max(x, -1)[0])  # (M)
        return self.prune(scores > threshold)

    def n_voxels_along_dim(self, dim: int) -> torch.Tensor:
        """
        Count valid (not pruned) voxels in each slice along a dimension.

        :param dim `int`: dimension to keep
        :return `Tensor(steps[dim])`: number of valid voxels in every slice along `dim`
        """
        sum_dims = [val for val in range(self.dims) if val != dim]
        # Pass plain ints to reshape instead of unpacking 0-d tensors
        grid = self.voxel_indices_in_grid.reshape(self.steps.tolist())
        return grid.ne(-1).sum(sum_dims)

    def balance_cut(self, dim: int, n_parts: int) -> List[int]:
        """
        Cut the grid along a dimension into parts holding roughly equal numbers of voxels.

        :param dim `int`: dimension to cut along
        :param n_parts `int`: number of parts wanted
        :return `List[int]`: number of slices in each part
        """
        n_voxels = self.n_voxels
        if n_voxels == 0:
            return []
        cumulative = self.n_voxels_along_dim(dim).cumsum(0).tolist()
        bins = []
        part = 1
        offset = 0
        for i, count in enumerate(cumulative):
            # Exact integer form of `count / n_voxels * n_parts >= part`; avoids losing a
            # cut to floating point rounding when the ratio is a whole number.
            if count * n_parts >= part * n_voxels:
                bins.append(i + 1 - offset)
                offset = i + 1
                part = count * n_parts // n_voxels + 1
        return bins

    def sample(self, bits: int, perturb: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample points in every voxel by splitting it with `split_voxels`.

        :param bits `int`: passed to `split_voxels` as the split factor
        :param perturb `bool`: currently ignored
        :return `Tensor(M*P, 3)`: sampled points, grouped by voxel
        :return `Tensor(M*P)`: index of the voxel each point was sampled from
        """
        sampled_xyz = split_voxels(self.voxels, self.voxel_size, bits)
        sampled_idx = torch.arange(self.n_voxels, device=self.device)[:, None].expand(
            *sampled_xyz.shape[:2])
        return sampled_xyz.reshape(-1, 3), sampled_idx.flatten()

    @torch.no_grad()
    def _get_scores(self, score_fn, reduce_fn=None, bits=16) -> torch.Tensor:
        """
        Evaluate `score_fn` on points sampled in every voxel, chunk by chunk.

        :param score_fn: callable `(pts, voxel_indices) -> scores`
        :param reduce_fn: optional reduction applied to the `(B, P)` scores of each chunk
        :param bits `int`: split factor passed to `sample`
        :return `Tensor(M[, ...])`: scores of all voxels
        """
        # Assumes `split_voxels` yields bits^3 points per voxel, as the reshape below requires
        n_samples = bits ** 3

        def get_scores_once(pts, idxs):
            scores = score_fn(pts, idxs).reshape(-1, n_samples)  # (B, P)
            if reduce_fn is not None:
                scores = reduce_fn(scores)  # (B[, ...])
            return scores

        sampled_xyz, sampled_idx = self.sample(bits)
        chunk_size = 64  # voxels per chunk
        # Samples are flattened per point, so a chunk of voxels spans chunk_size * P entries
        stride = chunk_size * n_samples
        return torch.cat([
            get_scores_once(sampled_xyz[i:i + stride], sampled_idx[i:i + stride])
            for i in range(0, sampled_idx.size(0), stride)
        ], 0)  # (M[, ...])

    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run the external AABB ray intersection on the voxel list.

        :return: voxel indices, min depths and max depths, as consumed by `ray_intersect`
        """
        return aabb_ray_intersect(self.voxel_size, n_max_hits, self.voxels, rays_o, rays_d)

    def _update_corners(self):
        """
        Update voxel corners.
        """
        corners, corner_indices = get_corners(self.voxels, self.bbox, self.steps)
        self.register_buffer("corners", corners)
        self.register_buffer("corner_indices", corner_indices)

    def _update_voxel_indices_in_grid(self):
        """
        Update voxel indices in grid.
        """
        grid_indices, _ = to_grid_indices(self.voxels, self.bbox, steps=self.steps)
        # Start with every grid cell marked as pruned, then fill in the existing voxels
        self.voxel_indices_in_grid = grid_indices.new_full([self.steps.prod().item()], -1)
        self.voxel_indices_in_grid[grid_indices] = torch.arange(self.n_voxels, device=self.device)

    @torch.no_grad()
    def _before_load_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
                                unexpected_keys, error_msgs):
        """
        Pre-hook of `load_state_dict`: resize buffers and embeddings so that the
        shapes stored in `state_dict` can be loaded.
        """
        # Handle buffers
        for name, buffer in self.named_buffers(recurse=False):
            if name in self._non_persistent_buffers_set:
                continue
            key = prefix + name
            # Missing keys are left for `load_state_dict` itself to report
            if key in state_dict:
                buffer.resize_as_(state_dict[key])

        # Handle embeddings (iterate over a copy since modules are replaced in the loop)
        for name, module in list(self.named_children()):
            if name.startswith('emb_'):
                new_emb = torch.nn.Embedding(self.n_corners, module.embedding_dim)
                # Keep the replacement on the same device as the embedding it replaces
                setattr(self, name, new_emb.to(module.weight.device))


class Octree(Voxels):
    """
    `Voxels` variant that intersects rays through an octree built from the voxel list.

    The octree is built on first use and cached; the cache is dropped whenever
    the voxel list changes through `splitting` or `prune`.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Nothing is built until the first call of `get`
        self.nodes_cached = None
        self.tree_cached = None

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the octree of current voxels, building it if there is no cached one.

        :return: the cached `(nodes, tree)` pair returned by `build_easy_octree`
        """
        if self.nodes_cached is not None:
            return self.nodes_cached, self.tree_cached
        half_voxel_size = 0.5 * self.voxel_size
        nodes, tree = build_easy_octree(self.voxels, half_voxel_size)
        self.nodes_cached, self.tree_cached = nodes, tree
        return nodes, tree

    def clear(self):
        """
        Drop the cached octree so that the next `get` rebuilds it.
        """
        self.nodes_cached = self.tree_cached = None

    def _ray_intersect(self, rays_o: torch.Tensor, rays_d: torch.Tensor, n_max_hits: int):
        octree_nodes, octree = self.get()
        return octree_ray_intersect(self.voxel_size, n_max_hits, octree_nodes, octree,
                                    rays_o, rays_d)

    @torch.no_grad()
    def splitting(self):
        """
        Split voxels as `Voxels.splitting` does, then invalidate the cached octree.
        """
        result = super().splitting()
        self.clear()
        return result

    @torch.no_grad()
    def prune(self, keeps: torch.Tensor) -> Tuple[int, int]:
        """
        Prune voxels as `Voxels.prune` does, then invalidate the cached octree.
        """
        result = super().prune(keeps)
        self.clear()
        return result