@@ -32,6 +32,7 @@ def LoopMHAAttention(
     cu_seqlens,
     scaling: float = 0.11180339887498948,
     num_heads: int = 16,
+    itype: int = onnx.TensorProto.FLOAT,
 ):
     to_3d_shape = op.Constant(value_ints=[0, 0, -1])
     query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
@@ -43,7 +44,8 @@ def LoopMHAAttention(
     num_patches = op.Size(cu_seqlens) - 1
     seq_axis = op.Constant(value_ints=[1])
     seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
-    attn_output = op.Slice(value_3d, [0], [0], seq_axis)
+    # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
+    seq_attn = op.SequenceEmpty(dtype=itype)
     for i_patch in range(num_patches):
         i_1d = op.Reshape(i_patch, [1])
         i_plus_1_1d = i_1d + 1
@@ -59,7 +61,9 @@ def LoopMHAAttention(
             num_heads=num_heads,
             scale=scaling,
         )
-        attn_output = op.Concat(attn_output, mha_output, axis=1)
+        # attn_output = op.Concat(attn_output, mha_output, axis=1)
+        seq_attn = op.SequenceInsert(seq_attn, mha_output)
+    attn_output = op.ConcatFromSequence(seq_attn, axis=1)
     attn_output_4d = op.Reshape(attn_output, output_shape)
     return attn_output_4d
 
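The hunks above drop the empty-Slice/Concat accumulation and collect the per-patch attention outputs in an ONNX sequence instead: SequenceEmpty starts an empty sequence typed by the new itype argument, SequenceInsert appends each MultiHeadAttention result, and a single ConcatFromSequence stitches them along the sequence axis. The following standalone sketch is my own illustration, not part of the patch; node and tensor names and shapes are invented, and it only checks the sequence-op pattern itself with onnx.helper and onnxruntime.

# Illustration only (not part of the patch): the Sequence pattern used above.
import numpy as np
import onnx
import onnx.helper as oh
import onnxruntime as ort

itype = onnx.TensorProto.FLOAT
graph = oh.make_graph(
    [
        # seq_attn = op.SequenceEmpty(dtype=itype)
        oh.make_node("SequenceEmpty", [], ["seq0"], dtype=itype),
        # seq_attn = op.SequenceInsert(seq_attn, mha_output), once per patch
        oh.make_node("SequenceInsert", ["seq0", "patch0"], ["seq1"]),
        oh.make_node("SequenceInsert", ["seq1", "patch1"], ["seq2"]),
        # attn_output = op.ConcatFromSequence(seq_attn, axis=1)
        oh.make_node("ConcatFromSequence", ["seq2"], ["out"], axis=1),
    ],
    "sequence_concat_sketch",
    [
        oh.make_tensor_value_info("patch0", itype, [1, None, 4]),
        oh.make_tensor_value_info("patch1", itype, [1, None, 4]),
    ],
    [oh.make_tensor_value_info("out", itype, [1, None, 4])],
)
model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])
sess = ort.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)
p0 = np.random.rand(1, 3, 4).astype(np.float32)
p1 = np.random.rand(1, 5, 4).astype(np.float32)
(out,) = sess.run(None, {"patch0": p0, "patch1": p1})
assert out.shape == (1, 8, 4)  # patches of length 3 and 5 concatenated on axis 1

Building the sequence once and concatenating at the end also avoids seeding the accumulator with a zero-length Slice, whose element type had to match the attention output and motivated exposing itype as a parameter.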
@@ -128,6 +132,7 @@ def qwen_sdpa_attention(
     cu_seqlens: torch.Tensor,  # F7su19
     scaling: float = 0,
     num_heads: int = 16,
+    itype: int = onnx.TensorProto.FLOAT,
 ) -> torch.Tensor:
     lengths = cu_seqlens[1:] - cu_seqlens[:-1]
     splits = [
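For reference, cu_seqlens holds cumulative sequence boundaries, so the subtraction above recovers the per-patch lengths used to split the packed batch. A quick check with invented values:

import torch

cu_seqlens = torch.tensor([0, 3, 7, 12])          # cumulative boundaries
lengths = cu_seqlens[1:] - cu_seqlens[:-1]        # tensor([3, 4, 5])
assert lengths.tolist() == [3, 4, 5]
assert int(lengths.sum()) == int(cu_seqlens[-1])  # lengths cover the packed axis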
@@ -162,7 +167,7 @@ def qwen_sdpa_attention(
     _add_com_microsoft_opset(PackedAttention.to_function_proto()),
     n_inputs=4,
     n_outputs=1,
-    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+    kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
     name="qwen_sdpa_attention_packed",
 )
 PLUGS.append(qwen_sdpa_attention_packed_versatile)
@@ -177,7 +182,7 @@ def qwen_sdpa_attention(
     _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
     n_inputs=4,
     n_outputs=1,
-    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+    kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
     name="qwen_sdpa_attention_loopmha",
 )
 PLUGS.append(qwen_sdpa_attention_loopmha_versatile)
@@ -561,6 +566,15 @@ def forward(
             cu_seqlens,
             self.scaling,
             self.num_heads,
+            (
+                onnx.TensorProto.FLOAT
+                if query_states.dtype == torch.float32
+                else (
+                    onnx.TensorProto.FLOAT16
+                    if query_states.dtype == torch.float16
+                    else onnx.TensorProto.BFLOAT16
+                )
+            ),
         )
 
         # to rewrite later with a for loop
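The nested conditional added in forward() maps the dtype of query_states to the matching ONNX element type for the new itype argument. A table-driven equivalent is shown below; the helper name is hypothetical and not part of the patch.

import onnx
import torch

# Hypothetical helper, not in the patch: torch dtype -> ONNX element type.
_ITYPE_FROM_DTYPE = {
    torch.float32: onnx.TensorProto.FLOAT,
    torch.float16: onnx.TensorProto.FLOAT16,
    torch.bfloat16: onnx.TensorProto.BFLOAT16,
}

def itype_from_dtype(dtype: torch.dtype) -> int:
    """Return the ONNX TensorProto element type for a supported float dtype."""
    return _ITYPE_FROM_DTYPE[dtype]

assert itype_from_dtype(torch.float16) == onnx.TensorProto.FLOAT16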