Skip to content

Commit 093c104

Browse files
authored
Support lists with OnnxruntimeEvaluator (#329)
* Support lists with OnnxruntimeEvaluator * keep type * more checks * fix clone * fix * catch more exception * mypy * add unit test * assert * fix sbs
1 parent 9dcb936 commit 093c104

File tree

10 files changed

+416
-25
lines changed

10 files changed

+416
-25
lines changed

.github/workflows/check-urls.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ jobs:
4242
print_all: false
4343
timeout: 2
4444
retry_count# : 2
45-
exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://www.linux.org/
46-
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/,https://www.linux.org/
45+
exclude_urls: https://hal.archives-,ouvertes.fr/hal-00990252/document,http://badge.fury.io/py/onnx-diagnostic,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://medium.com/@msouza.os/llm-from-scratch-with-pytorch-9f21808c6319,https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L5965,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/huggingface/transformers/pull/36311,https://www.linux.org/,https://docs.scipy.org/doc/scipy/
46+
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://azure.microsoft.com/en-us/products/devops/pipelines,https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670,https://github.com/NVIDIA/TransformerEngine.git@6a9edc38bf9b941b7d369af5103fa8fe0b121d61,https://github.com/pytorch/pytorch/blob/main/torch/,https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-04.html,https://badge.fury.io/py/onnx-diagnostic.svg,https://github.com/,https://codecov.io/,https://huggingface.co/,https://www.linux.org/,https://docs.scipy.org/doc/scipy/
4747
# force_pass : true

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Change Logs
55
+++++
66

77
* :pr:`330`: fixes access rope_parameters for ``transformers>=5``
8+
* :pr:`329`: supports lists with OnnxruntimeEvaluator
89
* :pr:`326`: use ConcatFromSequence in LoopMHA with the loop
910
* :pr:`325`: adds plug for LoopMHA, extends the unit tests to measure the discrepancies
1011
* :pr:`324`: supports FunctionProto with arguments in OnnxruntimeEvaluator

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,9 @@ def test_enumerate_results_loop(self):
402402
new_axis=0,
403403
),
404404
],
405-
)
405+
),
406+
ir_version=10,
407+
opset_imports=[oh.make_opsetid("", 22)],
406408
)
407409
res = list(enumerate_results(model, "slice_start", verbose=2))
408410
self.assertEqual(len(res), 2)

_unittests/ut_reference/test_onnxruntime_evaluator.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
TFLOAT = onnx.TensorProto.FLOAT
23+
TINT64 = onnx.TensorProto.INT64
2324

2425

2526
class TestOnnxruntimeEvaluator(ExtTestCase):
@@ -319,6 +320,123 @@ def test_function_proto_with_kwargs(self):
319320
got = sess.run(None, feeds)
320321
self.assertEqualArray(expected, got[0], atol=1e-5)
321322

323+
@hide_stdout()
324+
def test_ort_eval_loop_seq(self):
325+
x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
326+
_mkv_ = oh.make_tensor_value_info
327+
model = oh.make_model(
328+
graph=oh.make_graph(
329+
name="loop_test",
330+
inputs=[
331+
oh.make_tensor_value_info("trip_count", TINT64, ["a"]),
332+
oh.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []),
333+
],
334+
outputs=[oh.make_tensor_value_info("res", TFLOAT, [])],
335+
nodes=[
336+
oh.make_node("SequenceEmpty", [], ["seq_empty"], dtype=TFLOAT),
337+
oh.make_node(
338+
"Loop",
339+
inputs=["trip_count", "cond", "seq_empty"],
340+
outputs=["seq_res"],
341+
body=oh.make_graph(
342+
[
343+
oh.make_node(
344+
"Identity", inputs=["cond_in"], outputs=["cond_out"]
345+
),
346+
oh.make_node(
347+
"Constant",
348+
inputs=[],
349+
outputs=["x"],
350+
value=oh.make_tensor(
351+
name="const_tensor_x",
352+
data_type=TFLOAT,
353+
dims=x.shape,
354+
vals=x.flatten().astype(float),
355+
),
356+
),
357+
oh.make_node(
358+
"Constant",
359+
inputs=[],
360+
outputs=["one"],
361+
value=oh.make_tensor(
362+
name="const_tensor_one",
363+
data_type=TINT64,
364+
dims=(),
365+
vals=[1],
366+
),
367+
),
368+
oh.make_node(
369+
"Constant",
370+
inputs=[],
371+
outputs=["slice_start"],
372+
value=oh.make_tensor(
373+
name="const_tensor_zero",
374+
data_type=TINT64,
375+
dims=(1,),
376+
vals=[0],
377+
),
378+
),
379+
oh.make_node(
380+
"Add", inputs=["iter_count", "one"], outputs=["end"]
381+
),
382+
oh.make_node(
383+
"Constant",
384+
inputs=[],
385+
outputs=["axes"],
386+
value=oh.make_tensor(
387+
name="const_tensor_axes",
388+
data_type=TINT64,
389+
dims=(1,),
390+
vals=[0],
391+
),
392+
),
393+
oh.make_node(
394+
"Unsqueeze", inputs=["end", "axes"], outputs=["slice_end"]
395+
),
396+
oh.make_node(
397+
"Slice",
398+
inputs=["x", "slice_start", "slice_end"],
399+
outputs=["slice_out"],
400+
),
401+
oh.make_node(
402+
"SequenceInsert",
403+
inputs=["seq_in", "slice_out"],
404+
outputs=["seq_out"],
405+
),
406+
],
407+
"loop_body",
408+
[
409+
_mkv_("iter_count", TINT64, []),
410+
_mkv_("cond_in", onnx.TensorProto.BOOL, []),
411+
oh.make_tensor_sequence_value_info("seq_in", TFLOAT, None),
412+
],
413+
[
414+
_mkv_("cond_out", onnx.TensorProto.BOOL, []),
415+
oh.make_tensor_sequence_value_info("seq_out", TFLOAT, None),
416+
],
417+
),
418+
),
419+
oh.make_node(
420+
"ConcatFromSequence",
421+
inputs=["seq_res"],
422+
outputs=["res"],
423+
axis=0,
424+
new_axis=0,
425+
),
426+
],
427+
),
428+
ir_version=10,
429+
opset_imports=[oh.make_opsetid("", 22)],
430+
)
431+
ev = OnnxruntimeEvaluator(model, verbose=10)
432+
feeds = dict(trip_count=torch.tensor([3], dtype=torch.int64), cond=torch.tensor(True))
433+
got = ev.run(None, feeds)
434+
self.assertEqual((6,), got[0].shape)
435+
self.assertEqualArray(
436+
torch.tensor([1.0, 1.0, 2.0, 1.0, 2.0, 3.0], dtype=torch.float32), got[0]
437+
)
438+
self.assertIsInstance(got[0], torch.Tensor)
439+
322440

323441
if __name__ == "__main__":
324442
unittest.main(verbosity=2)

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1123,7 +1123,9 @@ def test_loop(self):
11231123
new_axis=0,
11241124
),
11251125
],
1126-
)
1126+
),
1127+
ir_version=10,
1128+
opset_imports=[oh.make_opsetid("", 22)],
11271129
)
11281130
self._finalize_test(
11291131
model, torch.tensor(5, dtype=torch.int64), torch.tensor(1, dtype=torch.bool)

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
ignore_errors,
99
requires_cuda,
1010
)
11+
from onnx_diagnostic.helpers.rt_helper import make_feeds
1112
from onnx_diagnostic.reference import ExtendedReferenceEvaluator, OnnxruntimeEvaluator
1213
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
14+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patch_qwen2_5
1315
from onnx_diagnostic.torch_onnx.sbs import run_aligned
1416
from onnx_diagnostic.torch_onnx.sbs_dataclasses import RunAlignedRecord, ReplayConfiguration
1517
from onnx_diagnostic.export.api import to_onnx
@@ -671,6 +673,124 @@ def forward(self, x):
671673
)
672674
self.clean_dump()
673675

676+
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
677+
@hide_stdout()
678+
def test_sbs_with_loops(self):
679+
import torch
680+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
681+
PLUGS_Qwen25,
682+
)
683+
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
684+
qwen_sdpa_attention_loopmha_versatile,
685+
)
686+
687+
class Model(torch.nn.Module):
688+
def forward(self, query, key, value, seq_lens):
689+
rg1 = torch.arange(4, dtype=torch.int32).unsqueeze(0)
690+
rg0 = torch.arange(4, dtype=torch.int32).unsqueeze(1)
691+
mask = (rg0 <= rg1).flatten().reshape((1, -1, 1, 1)).to(query.dtype)
692+
qs = query * mask
693+
ks = key * mask
694+
vs = value * mask
695+
attn_output = qwen_sdpa_attention_loopmha_versatile(
696+
qs,
697+
ks,
698+
vs,
699+
seq_lens,
700+
0.11,
701+
16,
702+
(
703+
onnx.TensorProto.FLOAT
704+
if query.dtype == torch.float32
705+
else (
706+
onnx.TensorProto.FLOAT16
707+
if query.dtype == torch.float16
708+
else onnx.TensorProto.BFLOAT16
709+
)
710+
),
711+
)
712+
red = attn_output.mean(dim=-1, keepdim=True)
713+
return attn_output - red
714+
715+
model = Model()
716+
inputs = (
717+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
718+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
719+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
720+
torch.tensor(
721+
[
722+
0,
723+
64,
724+
128,
725+
192,
726+
256,
727+
304,
728+
368,
729+
432,
730+
496,
731+
560,
732+
608,
733+
672,
734+
736,
735+
800,
736+
864,
737+
912,
738+
976,
739+
1040,
740+
1104,
741+
1168,
742+
1216,
743+
1232,
744+
1248,
745+
1264,
746+
1280,
747+
1292,
748+
],
749+
dtype=torch.int64,
750+
),
751+
)
752+
expected = model(*inputs)
753+
ds = ({2: "seq_length"}, {2: "seq_length"}, {2: "seq_length"}, {0: "num_patches"})
754+
onnx_file = self.get_dump_file("test_sbs_with_loops.onnx")
755+
ep_file = self.get_dump_file("test_sbs_with_loops")
756+
to_onnx(
757+
model,
758+
inputs,
759+
dynamic_shapes=ds,
760+
filename=onnx_file,
761+
save_ep=(ep_file, 2**28),
762+
exporter="custom",
763+
onnx_plugs=PLUGS_Qwen25,
764+
target_opset=22,
765+
)
766+
input_file = ep_file + ".input.pt"
767+
ep_file = ep_file + ".ep.pt2"
768+
self.assertExists(onnx_file)
769+
self.assertExists(ep_file)
770+
self.assertExists(input_file)
771+
sess = self.check_ort(onnx_file)
772+
input_names = [i.name for i in sess.get_inputs()]
773+
feeds = make_feeds(input_names, inputs, use_numpy=True)
774+
got = sess.run(None, feeds)
775+
self.assertEqualArray(expected, got[0], atol=1e-3)
776+
# sbs
777+
ep = torch.export.load(ep_file)
778+
onx = onnx.load(onnx_file)
779+
kwargs = make_feeds(input_names, inputs, use_numpy=False)
780+
results = list(
781+
run_aligned(
782+
ep,
783+
onx,
784+
kwargs=kwargs,
785+
run_cls=OnnxruntimeEvaluator,
786+
verbose=11,
787+
use_tensor=True,
788+
),
789+
)
790+
df = pandas.DataFrame(list(results)).dropna(axis=1, how="all")
791+
df.to_excel(self.get_dump_file("test_sbs_with_loops.xlsx"))
792+
# self.clean_dump()
793+
674794

675795
if __name__ == "__main__":
676796
unittest.main(verbosity=2)

onnx_diagnostic/ext_test_case.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1111,7 +1111,10 @@ def check_ort(
11111111
) -> "onnxruntime.InferenceSession": # noqa: F821
11121112
from onnxruntime import InferenceSession
11131113

1114-
return InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
1114+
return InferenceSession(
1115+
onx if isinstance(onx, str) else onx.SerializeToString(),
1116+
providers=["CPUExecutionProvider"],
1117+
)
11151118

11161119
def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
11171120
"""In the name"""

onnx_diagnostic/helpers/ort_session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,10 @@ def __init__(
108108
session_options,
109109
providers=providers,
110110
)
111-
except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
111+
except (
112+
onnxruntime.capi.onnxruntime_pybind11_state.Fail,
113+
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
114+
) as e:
112115
if isinstance(sess, onnx.ModelProto):
113116
debug_path = "_debug_InferenceSession_last_failure.onnx"
114117
onnx.save(

0 commit comments

Comments
 (0)