Skip to content

Commit 257f049

Browse files
committed
Supports multiple data types for input and output: int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64, etc. Avoiding data type conversion improves overall inference performance.
1 parent ddfdf1c commit 257f049

File tree

14 files changed

+1004
-165
lines changed

14 files changed

+1004
-165
lines changed

pybind/AppBuilder.cpp

Lines changed: 29 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -21,28 +21,31 @@ ShareMemory::~ShareMemory() {
2121
}
2222

2323
QNNContext::QNNContext(const std::string& model_name,
24-
const std::string& model_path, const std::string& backend_lib_path, const std::string& system_lib_path, bool async) {
24+
const std::string& model_path, const std::string& backend_lib_path, const std::string& system_lib_path,
25+
bool async, const std::string& input_data_type, const std::string& output_data_type) {
2526
m_model_name = model_name;
2627

27-
g_LibAppBuilder.ModelInitialize(model_name, model_path, backend_lib_path, system_lib_path, async);
28+
g_LibAppBuilder.ModelInitialize(model_name, model_path, backend_lib_path, system_lib_path, async, input_data_type, output_data_type);
2829
}
2930

3031
QNNContext::QNNContext(const std::string& model_name, const std::string& proc_name,
31-
const std::string& model_path, const std::string& backend_lib_path, const std::string& system_lib_path, bool async) {
32+
const std::string& model_path, const std::string& backend_lib_path, const std::string& system_lib_path,
33+
bool async, const std::string& input_data_type, const std::string& output_data_type) {
3234
m_model_name = model_name;
3335
m_proc_name = proc_name;
3436

35-
g_LibAppBuilder.ModelInitialize(model_name, proc_name, model_path, backend_lib_path, system_lib_path, async);
37+
g_LibAppBuilder.ModelInitialize(model_name, proc_name, model_path, backend_lib_path, system_lib_path, async, input_data_type, output_data_type);
3638
}
3739

3840
QNNContext::QNNContext(const std::string& model_name,
3941
const std::string& model_path, const std::string& backend_lib_path,
40-
const std::string& system_lib_path, const std::vector<LoraAdapter>& lora_adapters, bool async) {
42+
const std::string& system_lib_path, const std::vector<LoraAdapter>& lora_adapters,
43+
bool async, const std::string& input_data_type, const std::string& output_data_type) {
4144

4245
m_model_name = model_name;
4346
m_lora_adapters = lora_adapters;
4447

45-
g_LibAppBuilder.ModelInitialize(model_name, model_path, backend_lib_path, system_lib_path, m_lora_adapters, async);
48+
g_LibAppBuilder.ModelInitialize(model_name, model_path, backend_lib_path, system_lib_path, m_lora_adapters, async, input_data_type, output_data_type);
4649
}
4750

4851
// issue#24
@@ -54,14 +57,14 @@ std::vector<std::string> QNNContext::getInputDataType(){
5457
return g_LibAppBuilder.getInputDataType(m_model_name);
5558
};
5659

57-
std::vector<std::string> QNNContext::getOutputDataType(){
58-
return g_LibAppBuilder.getOutputDataType(m_model_name);
59-
};
60-
6160
std::vector<std::vector<size_t>> QNNContext::getOutputShapes(){
6261
return g_LibAppBuilder.getOutputShapes(m_model_name);
6362
};
6463

64+
std::vector<std::string> QNNContext::getOutputDataType(){
65+
return g_LibAppBuilder.getOutputDataType(m_model_name);
66+
};
67+
6568
std::string QNNContext::getGraphName(){
6669
return g_LibAppBuilder.getGraphName(m_model_name);
6770
};
@@ -84,16 +87,16 @@ std::vector<std::string> QNNContext::getInputDataType(const std::string& proc_na
8487
return m_moduleInfo.inputDataType;
8588
};
8689

87-
std::vector<std::string> QNNContext::getOutputDataType(const std::string& proc_name){
88-
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "od");
89-
return m_moduleInfo.outputDataType;
90-
};
91-
9290
std::vector<std::vector<size_t>> QNNContext::getOutputShapes(const std::string& proc_name){
9391
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "os");
9492
return m_moduleInfo.outputShapes;
9593
};
9694

95+
std::vector<std::string> QNNContext::getOutputDataType(const std::string& proc_name){
96+
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "od");
97+
return m_moduleInfo.outputDataType;
98+
};
99+
97100
std::string QNNContext::getGraphName(const std::string& proc_name){
98101
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "gn");
99102
return m_moduleInfo.graphName;
@@ -117,14 +120,14 @@ QNNContext::~QNNContext() {
117120
}
118121

119122

120-
std::vector<py::array_t<float>>
121-
QNNContext::Inference(const std::vector<py::array_t<float>>& input, const std::string& perf_profile, size_t graphIndex) {
122-
return inference(m_model_name, input, perf_profile, graphIndex);
123+
std::vector<py::array>
124+
QNNContext::Inference(const std::vector<py::array>& input, const std::string& perf_profile, size_t graphIndex, const std::string& input_data_type, const std::string& output_data_type) {
125+
return inference(m_model_name, input, perf_profile, graphIndex, input_data_type, output_data_type);
123126
}
124127

125-
std::vector<py::array_t<float>>
126-
QNNContext::Inference(const ShareMemory& share_memory, const std::vector<py::array_t<float>>& input, const std::string& perf_profile, size_t graphIndex) {
127-
return inference_P(m_model_name, m_proc_name, share_memory.m_share_memory_name, input, perf_profile, graphIndex);
128+
std::vector<py::array>
129+
QNNContext::Inference(const ShareMemory& share_memory, const std::vector<py::array>& input, const std::string& perf_profile, size_t graphIndex, const std::string& input_data_type, const std::string& output_data_type) {
130+
return inference_P(m_model_name, m_proc_name, share_memory.m_share_memory_name, input, perf_profile, graphIndex, input_data_type, output_data_type);
128131
}
129132

130133
bool QNNContext::ApplyBinaryUpdate(const std::vector<LoraAdapter>& lora_adapters) {
@@ -189,11 +192,11 @@ PYBIND11_MODULE(appbuilder, m) {
189192
.def(py::init<const std::string&, const size_t>());
190193

191194
py::class_<QNNContext>(m, "QNNContext")
192-
.def(py::init<const std::string&, const std::string&, const std::string&, const std::string&, bool>())
193-
.def(py::init<const std::string&, const std::string&, const std::string&, const std::string&, const std::vector<LoraAdapter>&, bool>())
194-
.def(py::init<const std::string&, const std::string&, const std::string&, const std::string&, const std::string&, bool>())
195-
.def("Inference", py::overload_cast<const std::vector<py::array_t<float>>&, const std::string&, size_t>(&QNNContext::Inference))
196-
.def("Inference", py::overload_cast<const ShareMemory&, const std::vector<py::array_t<float>>&, const std::string&, size_t>(&QNNContext::Inference))
195+
.def(py::init<const std::string&, const std::string&, const std::string&, const std::string&, bool, const std::string&, const std::string&>())
196+
.def(py::init<const std::string&, const std::string&, const std::string&, const std::string&, const std::vector<LoraAdapter>&, bool, const std::string&, const std::string&>())
197+
.def(py::init<const std::string&, const std::string&, const std::string&, const std::string&, const std::string&, bool, const std::string&, const std::string&>())
198+
.def("Inference", py::overload_cast<const std::vector<py::array>&, const std::string&, size_t, const std::string&, const std::string&>(&QNNContext::Inference))
199+
.def("Inference", py::overload_cast<const ShareMemory&, const std::vector<py::array>&, const std::string&, size_t, const std::string&, const std::string&>(&QNNContext::Inference))
197200
.def("ApplyBinaryUpdate", &QNNContext::ApplyBinaryUpdate, "Apply Lora binary update")
198201
.def("getInputShapes", py::overload_cast<>(&QNNContext::getInputShapes))
199202
.def("getInputDataType", py::overload_cast<>(&QNNContext::getInputDataType))

0 commit comments

Comments (0)