#include <cuda_gl_interop.h>
#include "resource.h"

namespace utils::cuda {
    template <typename T> class GlTextureResource : public GraphicsResource {
    public:
        GlTextureResource(GLuint textureID, glm::uvec2 textureSize) {
            CHECK_EX(cudaGraphicsGLRegisterImage(&_res, textureID, GL_TEXTURE_2D,
                                                 cudaGraphicsRegisterFlagsWriteDiscard));
            _size = textureSize.x * textureSize.y * sizeof(T);
            _textureSize = textureSize;
        }

        virtual ~GlTextureResource() { cudaGraphicsUnmapResources(1, &_res, 0); }

        virtual void *getBuffer() const {
            cudaArray_t buffer;
            try {
                CHECK_EX(cudaGraphicsSubResourceGetMappedArray(&buffer, _res, 0, 0));
            } catch (...) {
                return nullptr;
            }
            return buffer;
        }

        operator T *() { return (T *)getBuffer(); }

        glm::uvec2 textureSize() { return _textureSize; }

    private:
        glm::uvec2 _textureSize;
    };
} // namespace utils::cuda