 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
     requires_cuda,
+    requires_onnxruntime,
     requires_transformers,
     requires_torch,
     ignore_warnings,
@@ -519,9 +520,43 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
         )
         self.clean_dump()
 
+    @classmethod
+    def _get_seqlen(cls) -> torch.Tensor:
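+        """Cumulative sequence lengths (segment boundaries) shared by the attention
+        tests below; the last value matches the 1292-token sequence dimension."""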
+        return torch.tensor(
+            [
+                0,
+                64,
+                128,
+                192,
+                256,
+                304,
+                368,
+                432,
+                496,
+                560,
+                608,
+                672,
+                736,
+                800,
+                864,
+                912,
+                976,
+                1040,
+                1104,
+                1168,
+                1216,
+                1232,
+                1248,
+                1264,
+                1280,
+                1292,
+            ],
+            dtype=torch.int64,
+        )
+
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     @requires_cuda()
-    def test_plug_packed_multi_head_attention_qwen25_packed(self):
+    def test_plug_multi_head_attention_qwen25_packed_float16(self):
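+        """Packed variant: ONNX vs eager outputs, float16, CUDA."""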
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_packed_versatile,
         )
@@ -530,37 +565,7 @@ def test_plug_packed_multi_head_attention_qwen25_packed(self):
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
-            torch.tensor(
-                [
-                    0,
-                    64,
-                    128,
-                    192,
-                    256,
-                    304,
-                    368,
-                    432,
-                    496,
-                    560,
-                    608,
-                    672,
-                    736,
-                    800,
-                    864,
-                    912,
-                    976,
-                    1040,
-                    1104,
-                    1168,
-                    1216,
-                    1232,
-                    1248,
-                    1264,
-                    1280,
-                    1292,
-                ],
-                dtype=torch.int64,
-            ).to("cuda"),
+            self._get_seqlen().to("cuda"),
         )
 
         results = qwen_sdpa_attention_packed_versatile.verify(
@@ -579,8 +584,9 @@ def test_plug_packed_multi_head_attention_qwen25_packed(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
+    @requires_onnxruntime("1.24")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
+    def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
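+        """loopmha variant: ONNX vs eager outputs, float16."""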
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopmha_versatile,
         )
@@ -589,46 +595,15 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
-            torch.tensor(
-                [
-                    0,
-                    64,
-                    128,
-                    192,
-                    256,
-                    304,
-                    368,
-                    432,
-                    496,
-                    560,
-                    608,
-                    672,
-                    736,
-                    800,
-                    864,
-                    912,
-                    976,
-                    1040,
-                    1104,
-                    1168,
-                    1216,
-                    1232,
-                    1248,
-                    1264,
-                    1280,
-                    1292,
-                ],
-                dtype=torch.int64,
-            ),
+            self._get_seqlen(),
         )
 
         results = qwen_sdpa_attention_loopmha_versatile.verify(
             *inputs,
             scaling=0.5,
             num_heads=16,
-            itype=onnx.TensorProto.FLOAT16,
             dump_onnx_model=self.get_dump_file(
-                "test_plug_packed_multi_head_attention_qwen25_loopmha.onnx"
+                "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
             ),
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
@@ -637,13 +612,104 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
         results = qwen_sdpa_attention_loopmha_versatile.verify(
-            *inputs, scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT16
+            *inputs, scaling=0.11180339887498948, num_heads=16
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
         self.assertEqual(len(results.eager_outputs), len(results.diffs))
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
+    @requires_onnxruntime("1.24")
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
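+        """Same as the float16 loopmha test but in float32 with tighter tolerances."""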
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopmha_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            self._get_seqlen(),
+        )
+
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
+            *inputs,
+            scaling=0.5,
+            num_heads=16,
+            dump_onnx_model=self.get_dump_file(
+                "test_plug_packed_multi_head_attention_qwen25_loopmha_float32.onnx"
+            ),
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
+            *inputs, scaling=0.11180339887498948, num_heads=16
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+    @requires_onnxruntime("1.24")
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
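+        """loopa24 variant: ONNX vs eager outputs, float16."""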
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopa24_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            self._get_seqlen(),
+        )
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-2)
+        self.assertLess(results.diffs[0]["abs"], 1e-2)
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(
+            *inputs, scaling=0.11180339887498948
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005)
+        self.assertLess(results.diffs[0]["abs"], 0.005)
+
+    @requires_onnxruntime("1.24")
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
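+        """Same as the float16 loopa24 test but in float32 with tighter tolerances."""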
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopa24_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            self._get_seqlen(),
+        )
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(
+            *inputs, scaling=0.11180339887498948
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)