Commit d5798cc

committed: ut
1 parent 2a3a2d4 commit d5798cc

File tree

3 files changed: +54 -5 lines changed

.github/workflows/models.yml
_unittests/ut_torch_export_patches/test_patch_transformers.py
onnx_diagnostic/export/onnx_plug.py

.github/workflows/models.yml

Lines changed: 1 addition & 1 deletion
@@ -61,5 +61,5 @@ jobs:
         run: python -m pip freeze

       - name: qwen2.5_vl_instruct
-        run: PYTHONPATH=. UNITTEST_GOING=1 NEVERTEST=1 QWEN25ATTENTION=BIGMASK TESTDTYPE=float16 TESTDEVICE=cpu python _unittests/ut_tasks/try_export.py -f -k test_qwen_2_5_vli_visual
+        run: PYTHONPATH=. UNITTEST_GOING=1 NEVERTEST=1 QWEN25ATTENTION=BIGMASK TESTDTYPE=float16 TESTDEVICE=cpu python _unittests/ut_tasks/try_export.py -f -k test_qwen25_vli_visual
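
For reference, a rough local equivalent of this renamed CI step (a sketch only; it assumes nothing beyond the command shown in the workflow and may need adjusting to the local environment):

# Sketch: reproduce the qwen2.5_vl_instruct CI step from the repository root.
import os
import subprocess

env = dict(
    os.environ,
    PYTHONPATH=".",
    UNITTEST_GOING="1",
    NEVERTEST="1",
    QWEN25ATTENTION="BIGMASK",
    TESTDTYPE="float16",
    TESTDEVICE="cpu",
)
subprocess.run(
    ["python", "_unittests/ut_tasks/try_export.py", "-f", "-k", "test_qwen25_vli_visual"],
    env=env,
    check=True,
)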

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 48 additions & 0 deletions
@@ -6,6 +6,7 @@
 import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
+    requires_cuda,
     requires_transformers,
     requires_torch,
     ignore_warnings,
@@ -518,6 +519,53 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
         )
         self.clean_dump()

+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    @requires_cuda()
+    def test_plug_packed_multi_head_attention_qwen25(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
+            torch.tensor(
+                [
+                    0,
+                    64,
+                    128,
+                    192,
+                    256,
+                    304,
+                    368,
+                    432,
+                    496,
+                    560,
+                    608,
+                    672,
+                    736,
+                    800,
+                    864,
+                    912,
+                    976,
+                    1040,
+                    1104,
+                    1168,
+                    1216,
+                    1232,
+                    1248,
+                    1264,
+                    1280,
+                    1292,
+                ],
+                dtype=torch.int64,
+            ).to("cuda"),
+        )
+        qwen_sdpa_attention_versatile.verify(
+            *inputs, scaling=0.11180339887498948, num_heads=16
+        )
+


 if __name__ == "__main__":
     unittest.main(verbosity=2)
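
A side note on the constants used by the new test (an observation, not part of the commit): the query/key/value tensors have shape (1, 16, 1292, 80), i.e. 16 heads of dimension 80 over 1292 packed tokens, the last tensor looks like cumulative sequence boundaries for the packed batch, and the scaling literal matches the usual SDPA scale 1/sqrt(head_dim):

# Sketch: where 0.11180339887498948 comes from, assuming the standard SDPA convention.
import math

head_dim = 80                        # last dimension of the (1, 16, 1292, 80) tensors
scaling = 1.0 / math.sqrt(head_dim)
print(scaling)                       # 0.11180339887498948, the value passed to verify()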

onnx_diagnostic/export/onnx_plug.py

Lines changed: 5 additions & 4 deletions
@@ -238,20 +238,21 @@ def _register(self):
         custom_def.register_kernel(None)(self.eager_fn)
         custom_def._abstract_fn = self.shape_fn

-    def verify(self, *args, engine: Optional[Callable] = None) -> VerifyResult:
+    def verify(self, *args, engine: Optional[Callable] = None, **kwargs) -> VerifyResult:
         """
         Verifies that the eager mode is equivalent to the onnx function given
         as a replacements. This function evaluates `eager_fn`, checks that the shapes
         are equivalent to the ones given by `shape_fn`, and finally evaluates the
         onnx translation if the previous did not fail.

         :param args: function inputs
+        :param kwargs: arguments for eager_fn
         :param engine: by default an instance of
             :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`.
         :return: outputs of :func:`onnx_diagnostic.helpers.max_diff`
         """
-        expected = self.eager_fn(*args)
-        shapes = self.shape_fn(*args)
+        expected = self.eager_fn(*args, **kwargs)
+        shapes = self.shape_fn(*args, **kwargs)
         if isinstance(expected, torch.Tensor):
             expected = (expected,)
         assert isinstance(shapes, torch.Tensor), (
@@ -279,7 +280,7 @@ def verify(self, *args, engine: Optional[Callable] = None) -> VerifyResult:

         # Now the ONNX execution.
         assert engine is None, f"Not implemented yet with engine={engine!r}"
-        sess = OnnxruntimeEvaluator(self.function_proto)
+        sess = OnnxruntimeEvaluator(self.function_proto, whole=True)
         feeds = dict(zip(sess.input_names, args))
         got = sess.run(None, feeds)
         diffs = tuple(max_diff(e, g) for e, g in zip(expected, got))
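
To summarize the behavioural change (a simplified sketch using only the names visible in the diff, not the real implementation): keyword arguments passed to verify are now forwarded to eager_fn and shape_fn, while only the positional tensors are bound to the ONNX function inputs.

# Simplified sketch of the forwarding logic after this commit.
def verify(self, *args, engine=None, **kwargs):
    expected = self.eager_fn(*args, **kwargs)    # kwargs now reach the eager function
    shapes = self.shape_fn(*args, **kwargs)      # ... and the shape function
    sess = OnnxruntimeEvaluator(self.function_proto, whole=True)
    feeds = dict(zip(sess.input_names, args))    # only positional args feed the ONNX graph
    got = sess.run(None, feeds)
    return tuple(max_diff(e, g) for e, g in zip(expected, got))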
