Skip to content

Commit 87765df

Browse files
committed
fix issues
1 parent 38d7e10 commit 87765df

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

onnx_diagnostic/helpers/ort_session.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,17 +244,23 @@ def run_dlpack(
244244
feeds is a dictionary of :class:`np.ndarray`.
245245
The output device is CPU even if the outputs are on CUDA.
246246
"""
247+
memory = []
247248
new_feeds = {}
248249
for k, v in feeds.items():
249250
if not k:
250251
continue
251-
new_feeds[k] = (
252-
ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
252+
if isinstance(v, np.ndarray):
253+
new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
253254
v, np_dtype_to_tensor_dtype(v.dtype)
254255
)
255-
if isinstance(v, np.ndarray)
256-
else ORTC.OrtValue.from_dlpack(v.__dlpack__(), v.dtype == torch.bool)
257-
)
256+
elif v.dtype == torch.bool:
257+
vi = v.detach().cpu().numpy()
258+
memory.append(vi)
259+
new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
260+
vi, onnx.TensorProto.BOOL
261+
)
262+
else:
263+
new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
258264

259265
if self.nvtx:
260266
self.torch.cuda.nvtx.range_push("run_with_ort_values")

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def _build_garbage_collector(self) -> Dict[str, int]:
289289
Returns a dictionary with the last node using the results.
290290
"""
291291
needed = {}
292-
for i, node in enumerate(self.rt_nodes_):
292+
for i, node in enumerate(self.rt_nodes_ or []):
293293
for name in node.input:
294294
needed[name] = i
295295
if node.op_type in {"Scan", "If", "Loop"}:
@@ -298,13 +298,13 @@ def _build_garbage_collector(self) -> Dict[str, int]:
298298
needed[name] = i
299299
if isinstance(self.proto, ModelProto):
300300
for o in self.proto.graph.output:
301-
needed[o.name] = len(self.rt_nodes_)
301+
needed[o.name] = len(self.rt_nodes_ or [])
302302
elif isinstance(self.proto, GraphProto):
303303
for o in self.proto.output:
304-
needed[o.name] = len(self.rt_nodes_)
304+
needed[o.name] = len(self.rt_nodes_ or [])
305305
elif isinstance(self.proto, FunctionProto):
306306
for o in self.proto.output:
307-
needed[o] = len(self.rt_nodes_)
307+
needed[o] = len(self.rt_nodes_ or [])
308308
return needed
309309

310310
def _clean_unused_inplace(self, i_node: int, node: NodeProto, results: Dict[str, Any]):

0 commit comments

Comments
 (0)