Skip to content

Commit db380a1

Browse files
committed
less assert
1 parent 54819d0 commit db380a1

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ class TorchOnnxEvaluator:
6464
:func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
6565
"""
6666

67+
class IO:
68+
"IO"
69+
70+
def __init__(self, name: str, type: int, shape: Tuple[Union[str, int], ...]):
71+
self.name = name
72+
self.type = type
73+
self.shape = shape
74+
6775
def __init__(
6876
self,
6977
proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
@@ -103,6 +111,26 @@ def __init__(
103111
self._build_kernels(proto.graph.node)
104112
self.input_names = [i.name for i in proto.graph.input]
105113
self.output_names = [i.name for i in proto.graph.output]
114+
self._io_input_names = [
115+
self.IO(
116+
name=i.name,
117+
type=i.type.tensor_type.elem_type,
118+
shape=tuple(
119+
d.dim_param or d.dim_value for d in i.type.tensor_type.shape.dim
120+
),
121+
)
122+
for i in proto.graph.input
123+
]
124+
self._io_output_names = [
125+
self.IO(
126+
name=i.name,
127+
type=i.type.tensor_type.elem_type,
128+
shape=tuple(
129+
d.dim_param or d.dim_value for d in i.type.tensor_type.shape.dim
130+
),
131+
)
132+
for i in proto.graph.output
133+
]
106134
elif isinstance(proto, onnx.GraphProto):
107135
assert opsets, "opsets must be specified if proto is a graph"
108136
assert not proto.sparse_initializer, "sparse_initializer not support yet"
@@ -135,6 +163,16 @@ def __init__(
135163
elif not info.is_output and not info.is_initializer:
136164
self.last_used[info.last_used].append(name)
137165

166+
def get_inputs(self):
167+
"Same API than onnxruntime."
168+
assert hasattr(self, "_io_input_names"), "Missing attribute '_io_input_names'."
169+
return self._io_input_names
170+
171+
def get_outputs(self):
172+
"Same API than onnxruntime."
173+
assert hasattr(self, "_io_output_names"), "Missing attribute '_io_output_names'."
174+
return self._io_output_names
175+
138176
@property
139177
def on_cuda(self) -> bool:
140178
"Tells if the default device is CUDA."

onnx_diagnostic/torch_onnx/runtime_info.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,6 @@ def string_type(self) -> str:
119119
def set_value(self, value: torch.Tensor):
120120
"""Sets the value."""
121121
assert value is not None, "Use clean_value to set a value to None"
122-
assert (
123-
self.name != "position_ids" or value.get_device() >= 0
124-
), f"{value} - is_shape={self.is_shape}"
125122
self.value = value
126123
if self.dtype:
127124
assert (

0 commit comments

Comments
 (0)