@@ -39,12 +39,13 @@ def LoopMHAAttention(
     query_3d = op.Reshape(query_transposed, to_3d_shape)
     value_3d = op.Reshape(op.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
     key_3d = op.Reshape(op.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+    cu_seqlens = op.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
     num_patches = op.Size(cu_seqlens) - 1
     seq_axis = op.Constant(value_ints=[1])
     seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
     attn_output = op.Slice(value_3d, [0], [0], seq_axis)
-    for i in range(num_patches):
-        i_1d = op.Reshape(i, [1])
+    for i_patch in range(num_patches):
+        i_1d = op.Reshape(i_patch, [1])
         i_plus_1_1d = i_1d + 1
         start = op.Gather(cu_seqlens, i_1d, axis=0)
         end = op.Gather(cu_seqlens, i_plus_1_1d, axis=0)
@@ -62,6 +63,14 @@ def LoopMHAAttention(
     attn_output_4d = op.Reshape(attn_output, output_shape)
     return attn_output_4d

+def _add_com_microsoft_opset(function_proto):
+    opsets = {d.domain: d.version for d in function_proto.opset_import}
+    if "com.microsoft" not in opsets:
+        d = function_proto.opset_import.add()
+        d.domain = "com.microsoft"
+        d.version = 1
+    return function_proto
+
 @onnxscript.script(opset=onnx_plugs_op)
 def PackedAttention(
     query,
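
The helper introduced above only appends a `com.microsoft` entry to a FunctionProto's `opset_import` when it is absent, so that functions referencing ONNX Runtime contrib ops declare the domain they use. A minimal sketch of its behaviour, assuming the standard `onnx.helper` API; the `TinyAttention` function below is illustrative and not part of this patch:

from onnx import helper

# A toy function whose single node lives in the com.microsoft domain.
node = helper.make_node(
    "MultiHeadAttention",
    inputs=["q", "k", "v"],
    outputs=["y"],
    domain="com.microsoft",
    num_heads=16,
)
fn = helper.make_function(
    "local.plugs",                   # function domain
    "TinyAttention",                 # function name
    ["q", "k", "v"],
    ["y"],
    [node],
    [helper.make_opsetid("", 18)],   # com.microsoft is not declared yet
)
fn = _add_com_microsoft_opset(fn)
assert {d.domain for d in fn.opset_import} == {"", "com.microsoft"}
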
@@ -143,20 +152,35 @@ def qwen_sdpa_attention(
     return attn_output

 # not ideal
-qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
+qwen_sdpa_attention_packed_versatile = EagerDirectReplacementWithOnnx(
+    qwen_sdpa_attention,
+    lambda qs, *args, **kwargs: torch.empty(
+        (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
+        dtype=qs.dtype,
+        device=qs.device,
+    ),
+    _add_com_microsoft_opset(PackedAttention.to_function_proto()),
+    n_inputs=4,
+    n_outputs=1,
+    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+    name="qwen_sdpa_attention_packed",
+)
+PLUGS.append(qwen_sdpa_attention_packed_versatile)
+
+qwen_sdpa_attention_loopmha_versatile = EagerDirectReplacementWithOnnx(
     qwen_sdpa_attention,
     lambda qs, *args, **kwargs: torch.empty(
         (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
         dtype=qs.dtype,
         device=qs.device,
     ),
-    PackedAttention.to_function_proto(),
+    _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
     n_inputs=4,
     n_outputs=1,
     kwargs=dict(scaling=0.11180339887498948, num_heads=16),
-    name="qwen_sdpa_attention",
+    name="qwen_sdpa_attention_loopmha",
 )
-PLUGS.append(qwen_sdpa_attention_versatile)
+PLUGS.append(qwen_sdpa_attention_loopmha_versatile)

 class patched_Qwen2_5_VLForConditionalGeneration:
     _PATCHES_ = ["prepare_inputs_for_generation"]
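
Both registrations above hard-code `scaling=0.11180339887498948` and `num_heads=16`. That scaling is `1/sqrt(head_dim)` for `head_dim = 80`, consistent with a vision tower whose hidden size of 1280 is split over 16 heads; the 1280/16 split is stated here as an assumption rather than read from the patch:

import math

hidden_size, num_heads = 1280, 16        # assumed Qwen2.5-VL vision configuration
head_dim = hidden_size // num_heads      # 80
scaling = 1.0 / math.sqrt(head_dim)      # ~0.11180339887498948
assert abs(scaling - 0.11180339887498948) < 1e-15
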
@@ -496,8 +520,8 @@ def forward(
             or attention_interface is patched_sdpa_attention_forward
         )
         attention_strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
-        if is_sdpa and attention_strategy == "PACKED":
-            attn_output = qwen_sdpa_attention_versatile(
+        if is_sdpa and attention_strategy == "PACKED":
+            attn_output = qwen_sdpa_attention_packed_versatile(
                 query_states,
                 key_states,
                 value_states,
@@ -530,37 +554,37 @@ def forward(
                 version=1,
             )
         elif is_sdpa and attention_strategy == "LOOPMHA":
+            attn_output = qwen_sdpa_attention_loopmha_versatile(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens,
+                self.scaling,
+                self.num_heads,
+            )

-            def _iteration(start_end, query_states, key_states, value_states):
-                return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
-                    self,
-                    start_end,
-                    query_states,
-                    key_states,
-                    value_states,
-                    scaling=self.scaling,
-                    dropout=0.0 if not self.training else self.attention_dropout,
-                )
-
-            starts = cu_seqlens[:-1]
-            ends = cu_seqlens[1:]
-            # cu_seqlens = [0, 10, 14, 27]
-            # starts: [0, 10, 14]
-            # ends: [10, 14, 17]
-            # starts_ends: [[0, 10], [10, 14], [14, 27]]
-            starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
-            attn_outputs = [
-                _iteration(start_end, query_states, key_states, value_states)
-                for start_end in starts_ends
-            ]
-            # attn_outputs = torch._higher_order_ops.while_loop(
-            # attn_outputs = torch.ops.higher_order.while_loop(
-            #     (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
-            #     _iteration,
-            #     (torch.tensor(0),
-            #      starts_ends, query_states, key_states, value_states), tuple(),
-            # )
-            attn_output = torch.cat(attn_outputs, dim=1)
+            # to rewrite later with a for loop
+            # def _iteration(start_end, query_states, key_states, value_states):
+            #     return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
+            #         self,
+            #         start_end,
+            #         query_states,
+            #         key_states,
+            #         value_states,
+            #         scaling=self.scaling,
+            #         dropout=0.0 if not self.training else self.attention_dropout,
+            #     )
+
+            # starts = cu_seqlens[:-1]
+            # ends = cu_seqlens[1:]
+            # torch._check(starts.shape[0] > 0)
+            # torch._check(ends.shape[0] > 0)
+            # starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
+            # attn_outputs = [
+            #     _iteration(start_end, query_states, key_states, value_states)
+            #     for start_end in starts_ends
+            # ]
+            # attn_output = torch.cat(attn_outputs, dim=1)
         elif is_sdpa and attention_strategy == "BIGMASK":
             # make square mask
             indices = torch.arange(
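
For reference, the LOOPMHA path (both the `LoopMHAAttention` ONNX function in the first hunk and the commented-out eager loop above) splits the packed sequence at the `cu_seqlens` boundaries, runs attention independently on each segment, and concatenates the results along the sequence axis. A minimal eager sketch of that decomposition, using `torch.nn.functional.scaled_dot_product_attention` and illustrative shapes, not the patched module's code:

import torch
import torch.nn.functional as F

num_heads, head_dim = 16, 80
cu_seqlens = torch.tensor([0, 10, 14, 27])        # cumulative segment lengths
seq_len = int(cu_seqlens[-1])
q = torch.randn(1, num_heads, seq_len, head_dim)
k = torch.randn(1, num_heads, seq_len, head_dim)
v = torch.randn(1, num_heads, seq_len, head_dim)

outputs = []
for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    # Attention restricted to one segment: tokens outside [start, end) are never attended to.
    outputs.append(
        F.scaled_dot_product_attention(
            q[:, :, start:end], k[:, :, start:end], v[:, :, start:end]
        )
    )
attn_output = torch.cat(outputs, dim=2)           # sequence axis is dim 2 in this layout
assert attn_output.shape == (1, num_heads, seq_len, head_dim)
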