@@ -32,6 +32,7 @@ def LoopMHAAttention(
     cu_seqlens,
     scaling: float = 0.11180339887498948,
     num_heads: int = 16,
+    itype: int = onnx.TensorProto.FLOAT,
 ):
     to_3d_shape = op.Constant(value_ints=[0, 0, -1])
     query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
@@ -44,10 +45,7 @@ def LoopMHAAttention(
     seq_axis = op.Constant(value_ints=[1])
     seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
     # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
-    # SequenceEmpty needs dtype to be filled but it should be possible
-    # to leave it empty and just ensure that all tensors stored
-    # in the sequence share the same type.
-    seq_attn = op.SequenceEmpty()  # dtype=onnx.TensorProto.FLOAT16)
+    seq_attn = op.SequenceEmpty(dtype=itype)
     for i_patch in range(num_patches):
         i_1d = op.Reshape(i_patch, [1])
         i_plus_1_1d = i_1d + 1
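The comment removed above explained the original intent: leave `SequenceEmpty` without a `dtype` and let the element type follow from the tensors later inserted into the sequence. Since that does not work here, the element type is now threaded through as the new `itype` parameter. As a minimal, standalone sketch (not part of this patch), this is roughly the node the function body produces when `itype` is float16, written with `onnx.helper` directly:

```python
import onnx
from onnx import helper

# Illustrative only: the SequenceEmpty node emitted by LoopMHAAttention once
# itype=onnx.TensorProto.FLOAT16 is passed in, expressed at the protobuf level.
seq_empty = helper.make_node(
    "SequenceEmpty",
    inputs=[],
    outputs=["seq_attn"],
    dtype=onnx.TensorProto.FLOAT16,  # element type of every tensor appended in the loop
)
print(seq_empty)
```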
@@ -134,6 +132,7 @@ def qwen_sdpa_attention(
     cu_seqlens: torch.Tensor,  # F7su19
     scaling: float = 0,
     num_heads: int = 16,
+    itype: int = onnx.TensorProto.FLOAT,
 ) -> torch.Tensor:
     lengths = cu_seqlens[1:] - cu_seqlens[:-1]
     splits = [
@@ -168,7 +167,7 @@ def qwen_sdpa_attention(
     _add_com_microsoft_opset(PackedAttention.to_function_proto()),
     n_inputs=4,
     n_outputs=1,
-    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+    kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
     name="qwen_sdpa_attention_packed",
 )
 PLUGS.append(qwen_sdpa_attention_packed_versatile)
@@ -183,7 +182,7 @@ def qwen_sdpa_attention(
     _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
     n_inputs=4,
     n_outputs=1,
-    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+    kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
     name="qwen_sdpa_attention_loopmha",
 )
 PLUGS.append(qwen_sdpa_attention_loopmha_versatile)
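Both registrations keep `itype=onnx.TensorProto.FLOAT` as the default, which preserves the previous behaviour. As a hypothetical variant (not in this commit), targeting a float16 export would only require changing that keyword at registration time, for example:

```python
import onnx

# Hypothetical kwargs for a float16 export; only itype changes.
kwargs = dict(
    scaling=0.11180339887498948,  # 1 / sqrt(80)
    num_heads=16,
    itype=onnx.TensorProto.FLOAT16,  # element type used by SequenceEmpty in the loop body
)
```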
@@ -567,6 +566,15 @@ def forward(
             cu_seqlens,
             self.scaling,
             self.num_heads,
+            (
+                onnx.TensorProto.FLOAT
+                if query_states.dtype == torch.float32
+                else (
+                    onnx.TensorProto.FLOAT16
+                    if query_states.dtype == torch.float16
+                    else onnx.TensorProto.BFLOAT16
+                )
+            ),
         )

         # to rewrite later with a for loop
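The nested conditional added above maps `query_states.dtype` to the matching ONNX element type. A table-based helper (hypothetical names, not part of this patch) expresses the same mapping more compactly and fails loudly on an unexpected dtype:

```python
import onnx
import torch

# Hypothetical helper: same dtype mapping as the conditional above.
_TORCH_TO_ONNX_ELEM_TYPE = {
    torch.float32: onnx.TensorProto.FLOAT,
    torch.float16: onnx.TensorProto.FLOAT16,
    torch.bfloat16: onnx.TensorProto.BFLOAT16,
}


def _onnx_elem_type(dtype: torch.dtype) -> int:
    """Return the onnx.TensorProto element type matching a torch dtype."""
    return _TORCH_TO_ONNX_ELEM_TYPE[dtype]  # KeyError for unsupported dtypes


# The call site above would then reduce to:
#     self.num_heads,
#     _onnx_elem_type(query_states.dtype),
```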