Skip to content

Commit db41440

Browse files
committed
more checks
1 parent b8a5de3 commit db41440

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def to(self, tensor_like) -> "OnnxList":
8181
]
8282
)
8383

84+
def clone(self) -> "OnnxList":
85+
"Clone (torch)."
86+
return [t.clone() for t in self]
87+
8488

8589
class OnnxruntimeEvaluator:
8690
"""
@@ -708,6 +712,10 @@ def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> L
708712

709713
outputs = list(sess.run(None, feeds))
710714
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
715+
assert not any(type(v) is list for v in outputs), (
716+
f"One output type is a list, this should not be allowed, "
717+
f"node.op_type={node.op_type}, feeds={string_type(feeds,with_shape_type=True)}"
718+
)
711719
return outputs
712720

713721
def _run_if(
@@ -783,6 +791,11 @@ def _run_scan_or_loop(
783791
self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
784792
) -> List[Any]:
785793
"""Runs a node Scan."""
794+
assert not any(type(i) is list for i in inputs), (
795+
f"One input is a list but it should an OnnxList, "
796+
f"node.op_type={node.op_type!r}, node.input={node.input}, "
797+
f"inputs={string_type(inputs, with_shape=True)}"
798+
)
786799
feeds = dict(zip(node.input, inputs))
787800
feeds.update(results)
788801
name = "body"

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..helpers.onnx_helper import pretty_onnx
1010
from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype
1111
from ..helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node
12+
from ..reference.ort_evaluator import OnnxList
1213
from .sbs_dataclasses import (
1314
ReplayConfiguration,
1415
RunAlignedRecord,
@@ -26,11 +27,11 @@ def _check_tensor_(use_tensor, name, obj, flip_type=False):
2627
if isinstance(obj, torch.Tensor):
2728
obj = to_numpy(obj)
2829

29-
assert not use_tensor or isinstance(obj, torch.Tensor), (
30+
assert not use_tensor or isinstance(obj, (torch.Tensor, OnnxList)), (
3031
f"Unexpected type {type(obj)} for {name!r}. "
3132
f"use_tensor is True so torch.Tensor is expected."
3233
)
33-
assert use_tensor or isinstance(obj, np.ndarray), (
34+
assert use_tensor or isinstance(obj, (np.ndarray, OnnxList)), (
3435
f"Unexpected type {type(obj)} for {name!r}. "
3536
f"use_tensor is False so np.array is expected."
3637
)

0 commit comments

Comments
 (0)