// InferPipeline: runs one inference pass as a fixed sequence of stages —
// sample points along rays -> encode sample coordinates -> run the network ->
// render per-sample results into output colors — with optional per-stage timing.
#include "InferPipeline.h"
#include "Nmsl2.h"

/**
 * Builds the pipeline stages and allocates the intermediate buffers shared
 * between them.
 *
 * @param net            Network that runs in the "infer" stage. The pipeline
 *                       keeps a shared reference to it and binds the pipeline's
 *                       buffers to it (see bindResources below).
 * @param nRays          Number of rays processed per run().
 * @param nSamplesPerRay Number of samples taken along each ray. Also passed to
 *                       the Sampler.
 * @param depthRange     Forwarded to the Sampler. Presumably the (near, far)
 *                       sampling interval — TODO confirm in Sampler.
 * @param encodeDim      Forwarded to the Encoder.
 * @param coordChns      Number of coordinate channels per sample. It sizes the
 *                       coords buffer and is forwarded to the Encoder. The
 *                       Sampler receives only `coordChns == 3`. NOTE(review):
 *                       what that flag selects (e.g. a coordinate mode) is not
 *                       visible here — confirm in Sampler.
 */
InferPipeline::InferPipeline(sptr net, uint nRays, uint nSamplesPerRay, glm::vec2 depthRange,
                             int encodeDim, int coordChns)
    : _nRays(nRays),
      _nSamplesPerRay(nSamplesPerRay),
      _net(net),
      _sampler(new Sampler(depthRange, nSamplesPerRay, coordChns == 3)),
      _encoder(new Encoder(encodeDim, coordChns)),
      _renderer(new Renderer()) {
    // Total sample count for one run(): every ray gets nSamplesPerRay samples.
    uint nSamples = _nRays * _nSamplesPerRay;
    // Intermediate buffers, sized once here and reused by every run().
    // CudaArray is presumably device memory — confirm in its definition.
    _coords = sptr>(new CudaArray(nSamples * coordChns));                  // coordChns values per sample
    _depths = sptr>(new CudaArray(nSamples));                              // one depth per sample
    _encoded = sptr>(new CudaArray(nSamples * _encoder->outDim()));        // encoder output per sample
    _layeredColors = sptr>(new CudaArray(nSamples));                       // one network result per sample
    // Hand the encoded input, depths and result buffer to the network, so that
    // _net->infer() in run() takes no arguments. Presumably the net reads
    // _encoded/_depths and writes _layeredColors — confirm in the net class.
    _net->bindResources(_encoded.get(), _depths.get(), _layeredColors.get());
}

/**
 * Runs one inference pass over the given rays.
 *
 * Stages, in order:
 *   1. _sampler->sampleOnRays: fills _coords and _depths from rays/origin.
 *   2. _encoder->encode: fills _encoded from _coords.
 *   3. _net->infer: works on the buffers bound in the constructor.
 *   4. _renderer->render: writes o_colors from _layeredColors.
 *
 * @param o_colors Destination passed to the renderer (presumably the output
 *                 color per ray).
 * @param rays     Ray buffer passed to the sampler.
 * @param origin   Ray origin passed to the sampler.
 * @param showPerf If true, block the host until the device is idle, then log
 *                 the per-stage GPU timings through Logger::instance.info.
 *
 * NOTE(review): a CUDA event is recorded around each stage even when showPerf
 * is false. No stream argument is given, so the events go on the default
 * stream. The stage timings are only meaningful if the stages also run on
 * that stream — TODO confirm.
 * NOTE(review): the return codes of cudaEventRecord and cudaEventElapsedTime
 * are not checked. Only the synchronize call is wrapped in CHECK_EX.
 */
void InferPipeline::run(sptr> o_colors, sptr> rays, glm::vec3 origin, bool showPerf) {
    CudaEvent eStart, eSampled, eEncoded, eInferred, eRendered;
    // Each event marks the boundary after the preceding stage.
    cudaEventRecord(eStart);
    _sampler->sampleOnRays(_coords, _depths, rays, origin);
    cudaEventRecord(eSampled);
    _encoder->encode(_encoded, _coords);
    cudaEventRecord(eEncoded);
    _net->infer();
    cudaEventRecord(eInferred);
    _renderer->render(o_colors, _layeredColors);
    cudaEventRecord(eRendered);
    if (showPerf) {
        // Events must have completed before their elapsed times can be read.
        // This waits for all device work, not just these stages.
        CHECK_EX(cudaDeviceSynchronize());
        float timeTotal, timeSample, timeEncode, timeInfer, timeRender;
        // cudaEventElapsedTime reports milliseconds between two events.
        cudaEventElapsedTime(&timeTotal, eStart, eRendered);
        cudaEventElapsedTime(&timeSample, eStart, eSampled);
        cudaEventElapsedTime(&timeEncode, eSampled, eEncoded);
        cudaEventElapsedTime(&timeInfer, eEncoded, eInferred);
        cudaEventElapsedTime(&timeRender, eInferred, eRendered);
        std::ostringstream sout;
        sout << "Infer pipeline: " << timeTotal << "ms (Sample: " << timeSample
             << "ms, Encode: " << timeEncode << "ms, Infer: " << timeInfer
             << "ms, Render: " << timeRender << "ms)";
        Logger::instance.info(sout.str());
    }
    // NOTE(review): the disabled debug dump below is stale. It refers to
    // `sphericalCoords`, `depths`, `encoded` and `encoder.outDim()`, which do
    // not exist in this scope. The current members are _coords, _depths,
    // _encoded and _encoder->outDim(). Update the names before re-enabling it.
    /*
    {
        std::ostringstream sout;
        sout << "Rays:" << std::endl;
        dumpFloatArray(sout, *rays, 10);
        Logger::instance.info(sout.str());
    }
    {
        std::ostringstream sout;
        sout << "Spherical coords:" << std::endl;
        dumpFloatArray(sout, *sphericalCoords, 10);
        Logger::instance.info(sout.str());
    }
    {
        std::ostringstream sout;
        sout << "Depths:" << std::endl;
        dumpFloatArray(sout, *depths, 10);
        Logger::instance.info(sout.str());
    }
    {
        std::ostringstream sout;
        sout << "Encoded:" << std::endl;
        dumpFloatArray(sout, *encoded, 10, encoder.outDim());
        Logger::instance.info(sout.str());
    }
    */
}