Nothing to attend to after combining the causal mask and the padding mask in static batching #808
Replies: 3 comments
-
(Update, 2025/09/06) I finally found this interesting thread, but I'm still curious whether there's a recommended way to handle this problem in modern LLM architectures. Thanks!
-
Hi there, I haven't had a chance to read through the thread you linked yet, but let me share a few thoughts. I recently implemented a batched version with left-padding for Qwen3 here: https://github.com/rasbt/reasoning-from-scratch/blob/main/reasoning_from_scratch/qwen3_batched.py (I recommend looking at a file diff between qwen3.py and qwen3_batched.py to see the relevant lines more easily).

What I had to do there is implement a more numerically stable version of the masked softmax that uses a large negative value instead of `-inf`, so that fully masked rows don't produce `nan`s (and the renormalization below doesn't divide by zero):

```python
# More numerically stable attention
attn_scores = queries @ keys.transpose(2, 3)
# Use a large negative sentinel instead of -inf for a stable softmax when a row is fully masked
attn_scores = attn_scores.masked_fill(mask, -1e9)
attn_scores = attn_scores / (self.head_dim ** 0.5)
attn_weights = torch.softmax(attn_scores, dim=-1)
# Zero out masked positions post-softmax and renormalize so the remaining weights sum to ~1 where possible
attn_weights = attn_weights.masked_fill(mask, 0.0)
denom = attn_weights.sum(dim=-1, keepdim=True).clamp(min=1e-9)
attn_weights = attn_weights / denom
```

It seems to work relatively well, but I am not sure that's the best solution. I'd be happy to hear any suggestions or feedback.
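For reference, the boolean `mask` used above is expected to mark positions to ignore with `True`. Here is a minimal sketch of one way such a combined causal + left-padding mask could be built for a batch; this is not taken from the linked file, and the `attention_mask` convention (1 = real token, 0 = pad) is an assumption for illustration:

```python
import torch

def build_combined_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: (batch, seq_len), 1 = real token, 0 = (left) padding
    _, seq_len = attention_mask.shape

    # Causal part: True above the diagonal (future key positions are masked)
    causal = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device),
        diagonal=1,
    )

    # Padding part: True wherever the key position is a pad token
    padding = attention_mask == 0                      # (batch, seq_len)

    # Combine and add a head dimension: (batch, 1, seq_len, seq_len)
    mask = causal.unsqueeze(0) | padding[:, None, :]   # (batch, seq_len, seq_len)
    return mask.unsqueeze(1)

# Example matching the discussion below: one left-pad token, max length 3
attention_mask = torch.tensor([[0, 1, 1]])
print(build_combined_mask(attention_mask)[0, 0])
# tensor([[ True,  True,  True],    <- pad query row: everything is masked
#         [ True, False,  True],
#         [ True, False, False]])
```

With one left-pad token and max length 3, the first row of the resulting mask is fully `True`, which is exactly the case the stable softmax above guards against.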
-
Thanks @rasbt, that's good to know you are also using a small value; I was myself using
-
Hi all,
I’m implementing the Llama 3 model architecture and ran into a problem with the masking mechanism during the prefill phase.
Since I’m using static batching, all sequences in the batch are padded to the same length. For scaled dot-product attention, I apply both a padding mask and a causal mask.
Consider this example:
Sequence length = 2
Max sequence length in the current batch = 3
Padding mask:
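For this example (one left-pad token followed by two real tokens), the padding mask over the three key positions would be something like:

```
[1, 0, 0]   # 1 = pad position, 0 = real token (illustrative values)
```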
Using the attention bias trick, the combined bias looks like this:
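```
# combined causal + key-padding bias for the example above
# (illustrative reconstruction; pad token at position 0)
[[-inf, -inf, -inf],
 [-inf,    0, -inf],
 [-inf,    0,    0]]
```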
Here, the sequence starts with a single pad token in the first position, so there’s nothing valid to attend to at that step.
Then I add the bias:
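```
# attention scores after adding the bias
# (s_ij = raw scores that survive the mask; illustrative)
[[-inf, -inf, -inf],
 [-inf,  s11, -inf],
 [-inf,  s21,  s22]]
```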
This masks out invalid positions. The issue is that the first row becomes all `-inf`, which means after softmax the result is all `nan`. That propagates forward, making the hidden state at that position `nan` too. Passing this into the next decoder layer is clearly invalid.

My question: Is it reasonable to replace these `nan` outputs with zeros, or is there a standard approach/reference for handling this situation?

Thanks a lot.