[WIP] Add FP32 softmax support in unified attention#1040
afierka-intel wants to merge 2 commits into `main`.
Conversation
Signed-off-by: Artur Fierka <artur.fierka@intel.com>
🚧 CI Blocked: The main CI workflow was not started for the following reason:
Pull request overview
Adds initial FP32-softmax enablement for the HPU unified attention path by promoting the QK logits computation to float32 under a feature flag and relaxing an existing backend restriction.
Changes:
- Add `fp32_softmax` handling to unified attention partial paths (causal/shared/unique), including optional `out=` buffers for the QK matmul.
- Insert graph breaks in fp32 paths to control compilation boundaries.
- Remove the `fp32 softmax` "unsupported feature" gate for `HPUUnifiedAttentionImpl`.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `vllm_gaudi/extension/unified.py` | Adds FP32-softmax branches for QK logits computation across partial attention routines (causal/shared/unique). |
| `vllm_gaudi/attention/backends/hpu_attn.py` | Removes the fp32 softmax unsupported-feature check for the unified attention backend implementation. |
```python
    torch._dynamo.graph_break()
else:
    attn = torch.matmul(query, key.transpose(-1, -2))
attn = attn.flatten(0, 1)
```
When fp32_softmax is enabled and the matmul output is float32, bias is still passed in the original dtype and then added to attn. Aligning bias to attn.dtype before the add avoids mixed-dtype adds and matches the established handling in extension/ops.py (which casts block/position bias when attn.dtype != bias.dtype).
Suggested change:
```python
attn = attn.flatten(0, 1)
if attn.dtype != bias.dtype:
    bias = bias.to(attn.dtype)
```
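The dtype-alignment pattern this comment suggests can be sketched in numpy (hypothetical shapes; not the vllm-gaudi code): QK logits are promoted to float32 while the bias is still in the reduced model dtype, so the bias is cast to the logits dtype before the add.

```python
import numpy as np

# Hypothetical shapes, numpy stand-in for the PyTorch pattern:
# fp32 QK logits plus a float16 bias, aligned before the add.
rng = np.random.default_rng(0)
query = rng.standard_normal((4, 8)).astype(np.float16)
key = rng.standard_normal((4, 8)).astype(np.float16)
bias = np.zeros((4, 4), dtype=np.float16)

# fp32_softmax path: compute the logits in float32.
attn = query.astype(np.float32) @ key.T.astype(np.float32)

# Align the bias dtype before the add, mirroring the suggested change.
if attn.dtype != bias.dtype:
    bias = bias.astype(attn.dtype)
attn = attn + bias
assert attn.dtype == np.float32
```

The explicit cast makes the intent visible even where the framework's implicit type promotion would produce the same result.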
```python
# TODO: add downcasting attn to original dtype
attn = torch.matmul(attn.unflatten(0, (kv_heads if not is_mla else num_heads, -1)), value).flatten(0, 1)
```
Same as the causal path: with fp32_softmax enabled, the attention weights/value product is currently performed in float32 (TODO mentions missing downcast). This will tend to make the merged output float32 as well. Consider downcasting the exp weights and/or the attn output back to the original dtype (typically value.dtype / model dtype) after computing local_sum/local_max in fp32.
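The downcast the comment asks for can be sketched in numpy (hypothetical shapes, a stand-in for the PyTorch code): keep the max/sum statistics in float32 for numerical stability, then cast the exponentiated weights back to the value dtype before the weighted-V matmul so the partial output stays in the model dtype.

```python
import numpy as np

# Hypothetical shapes; fp32 logits, float16 values (the model dtype).
rng = np.random.default_rng(0)
logits = rng.standard_normal((2, 4, 4)).astype(np.float32)
value = rng.standard_normal((2, 4, 8)).astype(np.float16)

local_max = logits.max(-1, keepdims=True)   # fp32, kept for stability
weights = np.exp(logits - local_max)        # fp32 exp weights
local_sum = weights.sum(-1)                 # fp32 normalizer

weights = weights.astype(value.dtype)       # downcast before attn @ value
out = weights @ value
assert out.dtype == np.float16
```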
```python
    torch._dynamo.graph_break()
else:
    attn = torch.matmul(query, key.transpose(-1, -2))
attn = attn + bias.unsqueeze(1).unsqueeze(1).unsqueeze(1)
```
If fp32_softmax is enabled, attn becomes float32 but bias is still in the original dtype when added here. To avoid mixed-dtype adds (and match the pattern used elsewhere in the codebase), cast this bias term to attn.dtype before adding.
Suggested change:
```python
attn = attn + bias.to(attn.dtype).unsqueeze(1).unsqueeze(1).unsqueeze(1)
```
```python
attn = torch.exp(attn - block_max.unsqueeze(-1))
# TODO: (afierka) add downcasting attn to original dtype
block_sum = attn.sum(-1)
attn = torch.matmul(attn, value)
```
The fp32_softmax path currently leaves the exp(attn) weights and the subsequent attn @ value in float32 (TODO). Besides the output dtype change, this also interacts with block2batch() later since block_mapping_2d is built in query.dtype; mixed-dtype matmul can be problematic on HPU. Consider downcasting back to the original dtype before the matmul with value / block2batch, and/or ensure block_mapping_2d is created in the same dtype as the tensor it multiplies.
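The block2batch dtype concern can be illustrated with a numpy sketch (hypothetical shapes and names; the real `block2batch()` lives in the vllm-gaudi extension): a one-hot block-to-batch mapping matrix merges per-block partial results, and building it in the partials' dtype avoids the mixed-dtype matmul the comment warns about.

```python
import numpy as np

# Hypothetical block2batch-style merge: 6 blocks mapped onto 2 batches.
num_blocks, num_batches, hidden = 6, 2, 4
block_to_batch = np.array([0, 0, 0, 1, 1, 1])
partials = np.ones((num_blocks, hidden), dtype=np.float32)  # fp32 partials

# Build the mapping in the SAME dtype as the tensor it multiplies.
mapping = np.zeros((num_batches, num_blocks), dtype=partials.dtype)
mapping[block_to_batch, np.arange(num_blocks)] = 1.0

merged = mapping @ partials  # no mixed-dtype matmul
assert merged.dtype == partials.dtype
assert merged.shape == (num_batches, hidden)
```

Each batch here sums three blocks of ones, so every entry of `merged` is 3.0; the point is only that `mapping` tracks `partials.dtype` rather than a fixed `query.dtype`.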
```diff
@@ -929,7 +929,6 @@ def __init__(
     'Sliding window': sliding_window is not None,
     'non-GQA attention': num_kv_heads is None,
     'Encoder attn': attn_type != AttentionType.DECODER,
```
This PR removes the fp32 softmax unsupported-feature gate for HPUUnifiedAttentionImpl, but HPUUnifiedMLAImpl still rejects get_config().fp32_softmax later in this same file. If the intention is to support fp32 softmax across unified attention (including MLA), that remaining gate will still raise NotImplementedError for MLA configurations; consider removing or appropriately gating it as well.
Suggested change:
```python
'Encoder attn': attn_type != AttentionType.DECODER,
'fp32 softmax': get_config().fp32_softmax,
```
```python
if get_config().fp32_softmax:
    s_attn = torch.empty(hpu_ops.matmul_shape(q, k.transpose(-1, -2)), dtype=torch.float32, device=q.device)
    s_attn = torch.matmul(q, k.transpose(-1, -2), out=s_attn)
s_attn = s_attn + b.unsqueeze(0).unsqueeze(0)
```
In the use_output_tensor_in_matmulqk path, when fp32_softmax is enabled s_attn becomes float32, but the bias slice b is still in the original dtype. On HPU we already explicitly cast biases in the analogous path in extension/ops.py; doing the same here avoids mixed-dtype adds (and potential kernel/type issues) and makes the intent explicit. Consider casting b.unsqueeze(0).unsqueeze(0) to s_attn.dtype (or float32 when fp32_softmax) before adding.
Suggested change:
```python
bias_term = b.unsqueeze(0).unsqueeze(0).to(s_attn.dtype)
s_attn = s_attn + bias_term
```
```python
# TODO: add downcasting attn to original dtype
```
With fp32_softmax enabled, the code leaves the exp(attn) weights / unnormalized weighted-V in float32 (see TODO). This changes the dtype of the partial outputs and will propagate to unified_attn/unified_mla outputs (e.g., division by a float32 sum yields float32), which can break downstream layers that expect the model dtype and also increases memory/compute. Add an explicit downcast back to the original dtype at an appropriate point (commonly: keep max/sum in fp32 for stability, but cast the exp weights and/or final attention output back to query/value.dtype).
Suggested change:
```python
# Keep max/sum in fp32 for stability when fp32_softmax is enabled,
# but cast the exponentiated attention weights back to the original
# value dtype before the attention matmul to avoid dtype propagation.
if get_config().fp32_softmax and s_attn.dtype == torch.float32:
    s_attn = s_attn.to(v.dtype)
```
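Putting the pieces of this review together, the overall fp32-softmax shape the comments converge on can be sketched end to end in numpy (hypothetical function and shapes; the actual implementation lives in `vllm_gaudi/extension/unified.py`): statistics in float32, output returned in the original model dtype.

```python
import numpy as np

def fp32_softmax_attn(q, k, v):
    """Sketch of the fp32-softmax pattern: fp32 logits and softmax
    statistics, output cast back to the model (value) dtype."""
    logits = q.astype(np.float32) @ k.T.astype(np.float32)
    logits -= logits.max(-1, keepdims=True)      # fp32 stability
    w = np.exp(logits)
    w /= w.sum(-1, keepdims=True)                # fp32 normalization
    return w.astype(v.dtype) @ v                 # back to model dtype

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8)).astype(np.float16)
k = rng.standard_normal((4, 8)).astype(np.float16)
v = rng.standard_normal((4, 8)).astype(np.float16)

out = fp32_softmax_attn(q, k, v)
assert out.dtype == np.float16
```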
No description provided.