#include "View.h"
#include <cuda_runtime.h>
#include <fstream>
#include "../utils/cuda.h"

/**
 * Rotates and then translates the first `n` points: o_vecs[i] = vecs[i] * rot_t + t.
 *
 * GLM evaluates `v * M` with `v` as a row vector, which is the same as transpose(M) * v,
 * so the caller passes the transpose of the rotation it wants applied.
 *
 * Each thread handles at most one element; threads whose index is >= n do nothing.
 * NOTE(review): flattenIdx() is declared outside this file and is assumed to return the
 * thread's flat global index - confirm in ../utils/cuda.h.
 */
__global__ void cu_transPoints(glm::vec3 *o_vecs, glm::vec3 *vecs, glm::vec3 t, glm::mat3 rot_t,
                               uint n) {
    const uint i = flattenIdx();
    if (i < n) {
        const glm::vec3 rotated = vecs[i] * rot_t;
        o_vecs[i] = rotated + t;
    }
}

/**
 * Gathering variant of cu_transPoints: o_vecs[i] = vecs[indices[i]] * rot_t + t.
 *
 * `n` is the number of outputs (and of entries read from `indices`). The value of
 * indices[i] is used as-is to address `vecs`; it is not range-checked here.
 *
 * Each thread handles at most one element; threads whose index is >= n do nothing.
 */
__global__ void cu_transPoints(glm::vec3 *o_vecs, glm::vec3 *vecs, glm::vec3 t, glm::mat3 rot_t,
                               uint n, int *indices) {
    const uint i = flattenIdx();
    if (i < n) {
        const glm::vec3 rotated = vecs[indices[i]] * rot_t;
        o_vecs[i] = rotated + t;
    }
}

/**
 * Undoes a rotate-then-translate on the first `n` points:
 * o_pts[i] = (pts[i] - t) * inv_rot_t.
 *
 * The translation is removed first, then the row-vector product with `inv_rot_t` is taken
 * (GLM's `v * M` equals transpose(M) * v).
 *
 * Each thread handles at most one element; threads whose index is >= n do nothing.
 */
__global__ void cu_transPointsInverse(glm::vec3 *o_pts, glm::vec3 *pts, glm::vec3 t,
                                      glm::mat3 inv_rot_t, uint n) {
    const uint i = flattenIdx();
    if (i < n) {
        const glm::vec3 centered = pts[i] - t;
        o_pts[i] = centered * inv_rot_t;
    }
}

/**
 * Gathering variant of cu_transPointsInverse:
 * o_pts[i] = (pts[indices[i]] - t) * inv_rot_t.
 *
 * `n` is the number of outputs (and of entries read from `indices`). The value of
 * indices[i] is used as-is to address `pts`; it is not range-checked here.
 *
 * Each thread handles at most one element; threads whose index is >= n do nothing.
 */
__global__ void cu_transPointsInverse(glm::vec3 *o_pts, glm::vec3 *pts, glm::vec3 t,
                                      glm::mat3 inv_rot_t, uint n, int *indices) {
    const uint i = flattenIdx();
    if (i < n) {
        const glm::vec3 centered = pts[indices[i]] - t;
        o_pts[i] = centered * inv_rot_t;
    }
}

/**
 * Rotates the first `n` direction vectors without translating them:
 * o_vecs[i] = vecs[i] * rot_t (row-vector product, i.e. transpose(rot_t) * vecs[i]).
 *
 * Each thread handles at most one element; threads whose index is >= n do nothing.
 */
__global__ void cu_transVectors(glm::vec3 *o_vecs, glm::vec3 *vecs, glm::mat3 rot_t, uint n) {
    const uint i = flattenIdx();
    if (i < n) {
        const glm::vec3 src = vecs[i];
        o_vecs[i] = src * rot_t;
    }
}

/**
 * Gathering variant of cu_transVectors: o_vecs[i] = vecs[indices[i]] * rot_t.
 *
 * `n` is the number of outputs (and of entries read from `indices`). The value of
 * indices[i] is used as-is to address `vecs`; it is not range-checked here.
 *
 * Each thread handles at most one element; threads whose index is >= n do nothing.
 */
__global__ void cu_transVectors(glm::vec3 *o_vecs, glm::vec3 *vecs, glm::mat3 rot_t, uint n,
                                int *indices) {
    const uint i = flattenIdx();
    if (i < n) {
        const glm::vec3 src = vecs[indices[i]];
        o_vecs[i] = src * rot_t;
    }
}

/**
 * Writes one un-normalised ray per pixel: ((pixel - c) / f, 1), stored row-major at
 * o_rays[x + y * res.x].
 *
 * Threads whose 2D index falls outside `res` in either dimension do nothing.
 * NOTE(review): IDX2 is a macro defined outside this file and is assumed to yield the
 * thread's 2D global index - confirm in ../utils/cuda.h.
 */
__global__ void cu_genLocalRays(glm::vec3 *o_rays, glm::vec2 f, glm::vec2 c, glm::uvec2 res) {
    glm::uvec2 px = IDX2;
    const bool inside = px.x < res.x && px.y < res.y;
    if (!inside)
        return;
    const uint offset = px.x + px.y * res.x;
    const glm::vec2 planePos = (glm::vec2(px) - c) / f;
    o_rays[offset] = glm::vec3(planePos, 1.0f);
}

/**
 * Writes one unit-length ray per pixel: normalize(((pixel - c) / f, 1)), stored row-major
 * at o_rays[x + y * res.x].
 *
 * Threads whose 2D index falls outside `res` in either dimension do nothing.
 * NOTE(review): IDX2 is a macro defined outside this file and is assumed to yield the
 * thread's 2D global index - confirm in ../utils/cuda.h.
 */
__global__ void cu_genLocalRaysNormed(glm::vec3 *o_rays, glm::vec2 f, glm::vec2 c, glm::uvec2 res) {
    glm::uvec2 px = IDX2;
    const bool inside = px.x < res.x && px.y < res.y;
    if (!inside)
        return;
    const uint offset = px.x + px.y * res.x;
    const glm::vec3 dir((glm::vec2(px) - c) / f, 1.0f);
    o_rays[offset] = glm::normalize(dir);
}

/**
 * Gathers colours through an index table: o_colors[i] = colors[indices[i]].
 *
 * A negative entry in `indices` marks an output with no source; such outputs receive a
 * default-constructed glm::vec4 instead of a gathered colour.
 *
 * Each thread handles at most one element; threads whose index is >= n do nothing.
 */
__global__ void cu_indexedCopy(glm::vec4 *o_colors, glm::vec4 *colors, int *indices, uint n) {
    const uint i = flattenIdx();
    if (i >= n)
        return;
    const int src = indices[i];
    if (src < 0) {
        // No source element for this output.
        o_colors[i] = glm::vec4();
        return;
    }
    o_colors[i] = colors[src];
}

/**
 * Transforms points with this view's rotation `_r` and translation `_t`, on the device.
 *
 * Forward (inverse == false): results[i] = src * transpose(_r) + _t
 * Inverse (inverse == true):  results[i] = (src - _t) * _r
 * where `src` is points[i], or points[indices[i]] when `indices` is non-null. GLM's
 * `v * M` is a row-vector product, i.e. transpose(M) * v.
 *
 * @param results  Output array; its length determines how many elements are processed.
 * @param points   Input points.
 * @param indices  Optional gather table into `points`; pass nullptr for a 1:1 mapping.
 * @param inverse  Selects the inverse transform.
 *
 * The element count handed to every kernel is results->n(): the launch grid is sized from
 * the output array, and with a gather table the input array may have a different length.
 * Using the input length here could read `indices` / write `results` out of bounds or
 * leave part of `results` unwritten.
 */
void View::transPoints(sptr<CudaArray<glm::vec3>> results, sptr<CudaArray<glm::vec3>> points,
                       sptr<CudaArray<int>> indices, bool inverse) {
    glm::mat3 r_t = inverse ? _r : glm::transpose(_r);
    // blkSize / grdSize are consumed by CU_INVOKE; one thread per output element.
    dim3 blkSize(1024);
    dim3 grdSize(ceilDiv(results->n(), blkSize.x));
    if (inverse) {
        if (indices == nullptr)
            CU_INVOKE(cu_transPointsInverse)(*results, *points, _t, r_t, results->n());
        else
            CU_INVOKE(cu_transPointsInverse)(*results, *points, _t, r_t, results->n(), *indices);
    } else {
        if (indices == nullptr)
            CU_INVOKE(cu_transPoints)(*results, *points, _t, r_t, results->n());
        else
            CU_INVOKE(cu_transPoints)(*results, *points, _t, r_t, results->n(), *indices);
    }
}

/**
 * Rotates direction vectors with this view's rotation `_r` on the device; no translation
 * is applied.
 *
 * results[i] = src * M, where M is `_r` when `inverse` is true and transpose(_r)
 * otherwise, and `src` is vectors[i] or vectors[indices[i]] when `indices` is non-null.
 *
 * @param results  Output array; its length is the number of elements processed.
 * @param vectors  Input vectors.
 * @param indices  Optional gather table into `vectors`; pass nullptr for a 1:1 mapping.
 * @param inverse  Selects `_r` instead of its transpose as the multiplied matrix.
 */
void View::transVectors(sptr<CudaArray<glm::vec3>> results, sptr<CudaArray<glm::vec3>> vectors,
                        sptr<CudaArray<int>> indices, bool inverse) {
    // blkSize / grdSize are consumed by CU_INVOKE; one thread per output element.
    dim3 blkSize(1024);
    dim3 grdSize(ceilDiv(results->n(), blkSize.x));

    glm::mat3 r_t = _r;
    if (!inverse)
        r_t = glm::transpose(_r);

    if (indices != nullptr)
        CU_INVOKE(cu_transVectors)(*results, *vectors, r_t, results->n(), *indices);
    else
        CU_INVOKE(cu_transVectors)(*results, *vectors, r_t, results->n());
}

/**
 * Builds a camera from a field of view, a principal point and an image resolution.
 *
 * The focal length is 0.5 * res.x / tan(fov * pi / 360), i.e. it is derived from the
 * image width and half of `fov` converted from degrees to radians. The y focal length is
 * stored with the opposite sign of the x focal length.
 *
 * @param fov  Field of view, in degrees.
 * @param c    Principal point, stored unchanged.
 * @param res  Image resolution, stored unchanged.
 */
Camera::Camera(float fov, glm::vec2 c, glm::uvec2 res) {
    _res = res;
    _c = c;
    _f.x = 0.5f * res.x / tan(fov * (float)M_PI / 360.0f);
    // Same magnitude as the x focal length, opposite sign.
    _f.y = _f.x * -1.0f;
}

/**
 * Returns the cached per-pixel local rays, generating the normalised set on first use.
 */
sptr<CudaArray<glm::vec3>> Camera::localRays() {
    if (_localRays != nullptr)
        return _localRays;
    // Nothing cached yet: build the normalised rays once and keep them.
    _genLocalRays(true);
    return _localRays;
}

/**
 * Reads one length-prefixed block of ints from `fin` into `out`.
 *
 * Layout: a native `int` element count followed by that many native `int` values.
 * Returns false if the count cannot be read, is negative, or the payload is short.
 */
static bool readCountedInts(std::ifstream &fin, std::vector<int> &out) {
    int count = 0;
    if (!fin.read((char *)&count, sizeof(count)))
        return false;
    // A negative count would wrap to an enormous size when used as a vector length.
    if (count < 0)
        return false;
    out.resize((size_t)count);
    if (!fin.read((char *)out.data(), sizeof(int) * out.size()))
        return false;
    return true;
}

/**
 * Loads the subset index table and the inverse subset index table from a binary file.
 *
 * The file holds two consecutive length-prefixed int blocks: first the subset indices,
 * then the subset inverse indices.
 *
 * @param filepath  Path of the mask data file.
 * @return true when both blocks were read completely; false otherwise.
 *
 * If the file cannot be opened the current tables are left untouched. If the file opens
 * but is truncated or malformed, both tables are cleared. Each count and payload is
 * validated before it is used, and the device arrays are only created once both blocks
 * have been read, so a bad file can no longer size a buffer from an unread or negative
 * count.
 */
bool Camera::loadMaskData(std::string filepath) {
    std::ifstream fin(filepath, std::ios::binary);
    if (!fin)
        return false;

    std::vector<int> subsetIndicesBuffer;
    std::vector<int> subsetInverseIndicesBuffer;
    if (!readCountedInts(fin, subsetIndicesBuffer) ||
        !readCountedInts(fin, subsetInverseIndicesBuffer)) {
        _subsetIndices = nullptr;
        _subsetInverseIndices = nullptr;
        return false;
    }

    _subsetIndices.reset(new CudaArray<int>(subsetIndicesBuffer));
    _subsetInverseIndices.reset(new CudaArray<int>(subsetInverseIndicesBuffer));

    std::ostringstream sout;
    sout << "Mask data loaded. Subset indices: " << _subsetIndices->n()
         << ", subset inverse indices: " << _subsetInverseIndices->n() << std::endl;
    Logger::instance.info(sout.str());
    return true;
}

/**
 * Fills `o_rays` with this camera's local rays rotated by `view`.
 *
 * The local rays are passed through view.transVectors() together with the subset index
 * table, so when a subset is loaded only the selected rays are produced.
 */
void Camera::getRays(sptr<CudaArray<glm::vec3>> o_rays, View &view) {
    sptr<CudaArray<glm::vec3>> rays = localRays();
    view.transVectors(o_rays, rays, _subsetIndices);
}

/**
 * Writes `colors` into the image buffer `o_imgData`.
 *
 * With an inverse subset table loaded, each output element is gathered from `colors`
 * through that table by cu_indexedCopy. Without one, the first o_imgData->n() elements of
 * `colors` are copied straight across, device to device.
 *
 * NOTE(review): the result of cudaMemcpy is not checked here.
 */
void Camera::restoreImage(sptr<CudaArray<glm::vec4>> o_imgData, sptr<CudaArray<glm::vec4>> colors) {
    if (_subsetInverseIndices != nullptr) {
        // blkSize / grdSize are consumed by CU_INVOKE; one thread per output element.
        dim3 blkSize(1024);
        dim3 grdSize(ceilDiv(o_imgData->n(), blkSize.x));
        CU_INVOKE(cu_indexedCopy)(*o_imgData, *colors, *_subsetInverseIndices, o_imgData->n());
        return;
    }
    // No subset mapping: plain element-for-element copy.
    cudaMemcpy(o_imgData->getBuffer(), colors->getBuffer(), o_imgData->n() * sizeof(glm::vec4),
               cudaMemcpyDeviceToDevice);
}

/**
 * Allocates a fresh array of _res.x * _res.y rays and fills it on the device.
 *
 * @param norm  When true the rays are normalised (cu_genLocalRaysNormed); otherwise the
 *              un-normalised form is written (cu_genLocalRays).
 *
 * The launch uses 32x32 blocks and enough blocks to cover the resolution in each
 * dimension.
 */
void Camera::_genLocalRays(bool norm) {
    _localRays.reset(new CudaArray<glm::vec3>(_res.x * _res.y));
    // blkSize / grdSize are consumed by CU_INVOKE.
    dim3 blkSize(32, 32);
    dim3 grdSize(ceilDiv(_res.x, blkSize.x), ceilDiv(_res.y, blkSize.y));
    if (!norm) {
        CU_INVOKE(cu_genLocalRays)(*_localRays, _f, _c, _res);
    } else {
        CU_INVOKE(cu_genLocalRaysNormed)(*_localRays, _f, _c, _res);
    }
}