@@ -519,9 +519,43 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
         )
         self.clean_dump()

+    @classmethod
+    def _get_seqlen(cls) -> torch.Tensor:
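+        """Cumulative sequence lengths shared by the packed-attention tests.
+
+        Entry i marks where packed segment i starts, so segment i covers rows
+        [seqlen[i], seqlen[i + 1]); the last entry (1292) matches the sequence
+        dimension of the q/k/v tensors built below.
+        """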
+        return torch.tensor(
+            [
+                0,
+                64,
+                128,
+                192,
+                256,
+                304,
+                368,
+                432,
+                496,
+                560,
+                608,
+                672,
+                736,
+                800,
+                864,
+                912,
+                976,
+                1040,
+                1104,
+                1168,
+                1216,
+                1232,
+                1248,
+                1264,
+                1280,
+                1292,
+            ],
+            dtype=torch.int64,
+        )
+
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     @requires_cuda()
-    def test_plug_packed_multi_head_attention_qwen25_packed(self):
+    def test_plug_packed_multi_head_attention_qwen25_packed_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_packed_versatile,
         )
@@ -530,37 +564,7 @@ def test_plug_packed_multi_head_attention_qwen25_packed(self):
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
-            torch.tensor(
-                [
-                    0,
-                    64,
-                    128,
-                    192,
-                    256,
-                    304,
-                    368,
-                    432,
-                    496,
-                    560,
-                    608,
-                    672,
-                    736,
-                    800,
-                    864,
-                    912,
-                    976,
-                    1040,
-                    1104,
-                    1168,
-                    1216,
-                    1232,
-                    1248,
-                    1264,
-                    1280,
-                    1292,
-                ],
-                dtype=torch.int64,
-            ).to("cuda"),
+            self._get_seqlen().to("cuda"),
         )

         results = qwen_sdpa_attention_packed_versatile.verify(
@@ -580,7 +584,7 @@ def test_plug_packed_multi_head_attention_qwen25_packed(self):
         self.assertLess(results.diffs[0]["abs"], 0.01)

     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
-    def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
+    def test_plug_packed_multi_head_attention_qwen25_loopmha_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
             qwen_sdpa_attention_loopmha_versatile,
         )
@@ -589,46 +593,15 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
             torch.rand((1, 16, 1292, 80), dtype=torch.float16),
-            torch.tensor(
-                [
-                    0,
-                    64,
-                    128,
-                    192,
-                    256,
-                    304,
-                    368,
-                    432,
-                    496,
-                    560,
-                    608,
-                    672,
-                    736,
-                    800,
-                    864,
-                    912,
-                    976,
-                    1040,
-                    1104,
-                    1168,
-                    1216,
-                    1232,
-                    1248,
-                    1264,
-                    1280,
-                    1292,
-                ],
-                dtype=torch.int64,
-            ),
+            self._get_seqlen(),
         )

         results = qwen_sdpa_attention_loopmha_versatile.verify(
             *inputs,
             scaling=0.5,
             num_heads=16,
-            itype=onnx.TensorProto.FLOAT16,
             dump_onnx_model=self.get_dump_file(
-                "test_plug_packed_multi_head_attention_qwen25_loopmha.onnx"
+                "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
             ),
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
@@ -637,13 +610,101 @@ def test_plug_packed_multi_head_attention_qwen25_loopmha(self):
         self.assertLess(results.diffs[0]["abs"], 0.01)

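+        # 0.11180339887498948 is 1 / sqrt(80), i.e. the default softmax
+        # scaling for head_dim=80; the 0.5 above exercises a non-default value.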
         results = qwen_sdpa_attention_loopmha_versatile.verify(
-            *inputs, scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT16
+            *inputs, scaling=0.11180339887498948, num_heads=16
         )
         self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
         self.assertEqual(len(results.eager_outputs), len(results.diffs))
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)

+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_packed_multi_head_attention_qwen25_loopmha_float32(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopmha_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            self._get_seqlen(),
+        )
+
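+        # In float32 the eager and ONNX outputs should agree far more tightly
+        # than in float16, hence atol=1e-5 below instead of 0.01.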
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
+            *inputs,
+            scaling=0.5,
+            num_heads=16,
+            dump_onnx_model=self.get_dump_file(
+                "test_plug_packed_multi_head_attention_qwen25_loopmha_float32.onnx"
+            ),
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+        results = qwen_sdpa_attention_loopmha_versatile.verify(
+            *inputs, scaling=0.11180339887498948, num_heads=16
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_packed_multi_head_attention_qwen25_loopa24_float16(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopa24_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float16),
+            self._get_seqlen(),
+        )
+
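+        # Unlike the loopmha variant, verify is called here without num_heads
+        # or dump_onnx_model; num_heads is presumably derived from the input
+        # shape.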
+        results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
+        self.assertLess(results.diffs[0]["abs"], 0.01)
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(
+            *inputs, scaling=0.11180339887498948
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
+        self.assertLess(results.diffs[0]["abs"], 0.01)
+
+    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
+    def test_plug_packed_multi_head_attention_qwen25_loopa24_float32(self):
+        from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
+            qwen_sdpa_attention_loopa24_versatile,
+        )
+
+        inputs = (
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            torch.rand((1, 16, 1292, 80), dtype=torch.float32),
+            self._get_seqlen(),
+        )
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+
+        results = qwen_sdpa_attention_loopa24_versatile.verify(
+            *inputs, scaling=0.11180339887498948
+        )
+        self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
+        self.assertEqual(len(results.eager_outputs), len(results.diffs))
+        self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
+        self.assertLess(results.diffs[0]["abs"], 1e-5)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)