Net.h 762 Bytes
Newer Older
Nianchen Deng's avatar
Nianchen Deng committed
1
#pragma once
Nianchen Deng's avatar
sync    
Nianchen Deng committed
2
#include "../utils/common.h"
Nianchen Deng's avatar
Nianchen Deng committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33


/// Wrapper that owns an inference engine, its execution context, and the
/// resources bound to the network's inputs/outputs.
///
/// NOTE(review): `nv`, `Resource` and `Resources` come from
/// "../utils/common.h"; `nv` is presumably an alias for the TensorRT
/// `nvinfer1` namespace — confirm there.
class Net {
public:
    /// Constructs an empty Net with no engine or context.
    /// std::shared_ptr members default-construct to null, so no explicit
    /// member initializers are needed.
    Net() = default;

    /// Loads the network stored at `path`.
    /// @return a success flag (presumably true on success — confirm in the
    ///         implementation).
    bool load(const std::string& path);

    /// Registers `res` under the name `name` in mResources.
    /// NOTE(review): ownership of `res` is not visible here; the pointer is
    /// presumably borrowed and must outlive this Net — confirm.
    void bindResource(const std::string& name, Resource* res);

    /// Presumably releases the engine, context and bound resources.
    /// @return a success flag; exact semantics are defined in the
    ///         implementation.
    bool dispose();

    /// Runs inference.
    /// @param stream          CUDA stream to run on; defaults to nullptr
    ///                        (presumably the default stream — confirm).
    /// @param dumpInputOutput when true, input/output buffers are also dumped
    ///                        (presumably via _dumpInputOutput()).
    /// @return a success flag.
    bool infer(cudaStream_t stream = nullptr, bool dumpInputOutput = false);

private:
    // Keep this declaration order: members are destroyed in reverse order,
    // so mResources goes first, then mContext, then mEngine. The context is
    // presumably created from mEngine, so it must be released before it.
    std::shared_ptr<nv::ICudaEngine> mEngine;
    std::shared_ptr<nv::IExecutionContext> mContext;
    Resources mResources;

    /// Deserializes the network stored at `path` (presumably into mEngine).
    void _deserialize(const std::string& path);

    /// Collects the raw buffer pointers used as the network's bindings
    /// (presumably ordered by binding index — confirm).
    std::vector<void*> _getBindings();

    /// Dumps the current input/output buffers, for debugging.
    void _dumpInputOutput();

protected:
    /// Writes the buffer `deviceBuf` (presumed device memory) to `os`, using
    /// the shape/type associated with binding `index`.
    /// @return a success flag.
    bool _dumpBuffer(std::ostream& os, void* deviceBuf, int index);

    /// Writes the buffer `deviceBuf` (presumed device memory) with shape
    /// `bufDims` and element type `dataType` to `os`.
    /// @return a success flag.
    bool _dumpBuffer(std::ostream& os, void* deviceBuf, nv::Dims bufDims, nv::DataType dataType);
};