1 change: 0 additions & 1 deletion vllm_gaudi/attention/backends/hpu_attn.py
@@ -929,7 +929,6 @@ def __init__(
'Sliding window': sliding_window is not None,
'non-GQA attention': num_kv_heads is None,
'Encoder attn': attn_type != AttentionType.DECODER,
Copilot AI Feb 25, 2026

This PR removes the fp32 softmax unsupported-feature gate for HPUUnifiedAttentionImpl, but HPUUnifiedMLAImpl still rejects get_config().fp32_softmax later in this same file. If the intention is to support fp32 softmax across unified attention (including MLA), that remaining gate will still raise NotImplementedError for MLA configurations; consider removing or appropriately gating it as well.

Suggested change
-    'Encoder attn': attn_type != AttentionType.DECODER,
+    'Encoder attn': attn_type != AttentionType.DECODER,
+    'fp32 softmax': get_config().fp32_softmax,

'fp32 softmax': get_config().fp32_softmax,
}
for feature, check in unsupported_features.items():
if check:
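The gate the reviewer refers to can be sketched in plain Python. This is an illustrative reduction of the unsupported-feature pattern in `hpu_attn.py`, not the actual class; the function name and simplified arguments are assumptions:

```python
# Minimal sketch of the unsupported-feature gate: a dict of feature-name ->
# boolean check, raising NotImplementedError for the first enabled feature
# the backend cannot handle. (Illustrative names, not the real __init__.)

def check_unsupported_features(sliding_window, num_kv_heads, attn_type):
    unsupported_features = {
        'Sliding window': sliding_window is not None,
        'non-GQA attention': num_kv_heads is None,
        'Encoder attn': attn_type != 'decoder',
    }
    for feature, check in unsupported_features.items():
        if check:
            raise NotImplementedError(f'{feature} is not supported')

# A plain decoder config passes; a sliding-window config raises.
check_unsupported_features(None, 8, 'decoder')
```

Removing `'fp32 softmax': get_config().fp32_softmax` from this dict is exactly what lifts the gate for the unified-attention path; the reviewer's point is that the analogous dict in `HPUUnifiedMLAImpl` keeps it.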
47 changes: 43 additions & 4 deletions vllm_gaudi/extension/unified.py
@@ -299,7 +299,20 @@ def partial_attn_causal(query: torch.tensor,
v = value[:, :, 0:q_max, :]
b = bias[q_min:q_max, 0:q_max]

s_attn = torch.matmul(q, k.transpose(-1, -2)) + b.unsqueeze(0).unsqueeze(0)
if get_config().use_output_tensor_in_matmulqk:
s_attn = None
if get_config().fp32_softmax:
s_attn = torch.empty(hpu_ops.matmul_shape(q, k.transpose(-1, -2)), dtype=torch.float32, device=q.device)
s_attn = torch.matmul(q, k.transpose(-1, -2), out=s_attn)
s_attn = s_attn + b.unsqueeze(0).unsqueeze(0)
Copilot AI Feb 25, 2026

In the use_output_tensor_in_matmulqk path, when fp32_softmax is enabled s_attn becomes float32, but the bias slice b is still in the original dtype. On HPU we already explicitly cast biases in the analogous path in extension/ops.py; doing the same here avoids mixed-dtype adds (and potential kernel/type issues) and makes the intent explicit. Consider casting b.unsqueeze(0).unsqueeze(0) to s_attn.dtype (or float32 when fp32_softmax) before adding.

Suggested change
-        s_attn = s_attn + b.unsqueeze(0).unsqueeze(0)
+        bias_term = b.unsqueeze(0).unsqueeze(0).to(s_attn.dtype)
+        s_attn = s_attn + bias_term

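The mixed-dtype add the reviewer flags can be demonstrated in a few lines of plain PyTorch. Shapes and the function name are illustrative, not the actual `unified.py` code; the point is that upcasting the scores and casting the bias to match keeps the add in a single dtype:

```python
import torch

# Sketch of the fp32-softmax score path with the suggested bias cast:
# compute QK^T in the model dtype, upcast the scores to float32 for a
# stable softmax, then cast the bias before adding it.
def scores_fp32(q, k, bias):
    s = torch.matmul(q, k.transpose(-1, -2))  # model dtype (e.g. bf16)
    s = s.float()                             # upcast for fp32 softmax
    s = s + bias.to(s.dtype)                  # cast bias: no mixed-dtype add
    return s

q = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
k = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
bias = torch.zeros(4, 4, dtype=torch.bfloat16)
s = scores_fp32(q, k, bias)
assert s.dtype == torch.float32
```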
elif get_config().fp32_softmax:
s_attn = torch.matmul(q, k.transpose(-1, -2))
s_attn = s_attn.float()
s_attn = s_attn + b.unsqueeze(0).unsqueeze(0).float()
torch._dynamo.graph_break()
else:
s_attn = torch.matmul(q, k.transpose(-1, -2)) + b.unsqueeze(0).unsqueeze(0)

# TODO: remove dtype check once full support is added for fp8 in unified attention
if get_config().unified_attn_softmax_fa2 and s_attn.dtype == torch.bfloat16:
inputM_hpu, inputL_hpu = create_softmax_fa2_input_tensors(s_attn, fmin, inputL_hpu_tensors,
Expand All @@ -314,6 +327,7 @@ def partial_attn_causal(query: torch.tensor,
s_attn = torch.exp(s_attn - s_max.unsqueeze(-1))
s_sum = torch.sum(s_attn, -1)

# TODO: add downcasting attn to original dtype
Comment on lines 329 to +330
Copilot AI Feb 25, 2026

With fp32_softmax enabled, the code leaves the exp(attn) weights / unnormalized weighted-V in float32 (see TODO). This changes the dtype of the partial outputs and will propagate to unified_attn/unified_mla outputs (e.g., division by a float32 sum yields float32), which can break downstream layers that expect the model dtype and also increases memory/compute. Add an explicit downcast back to the original dtype at an appropriate point (commonly: keep max/sum in fp32 for stability, but cast the exp weights and/or final attention output back to query/value.dtype).

Suggested change
-    # TODO: add downcasting attn to original dtype
+    # Keep max/sum in fp32 for stability when fp32_softmax is enabled,
+    # but cast the exponentiated attention weights back to the original
+    # value dtype before the attention matmul to avoid dtype propagation.
+    if get_config().fp32_softmax and s_attn.dtype == torch.float32:
+        s_attn = s_attn.to(v.dtype)

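The downcast the comment proposes can be sketched standalone: keep the running max and sum in float32 for numerical stability, but cast the exponentiated weights back to the value dtype before the attention matmul so the partial output keeps the model dtype. Names are illustrative, not the actual `unified.py` code:

```python
import torch

# fp32-stable partial softmax-attention with an explicit downcast of the
# exp weights before multiplying by V. max/sum stay fp32 for merging.
def partial_softmax_attn(s_attn_fp32, v, fmin):
    s_max = torch.maximum(s_attn_fp32.amax(-1), fmin)   # fp32 running max
    w = torch.exp(s_attn_fp32 - s_max.unsqueeze(-1))    # fp32 exp weights
    s_sum = torch.sum(w, -1)                            # fp32 running sum
    w = w.to(v.dtype)                                   # downcast weights
    out = torch.matmul(w, v)                            # matmul in model dtype
    return out, s_max, s_sum

s = torch.randn(1, 2, 4, 4, dtype=torch.float32)
v = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)
fmin = torch.tensor(torch.finfo(torch.float32).min)
out, s_max, s_sum = partial_softmax_attn(s, v, fmin)
assert out.dtype == torch.bfloat16 and s_sum.dtype == torch.float32
```

The design choice here is the standard flash-attention split: accumulator statistics in fp32, weight/value product in the model dtype.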
# Attention: s_attn @ v
s_attn = torch.matmul(s_attn, v)

@@ -383,7 +397,19 @@ def _partial_attn_shared_core(query: torch.tensor,
"""
num_heads = query.size(0) * query.size(1) if not is_mla else query.size(0)

attn = torch.matmul(query, key.transpose(-1, -2))
if get_config().use_output_tensor_in_matmulqk:
attn = None
if get_config().fp32_softmax:
attn = torch.empty(hpu_ops.matmul_shape(query, key.transpose(-1, -2)),
dtype=torch.float32,
device=query.device)
attn = torch.matmul(query, key.transpose(-1, -2), out=attn)
elif get_config().fp32_softmax:
attn = torch.matmul(query, key.transpose(-1, -2))
attn = attn.float()
torch._dynamo.graph_break()
else:
attn = torch.matmul(query, key.transpose(-1, -2))
attn = attn.flatten(0, 1)
Copilot AI Feb 25, 2026

When fp32_softmax is enabled and the matmul output is float32, bias is still passed in the original dtype and then added to attn. Aligning bias to attn.dtype before the add avoids mixed-dtype adds and matches the established handling in extension/ops.py (which casts block/position bias when attn.dtype != bias.dtype).

Suggested change
-    attn = attn.flatten(0, 1)
+    attn = attn.flatten(0, 1)
+    if attn.dtype != bias.dtype:
+        bias = bias.to(attn.dtype)
attn = attn + bias

@@ -399,7 +425,7 @@ def _partial_attn_shared_core(query: torch.tensor,
local_max = torch.maximum(attn.amax(-1), fmin)
attn = torch.exp(attn - local_max.unsqueeze(-1))
local_sum = attn.sum(-1)

# TODO: add downcasting attn to original dtype
attn = torch.matmul(attn.unflatten(0, (kv_heads if not is_mla else num_heads, -1)), value).flatten(0, 1)

Comment on lines +428 to 430
Copilot AI Feb 25, 2026

Same as the causal path: with fp32_softmax enabled, the attention weights/value product is currently performed in float32 (TODO mentions missing downcast). This will tend to make the merged output float32 as well. Consider downcasting the exp weights and/or the attn output back to the original dtype (typically value.dtype / model dtype) after computing local_sum/local_max in fp32.

# MLA: Extract latent part and project to full V
@@ -659,10 +685,23 @@ def partial_attn_unique(query: torch.tensor,

block_mapping_2d = torch.nn.functional.one_hot(block_mapping, num_classes=batch_size).to(query.dtype)

attn = torch.matmul(query, key.transpose(-1, -2))
if get_config().use_output_tensor_in_matmulqk:
attn = None
if get_config().fp32_softmax:
attn = torch.empty(hpu_ops.matmul_shape(query, key.transpose(-1, -2)),
dtype=torch.float32,
device=query.device)
attn = torch.matmul(query, key.transpose(-1, -2), out=attn)
elif get_config().fp32_softmax:
attn = torch.matmul(query, key.transpose(-1, -2))
attn = attn.float()
torch._dynamo.graph_break()
else:
attn = torch.matmul(query, key.transpose(-1, -2))
attn = attn + bias.unsqueeze(1).unsqueeze(1).unsqueeze(1)
Copilot AI Feb 25, 2026

If fp32_softmax is enabled, attn becomes float32 but bias is still in the original dtype when added here. To avoid mixed-dtype adds (and match the pattern used elsewhere in the codebase), cast this bias term to attn.dtype before adding.

Suggested change
-    attn = attn + bias.unsqueeze(1).unsqueeze(1).unsqueeze(1)
+    attn = attn + bias.to(attn.dtype).unsqueeze(1).unsqueeze(1).unsqueeze(1)

block_max = torch.maximum(attn.amax(-1), fmin)
attn = torch.exp(attn - block_max.unsqueeze(-1))
# TODO: (afierka) add downcasting attn to original dtype
block_sum = attn.sum(-1)
attn = torch.matmul(attn, value)

Comment on lines 703 to 707
Copilot AI Feb 25, 2026

The fp32_softmax path currently leaves the exp(attn) weights and the subsequent attn @ value in float32 (TODO). Besides the output dtype change, this also interacts with block2batch() later since block_mapping_2d is built in query.dtype; mixed-dtype matmul can be problematic on HPU. Consider downcasting back to the original dtype before the matmul with value / block2batch, and/or ensure block_mapping_2d is created in the same dtype as the tensor it multiplies.

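The `block2batch` dtype interplay the reviewer mentions can be sketched as follows. This is an illustrative reduction, not the actual `block2batch` implementation: the one-hot mapping is built to match the attention tensor's dtype, and the reduction direction (summing per-block rows into batch slots) is assumed here:

```python
import torch

# one_hot yields int64; casting the mapping to the dtype of the tensor it
# multiplies keeps the block->batch reduction single-dtype even when the
# attention stays fp32 under fp32_softmax.
block_mapping = torch.tensor([0, 0, 1, 2])   # which batch each block feeds
batch_size = 3
attn = torch.randn(4, 5, dtype=torch.float32)  # fp32 under fp32_softmax

mapping_2d = torch.nn.functional.one_hot(
    block_mapping, num_classes=batch_size).to(attn.dtype)   # (4, 3), fp32
# block2batch-style reduction: accumulate each block's rows into its batch slot
out = torch.matmul(mapping_2d.transpose(0, 1), attn)        # (3, 5)
assert out.dtype == torch.float32
```

Building `mapping_2d` from `attn.dtype` rather than `query.dtype` is one way to satisfy the reviewer's note without an extra downcast, at the cost of doing the reduction in fp32.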