Skip to content

Commit 1272495

Browse files
committed
Handles parameters in OnnxruntimeEvaluator for FunctionProto
1 parent d5798cc commit 1272495

File tree

5 files changed

+116
-15
lines changed

5 files changed

+116
-15
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.3
55
+++++
66

7+
* :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator
78
* :pr:`323`: drops torch 2.8 on CI
89
* :pr:`322`: support rerunning onnx kernels with torch intermediate results in side-by-side
910
* :pr:`314`: fix modelbuilder download needed after this change https://github.com/microsoft/onnxruntime-genai/pull/1862

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,41 @@ def test_skip_simplified_layer_normalization(self):
284284
self.assertEqual(got[1].shape, feeds["x"].shape)
285285
self.assertEqual(got[1].dtype, feeds["x"].dtype)
286286

287+
def test_function_proto_with_kwargs(self):
288+
linear_function = oh.make_function(
289+
"test_domain",
290+
"LinearRegression",
291+
["x", "a", "b"],
292+
["y"],
293+
[
294+
oh.make_node("Constant", [], ["eps"]),
295+
oh.make_node("Constant", [], ["zero"], value_ints=[0]),
296+
oh.make_node("Unsqueeze", ["eps", "zero"], ["eps1d"]),
297+
oh.make_node("MatMul", ["x", "a"], ["xa"]),
298+
oh.make_node("Add", ["b", "eps1d"], ["beps"]),
299+
oh.make_node("Add", ["xa", "beps"], ["y"]),
300+
],
301+
[oh.make_opsetid("", 14)],
302+
["epsilon"],
303+
)
304+
att = onnx.AttributeProto()
305+
att.name = "value_float"
306+
att.ref_attr_name = "epsilon"
307+
att.type = onnx.AttributeProto.FLOAT
308+
linear_function.node[0].attribute.append(att)
309+
feeds = dict(
310+
x=np.random.rand(4, 4).astype(np.float32),
311+
a=np.random.rand(4, 2).astype(np.float32),
312+
b=np.random.rand(1, 2).astype(np.float32),
313+
)
314+
epsilon = 15.6
315+
expected = feeds["x"] @ feeds["a"] + feeds["b"] + epsilon
316+
sess = OnnxruntimeEvaluator(
317+
linear_function, whole=True, function_kwargs=dict(epsilon=epsilon)
318+
)
319+
got = sess.run(None, feeds)
320+
self.assertEqualArray(expected, got[0], atol=1e-5)
321+
287322

288323
if __name__ == "__main__":
289324
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,9 +562,27 @@ def test_plug_packed_multi_head_attention_qwen25(self):
562562
dtype=torch.int64,
563563
).to("cuda"),
564564
)
565-
qwen_sdpa_attention_versatile.verify(
565+
566+
results = qwen_sdpa_attention_versatile.verify(
567+
*inputs,
568+
scaling=0.5,
569+
num_heads=16,
570+
dump_onnx_model=self.get_dump_file(
571+
"test_plug_packed_multi_head_attention_qwen25.onnx"
572+
),
573+
)
574+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
575+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
576+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
577+
self.assertLess(results.diffs[0]["abs"], 0.01)
578+
579+
results = qwen_sdpa_attention_versatile.verify(
566580
*inputs, scaling=0.11180339887498948, num_heads=16
567581
)
582+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
583+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
584+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
585+
self.assertLess(results.diffs[0]["abs"], 0.01)
568586

569587

570588
if __name__ == "__main__":

onnx_diagnostic/export/onnx_plug.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class VerifyResult:
2727
"""
2828

2929
eager_outputs: TUPLE_TENSORS
30-
onnx_output: TUPLE_TENSORS
30+
onnx_outputs: TUPLE_TENSORS
3131
diffs: Tuple[Dict[str, float], ...]
3232

3333

@@ -238,7 +238,13 @@ def _register(self):
238238
custom_def.register_kernel(None)(self.eager_fn)
239239
custom_def._abstract_fn = self.shape_fn
240240

241-
def verify(self, *args, engine: Optional[Callable] = None, **kwargs) -> VerifyResult:
241+
def verify(
242+
self,
243+
*args,
244+
engine: Optional[Callable] = None,
245+
dump_onnx_model: Optional[str] = None,
246+
**kwargs,
247+
) -> VerifyResult:
242248
"""
243249
Verifies that the eager mode is equivalent to the onnx function given
244250
as a replacement. This function evaluates `eager_fn`, checks that the shapes
@@ -249,6 +255,9 @@ def verify(self, *args, engine: Optional[Callable] = None, **kwargs) -> VerifyRe
249255
:param kwargs: arguments for eager_fn
250256
:param engine: by default an instance of
251257
:class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
258+
:param dump_onnx_model: file name used to dump the onnx model built to verify that
259+
eager and onnx produce the same results
260+
:param kwargs: additional arguments to the function
252261
:return: a :class:`VerifyResult` holding the eager outputs, the onnx outputs and the outputs of :func:`onnx_diagnostic.helpers.max_diff`
253262
"""
254263
expected = self.eager_fn(*args, **kwargs)
@@ -280,11 +289,23 @@ def verify(self, *args, engine: Optional[Callable] = None, **kwargs) -> VerifyRe
280289

281290
# Now the ONNX execution.
282291
assert engine is None, f"Not implemented yet with engine={engine!r}"
283-
sess = OnnxruntimeEvaluator(self.function_proto, whole=True)
284-
feeds = dict(zip(sess.input_names, args))
292+
ags, kws = self._make_args_kwargs(*args, **kwargs)
293+
sess = OnnxruntimeEvaluator(
294+
self.function_proto,
295+
whole=True,
296+
dump_onnx_model=dump_onnx_model,
297+
function_kwargs=kws,
298+
)
299+
feeds = dict(zip(sess.input_names, ags))
285300
got = sess.run(None, feeds)
286-
diffs = tuple(max_diff(e, g) for e, g in zip(expected, got))
287-
return VerifyResult(eager_outputs=expected, onnx_output=tuple(got), diffs=diffs) # type: ignore[arg-type]
301+
diffs = tuple(max_diff(e, g, hist=[0.1, 0.01]) for e, g in zip(expected, got))
302+
return VerifyResult(eager_outputs=expected, onnx_outputs=tuple(got), diffs=diffs) # type: ignore[arg-type]
303+
304+
def _make_args_kwargs(self, *args, **kwargs):
305+
ags = args[: len(self.args_name)]
306+
kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
307+
kws.update(kwargs)
308+
return ags, kws
288309

289310
def custom_converter(
290311
self,
@@ -307,9 +328,7 @@ def converter(
307328
self.function_proto.name, domain=self.function_proto.domain
308329
):
309330
g.add_function(self.function_proto)
310-
ags = args[: len(self.args_name)]
311-
kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
312-
kws.update(kwargs)
331+
ags, kws = self._make_args_kwargs(*args, **kwargs)
313332
res = g.make_node(
314333
self.function_proto.name,
315334
ags,
@@ -370,7 +389,8 @@ def onnx_dynamo_converter(self) -> Callable:
370389
onnx.defs.register_schema(schema)
371390
op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)
372391

373-
def converter(*cargs):
374-
return op(*cargs, n_outputs=self.n_outputs)
392+
def converter(*cargs, **ckwargs):
393+
ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
394+
return op(*ags, n_outputs=self.n_outputs, **kws)
375395

376396
return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)

onnx_diagnostic/reference/ort_evaluator.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ class OnnxruntimeEvaluator:
5454
:param whole: if True, do not split node by node
5555
:param torch_or_numpy: force the use of one of them, True for torch,
5656
False for numpy, None to let the class choose
57+
:param dump_onnx_model: dumps the temporary onnx model created if whole is True
58+
:param function_kwargs: a FunctionProto may have parameters,
59+
this dictionary maps each parameter name to its value
5760
"""
5861

5962
def __init__(
@@ -77,6 +80,8 @@ def __init__(
7780
opsets: Optional[Union[int, Dict[str, int]]] = None,
7881
whole: bool = False,
7982
torch_or_numpy: Optional[bool] = None,
83+
function_kwargs: Optional[Dict[str, Any]] = None,
84+
dump_onnx_model: Optional[str] = None,
8085
):
8186
if isinstance(proto, str):
8287
self.proto: Proto = load(proto)
@@ -90,6 +95,9 @@ def __init__(
9095
assert isinstance(
9196
self.proto, PROTO
9297
), f"Unexpected type for self.proto {type(self.proto)}"
98+
assert (
99+
whole or not dump_onnx_model
100+
), f"whole must be True for dump_onnx_model={dump_onnx_model!r}"
93101

94102
self._cache: Dict[
95103
Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]] # noqa: UP037
@@ -109,6 +117,8 @@ def __init__(
109117
use_training_api=use_training_api,
110118
)
111119
self.to_tensor_or_array = to_array_extended if not torch_or_numpy else to_tensor
120+
self.function_kwargs = function_kwargs
121+
self.dump_onnx_model = dump_onnx_model
112122

113123
self.verbose = verbose
114124
self.torch_or_numpy = torch_or_numpy
@@ -357,11 +367,12 @@ def _make_model_proto(
357367
nodes: Sequence[NodeProto],
358368
vinputs: Sequence[ValueInfoProto],
359369
voutputs: Sequence[ValueInfoProto],
370+
functions: Optional[Sequence[FunctionProto]] = None,
360371
) -> ModelProto:
361372
onx = oh.make_model(
362373
oh.make_graph(nodes, "-", vinputs, voutputs),
363374
ir_version=getattr(self.proto, "ir_version", self.ir_version),
364-
functions=getattr(self.proto, "functions", None),
375+
functions=[*getattr(self.proto, "functions", []), *(functions or [])],
365376
)
366377
del onx.opset_import[:]
367378
if hasattr(self.proto, "opset_import"):
@@ -430,8 +441,18 @@ def _get_sess(
430441
if isinstance(node, ModelProto):
431442
onx = node
432443
else:
444+
functions = []
445+
if isinstance(node, FunctionProto):
446+
functions.append(node)
447+
node = oh.make_node(
448+
node.name,
449+
list(node.input),
450+
list(node.output),
451+
domain=node.domain,
452+
**(self.function_kwargs or {}),
453+
)
433454
assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
434-
if node.op_type == "Constant":
455+
if node.op_type == "Constant" and node.domain == "":
435456
# We force the type to be a boolean.
436457
ref = ExtendedReferenceEvaluator(node)
437458
cst = ref.run(None, {})[0]
@@ -457,10 +478,16 @@ def _get_sess(
457478
# no need to run shape inference
458479
prenodes, voutputs = self._make_model_outputs(node, vinputs)
459480

460-
onx = self._make_model_proto([*prenodes, node], vinputs, voutputs)
481+
onx = self._make_model_proto(
482+
[*prenodes, node], vinputs, voutputs, functions=functions
483+
)
461484
if node.op_type in {"Shape", "Size"}:
462485
on_cpu = True
463486

487+
if self.dump_onnx_model:
488+
onnx_save(
489+
onx, self.dump_onnx_model, save_as_external_data=len(onx.graph.node) > 100
490+
)
464491
cls = (
465492
InferenceSessionForNumpy
466493
if any(isinstance(i, np.ndarray) for i in inputs)

0 commit comments

Comments
 (0)