from typing import Tuple
import torch
import torch.nn as nn
from utils import device
from utils import sphere
from utils.constants import *
from .generic import *


class Bins(object):
    """
    A set of sample bins described by their centers.

    The bins live on the last dimension of `vals`, so both a shared set of
    bins `Tensor(M)` and per-ray bins `Tensor(..., M)` are supported.

    Attributes:
        vals `Tensor(..., M)`: centers of the bins
        bounds `Tensor(..., M+1)`: bounds of the bins; the outermost bounds
            coincide with the first/last center, the inner bounds are the
            midpoints between neighbouring centers
        up `Tensor(..., M)`: the bound following each center (`bounds[..., 1:]`)
        lo `Tensor(..., M)`: the bound preceding each center (`bounds[..., :-1]`)
    """

    def __init__(self, vals: torch.Tensor):
        """
        Build bins from their centers

        :param vals `Tensor(..., M)`: centers of the bins
        """
        self.vals = vals
        # Work on the last dimension so that batched centers produce
        # per-row bounds instead of being concatenated along the batch axis.
        self.bounds = torch.cat([
            self.vals[..., :1],
            0.5 * (self.vals[..., 1:] + self.vals[..., :-1]),
            self.vals[..., -1:]
        ], dim=-1)
        self.up = self.bounds[..., 1:]
        self.lo = self.bounds[..., :-1]

    @staticmethod
    def linspace(val_range: Tuple[float, float], N: int, device: torch.device = None):
        """
        Create bins whose centers are evenly spaced over `val_range`

        :param val_range: (start, end) values of the first and last center
        :param N: number of bins
        :param device: (optional) device on which the tensors are created
        :return: the new `Bins`
        """
        return Bins(torch.linspace(*val_range, N, device=device))

    def to(self, device: torch.device):
        """
        Move all tensors of the bins to `device` (in place)

        :param device: target device
        :return: `self`, so that calls can be chained like `Tensor.to`
        """
        self.vals = self.vals.to(device)
        self.bounds = self.bounds.to(device)
        # `up`/`lo` are views of `bounds` and must be re-derived after the move
        self.up = self.bounds[..., 1:]
        self.lo = self.bounds[..., :-1]
        return self


class Sampler(nn.Module):
    """
    Draw a fixed number of samples along every ray, evenly spaced in `s`
    (depth, or inverse depth when `lindisp` is set) and optionally jittered
    inside each bin.
    """

    def __init__(self, *, sample_range: Tuple[float, float], n_samples: int,
                 perturb_sample: bool, spherical: bool, lindisp: bool):
        """
        Initialize a Sampler module

        :param sample_range: (first, last) depth covered by the samples
        :param n_samples: count to sample along ray
        :param perturb_sample: perturb the sample depths
        :param spherical: If True, `forward` returns spherical coordinates
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.spherical = spherical
        self.perturb_sample = perturb_sample
        if self.lindisp:
            # bins are laid out in inverse depth
            s_range = (1 / sample_range[0], 1 / sample_range[1])
        else:
            s_range = sample_range
        self.bins = Bins.linspace(s_range, n_samples, device=device.default())

    def forward(self, rays_o, rays_d):
        """
        Sample points along rays. Return Spherical or Cartesian coordinates,
        specified by `self.spherical`

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :return `Tensor(B, N, 3)`: sampled points (spherical coordinates if `self.spherical`)
        :return `Tensor(B, N)`: corresponding depths along rays
        :return `Tensor(B, N)`: sample values `s` (inverse depths if `self.lindisp`)
        :return: Cartesian points returned by `sphere.ray_sphere_intersect`
            if `self.spherical`, otherwise `None`
        """
        n_rays = rays_o.size(0)
        s = self.bins.vals.expand(n_rays, -1)
        if self.perturb_sample:
            # jitter every sample uniformly between the bounds of its bin
            bin_extent = self.bins.up - self.bins.lo
            s = self.bins.lo + bin_extent * torch.rand_like(s)

        if self.lindisp:
            z = torch.reciprocal(s)
        else:
            z = s

        if not self.spherical:
            pts = rays_o[..., None, :] + rays_d[..., None, :] * z[..., None]
            return pts, z, s, None

        pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
        sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
        return sphers, depths, s, pts


class PdfSampler(nn.Module):
    """
    Draw samples along every ray according to a piecewise-constant PDF that
    is defined by per-bin weights (inverse transform sampling).
    """

    def __init__(self, *, depth_range: Tuple[float, float], n_samples: int, perturb_sample: bool,
                 spherical: bool, lindisp: bool):
        """
        Initialize a PdfSampler module

        :param depth_range: depth range for sampler
        :param n_samples: count to sample along ray
        :param perturb_sample: If True, draw random samples from the PDF;
            otherwise take evenly spaced (deterministic) quantiles
        :param spherical: If True, `forward` returns spherical coordinates
        :param lindisp: If True, sample linearly in inverse depth rather than in depth
        """
        super().__init__()
        self.lindisp = lindisp
        self.perturb_sample = perturb_sample
        self.spherical = spherical
        self.n_samples = n_samples
        self.s_range = (1 / depth_range[0], 1 / depth_range[1]) if self.lindisp else depth_range

    def forward(self, rays_o, rays_d, *, weights, s_vals=None, include_s_vals=False):
        """
        Sample points along rays. Return Spherical or Cartesian coordinates,
        specified by `self.spherical`

        NOTE: the spherical branch returns four values while the Cartesian
        branch returns three; this asymmetry is kept for existing callers.

        :param rays_o `Tensor(B, 3)`: rays' origin
        :param rays_d `Tensor(B, 3)`: rays' direction
        :param weights `Tensor(B, M)`: weights of sample bins
        :param s_vals `Tensor(B, M)` or `Tensor(M)`: (optional) center of sample bins,
            default to `n_samples` evenly spaced values over the sample range
        :param include_s_vals `bool`: (default to `False`) include `s_vals` in the sample array
        :return `Tensor(B, N, 3)`: sampled points (spherical coordinates if `self.spherical`)
        :return `Tensor(B, N)`: corresponding depths along rays
        :return `Tensor(B, N)`: sample values `s` (inverse depths if `self.lindisp`)
        :return: (only if `self.spherical`) Cartesian points returned by
            `sphere.ray_sphere_intersect`
        """
        if s_vals is None:
            # Create the default centers next to the weights: `sample_pdf`
            # combines both and would fail on a device mismatch.
            s_vals = torch.linspace(*self.s_range, self.n_samples, device=weights.device)
        # Perturbation means random sampling, i.e. NOT deterministic.
        s = self.sample_pdf(self._bin_bounds(s_vals), weights, self.n_samples,
                            det=not self.perturb_sample)
        if include_s_vals:
            # Shared (1-D) centers have to be broadcast to every ray first.
            s = torch.cat([s, s_vals.expand(*s.shape[:-1], -1)], dim=-1)
        # Inverse depths are sorted descending so that depths `z` ascend.
        s = torch.sort(s, descending=self.lindisp)[0]
        z = torch.reciprocal(s) if self.lindisp else s
        if self.spherical:
            pts, depths = sphere.ray_sphere_intersect(rays_o, rays_d, z)
            sphers = sphere.cartesian2spherical(pts, inverse_r=self.lindisp)
            return sphers, depths, s, pts
        else:
            return rays_o[..., None, :] + rays_d[..., None, :] * z[..., None], z, s

    @staticmethod
    def _bin_bounds(centers: torch.Tensor) -> torch.Tensor:
        """
        Compute the bounds of bins from their centers along the last dimension

        :param centers `Tensor(..., M)`: centers of bins
        :return `Tensor(..., M+1)`: bounds of bins; the outermost bounds coincide
            with the first/last center, inner bounds are midpoints of neighbours
        """
        mids = 0.5 * (centers[..., 1:] + centers[..., :-1])
        return torch.cat([centers[..., :1], mids, centers[..., -1:]], dim=-1)

    def sample_pdf(self, bins: torch.Tensor, weights: torch.Tensor, N: int, det=True):
        '''
        Draw samples from the piecewise-constant PDF defined by `weights` over `bins`

        :param bins `Tensor(..., M+1)`: bounds of bins (leading dims may be omitted
            when the bins are shared by all rays)
        :param weights `Tensor(..., M)`: weights of bins
        :param N `int`: # of samples along each ray
        :param det `bool`: (default to `True`) perform deterministic sampling or not
        :return `Tensor(..., N)`: samples
        '''
        # Get pdf
        weights = weights + TINY_FLOAT                                          # prevent nans
        pdf = weights / torch.sum(weights, dim=-1, keepdim=True)                # [..., M]
        cdf = torch.cat([
            torch.zeros_like(pdf[..., :1]),
            torch.cumsum(pdf, dim=-1)
        ], dim=-1)                                                              # [..., M+1]

        # Take uniform samples
        dots_sh = list(weights.shape[:-1])
        M = weights.shape[-1]

        u = torch.linspace(0, 1, N, device=bins.device).expand(dots_sh + [N]) \
            if det else torch.rand(dots_sh + [N], device=bins.device)           # [..., N]

        # Invert CDF: count the lower cdf bounds not exceeding u. cdf[..., 0]
        # is 0, so the count lies in [1, M] and indexes the upper bin bound.
        # [..., N, 1] >= [..., 1, M] ----> [..., N, M] ----> [..., N,]
        above_inds = torch.sum(u[..., None] >= cdf[..., None, :-1], dim=-1).long()

        # random sample inside each bin
        below_inds = torch.clamp(above_inds - 1, min=0)
        inds_g = torch.stack((below_inds, above_inds), dim=-1)                  # [..., N, 2]

        cdf = cdf[..., None, :].expand(dots_sh + [N, M + 1])                    # [..., N, M+1]
        cdf_g = torch.gather(cdf, dim=-1, index=inds_g)                         # [..., N, 2]

        bins = bins[..., None, :].expand(dots_sh + [N, M + 1])                  # [..., N, M+1]
        bins_g = torch.gather(bins, dim=-1, index=inds_g)                       # [..., N, 2]

        # fix numeric issue: avoid dividing by a (nearly) empty cdf interval
        denom = cdf_g[..., 1] - cdf_g[..., 0]                                   # [..., N]
        denom = torch.where(denom < TINY_FLOAT, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom

        # linear interpolation between the lower and upper bound of the bin
        samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0] + TINY_FLOAT)

        return samples