
[WIP] Add FP32 softmax support in unified attention#1040

Draft
afierka-intel wants to merge 2 commits into main from dev/afierka/fp32-unified

Conversation

@afierka-intel
Collaborator

No description provided.

Signed-off-by: Artur Fierka <artur.fierka@intel.com>
Copilot AI review requested due to automatic review settings on February 25, 2026 at 11:38
@github-actions

🚧 CI Blocked

The main CI workflow was not started for the following reason:

This is a Draft PR. Please mark it as 'Ready for Review' to trigger the CI.


Copilot AI (Contributor) left a comment

Pull request overview

Adds initial FP32-softmax enablement for the HPU unified attention path by promoting the QK logits computation to float32 under a feature flag and relaxing an existing backend restriction.

Changes:

  • Add fp32_softmax handling to unified attention partial paths (causal/shared/unique), including optional out= buffers for QK matmul.
  • Insert graph breaks in fp32 paths to control compilation boundaries.
  • Remove the fp32 softmax “unsupported feature” gate for HPUUnifiedAttentionImpl.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.

File | Description
vllm_gaudi/extension/unified.py | Adds FP32-softmax branches for QK logits computation across partial attention routines (causal/shared/unique).
vllm_gaudi/attention/backends/hpu_attn.py | Removes the fp32 softmax unsupported-feature check for the unified attention backend implementation.


torch._dynamo.graph_break()
else:
attn = torch.matmul(query, key.transpose(-1, -2))
attn = attn.flatten(0, 1)

Copilot AI Feb 25, 2026


When fp32_softmax is enabled and the matmul output is float32, bias is still passed in the original dtype and then added to attn. Aligning bias to attn.dtype before the add avoids mixed-dtype adds and matches the established handling in extension/ops.py (which casts block/position bias when attn.dtype != bias.dtype).

Suggested change:
- attn = attn.flatten(0, 1)
+ attn = attn.flatten(0, 1)
+ if attn.dtype != bias.dtype:
+     bias = bias.to(attn.dtype)

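The dtype alignment suggested above can be sketched as follows. This is a minimal, self-contained illustration: the tensor shapes and names are invented here, not taken from unified.py, and the matmul is promoted to float32 explicitly to stand in for the HPU fp32_softmax path.

```python
import torch

# Illustrative shapes; the real path operates on attention-shaped tensors.
query = torch.randn(2, 4, 8, dtype=torch.bfloat16)
key = torch.randn(2, 4, 8, dtype=torch.bfloat16)
bias = torch.zeros(2, 4, 4, dtype=torch.bfloat16)

# Promote the QK matmul to float32 (stand-in for the fp32_softmax path).
attn = torch.matmul(query.float(), key.float().transpose(-1, -2))

# Explicit cast avoids relying on implicit mixed-dtype promotion,
# matching the pattern the review comment points to in extension/ops.py.
if attn.dtype != bias.dtype:
    bias = bias.to(attn.dtype)
attn = attn + bias
```

The explicit `.to(attn.dtype)` makes the promotion visible at the call site instead of depending on backend-specific mixed-dtype add behavior.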
Comment on lines +428 to 430
# TODO: add downcasting attn to original dtype
attn = torch.matmul(attn.unflatten(0, (kv_heads if not is_mla else num_heads, -1)), value).flatten(0, 1)


Copilot AI Feb 25, 2026


Same as the causal path: with fp32_softmax enabled, the attention weights/value product is currently performed in float32 (TODO mentions missing downcast). This will tend to make the merged output float32 as well. Consider downcasting the exp weights and/or the attn output back to the original dtype (typically value.dtype / model dtype) after computing local_sum/local_max in fp32.

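The downcast strategy the comment proposes (fp32 statistics, model-dtype matmul) can be sketched like this. All names and shapes here are hypothetical stand-ins for the variables in the partial-attention routines:

```python
import torch

attn = torch.randn(2, 4, 4, dtype=torch.float32)   # fp32 QK logits
value = torch.randn(2, 4, 8, dtype=torch.bfloat16)

# Keep the softmax statistics in float32 for numerical stability.
local_max = attn.amax(-1, keepdim=True)            # fp32
weights = torch.exp(attn - local_max)              # fp32
local_sum = weights.sum(-1)                        # fp32, kept for merging

# Downcast only the weights feeding the weighted-value matmul, so the
# partial output stays in the model dtype (value.dtype).
out = torch.matmul(weights.to(value.dtype), value)
```

This keeps `local_max`/`local_sum` in fp32 for the later merge while preventing float32 from propagating into the attention output.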
torch._dynamo.graph_break()
else:
attn = torch.matmul(query, key.transpose(-1, -2))
attn = attn + bias.unsqueeze(1).unsqueeze(1).unsqueeze(1)

Copilot AI Feb 25, 2026


If fp32_softmax is enabled, attn becomes float32 but bias is still in the original dtype when added here. To avoid mixed-dtype adds (and match the pattern used elsewhere in the codebase), cast this bias term to attn.dtype before adding.

Suggested change:
- attn = attn + bias.unsqueeze(1).unsqueeze(1).unsqueeze(1)
+ attn = attn + bias.to(attn.dtype).unsqueeze(1).unsqueeze(1).unsqueeze(1)

Comment on lines 703 to 707
attn = torch.exp(attn - block_max.unsqueeze(-1))
# TODO: (afierka) add downcasting attn to original dtype
block_sum = attn.sum(-1)
attn = torch.matmul(attn, value)


Copilot AI Feb 25, 2026


The fp32_softmax path currently leaves the exp(attn) weights and the subsequent attn @ value in float32 (TODO). Besides the output dtype change, this also interacts with block2batch() later since block_mapping_2d is built in query.dtype; mixed-dtype matmul can be problematic on HPU. Consider downcasting back to the original dtype before the matmul with value / block2batch, and/or ensure block_mapping_2d is created in the same dtype as the tensor it multiplies.

@@ -929,7 +929,6 @@ def __init__(
'Sliding window': sliding_window is not None,
'non-GQA attention': num_kv_heads is None,
'Encoder attn': attn_type != AttentionType.DECODER,

Copilot AI Feb 25, 2026


This PR removes the fp32 softmax unsupported-feature gate for HPUUnifiedAttentionImpl, but HPUUnifiedMLAImpl still rejects get_config().fp32_softmax later in this same file. If the intention is to support fp32 softmax across unified attention (including MLA), that remaining gate will still raise NotImplementedError for MLA configurations; consider removing or appropriately gating it as well.

Suggested change:
- 'Encoder attn': attn_type != AttentionType.DECODER,
+ 'Encoder attn': attn_type != AttentionType.DECODER,
+ 'fp32 softmax': get_config().fp32_softmax,

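The unsupported-feature gate pattern this comment refers to can be sketched as below. The structure is inferred from the diff context ('Sliding window', 'Encoder attn' entries); the function name and string-based attn_type are hypothetical simplifications of the real backend code:

```python
def check_unsupported(fp32_softmax: bool, sliding_window, attn_type: str):
    # Map each feature description to whether it is requested; any
    # requested-but-unsupported feature raises, mirroring the gate
    # style used in hpu_attn.py.
    unsupported = {
        'Sliding window': sliding_window is not None,
        'Encoder attn': attn_type != 'decoder',
        'fp32 softmax': fp32_softmax,  # the gate this PR relaxes
    }
    active = [name for name, on in unsupported.items() if on]
    if active:
        raise NotImplementedError(f'Unsupported features: {active}')

# A plain decoder config without fp32 softmax passes the gate.
check_unsupported(fp32_softmax=False, sliding_window=None, attn_type='decoder')
```

As the comment notes, removing the entry from one impl while HPUUnifiedMLAImpl keeps it means MLA configurations still raise NotImplementedError.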
if get_config().fp32_softmax:
s_attn = torch.empty(hpu_ops.matmul_shape(q, k.transpose(-1, -2)), dtype=torch.float32, device=q.device)
s_attn = torch.matmul(q, k.transpose(-1, -2), out=s_attn)
s_attn = s_attn + b.unsqueeze(0).unsqueeze(0)

Copilot AI Feb 25, 2026


In the use_output_tensor_in_matmulqk path, when fp32_softmax is enabled s_attn becomes float32, but the bias slice b is still in the original dtype. On HPU we already explicitly cast biases in the analogous path in extension/ops.py; doing the same here avoids mixed-dtype adds (and potential kernel/type issues) and makes the intent explicit. Consider casting b.unsqueeze(0).unsqueeze(0) to s_attn.dtype (or float32 when fp32_softmax) before adding.

Suggested change:
- s_attn = s_attn + b.unsqueeze(0).unsqueeze(0)
+ bias_term = b.unsqueeze(0).unsqueeze(0).to(s_attn.dtype)
+ s_attn = s_attn + bias_term

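The out=-buffer pattern from the snippet above can be sketched in plain PyTorch. Note the caveats: `hpu_ops.matmul_shape` is replaced by an inline shape computation here, and on CPU the `out=` buffer must match the operand dtype, so the inputs are upcast first, whereas the PR's HPU path writes fp16/bf16 inputs into an fp32 buffer via the HPU kernel:

```python
import torch

q = torch.randn(2, 4, 8, dtype=torch.bfloat16)
k = torch.randn(2, 4, 8, dtype=torch.bfloat16)

# Upcast operands, then preallocate a float32 output buffer of the
# matmul result shape and write the QK logits directly into it.
q32, k32t = q.float(), k.float().transpose(-1, -2)
s_attn = torch.empty(q32.shape[:-1] + k32t.shape[-1:], dtype=torch.float32)
torch.matmul(q32, k32t, out=s_attn)
```

Preallocating the buffer pins the logits dtype to float32 regardless of input dtype, which is the point of the use_output_tensor_in_matmulqk path.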
Comment on lines 329 to +330

# TODO: add downcasting attn to original dtype

Copilot AI Feb 25, 2026


With fp32_softmax enabled, the code leaves the exp(attn) weights / unnormalized weighted-V in float32 (see TODO). This changes the dtype of the partial outputs and will propagate to unified_attn/unified_mla outputs (e.g., division by a float32 sum yields float32), which can break downstream layers that expect the model dtype and also increases memory/compute. Add an explicit downcast back to the original dtype at an appropriate point (commonly: keep max/sum in fp32 for stability, but cast the exp weights and/or final attention output back to query/value.dtype).

Suggested change:
- # TODO: add downcasting attn to original dtype
+ # Keep max/sum in fp32 for stability when fp32_softmax is enabled,
+ # but cast the exponentiated attention weights back to the original
+ # value dtype before the attention matmul to avoid dtype propagation.
+ if get_config().fp32_softmax and s_attn.dtype == torch.float32:
+     s_attn = s_attn.to(v.dtype)
