Skip to content

Commit 91f0f51

Browse files
committed
Fix get shape issue for remote process.
1 parent 6951fc7 commit 91f0f51

File tree

11 files changed

+411
-25
lines changed

11 files changed

+411
-25
lines changed

pybind/AppBuilder.cpp

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,53 @@ std::vector<std::vector<size_t>> QNNContext::getOutputShapes(){
6262
return g_LibAppBuilder.getOutputShapes(m_model_name);
6363
};
6464

65+
std::string QNNContext::getGraphName(){
66+
return g_LibAppBuilder.getGraphName(m_model_name);
67+
};
68+
69+
std::vector<std::string> QNNContext::getInputName(){
70+
return g_LibAppBuilder.getInputName(m_model_name);
71+
};
72+
73+
std::vector<std::string> QNNContext::getOutputName(){
74+
return g_LibAppBuilder.getOutputName(m_model_name);
75+
};
76+
77+
std::vector<std::vector<size_t>> QNNContext::getInputShapes(const std::string& proc_name){
78+
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "is", /*perf_profile, graphIndex*/ 0);
79+
return m_moduleInfo.inputShapes;
80+
};
81+
82+
std::vector<std::string> QNNContext::getInputDataType(const std::string& proc_name){
83+
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "id");
84+
return m_moduleInfo.inputDataType;
85+
};
86+
87+
std::vector<std::string> QNNContext::getOutputDataType(const std::string& proc_name){
88+
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "od");
89+
return m_moduleInfo.outputDataType;
90+
};
91+
92+
std::vector<std::vector<size_t>> QNNContext::getOutputShapes(const std::string& proc_name){
93+
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "os");
94+
return m_moduleInfo.outputShapes;
95+
};
96+
97+
std::string QNNContext::getGraphName(const std::string& proc_name){
98+
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "gn");
99+
return m_moduleInfo.graphName;
100+
};
101+
102+
std::vector<std::string> QNNContext::getInputName(const std::string& proc_name){
103+
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "in");
104+
return m_moduleInfo.inputName;
105+
};
106+
107+
std::vector<std::string> QNNContext::getOutputName(const std::string& proc_name){
108+
::ModelInfo_t m_moduleInfo = getModelInfo_P(m_model_name, m_proc_name, "on");
109+
return m_moduleInfo.outputName;
110+
};
111+
65112
QNNContext::~QNNContext() {
66113
if (m_proc_name.empty())
67114
g_LibAppBuilder.ModelDestroy(m_model_name);
@@ -148,11 +195,20 @@ PYBIND11_MODULE(appbuilder, m) {
148195
.def("Inference", py::overload_cast<const std::vector<py::array_t<float>>&, const std::string&, size_t>(&QNNContext::Inference))
149196
.def("Inference", py::overload_cast<const ShareMemory&, const std::vector<py::array_t<float>>&, const std::string&, size_t>(&QNNContext::Inference))
150197
.def("ApplyBinaryUpdate", &QNNContext::ApplyBinaryUpdate, "Apply Lora binary update")
151-
.def("getInputShapes", &QNNContext::getInputShapes, "Get Input Shape")
152-
.def("getInputDataType", &QNNContext::getInputDataType, "Get Input data type")
153-
.def("getOutputDataType", &QNNContext::getOutputDataType, "Get output data type")
154-
.def("getOutputShapes", &QNNContext::getOutputShapes, "Get Output Shape");
155-
198+
.def("getInputShapes", py::overload_cast<>(&QNNContext::getInputShapes))
199+
.def("getInputDataType", py::overload_cast<>(&QNNContext::getInputDataType))
200+
.def("getOutputShapes", py::overload_cast<>(&QNNContext::getOutputShapes))
201+
.def("getOutputDataType", py::overload_cast<>(&QNNContext::getOutputDataType))
202+
.def("getInputName", py::overload_cast<>(&QNNContext::getInputName))
203+
.def("getOutputName", py::overload_cast<>(&QNNContext::getOutputName))
204+
.def("getGraphName", py::overload_cast<>(&QNNContext::getGraphName))
205+
.def("getInputShapes", py::overload_cast<const std::string&>(&QNNContext::getInputShapes))
206+
.def("getInputDataType", py::overload_cast<const std::string&>(&QNNContext::getInputDataType))
207+
.def("getOutputDataType", py::overload_cast<const std::string&>(&QNNContext::getOutputDataType))
208+
.def("getOutputShapes", py::overload_cast<const std::string&>(&QNNContext::getOutputShapes))
209+
.def("getInputName", py::overload_cast<const std::string&>(&QNNContext::getInputName))
210+
.def("getOutputName", py::overload_cast<const std::string&>(&QNNContext::getOutputName))
211+
.def("getGraphName", py::overload_cast<const std::string&>(&QNNContext::getGraphName));
156212

157213
py::class_<LoraAdapter>(m, "LoraAdapter")
158214
.def(py::init<const std::string &, const std::vector<std::string> &>());

pybind/AppBuilder.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ std::vector<py::array_t<float>> inference_P(std::string model_name, std::string
140140
return output;
141141
}
142142

143+
ModelInfo_t getModelInfo_P(std::string model_name, std::string proc_name,
144+
std::string input, size_t graphIndex = 0) {
145+
146+
std::vector<void*> outputBuffers;
147+
std::vector<size_t> outputSize;
148+
ModelInfo_t output = g_LibAppBuilder.getModelInfo(model_name, proc_name, input);
149+
return output;
150+
}
151+
143152
bool ApplyBinaryUpdate(const std::vector<LoraAdapter>& lora_adapters);
144153

145154
int create_memory(std::string share_memory_name, size_t share_memory_size) {
@@ -187,7 +196,25 @@ class QNNContext {
187196
std::vector<std::string> getInputDataType();
188197
std::vector<std::string> getOutputDataType();
189198
std::vector<std::vector<size_t>> getOutputShapes();
190-
199+
std::string getGraphName();
200+
std::vector<std::string> getInputName();
201+
std::vector<std::string> getOutputName();
202+
203+
std::vector<std::vector<size_t>> getInputShapes(const std::string& proc_name);
204+
std::vector<std::string> getInputDataType(const std::string& proc_name);
205+
std::vector<std::string> getOutputDataType(const std::string& proc_name);
206+
std::vector<std::vector<size_t>> getOutputShapes(const std::string& proc_name);
207+
std::string getGraphName(const std::string& proc_name);
208+
std::vector<std::string> getInputName(const std::string& proc_name);
209+
std::vector<std::string> getOutputName(const std::string& proc_name);
210+
211+
typedef struct ModelInfo {
212+
std::vector<std::vector<size_t>> inputShapes;
213+
std::vector<std::string> inputDataType;
214+
std::vector<std::vector<size_t>> onputShapes;
215+
std::vector<std::string> onputDataType;
216+
std::string graphName;
217+
} ModelInfo_t;
191218
~QNNContext();
192219
};
193220

samples/python/real_esrgan_x4plus/real_esrgan_x4plus.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,21 +114,33 @@ def Inference(input_image_path, output_image_path, show_image = True):
114114
if show_image:
115115
image_buffer.show()
116116

117+
def getGraphName():
118+
graph_name = realesrgan.getGraphName()
119+
print("debug_realesrgan,graph_name:",graph_name)
120+
117121
def getInputShapes():
118122
input_shapes = realesrgan.getInputShapes()
119-
print(input_shapes)
123+
print("debug_realesrgan,input_shapes:",input_shapes)
120124

121125
def getInputDataType():
122126
input_dataType = realesrgan.getInputDataType()
123-
print(input_dataType)
127+
print("debug_realesrgan,input_dataType:",input_dataType)
124128

125129
def getOutputShapes():
126130
output_shapes = realesrgan.getOutputShapes()
127-
print(output_shapes)
131+
print("debug_realesrgan,output_shapes:",output_shapes)
128132

129133
def getOutputDataType():
130134
output_dataType = realesrgan.getOutputDataType()
131-
print(output_dataType)
135+
print("debug_realesrgan,output_dataType:",output_dataType)
136+
137+
def getInputName():
138+
input_name = realesrgan.getInputName()
139+
print("debug_realesrgan,input_name:",input_name)
140+
141+
def getOutputName():
142+
output_name = realesrgan.getOutputName()
143+
print("debug_realesrgan,output_name:",output_name)
132144

133145
def Release():
134146
global realesrgan
@@ -151,7 +163,10 @@ def main(input_image_path=None, output_image_path=None, show_image = True):
151163
getInputDataType()
152164
getOutputShapes()
153165
getOutputDataType()
154-
166+
getGraphName()
167+
getInputName()
168+
getOutputName()
169+
155170
Inference(input_image_path=input_image_path,output_image_path=output_image_path,show_image=show_image)
156171

157172
Release()

script/qai_appbuilder/qnncontext.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,22 @@ def getInputDataType(self, ):
192192

193193
def getOutputDataType(self, ):
194194
return self.m_context.getOutputDataType()
195+
196+
def getGraphName(self, ):
197+
return self.m_context.getGraphName()
198+
199+
def getInputName(self, ):
200+
return self.m_context.getInputName()
201+
202+
def getOutputName(self, ):
203+
return self.m_context.getOutputName()
204+
195205
#@timer
196206
def Inference(self, input, perf_profile = PerfProfile.DEFAULT, graphIndex = 0):
197207
input= reshape_input(input)
198208
output = self.m_context.Inference(input, perf_profile, graphIndex)
199209
outputshape_list = self.getOutputShapes()
200210
output = reshape_output(output, outputshape_list)
201-
202211
return output
203212

204213
def apply_binary_update(self, lora_adapters=None):
@@ -261,15 +270,21 @@ def getInputDataType(self, ):
261270
def getOutputDataType(self, ):
262271
return self.m_context.getOutputDataType()
263272

273+
def getGraphName(self, ):
274+
return self.m_context.getGraphName()
275+
276+
def getInputName(self, ):
277+
return self.m_context.getInputName()
278+
279+
def getOutputName(self, ):
280+
return self.m_context.getOutputName()
281+
264282
#@timer
265283
def Inference(self, input, perf_profile = PerfProfile.DEFAULT, graphIndex = 0):
266284
input = reshape_input(input)
267-
268285
output = self.m_context.Inference(input, perf_profile, graphIndex)
269-
270286
outputshape_list = self.getOutputShapes()
271287
output = reshape_output(output, outputshape_list)
272-
273288
return output
274289

275290
#@timer
@@ -317,26 +332,33 @@ def __init__(self,
317332

318333
# issue#24
319334
def getInputShapes(self, ):
320-
return self.m_context.getInputShapes()
335+
return self.m_context.getInputShapes(self.proc_name)
321336

322337
def getOutputShapes(self, ):
323-
return self.m_context.getOutputShapes()
338+
return self.m_context.getOutputShapes(self.proc_name)
324339

325340
def getInputDataType(self, ):
326-
return self.m_context.getInputDataType()
341+
return self.m_context.getInputDataType(self.proc_name)
327342

328343
def getOutputDataType(self, ):
329-
return self.m_context.getOutputDataType()
344+
return self.m_context.getOutputDataType(self.proc_name)
345+
346+
347+
def getGraphName(self, ):
348+
return self.m_context.getGraphName(self.proc_name)
349+
350+
def getInputName(self, ):
351+
return self.m_context.getInputName(self.proc_name)
352+
353+
def getOutputName(self, ):
354+
return self.m_context.getOutputName(self.proc_name)
330355

331356
#@timer
332357
def Inference(self, shareMemory, input, perf_profile = PerfProfile.DEFAULT, graphIndex = 0):
333358
input = reshape_input(input)
334-
335359
output = self.m_context.Inference(shareMemory.m_memory, input, perf_profile, graphIndex)
336-
337360
outputshape_list = self.getOutputShapes()
338361
output = reshape_output(output, outputshape_list)
339-
340362
return output
341363

342364
#@timer

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
# Compile Commands:
1010
# [windows]
1111
# Set QNN_SDK_ROOT=C:/Qualcomm/AIStack/QAIRT/2.40.0.251030/
12-
# Set QNN_SDK_ROOT=C:/Qualcomm/AIStack/QAIRT/2.39.0.250926/
1312
# Set QNN_SDK_ROOT=C:/Qualcomm/AIStack/QAIRT/2.38.0.250901/
1413
# python setup.py bdist_wheel
1514
# [linux]
@@ -50,7 +49,7 @@
5049

5150
python_path = "script"
5251
binary_path = python_path + "/" + package_name
53-
qai_libs_path = os.path.join(binary_path, 'libs')
52+
qai_libs_path = binary_path + "/libs"
5453
os.makedirs(qai_libs_path, exist_ok=True)
5554
init_path = os.path.join(qai_libs_path, "__init__.py")
5655
with open(init_path, "w") as f:
@@ -115,6 +114,7 @@ def build_clean():
115114
os.remove(binary_path + "/libappbuilder.so")
116115
if os.path.exists(binary_path + "/Genie.dll"):
117116
os.remove(binary_path + "/Genie.dll")
117+
shutil.rmtree(qai_libs_path)
118118

119119
def build_cmake():
120120
if not os.path.exists("build"):

src/LibAppBuilder.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,76 @@ std::vector<std::string> LibAppBuilder::getOutputDataType(std::string model_name
635635
return m_outputDataType;
636636
};
637637

638+
std::string LibAppBuilder::getGraphName(std::string model_name){
639+
std::unique_ptr<sample_app::QnnSampleApp> app = getQnnSampleApp(model_name);
640+
m_graphName = app->getGraphName();
641+
sg_model_map.insert(std::make_pair(model_name, std::move(app)));
642+
return m_graphName;
643+
};
644+
645+
std::vector<std::string> LibAppBuilder::getInputName(std::string model_name){
646+
std::unique_ptr<sample_app::QnnSampleApp> app = getQnnSampleApp(model_name);
647+
m_inputName = app->getInputName();
648+
sg_model_map.insert(std::make_pair(model_name, std::move(app)));
649+
return m_inputName;
650+
};
651+
652+
std::vector<std::string> LibAppBuilder::getOutputName(std::string model_name){
653+
std::unique_ptr<sample_app::QnnSampleApp> app = getQnnSampleApp(model_name);
654+
m_outputName = app->getOutputName();
655+
sg_model_map.insert(std::make_pair(model_name, std::move(app)));
656+
return m_outputName;
657+
};
658+
659+
ModelInfo_t LibAppBuilder::getModelInfo(std::string model_name, std::string proc_name, std::string input) {
660+
ModelInfo_t output;
661+
#ifdef _WIN32
662+
if (!proc_name.empty()) { // If proc_name, run the model in that process.
663+
output = TalkToSvc_getModelInfo(model_name, proc_name, input);
664+
665+
}
666+
#endif
667+
return output;
668+
}
669+
670+
ModelInfo_t LibAppBuilder::getModelInfo(std::string model_name, std::string input) {
671+
return getModelInfoExt(model_name, input);
672+
}
673+
ModelInfo_t LibAppBuilder::getModelInfoExt(std::string model_name, std::string input) {
674+
bool result = true;
675+
ModelInfo_t info;
676+
677+
std::unique_ptr<sample_app::QnnSampleApp> app = getQnnSampleApp(model_name);
678+
if (nullptr == app) {
679+
app->reportError("getModelInfoExt failure");
680+
result = false;
681+
}
682+
if(result){
683+
if (input == "is") {
684+
info.inputShapes = app->getInputShapes();
685+
} else if (input == "id") {
686+
info.inputDataType = app->getInputDataType();
687+
} else if (input == "os") {
688+
info.outputShapes = app->getOutputShapes();
689+
} else if (input == "od") {
690+
info.outputDataType = app->getOutputDataType();
691+
} else if (input == "in") {
692+
info.inputName = app->getInputName();
693+
} else if (input == "on") {
694+
info.outputName = app->getOutputName();
695+
} else if (input == "gn") {
696+
info.graphName = app->getGraphName();
697+
} else {
698+
printf("wrong input in LibAppBuilder::getModelInfoExt: %s\n", input.c_str());
699+
app->reportError("getModelInfoExt failure");
700+
return info;
701+
}
702+
}
703+
sg_model_map.insert(std::make_pair(model_name, std::move(app)));
704+
705+
return info;
706+
}
707+
638708
int main(int argc, char** argv) {
639709

640710
return EXIT_SUCCESS;

0 commit comments

Comments
 (0)