#include "FsNeRF.h"

#include <sstream>

namespace models {

/**
 * @brief Builds the sample -> encode -> infer -> render pipeline and allocates
 *        the intermediate device buffers for a fixed batch of rays.
 *
 * @param args  Model configuration: sample count, near/far range, whether the
 *              radius channel is kept, encoder frequencies, model path and
 *              background mode.
 * @param nRays Number of rays processed per call. Every intermediate buffer is
 *              sized for nRays * args.nSamples samples.
 */
FsNeRF::FsNeRF(const Args &args, int nRays)
    : _nRays(nRays),
      _nSamples(args.nSamples),
      // 3 coordinate channels per sample, or 2 when args.withRadius is false.
      _xChns(3 - !args.withRadius),
      _field(new fields::FsNeRF(args.modelPath)),
      _sampler(new modules::Sampler({args.near, args.far}, "spherical_radius",
                                    args.nSamples, args.withRadius)),
      // The channel count is recomputed rather than read from _xChns. Members are
      // initialized in header-declaration order, which is not visible here.
      _encoder(new modules::Encoder(args.xfreqs, 3 - !args.withRadius, false)),
      _renderer(new modules::Renderer(args.whiteBg)) {
    auto n = _nRays * _nSamples;
    _x = darray(new CudaArray(n * _xChns));
    _depths = darray(new CudaArray(n));
    _encoded = darray(new CudaArray(n * _encoder->outChns()));
    _rgbd = darray(new CudaArray(n));
    // infer() takes no arguments, so the field is bound to its buffers up front.
    // Presumably it reads _encoded and writes _rgbd, which the renderer consumes
    // afterwards. TODO confirm against fields::FsNeRF.
    _field->bindResources(_encoded.get(), _depths.get(), _rgbd.get());
}

/**
 * @brief Releases the pipeline modules allocated in the constructor.
 *
 * NOTE(review): _field is also created with `new` in the constructor but is not
 * deleted here. If FsNeRF.h declares it as a raw pointer, it leaks; if it is a
 * smart pointer, this is correct. Confirm against the header.
 * NOTE(review): because this class owns raw pointers, its copy constructor and
 * copy assignment should be deleted in the header to prevent double deletion.
 * TODO confirm.
 */
FsNeRF::~FsNeRF() {
    delete _sampler;
    delete _encoder;
    delete _renderer;
}

/**
 * @brief Renders colors for the first _nRays rays.
 *
 * Runs sample -> encode -> infer -> render. After each stage the device is
 * synchronized and the result is checked with CHECK_EX, so an asynchronous
 * failure surfaces at the stage that caused it.
 *
 * @param o_colors Output color buffer. Only its first _nRays elements are written.
 * @param dirs     Ray directions. Only the first _nRays are used.
 * @param origin   Single ray origin passed to the sampler for all rays.
 * @param showPerf When true, logs the GPU time spent in each stage.
 *
 * The default value for showPerf belongs only on the declaration in FsNeRF.h.
 * Repeating it on this out-of-line definition is ill-formed when the header
 * already provides it.
 */
void FsNeRF::operator()(darray o_colors, const darray dirs, glm::vec3 origin, bool showPerf) {
    CudaEvent eStart, eSampled, eEncoded, eInferred, eRendered;
    cudaEventRecord(eStart);

    (*_sampler)(_x, _depths, origin, darray(dirs->subArray(0, _nRays)));
    CHECK_EX(cudaDeviceSynchronize());
    cudaEventRecord(eSampled);

    (*_encoder)(_encoded, _x);
    CHECK_EX(cudaDeviceSynchronize());
    cudaEventRecord(eEncoded);

    _field->infer();
    CHECK_EX(cudaDeviceSynchronize());
    cudaEventRecord(eInferred);

    (*_renderer)(darray(o_colors->subArray(0, _nRays)), _depths, _rgbd);
    CHECK_EX(cudaDeviceSynchronize());
    cudaEventRecord(eRendered);

    if (showPerf) {
        // eRendered must have completed before elapsed times can be queried.
        CHECK_EX(cudaDeviceSynchronize());
        // Zero-initialized so a failed query logs 0 instead of an indeterminate value.
        float timeTotal = 0.f;
        float timeSample = 0.f;
        float timeEncode = 0.f;
        float timeInfer = 0.f;
        float timeRender = 0.f;
        cudaEventElapsedTime(&timeTotal, eStart, eRendered);
        cudaEventElapsedTime(&timeSample, eStart, eSampled);
        cudaEventElapsedTime(&timeEncode, eSampled, eEncoded);
        cudaEventElapsedTime(&timeInfer, eEncoded, eInferred);
        cudaEventElapsedTime(&timeRender, eInferred, eRendered);
        std::ostringstream sout;
        sout << "Infer pipeline: " << timeTotal << "ms (Sample: " << timeSample
             << "ms, Encode: " << timeEncode << "ms, Infer: " << timeInfer
             << "ms, Render: " << timeRender << "ms)";
        Logger::instance.info(sout.str().c_str());
    }

    // Disabled debug dumps of the intermediate buffers. The names have been
    // updated to the current members (dirs, _x, _encoder->outChns()). This code
    // is not compiled, so the dumpArray calls are unverified.
    /*
    {
        std::ostringstream sout;
        sout << "Rays:" << std::endl;
        dumpArray(sout, *dirs, 10);
        Logger::instance.info(sout.str());
    }
    {
        std::ostringstream sout;
        sout << "Spherical coords:" << std::endl;
        dumpArray(sout, *_x, 10, _xChns * _nSamples);
        Logger::instance.info(sout.str());
    }
    {
        std::ostringstream sout;
        sout << "Depths:" << std::endl;
        dumpArray(sout, *_depths, 10, _nSamples);
        Logger::instance.info(sout.str());
    }
    {
        std::ostringstream sout;
        sout << "Encoded:" << std::endl;
        dumpArray(sout, *_encoded, 10, _encoder->outChns() * _nSamples);
        Logger::instance.info(sout.str());
    }
    {
        std::ostringstream sout;
        sout << "Color:" << std::endl;
        dumpArray(sout, *o_colors, 10);
        Logger::instance.info(sout.str());
    }
    */
}

} // namespace models