Commit 5842d50

xadupre and sdpython authored
Handles parameters in OnnxruntimeEvaluator for FunctionProto (#324)
* update ci
* freeze
* patch
* better with cache
* shorten names
* few changes
* ut
* Handles parameters in OnnxruntimeEvaluator for FunctionProto
* final

---------

Co-authored-by: xavier dupré <[email protected]>
1 parent 173f6a4 commit 5842d50

8 files changed: +194 −42 lines changed

.github/workflows/models.yml

Lines changed: 1 addition & 1 deletion
@@ -61,5 +61,5 @@ jobs:
         run: python -m pip freeze

       - name: qwen2.5_vl_instruct
-        run: PYTHONPATH=. UNITTEST_GOING=1 NEVERTEST=1 QWEN25ATTENTION=BIGMASK TESTDTYPE=float16 TESTDEVICE=cpu python _unittests/ut_tasks/try_export.py -f -k test_qwen_2_5_vli_visual
+        run: PYTHONPATH=. UNITTEST_GOING=1 NEVERTEST=1 QWEN25ATTENTION=BIGMASK TESTDTYPE=float16 TESTDEVICE=cpu python _unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.8.3
 +++++

+* :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator
 * :pr:`323`: drops torch 2.8 on CI
 * :pr:`322`: support rerunning onnx kernels with torch intermediate results in side-by-side
 * :pr:`314`: fix modelbuilder download needed after this change https://github.com/microsoft/onnxruntime-genai/pull/1862
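Before the file-by-file diff, a distilled view of the headline change may help. The sketch below assumes only what the diffs further down confirm (the `whole=True` and `function_kwargs` arguments of the OnnxruntimeEvaluator constructor); the domain, the `ScaleBy` function, and the `alpha` parameter are invented for illustration:

import numpy as np
import onnx
import onnx.helper as oh
from onnx_diagnostic.reference import OnnxruntimeEvaluator

# Hypothetical one-parameter function: y = x * alpha.
fn = oh.make_function(
    "demo_domain",                                 # invented domain
    "ScaleBy",                                     # invented name
    ["x"],
    ["y"],
    [
        oh.make_node("Constant", [], ["alpha"]),   # value comes from the parameter
        oh.make_node("Mul", ["x", "alpha"], ["y"]),
    ],
    [oh.make_opsetid("", 14)],
    ["alpha"],                                     # declared function parameter
)
# Tie the Constant's value_float to the function parameter "alpha".
att = onnx.AttributeProto()
att.name = "value_float"
att.ref_attr_name = "alpha"
att.type = onnx.AttributeProto.FLOAT
fn.node[0].attribute.append(att)

# New in this commit: the parameter value is supplied via function_kwargs.
sess = OnnxruntimeEvaluator(fn, whole=True, function_kwargs=dict(alpha=2.0))
x = np.arange(4).astype(np.float32)
got = sess.run(None, dict(x=x))
assert np.allclose(got[0], x * 2.0)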

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 35 additions & 0 deletions
@@ -284,6 +284,41 @@ def test_skip_simplified_layer_normalization(self):
         self.assertEqual(got[1].shape, feeds["x"].shape)
         self.assertEqual(got[1].dtype, feeds["x"].dtype)

+    def test_function_proto_with_kwargs(self):
+        linear_function = oh.make_function(
+            "test_domain",
+            "LinearRegression",
+            ["x", "a", "b"],
+            ["y"],
+            [
+                oh.make_node("Constant", [], ["eps"]),
+                oh.make_node("Constant", [], ["zero"], value_ints=[0]),
+                oh.make_node("Unsqueeze", ["eps", "zero"], ["eps1d"]),
+                oh.make_node("MatMul", ["x", "a"], ["xa"]),
+                oh.make_node("Add", ["b", "eps1d"], ["beps"]),
+                oh.make_node("Add", ["xa", "beps"], ["y"]),
+            ],
+            [oh.make_opsetid("", 14)],
+            ["epsilon"],
+        )
+        att = onnx.AttributeProto()
+        att.name = "value_float"
+        att.ref_attr_name = "epsilon"
+        att.type = onnx.AttributeProto.FLOAT
+        linear_function.node[0].attribute.append(att)
+        feeds = dict(
+            x=np.random.rand(4, 4).astype(np.float32),
+            a=np.random.rand(4, 2).astype(np.float32),
+            b=np.random.rand(1, 2).astype(np.float32),
+        )
+        epsilon = 15.6
+        expected = feeds["x"] @ feeds["a"] + feeds["b"] + epsilon
+        sess = OnnxruntimeEvaluator(
+            linear_function, whole=True, function_kwargs=dict(epsilon=epsilon)
+        )
+        got = sess.run(None, feeds)
+        self.assertEqualArray(expected, got[0], atol=1e-5)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
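The mechanism this test exercises: inside a FunctionProto, a node attribute can carry a `ref_attr_name` instead of a concrete value, deferring to a function-level attribute (here, `value_float` of the first `Constant` defers to the function parameter `epsilon`). Such a function is not executable until every reference is bound, which is what `function_kwargs` now supplies. A hypothetical helper sketching that binding step, for illustration only and not the library's actual implementation:

import onnx

def bind_function_attributes(fn: onnx.FunctionProto, **values) -> onnx.FunctionProto:
    """Return a copy of ``fn`` with every ref_attr_name resolved to a value."""
    bound = onnx.FunctionProto()
    bound.CopyFrom(fn)
    for node in bound.node:
        for att in node.attribute:
            if att.ref_attr_name and att.ref_attr_name in values:
                value = values[att.ref_attr_name]
                att.ClearField("ref_attr_name")
                if att.type == onnx.AttributeProto.FLOAT:
                    att.f = float(value)
                elif att.type == onnx.AttributeProto.INT:
                    att.i = int(value)
                else:
                    raise NotImplementedError(f"unhandled attribute type {att.type}")
    del bound.attribute[:]  # the parameters are now baked in
    return bound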

_unittests/ut_tasks/try_export.py

Lines changed: 26 additions & 24 deletions
@@ -185,24 +185,25 @@ def _config_reduction(config, task):
             onnx_plugs=PLUGS,
         )

-        with open(
-            self.get_dump_file(
-                f"sbs_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.sh"
-            ),
-            "w",
-        ) as f:
-            f.write(
-                textwrap.dedent(
-                    f"""
-                    clear&&python -m onnx_diagnostic sbs \\
-                    -i qwen25_vli_visual.inputs.pt \\
-                    -e test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.graph.ep.pt2 \\
-                    -m test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.onnx \\
-                    -o test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.xlsx \\
-                    -v 1 --atol 0.1 --rtol 1000
-                    """
+        if not self.unit_test_going():
+            with open(
+                self.get_dump_file(
+                    f"sbs_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.sh"
+                ),
+                "w",
+            ) as f:
+                f.write(
+                    textwrap.dedent(
+                        f"""
+                        clear&&python -m onnx_diagnostic sbs \\
+                        -i qwen25_vli_visual.inputs.pt \\
+                        -e test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.graph.ep.pt2 \\
+                        -m test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.onnx \\
+                        -o test_qwen25_vli_visual.{device}.{dtype}.{attention}.{exporter}.xlsx \\
+                        -v 1 --atol 0.1 --rtol 1000
+                        """
+                    )
                 )
-            )
         print(f"-- MODEL CONVERTED IN {time.perf_counter() - begin}")
         model = onnx.load(filename, load_external_data=False)
         if attention == "PACKED":
@@ -226,11 +227,13 @@ def _config_reduction(config, task):
         assert (
             self.unit_test_going() or pt2_files
         ), f"Unable to find an existing file among {pt2_files!r}"
-        pt2_file = (
-            (pt2_files[0] if pt2_files else None)
-            if not self.unit_test_going()
-            else None
-        )
+
+        # pt2_file = (
+        #     (pt2_files[0] if pt2_files else None)
+        #     if not self.unit_test_going()
+        #     else None
+        # )
+
         # self.assertExists(pt2_file)
         # ep = torch.export.load(pt2_file)
         # diff = self.max_diff(ep.module()(**export_inputs), model.visual(**export_inputs))
@@ -250,8 +253,7 @@ def _config_reduction(config, task):
             use_ort=True,
             atol=0.02,
             rtol=10,
-            ort_optimized_graph=False,
-            ep=pt2_file,
+            # ep=pt2_file,
             expected=expected,
         )
         print(f"-- MODEL VERIFIED IN {time.perf_counter() - begin}")

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 66 additions & 0 deletions
@@ -6,6 +6,7 @@
 import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
+    requires_cuda,
     requires_transformers,
     requires_torch,
     ignore_warnings,
@@ -518,6 +519,71 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
         )
         self.clean_dump()

+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    @requires_cuda()
+    def test_plug_packed_multi_head_attention_qwen25(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
+            torch.tensor(
+                [
+                    0,
+                    64,
+                    128,
+                    192,
+                    256,
+                    304,
+                    368,
+                    432,
+                    496,
+                    560,
+                    608,
+                    672,
+                    736,
+                    800,
+                    864,
+                    912,
+                    976,
+                    1040,
+                    1104,
+                    1168,
+                    1216,
+                    1232,
+                    1248,
+                    1264,
+                    1280,
+                    1292,
+                ],
+                dtype=torch.int64,
+            ).to("cuda"),
+        )
+
+        results = qwen_sdpa_attention_versatile.verify(
+            *inputs,
+            scaling=0.5,
+            num_heads=16,
+            dump_onnx_model=self.get_dump_file(
+                "test_plug_packed_multi_head_attention_qwen25.onnx"
+            ),
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
+        self.assertLess(results.diffs[0]["abs"], 0.01)
+
+        results = qwen_sdpa_attention_versatile.verify(
+            *inputs, scaling=0.11180339887498948, num_heads=16
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
+        self.assertLess(results.diffs[0]["abs"], 0.01)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
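Two details of this new test are worth spelling out. The constant `0.11180339887498948` is the default SDPA scaling `head_dim ** -0.5` for `head_dim=80`, and the final `int64` input is consistent with a `cu_seqlens`-style vector of cumulative token offsets delimiting the packed sequences, ending at the total length 1292. A quick sanity check of both readings (plain arithmetic, no library assumptions):

head_dim = 80
assert abs(head_dim ** -0.5 - 0.11180339887498948) < 1e-15  # default SDPA scaling

# Offsets from the test; each consecutive pair delimits one packed segment.
cu_seqlens = [0, 64, 128, 192, 256, 304, 368, 432, 496, 560, 608, 672, 736,
              800, 864, 912, 976, 1040, 1104, 1168, 1216, 1232, 1248, 1264, 1280, 1292]
lengths = [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]
assert all(n > 0 for n in lengths) and cu_seqlens[-1] == 1292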

onnx_diagnostic/export/onnx_plug.py

Lines changed: 34 additions & 13 deletions
@@ -27,7 +27,7 @@ class VerifyResult:
     """

     eager_outputs: TUPLE_TENSORS
-    onnx_output: TUPLE_TENSORS
+    onnx_outputs: TUPLE_TENSORS
     diffs: Tuple[Dict[str, float], ...]

@@ -238,20 +238,30 @@ def _register(self):
         custom_def.register_kernel(None)(self.eager_fn)
         custom_def._abstract_fn = self.shape_fn

-    def verify(self, *args, engine: Optional[Callable] = None) -> VerifyResult:
+    def verify(
+        self,
+        *args,
+        engine: Optional[Callable] = None,
+        dump_onnx_model: Optional[str] = None,
+        **kwargs,
+    ) -> VerifyResult:
         """
         Verifies that the eager mode is equivalent to the onnx function given
         as a replacements. This function evaluates `eager_fn`, checks that the shapes
         are equivalent to the ones given by `shape_fn`, and finally evaluates the
         onnx translation if the previous did not fail.

         :param args: function inputs
+        :param kwargs: arguments for eager_fn
         :param engine: by default an instance of
             :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
+        :param dump_onnx_model: to dump the onnx model used to verify
+            eager and onnx produce the same results
+        :param kwargs: additional arguments to the function
         :return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
         """
-        expected = self.eager_fn(*args)
-        shapes = self.shape_fn(*args)
+        expected = self.eager_fn(*args, **kwargs)
+        shapes = self.shape_fn(*args, **kwargs)
         if isinstance(expected, torch.Tensor):
             expected = (expected,)
         assert isinstance(shapes, torch.Tensor), (
@@ -279,11 +289,23 @@ def verify(self, *args, engine: Optional[Callable] = None) -> VerifyResult:

         # Now the ONNX execution.
         assert engine is None, f"Not implemented yet with engine={engine!r}"
-        sess = OnnxruntimeEvaluator(self.function_proto)
-        feeds = dict(zip(sess.input_names, args))
+        ags, kws = self._make_args_kwargs(*args, **kwargs)
+        sess = OnnxruntimeEvaluator(
+            self.function_proto,
+            whole=True,
+            dump_onnx_model=dump_onnx_model,
+            function_kwargs=kws,
+        )
+        feeds = dict(zip(sess.input_names, ags))
         got = sess.run(None, feeds)
-        diffs = tuple(max_diff(e, g) for e, g in zip(expected, got))
-        return VerifyResult(eager_outputs=expected, onnx_output=tuple(got), diffs=diffs)  # type: ignore[arg-type]
+        diffs = tuple(max_diff(e, g, hist=[0.1, 0.01]) for e, g in zip(expected, got))
+        return VerifyResult(eager_outputs=expected, onnx_outputs=tuple(got), diffs=diffs)  # type: ignore[arg-type]
+
+    def _make_args_kwargs(self, *args, **kwargs):
+        ags = args[: len(self.args_name)]
+        kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
+        kws.update(kwargs)
+        return ags, kws

 def custom_converter(
         self,
@@ -306,9 +328,7 @@ def converter(
                 self.function_proto.name, domain=self.function_proto.domain
             ):
                 g.add_function(self.function_proto)
-            ags = args[: len(self.args_name)]
-            kws = dict(zip(self.kwargs_name, args[len(self.args_name) :]))
-            kws.update(kwargs)
+            ags, kws = self._make_args_kwargs(*args, **kwargs)
             res = g.make_node(
                 self.function_proto.name,
                 ags,
@@ -369,7 +389,8 @@ def onnx_dynamo_converter(self) -> Callable:
         onnx.defs.register_schema(schema)
         op = onnxscript.values.Op(onnx_plug_op, self.function_proto.name, schema)

-        def converter(*cargs):
-            return op(*cargs, n_outputs=self.n_outputs)
+        def converter(*cargs, **ckwargs):
+            ags, kws = self._make_args_kwargs(*cargs, **ckwargs)
+            return op(*ags, n_outputs=self.n_outputs, **kws)

         return onnxscript.values.TracedOnnxFunction(onnx_plug_op, converter)
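The refactoring hinges on the extracted `_make_args_kwargs`, now shared by `verify` and both converters: leading positionals map to the ONNX function inputs (`self.args_name`), trailing positionals are folded into function attributes by declared name (`self.kwargs_name`), and explicit keywords fill in or override the rest. A standalone restatement of that split, with the `args_name`/`kwargs_name` values assumed purely for illustration:

# Assumed declarations for illustration: three tensor inputs, one attribute.
args_name = ["query", "key", "value"]
kwargs_name = ["scaling"]

def make_args_kwargs(*args, **kwargs):
    # Leading positionals become ONNX function inputs ...
    ags = args[: len(args_name)]
    # ... trailing positionals become function attributes, matched by name ...
    kws = dict(zip(kwargs_name, args[len(args_name):]))
    # ... and explicit keywords fill in or override the rest.
    kws.update(kwargs)
    return ags, kws

ags, kws = make_args_kwargs("Q", "K", "V", 0.5)
assert ags == ("Q", "K", "V") and kws == {"scaling": 0.5}
ags, kws = make_args_kwargs("Q", "K", "V", scaling=0.5)
assert kws == {"scaling": 0.5}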

onnx_diagnostic/ext_test_case.py

Lines changed: 1 addition & 1 deletion
@@ -1323,7 +1323,7 @@ def assert_onnx_disc(
             and not numpy.isnan(ep_diff["rel"])
             and ep_diff["rel"] <= rtol
         ), (
-            f"discrepancies in {test_name!r} between the model "
+            f"discrepancies in {test_name!r} between the exported program "
             f"and the exported model diff={string_diff(ep_diff)}"
         )
         ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
