Commit 1cf41ae

require
1 parent b34d94f commit 1cf41ae

File tree: 1 file changed, +10 −5 lines

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 10 additions & 5 deletions
@@ -7,6 +7,7 @@
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
     requires_cuda,
+    requires_onnxruntime,
     requires_transformers,
     requires_torch,
     ignore_warnings,
@@ -555,7 +556,7 @@ def _get_seqlen(cls) -> torch.Tensor:
 
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     @requires_cuda()
-    def test_plug_packed_multi_head_attention_qwen25_packed_float16(self):
+    def test_plug_multi_head_attention_qwen25_packed_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_packed_versatile,
         )
@@ -583,8 +584,9 @@ def test_plug_packed_multi_head_attention_qwen25_packed_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
+    @requires_onnxruntime("1.24")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopmha_float16(self):
+    def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopmha_versatile,
         )
@@ -617,8 +619,9 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
+    @requires_onnxruntime("1.24")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopmha_float32(self):
+    def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopmha_versatile,
         )
@@ -651,8 +654,9 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha_float32(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
         self.assertLess(results.diffs[0]["abs"], 1e-5)
 
+    @requires_onnxruntime("1.24")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self):
+    def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopa24_versatile,
         )
@@ -678,8 +682,9 @@ def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005)
         self.assertLess(results.diffs[0]["abs"], 0.005)
 
+    @requires_onnxruntime("1.24")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopa24_float32(self):
+    def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopa24_versatile,
         )
