@@ -70,6 +70,33 @@ static inline py::dtype dtypeFromString(const std::string& dtypeStr) {
7070 return py::dtype::of<uint8_t >();
7171}
7272
73+ // ---------------------------------------------------------------------------------
74+ // Helper: case-insensitive "float32 request" for input_data_type/output_data_type
75+ // Accepts: "float", "float32", "fp32"
76+ // ---------------------------------------------------------------------------------
77+ static inline bool isFloat32Request (const std::string& s) {
78+ std::string t = s;
79+ for (auto & c : t) c = static_cast <char >(::tolower (c));
80+ return (t == " float" || t == " float32" || t == " fp32" );
81+ }
82+
83+ // ---------------------------------------------------------------------------------
84+ // Helper: identify if a py::dtype is float32 (NumPy kind 'f' and itemsize == 4)
85+ // Note: We avoid relying on dtype object identity; use kind/itemsize instead.
86+ // ---------------------------------------------------------------------------------
87+ static inline bool isNumpyFloat32Dtype (const py::dtype& dt) {
88+ try {
89+ // dt.kind is a 1-char string in NumPy, e.g. 'f' for floating
90+ std::string kindStr = py::str (dt.attr (" kind" ));
91+ char kind = kindStr.empty () ? ' \0 ' : kindStr[0 ];
92+ py::ssize_t itemsize = dt.attr (" itemsize" ).cast <py::ssize_t >();
93+ return (kind == ' f' && itemsize == 4 );
94+ } catch (...) {
95+ // conservative fallback
96+ return false ;
97+ }
98+ }
99+
73100// ---------------------------------------------------------------------------
74101// Helper: product of dims (for output element count)
75102// ---------------------------------------------------------------------------
@@ -206,7 +233,8 @@ std::vector<py::array> inference(std::string model_name, const std::vector<py::a
206233
207234 // Keep temporary converted/contiguous arrays alive during ModelInference
208235 std::vector<py::array> keepAlive;
209- const bool floatMode = (input_data_type == " float" );
236+ const bool floatMode = isFloat32Request (input_data_type);
237+ const bool floatOutMode = isFloat32Request (output_data_type);
210238
211239 // QNN_INF("inference input vector length: %d\n", input.size());
212240
@@ -271,7 +299,16 @@ std::vector<py::array> inference(std::string model_name, const std::vector<py::a
271299 { static_cast <py::ssize_t >(dt.itemsize ()) },
272300 outputBuffers[i],
273301 free_data);
274- output.push_back (result);
302+
303+ // If user requests float output, cast to float32 before returning.
304+ // IMPORTANT: do NOT reinterpret the raw buffer as float32 (size may not match).
305+ // We first create 'result' using the inferred real dtype, then cast (copy) if needed.
306+ if (floatOutMode && !isNumpyFloat32Dtype (dt)) {
307+ py::array_t <float , py::array::c_style | py::array::forcecast> farr (result);
308+ output.push_back (py::array (farr));
309+ } else {
310+ output.push_back (result);
311+ }
275312 }
276313 // print_time("convert Data To ArrayV");
277314
@@ -288,7 +325,8 @@ std::vector<py::array> inference_P(std::string model_name, std::string proc_name
288325
289326 // Keep temporary converted/contiguous arrays alive during ModelInference
290327 std::vector<py::array> keepAlive;
291- const bool floatMode = (input_data_type == " float" );
328+ const bool floatMode = isFloat32Request (input_data_type);
329+ const bool floatOutMode = isFloat32Request (output_data_type);
292330
293331 for (auto i = 0 ; i < input.size (); i++) {
294332 if (floatMode) {
@@ -351,7 +389,15 @@ std::vector<py::array> inference_P(std::string model_name, std::string proc_name
351389 { static_cast <py::ssize_t >(dt.itemsize ()) },
352390 outputBuffers[i],
353391 free_data);
354- output.push_back (result);
392+
393+ // If user requests float output, cast to float32 before returning.
394+ // For shared memory outputs, this will create a float32 copy (shared memory remains untouched).
395+ if (floatOutMode && !isNumpyFloat32Dtype (dt)) {
396+ py::array_t <float , py::array::c_style | py::array::forcecast> farr (result);
397+ output.push_back (py::array (farr));
398+ } else {
399+ output.push_back (result);
400+ }
355401 }
356402 // print_time("convert Data To ArrayV");
357403
@@ -371,6 +417,7 @@ int delete_memory(std::string share_memory_name) {
371417class ShareMemory {
372418public:
373419 std::string m_share_memory_name;
420+ size_t m_share_memory_size = 0 ;
374421
375422 ShareMemory (const std::string& share_memory_name, const size_t share_memory_size);
376423 ~ShareMemory ();
0 commit comments