#include "InferPipeline.h"
#include "Nmsl2.h"

/**
 * Builds the pipeline stages and allocates the intermediate buffers they share.
 *
 * @param net             Network that performs the inference stage; the buffers
 *                        allocated here are bound to it before the constructor returns.
 * @param nRays           Number of rays handled per run.
 * @param nSamplesPerRay  Number of samples taken on each ray.
 * @param depthRange      Depth range forwarded to the sampler.
 * @param encodeDim       Encoding dimension forwarded to the encoder.
 * @param coordChns       Number of channels per sample coordinate; the sampler is
 *                        additionally told whether this equals 3.
 */
InferPipeline::InferPipeline(sptr<Msl> net, uint nRays, uint nSamplesPerRay, glm::vec2 depthRange,
                             int encodeDim, int coordChns)
    : _nRays(nRays),
      _nSamplesPerRay(nSamplesPerRay),
      _net(net),
      _sampler(new Sampler(depthRange, nSamplesPerRay, coordChns == 3)),
      _encoder(new Encoder(encodeDim, coordChns)),
      _renderer(new Renderer()) {
    using FloatBuffer = CudaArray<float>;
    using ColorBuffer = CudaArray<glm::vec4>;

    // Every intermediate buffer holds one entry (or one group of channels) per sample.
    const uint totalSamples = nRays * nSamplesPerRay;

    // Sample coordinates: coordChns floats per sample.
    _coords = sptr<FloatBuffer>(new FloatBuffer(totalSamples * coordChns));
    // Sample depths: one float per sample.
    _depths = sptr<FloatBuffer>(new FloatBuffer(totalSamples));
    // Encoded coordinates: the encoder's output dimension per sample.
    _encoded = sptr<FloatBuffer>(new FloatBuffer(totalSamples * _encoder->outDim()));
    // Per-sample colors: one vec4 per sample, later consumed by the renderer.
    _layeredColors = sptr<ColorBuffer>(new ColorBuffer(totalSamples));

    // The network receives raw pointers only; this object keeps the buffers alive
    // through its own sptr members.
    _net->bindResources(_encoded.get(), _depths.get(), _layeredColors.get());
}

/**
 * Executes the pipeline stages in order: sample -> encode -> infer -> render.
 *
 * A CUDA event is recorded before the first stage and after every stage, so
 * that per-stage timings can be reported on request.
 *
 * @param o_colors  Output color buffer passed to the renderer.
 * @param rays      Rays passed to the sampler.
 * @param origin    Origin passed to the sampler together with the rays.
 * @param showPerf  When true, waits for the device to finish and logs the total
 *                  and per-stage elapsed times in milliseconds. Failures while
 *                  synchronizing or querying the timings are reported through
 *                  CHECK_EX.
 */
void InferPipeline::run(sptr<CudaArray<glm::vec4>> o_colors, sptr<CudaArray<glm::vec3>> rays,
                        glm::vec3 origin, bool showPerf) {
    // One event per stage boundary.
    CudaEvent eStart, eSampled, eEncoded, eInferred, eRendered;

    cudaEventRecord(eStart);

    // Stage 1: sampling along the rays, using the _coords and _depths buffers.
    _sampler->sampleOnRays(_coords, _depths, rays, origin);
    cudaEventRecord(eSampled);

    // Stage 2: encoding of the sampled coordinates into _encoded.
    _encoder->encode(_encoded, _coords);
    cudaEventRecord(eEncoded);

    // Stage 3: network inference on the buffers bound in the constructor.
    _net->infer();
    cudaEventRecord(eInferred);

    // Stage 4: rendering of the per-sample colors into the caller's output buffer.
    _renderer->render(o_colors, _layeredColors);
    cudaEventRecord(eRendered);

    if (!showPerf)
        return;

    // Elapsed-time queries require the events to have completed.
    CHECK_EX(cudaDeviceSynchronize());

    // Zero-initialized and checked: a failed query (e.g. an event that was never
    // recorded) must not leave indeterminate values to be written to the log.
    float timeTotal = 0.0f;
    float timeSample = 0.0f;
    float timeEncode = 0.0f;
    float timeInfer = 0.0f;
    float timeRender = 0.0f;
    CHECK_EX(cudaEventElapsedTime(&timeTotal, eStart, eRendered));
    CHECK_EX(cudaEventElapsedTime(&timeSample, eStart, eSampled));
    CHECK_EX(cudaEventElapsedTime(&timeEncode, eSampled, eEncoded));
    CHECK_EX(cudaEventElapsedTime(&timeInfer, eEncoded, eInferred));
    CHECK_EX(cudaEventElapsedTime(&timeRender, eInferred, eRendered));

    std::ostringstream sout;
    sout << "Infer pipeline: " << timeTotal << "ms (Sample: " << timeSample
         << "ms, Encode: " << timeEncode << "ms, Infer: " << timeInfer
         << "ms, Render: " << timeRender << "ms)";
    Logger::instance.info(sout.str());
}