import torch
from typing import Tuple

from .generic import *
from .space import Space
from clib import *
from utils import device
from utils import sphere
from utils import misc
from utils import math
from utils.module import Module
from utils.samples import Samples
from utils.perf import perf, checkpoint


class Bins(object):
    """
    A set of bins described by their center values.

    Bin boundaries are derived from the centers: the two outermost boundaries coincide with
    the first and last center, every inner boundary is the midpoint of two adjacent centers.
    All slicing is done along the last dimension, so `vals` may be a plain `Tensor(M)` or a
    batched `Tensor(..., M)`; the boundaries are then `Tensor(..., M+1)`.
    """

    @property
    def up(self):
        """
        :return `Tensor(..., M)`: upper boundary of every bin
        """
        return self.bounds[..., 1:]

    @property
    def lo(self):
        """
        :return `Tensor(..., M)`: lower boundary of every bin
        """
        return self.bounds[..., :-1]

    def __init__(self, vals: torch.Tensor):
        """
        Build bins from their center values

        :param vals `Tensor(..., M)`: centers of bins (along the last dimension)
        """
        self.vals = vals
        # Operate on the last dimension so that batched centers produce batched boundaries
        # (for 1-D input this is the same as slicing/concatenating along dim 0).
        self.bounds = torch.cat([
            self.vals[..., :1],
            0.5 * (self.vals[..., 1:] + self.vals[..., :-1]),
            self.vals[..., -1:]
        ], dim=-1)

    @staticmethod
    def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
        """
        Create bins whose centers are evenly spaced over `val_range` (both ends included)

        :param val_range `float, float`: first and last center value
        :param N `int`: number of bins
        :param device `torch.device`: (optional) device used to create the tensors
        :return `Bins`: the new bins
        """
        return Bins(torch.linspace(*val_range, N, device=device))

    def to(self, device: torch.device):
        """
        Move the stored tensors to `device` (in place)

        :param device `torch.device`: target device
        :return `Bins`: `self`, so that calls can be chained
        """
        self.vals = self.vals.to(device)
        self.bounds = self.bounds.to(device)
        return self


class Sampler(Module):
    """
    Sampler which places a fixed number of samples per ray inside a given range,
    optionally jittering the sample intervals.
    """

    def __init__(self, **kwargs):
        """
        Initialize a Sampler module

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        # Lazily built (ray, sample) index grid reused across calls, see `_get_samples_indices`
        self._samples_indices_cached = None

    def _sample(self, range: Tuple[float, float], n_rays: int, n_samples: int, perturb: bool,
                device: torch.device) -> torch.Tensor:
        """
        Generate the `n_samples + 1` interval bounds of every ray.

        Without perturbation all rays share the same evenly spaced bounds. With perturbation
        every bound is drawn uniformly between the midpoints to its two neighbours (the first
        and last bound are drawn from the half interval next to the range ends), so the bounds
        stay ordered along each ray.

        :param range `float, float`: sampling range
        :param n_rays `int`: number of rays (B)
        :param n_samples `int`: number of samples per ray (P)
        :param perturb `bool`: whether perturb sampling
        :param device `torch.device`: the device used to create tensors
        :return `Tensor(B, P+1)`: sampling bounds of t
        """
        bounds = torch.linspace(*range, n_samples + 1, device=device)  # (P+1)
        if perturb:
            # (P+2) limits between which each of the P+1 bounds is allowed to move
            rand_bounds = torch.cat([
                bounds[:1],
                0.5 * (bounds[1:] + bounds[:-1]),
                bounds[-1:]
            ])
            rand_vals = torch.rand(n_rays, n_samples + 1, device=device)
            # Linear interpolation between lower and upper limit, broadcast to (B, P+1)
            bounds = rand_bounds[:-1] * (1 - rand_vals) + rand_bounds[1:] * rand_vals
        else:
            bounds = bounds[None].expand(n_rays, -1)
        return bounds

    def _get_samples_indices(self, pts: torch.Tensor):
        """
        Get the index grid for the first two dimensions of `pts`, rebuilding the cached grid
        only when it is missing, too small, or stored on a different device than `pts`.

        :param pts `Tensor(B, P, ...)`: sampled points
        :return: the grid returned by `misc.meshgrid`, cropped to `(B, P)` in its first two dims
        """
        cached = self._samples_indices_cached
        # NOTE(review): assumes `misc.meshgrid` returns a tensor (it is given a `device` and the
        # result is indexed like one) -- the device check keeps a grid built for CPU inputs
        # from being handed out after the inputs moved to another device.
        if cached is None \
                or cached.device != pts.device \
                or cached.shape[0] < pts.shape[0] \
                or cached.shape[1] < pts.shape[1]:
            cached = misc.meshgrid(*pts.shape[:2], swap_dim=True, device=pts.device)
            self._samples_indices_cached = cached
        return cached[:pts.shape[0], :pts.shape[1]]

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_: Space, *,
                sample_range: Tuple[float, float], n_samples: int, lindisp: bool = False,
                perturb_sample: bool = True, spherical: bool = False,
                **kwargs) -> Samples:
        """
        Sample points along rays.

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param space_ `Space`: space used to look up the voxel index of every sample
        :param sample_range `float, float`: sampling range
        :param n_samples `int`: number of samples per ray
        :param lindisp `bool`: whether sample linearly in disparity space (1/depth),
            only used when `spherical` is `False`
        :param perturb_sample `bool`: whether perturb sampling
        :param spherical `bool`: whether sample on spheres around the origin and return
            the points in spherical coordinates
        :return `Samples(B, P)`: samples
        """
        if spherical:
            t_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
                                    rays_o.device)
            t0, t1 = t_bounds[:, :-1], t_bounds[:, 1:]  # (B, P)
            t = (t0 + t1) * .5

            # The reciprocal of t is handed to the sphere intersection, i.e. t is sampled in
            # inverse space (presumably inverse radius -- confirm against `utils.sphere`).
            p, z = sphere.ray_sphere_intersect(rays_o, rays_d, t.reciprocal())
            p = sphere.cartesian2spherical(p, inverse_r=True)
            vidxs = space_.get_voxel_indices(p)
            return Samples(
                pts=p,
                dirs=rays_d[:, None].expand(-1, n_samples, -1),
                depths=z,
                # `math.tiny` keeps the reciprocal finite when the bound t1 is zero
                dists=(t1 + math.tiny).reciprocal() - t0.reciprocal(),
                voxel_indices=vidxs,
                indices=self._get_samples_indices(p),
                t=t
            )
        else:
            # For disparity sampling the bounds are spaced evenly in 1/depth and converted
            # back to depth afterwards
            sample_range = (1 / sample_range[0], 1 / sample_range[1]) if lindisp else sample_range
            z_bounds = self._sample(sample_range, rays_o.shape[0], n_samples, perturb_sample,
                                    rays_o.device)
            if lindisp:
                z_bounds = z_bounds.reciprocal()
            z0, z1 = z_bounds[:, :-1], z_bounds[:, 1:]  # (B, P)
            z = (z0 + z1) * .5  # sample at the middle of every interval
            p = rays_o[:, None] + rays_d[:, None] * z[..., None]
            vidxs = space_.get_voxel_indices(p)
            return Samples(
                pts=p,
                dirs=rays_d[:, None].expand(-1, n_samples, -1),
                depths=z,
                dists=z1 - z0,
                voxel_indices=vidxs,
                indices=self._get_samples_indices(p),
                t=z
            )


class PdfSampler(Module):
    """
    Importance sampler which draws samples along rays according to per-bin weights.
    """

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
                 spherical: bool, lindisp: bool, **kwargs):
        """
        Initialize a PdfSampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param spherical: If True, `forward` additionally returns spherical coordinates
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.n_samples = n_samples
        # Sampling is carried out in "s" space: inverse depth if `lindisp`, depth otherwise
        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range

    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False, **kwargs):
        """
        Sample points along rays. return Spherical or Cartesian coordinates,
        specified by `self.spherical`

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param weights `Tensor(B, M)`: weights of sample bins
        :param s_vals `Tensor(B, M)`: (optional) center of sample bins, defaults to
            `self.n_samples` values evenly spaced over the sampler's range
        :param include_s_vals `bool`: (default to `False`) include `s_vals` in the sample array
        :return: if `self.spherical` is `False`, a tuple of (sampled points in Cartesian
            coordinates, their depths `z`, the sampled values `s`); otherwise a tuple of
            (spherical coordinates, depths returned by the sphere intersection, `s`,
            Cartesian points)
        """
        if s_vals is None:
            # Created on the device of `weights` so the CDF computations never mix devices
            s_vals = torch.linspace(*self.s_range, self.n_samples, device=weights.device)
        # Sampling is deterministic exactly when perturbation is disabled
        s = self.sample_pdf(Bins(s_vals).bounds, weights, self.n_samples,
                            det=not self.perturb_sample)
        if include_s_vals:
            # Broadcast the (possibly 1-D) bin centers to the batch shape of the samples
            s = torch.cat([s, s_vals.expand(*s.shape[:-1], s_vals.shape[-1])], dim=-1)
        # Descending inverse depth corresponds to ascending depth
        s = torch.sort(s, descending=self.lindisp)[0]
        z = torch.reciprocal(s) if self.lindisp else s
        if self.spherical:
            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
            return sphers, depths, s, pts
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s

    def sample_pdf(self, bins: torch.Tensor, weights: torch.Tensor, N: int, det=True):
        '''
        Draw samples from the piecewise-constant distribution defined by `bins` and `weights`
        using inverse transform sampling.

        :param bins `Tensor(..., M+1)`: bounds of bins
        :param weights `Tensor(..., M)`: weights of bins
        :param N `int`: # of samples along each ray
        :param det `bool`: (default to `True`) perform deterministic sampling or not
        :return `Tensor(..., N)`: samples
        '''
        # Get pdf
        weights = weights + math.tiny                                          # prevent nans
        pdf = weights / torch.sum(weights, dim=-1, keepdim=True)                # [..., M]
        cdf = torch.cat([
            torch.zeros_like(pdf[..., :1]),
            torch.cumsum(pdf, dim=-1)
        ], dim=-1)                                                              # [..., M+1]

        # Take uniform samples
        dots_sh = list(weights.shape[:-1])
        M = weights.shape[-1]

        # Evenly spaced positions in [0, 1] when deterministic, uniform random ones otherwise
        u = torch.linspace(0, 1, N, device=bins.device).expand(dots_sh + [N]) \
            if det else torch.rand(dots_sh + [N], device=bins.device)           # [..., N]

        # Invert CDF
        # Count the lower CDF knots not exceeding u; since cdf[..., 0] == 0 the count lies
        # in [1, M] for u in [0, 1], i.e. it is the index of the upper bound of the hit bin
        # [..., N, 1] >= [..., 1, M] ----> [..., N, M] ----> [..., N,]
        above_inds = torch.sum(u[..., None] >= cdf[..., None, :-1], dim=-1).long()

        # random sample inside each bin
        below_inds = torch.clamp(above_inds - 1, min=0)
        inds_g = torch.stack((below_inds, above_inds), dim=-1)                  # [..., N, 2]

        cdf = cdf[..., None, :].expand(dots_sh + [N, M + 1])                    # [..., N, M+1]
        cdf_g = torch.gather(cdf, dim=-1, index=inds_g)                         # [..., N, 2]

        bins = bins[..., None, :].expand(dots_sh + [N, M + 1])                  # [..., N, M+1]
        bins_g = torch.gather(bins, dim=-1, index=inds_g)                       # [..., N, 2]

        # fix numeric issue: avoid dividing by a (near) zero CDF increment
        denom = cdf_g[..., 1] - cdf_g[..., 0]                                   # [..., N]
        denom = torch.where(denom < math.tiny, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom                                         # position in bin

        samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + math.tiny)

        return samples


class VoxelSampler(Module):
    """
    Sampler which only places samples inside the voxels intersected by each ray.
    """

    def __init__(self, *, sample_step: float, **kwargs):
        """
        Initialize a VoxelSampler module

        :param sample_step: step size (gap between samples along a ray)
        """
        super().__init__()
        self.sample_step = sample_step

    def _forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor, space_module: Space, *,
                 perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample with (nearly) uniform steps between the first and the last voxel hit by a ray.

        Every ray is divided into `ceil(length / self.sample_step)` equal steps. Rays with fewer
        steps than the longest ray are padded with invalid samples (depth `math.huge`,
        distance 0, voxel index -1).

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space_module `Space`: space used for ray-voxel intersection
        :param perturb_sample `bool`: whether place the sample randomly inside its step
            (otherwise at the middle of the step)
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel),
            `None` if no ray is valid
        :return `Tensor(N)`: valid rays mask
        """
        # NOTE(review): 100 is presumably the max number of hits per ray -- confirm with `Space`
        intersections = space_module.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')
        n_rays = rays_o.size(0)
        if n_rays == 0:
            # Same contract as `forward`: nothing to sample (and `max()` below would raise
            # on an empty tensor)
            return None, valid_rays_mask
        ray_index_list = torch.arange(n_rays, device=rays_o.device, dtype=torch.long)  # (N')

        hits = intersections.hits
        min_depths = intersections.min_depths
        max_depths = intersections.max_depths
        voxel_indices = intersections.voxel_indices

        rays_near_depth = min_depths[:, :1]  # (N', 1)
        rays_far_depth = max_depths[ray_index_list, hits - 1][:, None]  # (N', 1)
        rays_length = rays_far_depth - rays_near_depth
        rays_steps = (rays_length / self.sample_step).ceil().long()
        rays_step_size = rays_length / rays_steps
        max_steps = rays_steps.max().item()
        rays_step = torch.arange(max_steps, device=rays_o.device,
                                 dtype=torch.float)[None].repeat(n_rays, 1)  # (N', P)
        # Steps beyond a ray's own step count are padding
        invalid_samples_mask = rays_step >= rays_steps
        samples_min_depth = rays_near_depth + rays_step * rays_step_size
        samples_depth = samples_min_depth + rays_step_size \
            * (torch.rand_like(samples_min_depth) if perturb_sample else 0.5)  # (N', P)
        samples_dist = rays_step_size.repeat(1, max_steps)  # (N', 1) -> (N', P)
        # `searchsorted` may return an index one past the last column for depths beyond every
        # entry of `max_depths` (e.g. padding samples); clamp it so the gather below stays in
        # range -- in-range indices are not affected
        samples_hit_index = torch.searchsorted(max_depths, samples_depth) \
            .clamp(max=max_depths.size(1) - 1)
        samples_voxel_index = voxel_indices[ray_index_list[:, None], samples_hit_index]  # (N', P)
        samples_depth[invalid_samples_mask] = math.huge
        samples_dist[invalid_samples_mask] = 0
        samples_voxel_index[invalid_samples_mask] = -1

        rays_o, rays_d = rays_o[:, None], rays_d[:, None]
        return Samples(
            pts=rays_o + rays_d * samples_depth[..., None],
            dirs=rays_d.expand(-1, max_steps, -1),
            depths=samples_depth,
            dists=samples_dist,
            voxel_indices=samples_voxel_index
        ), valid_rays_mask

    @perf
    def forward(self, rays_o: torch.Tensor, rays_d: torch.Tensor,
                space: Space, *, perturb_sample: bool, **kwargs) -> Tuple[Samples, torch.Tensor]:
        """
        Sample points inside the voxels intersected by each ray using `inverse_cdf_sampling`,
        with the in-voxel lengths of a ray as the sampling weights.

        :param rays_o `Tensor(N, 3)`: rays' origin positions
        :param rays_d `Tensor(N, 3)`: rays' directions
        :param space `Space`: space used for ray-voxel intersection
        :param perturb_sample `bool`: whether perturb sampling
        :return `Samples(N', P)`: samples along valid rays (which hit at least one voxel),
            `None` if no ray is valid
        :return `Tensor(N)`: valid rays mask
        """
        # NOTE(review): 100 is presumably the max number of hits per ray -- confirm with `Space`
        intersections = space.ray_intersect(rays_o, rays_d, 100)
        valid_rays_mask = intersections.hits > 0
        rays_o = rays_o[valid_rays_mask]
        rays_d = rays_d[valid_rays_mask]
        intersections = intersections[valid_rays_mask]  # (N) -> (N')

        checkpoint("Ray intersect")

        if intersections.size == 0:
            return None, valid_rays_mask
        else:
            min_depth = intersections.min_depths
            max_depth = intersections.max_depths
            pts_idx = intersections.voxel_indices
            dists = max_depth - min_depth
            tot_dists = dists.sum(dim=-1, keepdim=True)  # (N, 1)
            probs = dists / tot_dists
            # Number of steps per ray is proportional to the total length inside voxels
            steps = tot_dists[:, 0] / self.sample_step

            # sample points and use middle point approximation
            # (the last argument requests deterministic sampling when perturbation is off)
            sampled_indices, sampled_depths, sampled_dists = inverse_cdf_sampling(
                pts_idx, min_depth, max_depth, probs, steps, -1, not perturb_sample)
            sampled_indices = sampled_indices.long()
            # Index -1 marks padding samples: give them zero length and a huge depth
            invalid_idx_mask = sampled_indices.eq(-1)
            sampled_dists.clamp_min_(0).masked_fill_(invalid_idx_mask, 0)
            sampled_depths.masked_fill_(invalid_idx_mask, math.huge)

            checkpoint("Inverse CDF sampling")

            rays_o, rays_d = rays_o[:, None], rays_d[:, None]
            return Samples(
                pts=rays_o + rays_d * sampled_depths[..., None],
                dirs=rays_d.expand(-1, sampled_depths.size(1), -1),
                depths=sampled_depths,
                dists=sampled_dists,
                voxel_indices=sampled_indices
            ), valid_rays_mask