[WIP] Add FP32 softmax support in unified attention #1040
```diff
@@ -299,7 +299,20 @@ def partial_attn_causal(query: torch.tensor,
     v = value[:, :, 0:q_max, :]
     b = bias[q_min:q_max, 0:q_max]

-    s_attn = torch.matmul(q, k.transpose(-1, -2)) + b.unsqueeze(0).unsqueeze(0)
+    if get_config().use_output_tensor_in_matmulqk:
+        s_attn = None
+        if get_config().fp32_softmax:
+            s_attn = torch.empty(hpu_ops.matmul_shape(q, k.transpose(-1, -2)), dtype=torch.float32, device=q.device)
+        s_attn = torch.matmul(q, k.transpose(-1, -2), out=s_attn)
+        s_attn = s_attn + b.unsqueeze(0).unsqueeze(0)
```
Suggested change:

```diff
-    s_attn = s_attn + b.unsqueeze(0).unsqueeze(0)
+    bias_term = b.unsqueeze(0).unsqueeze(0).to(s_attn.dtype)
+    s_attn = s_attn + bias_term
```
Copilot AI (Feb 25, 2026):
With fp32_softmax enabled, the code leaves the exp(attn) weights / unnormalized weighted-V in float32 (see TODO). This changes the dtype of the partial outputs and will propagate to unified_attn/unified_mla outputs (e.g., division by a float32 sum yields float32), which can break downstream layers that expect the model dtype and also increases memory/compute. Add an explicit downcast back to the original dtype at an appropriate point (commonly: keep max/sum in fp32 for stability, but cast the exp weights and/or final attention output back to query/value.dtype).
Suggested change:

```diff
-    # TODO: add downcasting attn to original dtype
+    # Keep max/sum in fp32 for stability when fp32_softmax is enabled,
+    # but cast the exponentiated attention weights back to the original
+    # value dtype before the attention matmul to avoid dtype propagation.
+    if get_config().fp32_softmax and s_attn.dtype == torch.float32:
+        s_attn = s_attn.to(v.dtype)
```
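The dtype propagation the comment describes can be reproduced in a few lines. This is an illustrative sketch only, using NumPy as a stand-in for the HPU torch tensors (NumPy's type-promotion rules match the relevant behavior: any op mixing float16 and float32 yields float32); the shapes and names are made up for the demo.

```python
import numpy as np

# fp32 softmax weights, as produced when fp32_softmax is enabled.
attn = np.exp(np.ones((2, 4), dtype=np.float32))
# Values stay in the model dtype (half precision here).
value = np.ones((4, 3), dtype=np.float16)

# Without a downcast, the attention output is silently promoted to fp32.
promoted = attn @ value
assert promoted.dtype == np.float32

# Explicit downcast of the exp weights back to the value dtype, as the
# review suggests, keeps the output in the model dtype.
out = attn.astype(value.dtype) @ value
assert out.dtype == np.float16
```

The same promotion rule is why a float32 `local_sum` in the denominator would also drag the merged output to float32.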
Copilot AI (Feb 25, 2026):
When fp32_softmax is enabled and the matmul output is float32, bias is still passed in the original dtype and then added to attn. Aligning bias to attn.dtype before the add avoids mixed-dtype adds and matches the established handling in extension/ops.py (which casts block/position bias when attn.dtype != bias.dtype).
Suggested change:

```diff
-    attn = attn.flatten(0, 1)
+    attn = attn.flatten(0, 1)
+    if attn.dtype != bias.dtype:
+        bias = bias.to(attn.dtype)
```
Copilot AI (Feb 25, 2026):
Same as the causal path: with fp32_softmax enabled, the attention weights/value product is currently performed in float32 (TODO mentions missing downcast). This will tend to make the merged output float32 as well. Consider downcasting the exp weights and/or the attn output back to the original dtype (typically value.dtype / model dtype) after computing local_sum/local_max in fp32.
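The pattern the comment recommends — fp32 softmax statistics, model-dtype attention matmul — can be sketched as follows. This is a hedged illustration, again with NumPy standing in for the HPU tensors; the function name and shapes are invented for the demo and do not appear in the PR.

```python
import numpy as np

def partial_softmax_fp32_stats(scores_f32: np.ndarray, v: np.ndarray):
    """Softmax max/sum kept in float32 for numerical stability; the
    exponentiated weights are cast back to the value dtype before the
    weighted-V matmul, so the partial output stays in the model dtype."""
    local_max = scores_f32.max(axis=-1, keepdims=True)   # float32
    exp_w = np.exp(scores_f32 - local_max)               # float32
    local_sum = exp_w.sum(axis=-1, keepdims=True)        # float32
    out = exp_w.astype(v.dtype) @ v                      # model dtype
    return out, local_max, local_sum
```

Merging partial results then uses the fp32 `local_max`/`local_sum` pair, while the unnormalized output tensor never leaves the model dtype.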
Copilot AI (Feb 25, 2026):
If fp32_softmax is enabled, attn becomes float32 but bias is still in the original dtype when added here. To avoid mixed-dtype adds (and match the pattern used elsewhere in the codebase), cast this bias term to attn.dtype before adding.
Suggested change:

```diff
-    attn = attn + bias.unsqueeze(1).unsqueeze(1).unsqueeze(1)
+    attn = attn + bias.to(attn.dtype).unsqueeze(1).unsqueeze(1).unsqueeze(1)
```
Copilot AI (Feb 25, 2026):
The fp32_softmax path currently leaves the exp(attn) weights and the subsequent attn @ value in float32 (TODO). Besides the output dtype change, this also interacts with block2batch() later since block_mapping_2d is built in query.dtype; mixed-dtype matmul can be problematic on HPU. Consider downcasting back to the original dtype before the matmul with value / block2batch, and/or ensure block_mapping_2d is created in the same dtype as the tensor it multiplies.
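The block2batch interaction can be illustrated in the same way. A minimal sketch, assuming (as the comment does) that block2batch reduces per-block rows into per-batch rows via a matmul with a mapping matrix built in query.dtype; NumPy stands in for the HPU tensors and the shapes are hypothetical.

```python
import numpy as np

# Attention tensor promoted to float32 by the fp32_softmax path.
attn = np.ones((4, 8), dtype=np.float32)
# Mapping matrix built in query.dtype (half precision), as the comment notes.
block_mapping_2d = np.eye(4, dtype=np.float16)

# Mixed-dtype matmul: NumPy promotes to float32; on HPU this mixing is
# what the reviewer flags as problematic.
mixed = block_mapping_2d.T @ attn

# Casting the mapping to the tensor's dtype keeps the product uniform.
aligned = block_mapping_2d.astype(attn.dtype).T @ attn
assert aligned.dtype == attn.dtype
```

Equivalently, the earlier downcast of attn back to the model dtype would let block_mapping_2d stay in query.dtype unchanged.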
This PR removes the
fp32 softmaxunsupported-feature gate forHPUUnifiedAttentionImpl, butHPUUnifiedMLAImplstill rejectsget_config().fp32_softmaxlater in this same file. If the intention is to support fp32 softmax across unified attention (including MLA), that remaining gate will still raiseNotImplementedErrorfor MLA configurations; consider removing or appropriately gating it as well.