class IntegratedPosEncoder(InputEncoder):

    def __init__(self, chns, L, shape: str, cat_input=False):
        super().__init__(chns)
        self.shape = shape

    def _lift_gaussian(self, d: torch.Tensor, t_mean: torch.Tensor, t_var: torch.Tensor,
                       r_var: torch.Tensor, diag: bool):
        """Lift a Gaussian defined along a ray to 3D coordinates."""
        mean = d[..., None, :] * t_mean[..., None]
        d_sq = d**2

        d_mag_sq = torch.sum(d_sq, -1, keepdim=True).clamp_min(1e-10)

        if diag:
            d_outer_diag = d_sq
            null_outer_diag = 1 - d_outer_diag / d_mag_sq
            t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
            xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
            cov_diag = t_cov_diag + xy_cov_diag
            return mean, cov_diag
        else:
            d_outer = d[..., :, None] * d[..., None, :]
            eye = torch.eye(d.shape[-1], device=d.device)
            null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
            t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
            xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
            cov = t_cov + xy_cov
            return mean, cov

    def _conical_frustum_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, base_radius: float,
                                     diag: bool, stable: bool = True):
        """Approximate a conical frustum as a Gaussian distribution (mean+cov).

        Assumes the ray is originating from the origin, and base_radius is the
        radius at dist=1. Doesn't assume `d` is normalized.

        Args:
            d: torch.float32 3-vector, the axis of the cone
            t0: float, the starting distance of the frustum.
            t1: float, the ending distance of the frustum.
            base_radius: float, the scale of the radius as a function of distance.
            diag: boolean, whether the Gaussian will be diagonal or full-covariance.
            stable: boolean, whether or not to use the stable computation described in
            the paper (setting this to False will cause catastrophic failure).

        Returns:
            a Gaussian (mean and covariance).
        """
        if stable:
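            # Stable reparameterization from the mip-NeRF paper: mu is the midpoint of the
            # frustum along the ray and hw its half-width; the moment formulas below avoid
            # the catastrophic cancellation of the naive expressions when t0 is close to t1.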
            mu = (t0 + t1) / 2
            hw = (t1 - t0) / 2
            t_mean = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2)
            t_var = (hw**2) / 3 - (4 / 15) * ((hw**4 * (12 * mu**2 - hw**2)) /
                                              (3 * mu**2 + hw**2)**2)
            r_var = base_radius**2 * ((mu**2) / 4 + (5 / 12) * hw**2 - 4 / 15 *
                                      (hw**4) / (3 * mu**2 + hw**2))
        else:
            t_mean = (3 * (t1**4 - t0**4)) / (4 * (t1**3 - t0**3))
            r_var = base_radius**2 * (3 / 20 * (t1**5 - t0**5) / (t1**3 - t0**3))
            t_mosq = 3 / 5 * (t1**5 - t0**5) / (t1**3 - t0**3)
            t_var = t_mosq - t_mean**2
        return self._lift_gaussian(d, t_mean, t_var, r_var, diag)

    def _cylinder_to_gaussian(self, d: torch.Tensor, t0: float, t1: float, radius: float, diag: bool):
        """Approximate a cylinder as a Gaussian distribution (mean+cov).

        Assumes the ray is originating from the origin, and radius is the
        radius. Does not renormalize `d`.

        Args:
            d: torch.float32 3-vector, the axis of the cylinder
            t0: float, the starting distance of the cylinder.
            t1: float, the ending distance of the cylinder.
            radius: float, the radius of the cylinder
            diag: boolean, whether the Gaussian will be diagonal or full-covariance.

        Returns:
            a Gaussian (mean and covariance).
        """
        t_mean = (t0 + t1) / 2
        r_var = radius**2 / 4
        t_var = (t1 - t0)**2 / 12
        return self._lift_gaussian(d, t_mean, t_var, r_var, diag)

    def cast_rays(self, t_vals: torch.Tensor, rays_o: torch.Tensor, rays_d: torch.Tensor,
                  rays_r: torch.Tensor, diag: bool = True):
        """Cast rays (cone- or cylinder-shaped) and featurize sections of it.

        Args:
            t_vals: float array, the "fencepost" distances along the ray.
            rays_o: float array, the ray origin coordinates.
            rays_d: float array, the ray direction vectors.
            radii: float array, the radii (base radii for cones) of the rays.
            ray_shape: string, the shape of the ray, must be 'cone' or 'cylinder'.
            diag: boolean, whether or not the covariance matrices should be diagonal.

        Returns:
            a tuple of arrays of means and covariances.
        """
        t0 = t_vals[..., :-1]
        t1 = t_vals[..., 1:]
        if self.shape == 'cone':
            gaussian_fn = self._conical_frustum_to_gaussian
        elif self.shape == 'cylinder':
            gaussian_fn = self._cylinder_to_gaussian
        else:
            raise ValueError(f"Unsupported ray shape: {self.shape!r}; expected 'cone' or 'cylinder'")
        means, covs = gaussian_fn(rays_d, t0, t1, rays_r, diag)
        means = means + rays_o[..., None, :]
        return means, covs

    def integrated_pos_enc(self, x_coord: tuple[torch.Tensor, torch.Tensor], min_deg: int,
                           max_deg: int, diag: bool = True):
        """Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1].

        Args:
            x_coord: a tuple (x, x_cov) where x is a torch.Tensor of variables to be
                encoded (should be in [-pi, pi]) and x_cov holds the covariance
                matrices (or their diagonals) for `x`.
            min_deg: int, the min degree of the encoding.
            max_deg: int, the max degree of the encoding.
            diag: bool, if true, expects input covariances to be diagonal (full
            otherwise).

        Returns:
            encoded: torch.Tensor, encoded variables.
        """
        if diag:
            x, x_cov_diag = x_coord
            scales = torch.tensor([2**i for i in range(min_deg, max_deg)], device=x.device)[:, None]
            shape = list(x.shape[:-1]) + [-1]
            y = torch.reshape(x[..., None, :] * scales, shape)
            y_var = torch.reshape(x_cov_diag[..., None, :] * scales**2, shape)
        else:
            x, x_cov = x_coord
            num_dims = x.shape[-1]
            basis = torch.cat([
                2**i * torch.eye(num_dims, device=x.device)
                for i in range(min_deg, max_deg)
            ], 1)
            y = torch.matmul(x, basis)
            # Get the diagonal of each covariance matrix (i.e., the variance). This is
            # equivalent to computing diag(basis.T @ x_cov @ basis) for each covariance matrix.
            y_var = (torch.matmul(x_cov, basis) * basis).sum(-2)

        return math.expected_sin(
            torch.cat([y, y + 0.5 * math.pi], -1),
            torch.cat([y_var] * 2, -1))[0]
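

# A minimal usage sketch (illustrative only, not part of the original module), assuming that
# `InputEncoder(chns)` can be constructed as in `__init__` above and that the local `math`
# module provides `expected_sin` and `pi` as used in `integrated_pos_enc`. It chains
# `cast_rays` and `integrated_pos_enc` to produce mip-NeRF style features for a batch of rays.
def _example_integrated_pos_enc() -> torch.Tensor:
    encoder = IntegratedPosEncoder(3, 16, "cone")
    t_vals = torch.linspace(2., 6., 65).expand(1024, 65)  # "fencepost" distances per ray
    rays_o = torch.zeros(1024, 3)                         # ray origins
    rays_d = torch.randn(1024, 3)                         # ray directions (need not be unit length)
    rays_r = torch.full((1024, 1), 1e-3)                  # base radii at distance 1
    means, covs = encoder.cast_rays(t_vals, rays_o, rays_d, rays_r, diag=True)
    return encoder.integrated_pos_enc((means, covs), min_deg=0, max_deg=16)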


# @torch.jit.script
def intepolate_calc_weight(x, corners):
    """Per-dimension linear interpolation weight: 1 - x for a corner at 0 and x for a corner at 1."""
    return x * corners * 2 - x - corners + 1

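
# A small sanity check (illustrative only, not part of the original module): for a fractional
# position x in [0, 1] and a corner coordinate c in {0, 1}, the expression
# x * c * 2 - x - c + 1 reduces to the usual linear-interpolation weight (1 - x for c = 0,
# x for c = 1).
def _example_interp_weight() -> torch.Tensor:
    x = torch.tensor([0.25])
    corners = torch.tensor([0., 1.])
    return intepolate_calc_weight(x, corners)  # tensor([0.7500, 0.2500])
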

class MultiresHashEncoder(InputEncoder):
    fast_op = True
    t_ind: torch.dtype

    layers: int
    coarse_levels: int
    layers_hashsize: list[int]

    layers_res: torch.Tensor
    """Tensor(L, D)"""
    local_corners: torch.Tensor
    """Tensor(C, D)"""
    layers_hashoffset: torch.Tensor
    """Tensor(L+1)"""
    hashtable: torch.nn.parameter.Parameter
    """Parameter(T, F)"""

    def __init__(self, chns: int, layers: int, log2_hashsize: int, features: int,
                 res0: int | list[int], scale_up: float = 2.0):
        super().__init__(chns, layers * features, (0., 1.))

        res0 = torch.tensor([res0] * chns if isinstance(res0, int) else res0)

        self.layers = layers
        self.features = features
        self.scale_up = scale_up
        self.max_hashsize = 2 ** log2_hashsize
        self.t_ind = torch.int if self.fast_op else torch.long

        layers_res: list[torch.Tensor] = []
        self.layers_hashsize: list[int] = []
        self.coarse_levels = 0
        layers_hashoffset: list[int] = [0]
        for i in range(layers):
            layers_res.append((res0 * scale_up ** i).to(self.t_ind))
            if layers_res[-1].max() > self.max_hashsize ** (1 / 3)\
                    or layers_res[-1].prod() > self.max_hashsize:
                self.layers_hashsize.append(self.max_hashsize)
            else:
                self.layers_hashsize.append(layers_res[-1].prod().item())
                self.coarse_levels = i + 1
            layers_hashoffset.append(layers_hashoffset[-1] + self.layers_hashsize[-1])
        self.register_temp("layers_res", torch.stack(layers_res, 0))
        self.register_temp("layers_hashoffset", torch.tensor(layers_hashoffset, dtype=self.t_ind))
        self.register_temp("local_corners", split_voxels_local(1, 2, dims=chns) + .5)

        # Initialize the hash table entries using the uniform distribution U(−10^−4, 10^−4) to provide
        # a small amount of randomness while encouraging initial predictions close to zero [muller2022instant]
        self.hashtable = torch.nn.parameter.Parameter(
            (torch.rand(layers_hashoffset[-1], features, device=self.device) - .5) * 2e-4)

    @profile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode inputs using multi-resolution hash encoder [muller2022instant]

        :param x `Tensor(N..., D)`: D-dim inputs
        :return `Tensor(N..., LF)`: encoded outputs
        """
        if self.fast_op:
            N_, D = x.shape[:-1], x.shape[-1]
            return multires_hash_encode(self.layers, self.coarse_levels, self.layers_res,
                                        self.layers_hashoffset, x.reshape(-1, D), self.hashtable)\
                .transpose(0, 1).reshape(*N_, -1)

        @profile("Calculate corners")
        def calc_corners(x) -> tuple[torch.Tensor, torch.Tensor]:
            grid_pos = x.unsqueeze(-2) * (self.layers_res - 1)  # (N..., L, D)
            grid_pos.unsqueeze_(-2)  # (N..., L, 1, D)
            grid_lo = torch.floor(grid_pos)
            grid_pos.sub_(grid_lo)
            corners = (grid_lo + self.local_corners).long().min(self.layers_res.unsqueeze(-2) - 1)
            # (N..., L, C, D)
            return grid_pos, corners

        grid_pos, corners = calc_corners(x)
        # (N..., L, 1, D), (N..., L, C, D)

        @profile("Calculate encoded")
        def calc_encoded(level: int) -> torch.Tensor:
            if level < self.coarse_levels:
                idx = to_flat_indices(corners[..., level, :, :], self.layers_res[level, None])
            else:
                idx = self._fast_hash(corners[..., level, :, :]) % self.max_hashsize
            idx.add_(self.layers_hashoffset[level, None])
            return self._linear_interp(grid_pos[..., level, :, :], self.hashtable[idx])

        result = torch.stack([calc_encoded(level) for level in range(self.layers)], dim=-2)
        # (N..., L, X)

        return result.flatten(-2)

    def _linear_interp(self, x: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
        """
        Multilinear interpolation of the corner features within each voxel.

        :param x `Tensor(N..., L, 1, D)`: fractional positions inside the voxels
        :param corner_values `Tensor(N..., L, C, X)`: feature vectors at the voxel corners
        :return `Tensor(N..., L, X)`: interpolated feature vectors
        """
        weights = (x * self.local_corners * 2 - x - self.local_corners + 1).prod(-1, keepdim=True)
        # (N..., L, C, 1)
        return (weights * corner_values).sum(-2)  # (N..., L, X)

    def extra_repr(self) -> str:
        return f"{self.in_chns} -> {self.out_chns}({self.layers}x{self.features})"\
            f", resolution={self.layers_res[0].tolist()}*{self.scale_up}^L"\
            f", max_hashsize={self.max_hashsize}"

    @profile
    def _fast_hash(self, grid_pos: torch.Tensor) -> torch.Tensor:
        """
        Perform fast hash according to instant-ngp

        :param grid_pos `Tensor(N..., D)`: integer grid positions
        :return `Tensor(N...)`: hash values
        """
        if grid_pos.shape[-1] > 7:
            raise ValueError("fast_hash can only hash up to 7 dimensions.")

        # While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
        # and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
        # coordinates. [muller2022instant]
        primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737]
        result = grid_pos[..., 0] * primes[0]
        for i in range(1, grid_pos.shape[-1]):
            result.bitwise_xor_(grid_pos[..., i] * primes[i])
        return result
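

# A minimal usage sketch (illustrative only, not part of the original module). It assumes the
# `InputEncoder` base class provides the `register_temp`/`device` helpers used above and, with
# the default `fast_op = True`, that the external `multires_hash_encode` op is importable;
# set `MultiresHashEncoder.fast_op = False` beforehand to exercise the pure-PyTorch path.
def _example_multires_hash_encoder() -> torch.Tensor:
    enc = MultiresHashEncoder(chns=3, layers=16, log2_hashsize=19, features=2, res0=16)
    x = torch.rand(4096, 3)  # inputs are expected to lie in the unit cube [0, 1]^3
    return enc(x)            # Tensor(4096, 16 * 2)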


class LayeredMultiresHashEncoder(InputEncoder):
    use_cpp = False

    layers: int
    coarse_levels: int
    layers_res: torch.Tensor
    """Tensor(L, D)"""
    local_corners: torch.Tensor
    """Tensor(C, D)"""
    layers_hashsize: list[int]
    layers_hashoffset: list[int]
    t_ind: torch.dtype

    def __init__(self, chns: int, layers: int, log2_hashsize: int, features: int,
                 res0: int | list[int], scale_up: float = 2.0, parts: int = 64):
        super().__init__(chns, layers * features, (0., 1.))

        res0 = torch.tensor([res0] * chns if isinstance(res0, int) else res0)

        self.layers = layers
        self.features = features
        self.scale_up = scale_up
        self.max_hashsize = 2 ** log2_hashsize // parts
        self.t_ind = torch.int if self.use_cpp else torch.long

        layers_res: list[torch.Tensor] = []
        self.layers_hashsize: list[int] = []
        self.layers_usehash: list[bool] = []
        self.coarse_levels = 0
        layers_hashoffset: list[int] = [0]
        for i in range(layers):
            layers_res.append((res0 if i == 0 else layers_res[-1] * scale_up).to(self.t_ind))
            if layers_res[-1].max() > self.max_hashsize ** (1 / 3)\
                    or layers_res[-1].prod() > self.max_hashsize:
                self.layers_hashsize.append(self.max_hashsize)
            else:
                self.layers_hashsize.append(layers_res[-1].prod().item())
                self.coarse_levels = i + 1
            layers_hashoffset.append(layers_hashoffset[-1] + self.layers_hashsize[-1])
        self.register_temp("layers_res", torch.stack(layers_res, 0))
        self.register_temp("layers_hashoffset", torch.tensor(layers_hashoffset, dtype=self.t_ind))

        # Initialize the hash table entries using the uniform distribution U(−10^−4, 10^−4) to provide
        # a small amount of randomness while encouraging initial predictions close to zero [muller2022instant]
        self.hashtable = torch.nn.parameter.Parameter(
            (torch.rand(parts, layers_hashoffset[-1], features, device=self.device) - .5) * 2e-4)

        self.register_temp("local_corners", split_voxels_local(1, 2, dims=chns) + .5)

    @profile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode inputs using multi-resolution hash encoder [muller2022instant]

        :param x `Tensor(N..., P, D)`: D-dim inputs
        :return `Tensor(N..., P, LF)`: encoded outputs
        """
        if self.use_cpp:
            N_, P, D = x.shape[:-2], x.shape[-2], x.shape[-1]
            return torch.stack([
                multires_hash_encode(self.layers, self.coarse_levels, self.layers_res,
                                     self.layers_hashoffset, x[..., i, :].reshape(-1, D),
                                     self.hashtable[i]).reshape(*N_, -1)
                for i in range(P)
            ], dim=-2)

        @profile("Calculate corners")
        def calc_corners(x) -> tuple[torch.Tensor, torch.Tensor]:
            grid_pos = x.unsqueeze(-2) * (self.layers_res - 1)  # (N..., P, L, D)
            grid_pos.unsqueeze_(-2)  # (N..., P, L, 1, D)
            grid_lo = torch.floor(grid_pos)
            grid_pos.sub_(grid_lo)
            corners = (grid_lo + self.local_corners).long().min(self.layers_res.unsqueeze(-2) - 1)
            # (N..., P, L, C, D)
            return grid_pos, corners

        grid_pos, corners = calc_corners(x)
        # (N..., P, L, 1, D), (N..., P, L, C, D)

        @profile("Calculate encoded")
        def calc_encoded(level: int) -> torch.Tensor:
            if level < self.coarse_levels:
                idx = to_flat_indices(corners[..., level, :, :], self.layers_res[level, None])
            else:
                idx = self._fast_hash(corners[..., level, :, :]) % self.max_hashsize
            idx.add_(self.layers_hashoffset[level, None])
            part_idx = torch.arange(x.shape[-2], device=x.device)[:, None].broadcast_to(idx.shape)
            return self._linear_interp(grid_pos[..., level, :, :], self.hashtable[part_idx, idx])

        result = torch.stack([calc_encoded(level) for level in range(self.layers)], dim=-2)
        # (N..., P, L, X)

        return result.flatten(-2)

    def _linear_interp(self, x: torch.Tensor, corner_values: torch.Tensor) -> torch.Tensor:
        """
        Multilinear interpolation of the corner features within each voxel.

        :param x `Tensor(N..., P, L, 1, D)`: fractional positions inside the voxels
        :param corner_values `Tensor(N..., P, L, C, X)`: feature vectors at the voxel corners
        :return `Tensor(N..., P, L, X)`: interpolated feature vectors
        """
        weights = (x * self.local_corners * 2 - x - self.local_corners + 1).prod(-1, keepdim=True)
        # (N..., L, C, 1)
        return (weights * corner_values).sum(-2)  # (N..., L, X)

    def extra_repr(self) -> str:
        return f"{self.in_chns} -> {self.out_chns}({self.layers}x{self.features})"\
            f", resolution={self.layers_res[0].tolist()}*{self.scale_up}^L"\
            f", max_hashsize={self.max_hashsize}"

    @profile
    def _fast_hash(self, grid_pos: torch.Tensor) -> torch.Tensor:
        """
        Perform fast hash according to instant-ngp

        :param grid_pos `Tensor(N..., D)`: integer grid positions
        :return `Tensor(N...)`: hash values
        """
        if grid_pos.shape[-1] > 7:
            raise ValueError("fast_hash can only hash up to 7 dimensions.")

        # While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
        # and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
        # coordinates. [muller2022instant]
        primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737]
        result = grid_pos[..., 0] * primes[0]
        for i in range(1, grid_pos.shape[-1]):
            result.bitwise_xor_(grid_pos[..., i] * primes[i])
        return result
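

# A minimal usage sketch (illustrative only, not part of the original module), mirroring the
# `MultiresHashEncoder` example but with an explicit part dimension P: each of the P parts
# indexes its own slice of the hash table. The pure-PyTorch path is used since `use_cpp`
# defaults to False; the `InputEncoder` helpers and the `split_voxels_local`/`to_flat_indices`
# utilities are assumed to be available as above.
def _example_layered_multires_hash_encoder() -> torch.Tensor:
    enc = LayeredMultiresHashEncoder(chns=3, layers=16, log2_hashsize=19, features=2,
                                     res0=16, parts=64)
    x = torch.rand(1024, 64, 3)  # Tensor(N, P, D) with points in the unit cube [0, 1]^3
    return enc(x)                # Tensor(1024, 64, 16 * 2)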