 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
     requires_cuda,
+    requires_onnxruntime,
     requires_transformers,
     requires_torch,
     ignore_warnings,
@@ -555,7 +556,7 @@ def _get_seqlen(cls) -> torch.Tensor:
 
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     @requires_cuda()
-    def test_plug_packed_multi_head_attention_qwen25_packed_float16(self):
+    def test_plug_multi_head_attention_qwen25_packed_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_packed_versatile,
         )
@@ -583,8 +584,9 @@ def test_plug_packed_multi_head_attention_qwen25_packed_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
+    @requires_onnxruntime("1.24")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopmha_float16(self):
+    def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopmha_versatile,
         )
@@ -617,8 +619,9 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
+    @requires_onnxruntime("1.24")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopmha_float32(self):
+    def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopmha_versatile,
         )
@@ -651,8 +654,9 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha_float32(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
         self.assertLess(results.diffs[0]["abs"], 1e-5)
 
+    @requires_onnxruntime("1.24")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self):
+    def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopa24_versatile,
         )
@@ -678,8 +682,9 @@ def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005)
         self.assertLess(results.diffs[0]["abs"], 0.005)
 
+    @requires_onnxruntime("1.24")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopa24_float32(self):
+    def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopa24_versatile,
         )