Skip to content

Commit bebe8c0

Browse files
authored
Fixes minor details in make_feeds (#174)
* fix to_any for BaseModelOutput * doc * Fixes dummy input in the signature * fix sig
1 parent 8c47d90 commit bebe8c0

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

onnx_diagnostic/helpers/rt_helper.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,14 @@ def make_feeds(
112112

113113
if copy:
114114
flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
115-
return dict(zip(names, flat))
115+
# bool, int, float, onnxruntime does not support float, bool, int
116+
new_flat = []
117+
for i in flat:
118+
if isinstance(i, bool):
119+
i = np.array(i, dtype=np.bool_)
120+
elif isinstance(i, int):
121+
i = np.array(i, dtype=np.int64)
122+
elif isinstance(i, float):
123+
i = np.array(i, dtype=np.float32)
124+
new_flat.append(i)
125+
return dict(zip(names, new_flat))

0 commit comments

Comments
 (0)