@@ -21,28 +21,31 @@ ShareMemory::~ShareMemory() {
2121}
2222
2323QNNContext::QNNContext (const std::string& model_name,
24- const std::string& model_path, const std::string& backend_lib_path, const std::string& system_lib_path, bool async) {
24+ const std::string& model_path, const std::string& backend_lib_path, const std::string& system_lib_path,
25+ bool async, const std::string& input_data_type, const std::string& output_data_type) {
2526 m_model_name = model_name;
2627
27- g_LibAppBuilder.ModelInitialize (model_name, model_path, backend_lib_path, system_lib_path, async);
28+ g_LibAppBuilder.ModelInitialize (model_name, model_path, backend_lib_path, system_lib_path, async, input_data_type, output_data_type );
2829}
2930
3031QNNContext::QNNContext (const std::string& model_name, const std::string& proc_name,
31- const std::string& model_path, const std::string& backend_lib_path, const std::string& system_lib_path, bool async) {
32+ const std::string& model_path, const std::string& backend_lib_path, const std::string& system_lib_path,
33+ bool async, const std::string& input_data_type, const std::string& output_data_type) {
3234 m_model_name = model_name;
3335 m_proc_name = proc_name;
3436
35- g_LibAppBuilder.ModelInitialize (model_name, proc_name, model_path, backend_lib_path, system_lib_path, async);
37+ g_LibAppBuilder.ModelInitialize (model_name, proc_name, model_path, backend_lib_path, system_lib_path, async, input_data_type, output_data_type );
3638}
3739
3840QNNContext::QNNContext (const std::string& model_name,
3941 const std::string& model_path, const std::string& backend_lib_path,
40- const std::string& system_lib_path, const std::vector<LoraAdapter>& lora_adapters, bool async) {
42+ const std::string& system_lib_path, const std::vector<LoraAdapter>& lora_adapters,
43+ bool async, const std::string& input_data_type, const std::string& output_data_type) {
4144
4245 m_model_name = model_name;
4346 m_lora_adapters = lora_adapters;
4447
45- g_LibAppBuilder.ModelInitialize (model_name, model_path, backend_lib_path, system_lib_path, m_lora_adapters, async);
48+ g_LibAppBuilder.ModelInitialize (model_name, model_path, backend_lib_path, system_lib_path, m_lora_adapters, async, input_data_type, output_data_type );
4649}
4750
4851// issue#24
@@ -54,14 +57,14 @@ std::vector<std::string> QNNContext::getInputDataType(){
5457 return g_LibAppBuilder.getInputDataType (m_model_name);
5558};
5659
57- std::vector<std::string> QNNContext::getOutputDataType (){
58- return g_LibAppBuilder.getOutputDataType (m_model_name);
59- };
60-
6160std::vector<std::vector<size_t >> QNNContext::getOutputShapes (){
6261 return g_LibAppBuilder.getOutputShapes (m_model_name);
6362};
6463
64+ std::vector<std::string> QNNContext::getOutputDataType (){
65+ return g_LibAppBuilder.getOutputDataType (m_model_name);
66+ };
67+
6568std::string QNNContext::getGraphName (){
6669 return g_LibAppBuilder.getGraphName (m_model_name);
6770};
@@ -84,16 +87,16 @@ std::vector<std::string> QNNContext::getInputDataType(const std::string& proc_na
8487 return m_moduleInfo.inputDataType ;
8588};
8689
87- std::vector<std::string> QNNContext::getOutputDataType (const std::string& proc_name){
88- ::ModelInfo_t m_moduleInfo = getModelInfo_P (m_model_name, m_proc_name, " od" );
89- return m_moduleInfo.outputDataType ;
90- };
91-
9290std::vector<std::vector<size_t >> QNNContext::getOutputShapes (const std::string& proc_name){
9391 ::ModelInfo_t m_moduleInfo = getModelInfo_P (m_model_name, m_proc_name, " os" );
9492 return m_moduleInfo.outputShapes ;
9593};
9694
95+ std::vector<std::string> QNNContext::getOutputDataType (const std::string& proc_name){
96+ ::ModelInfo_t m_moduleInfo = getModelInfo_P (m_model_name, m_proc_name, " od" );
97+ return m_moduleInfo.outputDataType ;
98+ };
99+
97100std::string QNNContext::getGraphName (const std::string& proc_name){
98101 ::ModelInfo_t m_moduleInfo = getModelInfo_P (m_model_name, m_proc_name, " gn" );
99102 return m_moduleInfo.graphName ;
@@ -117,14 +120,14 @@ QNNContext::~QNNContext() {
117120}
118121
119122
120- std::vector<py::array_t < float > >
121- QNNContext::Inference (const std::vector<py::array_t < float >> & input, const std::string& perf_profile, size_t graphIndex) {
122- return inference (m_model_name, input, perf_profile, graphIndex);
123+ std::vector<py::array >
124+ QNNContext::Inference (const std::vector<py::array> & input, const std::string& perf_profile, size_t graphIndex, const std::string& input_data_type, const std::string& output_data_type ) {
125+ return inference (m_model_name, input, perf_profile, graphIndex, input_data_type, output_data_type );
123126}
124127
125- std::vector<py::array_t < float > >
126- QNNContext::Inference (const ShareMemory& share_memory, const std::vector<py::array_t < float >> & input, const std::string& perf_profile, size_t graphIndex) {
127- return inference_P (m_model_name, m_proc_name, share_memory.m_share_memory_name , input, perf_profile, graphIndex);
128+ std::vector<py::array >
129+ QNNContext::Inference (const ShareMemory& share_memory, const std::vector<py::array> & input, const std::string& perf_profile, size_t graphIndex, const std::string& input_data_type, const std::string& output_data_type ) {
130+ return inference_P (m_model_name, m_proc_name, share_memory.m_share_memory_name , input, perf_profile, graphIndex, input_data_type, output_data_type );
128131}
129132
130133bool QNNContext::ApplyBinaryUpdate (const std::vector<LoraAdapter>& lora_adapters) {
@@ -189,11 +192,11 @@ PYBIND11_MODULE(appbuilder, m) {
189192 .def (py::init<const std::string&, const size_t >());
190193
191194 py::class_<QNNContext>(m, " QNNContext" )
192- .def (py::init<const std::string&, const std::string&, const std::string&, const std::string&, bool >())
193- .def (py::init<const std::string&, const std::string&, const std::string&, const std::string&, const std::vector<LoraAdapter>&, bool >())
194- .def (py::init<const std::string&, const std::string&, const std::string&, const std::string&, const std::string&, bool >())
195- .def (" Inference" , py::overload_cast<const std::vector<py::array_t < float >> &, const std::string&, size_t >(&QNNContext::Inference))
196- .def (" Inference" , py::overload_cast<const ShareMemory&, const std::vector<py::array_t < float >> &, const std::string&, size_t >(&QNNContext::Inference))
195+ .def (py::init<const std::string&, const std::string&, const std::string&, const std::string&, bool , const std::string&, const std::string& >())
196+ .def (py::init<const std::string&, const std::string&, const std::string&, const std::string&, const std::vector<LoraAdapter>&, bool , const std::string&, const std::string& >())
197+ .def (py::init<const std::string&, const std::string&, const std::string&, const std::string&, const std::string&, bool , const std::string&, const std::string& >())
198+ .def (" Inference" , py::overload_cast<const std::vector<py::array> &, const std::string&, size_t , const std::string&, const std::string& >(&QNNContext::Inference))
199+ .def (" Inference" , py::overload_cast<const ShareMemory&, const std::vector<py::array> &, const std::string&, size_t , const std::string&, const std::string& >(&QNNContext::Inference))
197200 .def (" ApplyBinaryUpdate" , &QNNContext::ApplyBinaryUpdate, " Apply Lora binary update" )
198201 .def (" getInputShapes" , py::overload_cast<>(&QNNContext::getInputShapes))
199202 .def (" getInputDataType" , py::overload_cast<>(&QNNContext::getInputDataType))
0 commit comments