

@NeoLegends NeoLegends commented Dec 3, 2025

This switches the rel. pos. MHSA attention computation to torch's fused scaled_dot_product_attention, for efficiency in training and because scaled_dot_product_attention is automatically exported into an optimized ONNX attention op.

Open Qs:

  • Is this actually more efficient in practice?
  • The relative positional encoding is still integrated outside of the op, because the op only accepts a single additional summand (via the attn_mask parameter), not an additional factor plus sum (see the sketch below). Can this be done better?
  • Does this need a flag to turn it on and off (i.e. to switch back to a non-fused implementation)?

Tests pass, so the output continues to be torch.allclose(...) to the ESPNet output even when the fused op is used.
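
A minimal sketch of the approach (standalone, with hypothetical shapes and names, not the actual module code): the relative positional term is computed outside the fused op and passed in as the additive attn_mask, since scaled_dot_product_attention only accepts one extra summand on the scaled scores.

import torch
import torch.nn.functional as F

batch, heads, time, dim = 2, 4, 10, 64
q = torch.randn(batch, heads, time, dim)
k = torch.randn(batch, heads, time, dim)
v = torch.randn(batch, heads, time, dim)

# Additive rel. pos. scores (the B/D-matrix terms), computed outside the op.
# They must be scaled consistently with the 1/sqrt(dim) applied inside SDPA.
rel_pos_bias = torch.randn(batch, heads, time, time)

# Key padding folded in as an additive float mask (0 = keep, -inf = masked),
# then combined with the rel. pos. bias into the single attn_mask summand.
key_padding = torch.zeros(batch, 1, 1, time)
attn_mask = rel_pos_bias + key_padding

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)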

@NeoLegends NeoLegends self-assigned this Dec 3, 2025
@NeoLegends NeoLegends force-pushed the moritz-rel-pos-conf-sdpa branch from 324fc4a to 358de26 on December 3, 2025 13:35
@NeoLegends NeoLegends force-pushed the moritz-rel-pos-conf-sdpa branch from eac4313 to 58699f6 on December 3, 2025 13:47
@NeoLegends NeoLegends marked this pull request as ready for review December 3, 2025 13:57
@NeoLegends NeoLegends changed the title MHSA: use fused SDPA for attention computation Rel. pos. MHSA: use fused SDPA for attention computation Dec 3, 2025
@NeoLegends NeoLegends changed the title Rel. pos. MHSA: use fused SDPA for attention computation Rel. pos. MHSA: use fused op for attention computation Dec 3, 2025

albertz commented Dec 3, 2025

Did you check which SDPA backend it would actually use? And which one it does use in practice?

I looked a bit into the backend-selection logic. I think you can see it here:

https://github.com/pytorch/pytorch/blob/7ba4680f3755a560af81aa0f688791e367aa3609/aten/src/ATen/native/transformers/attention.cpp#L718
https://github.com/pytorch/pytorch/blob/e3f24fd73ad74c6e7176687986436956c7c18235/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L764

For example, I think Flash Attention will not be used, because it does not support attn_mask.
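
For reference, one way to confirm this on a GPU box (a standalone sketch, not code from this PR): restrict SDPA to the flash backend via torch.nn.attention.sdpa_kernel and pass an explicit attn_mask; the backend's input checks then fail, PyTorch warns with the reason, and the call raises because no kernel is available.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.randn(2, 4, 128, 128, device="cuda", dtype=torch.float16)

# Expected to fail: the flash backend does not accept an explicit attn_mask.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask)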


NeoLegends commented Dec 22, 2025

Yes, Flash Attention is not used due to the mask. In profiling I found that it uses "memory efficient" attention instead.

I've now made some changes that should make the code compatible with cuDNN attention on the relevant GPUs (mainly around proper bfloat16/float16 support), but it seems like the AppTek runtime image is built in such a way that cuDNN attention is disabled (or it's disabled for H200s?), so I wasn't able to test properly.

I think I can merge once tests pass, but I'll leave some time in case you want to re-review the dtype changes.

test-output.txt
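
As an aside, the dispatched kernel can also be checked with the profiler (a hypothetical standalone snippet, not the PR's test setup): the kernel names in the trace show whether a memory-efficient (cutlass), flash, or cuDNN attention kernel actually ran.

import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile

q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.randn(2, 4, 128, 128, device="cuda", dtype=torch.bfloat16)

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Sort by GPU time; the attention kernel's name reveals the chosen backend.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))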


albertz commented Dec 22, 2025

it seems like the AppTek runtime image is built in such a way that cuDNN is disabled

Are you sure? That would be very suboptimal for everything you compute there. You should use cuDNN for a lot of other things as well.


NeoLegends commented Dec 22, 2025

Ah, sorry, I was unclear. I was referring just to cuDNN fused attention. For cuDNN in general I don't know, but it's probably fine. As you can see in the test output, it unfortunately does not print the reason why cuDNN attention is not being used.

Python 3.13.7 (main, Sep 18 2025, 16:28:29) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.__version__
'2.8.0'
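
For what it's worth, the global per-backend toggles can be queried directly in the same session (recent PyTorch versions provide these torch.backends.cuda helpers; the returned booleans depend on how the runtime image configures them, so no output is shown here):

>>> torch.backends.cuda.flash_sdp_enabled()
>>> torch.backends.cuda.mem_efficient_sdp_enabled()
>>> torch.backends.cuda.cudnn_sdp_enabled()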


albertz commented Dec 22, 2025

As you can see in the test output, it unfortunately does not print the reason why cuDNN attention is not being used.

You can go through all of the conditions and just check (print) them yourself.
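
One way to do that from Python without stepping through the C++ (a hypothetical helper, not part of this PR, extending the flash check sketched above to all fused backends; SDPBackend.CUDNN_ATTENTION assumes a recent PyTorch): restrict SDPA to one backend at a time, and for most failing checks PyTorch warns with the violated condition before raising.

import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def probe_sdpa_backends(q, k, v, attn_mask=None):
    # Try each fused backend in isolation; when a backend's input checks fail,
    # PyTorch (usually) warns with the reason and raises "No available kernel".
    for backend in (SDPBackend.FLASH_ATTENTION,
                    SDPBackend.EFFICIENT_ATTENTION,
                    SDPBackend.CUDNN_ATTENTION):
        try:
            with sdpa_kernel(backend):
                F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            print(backend, "usable")
        except RuntimeError as err:
            print(backend, "not usable:", err)

Call it with the exact tensors the model produces (same device, dtype, shapes, and requires_grad), since the checks depend on all of these.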


NeoLegends commented Dec 22, 2025

I searched for SDPA input-validation functions that don't print a debug message when their check fails, and found the one that checks the attention mask: in some cases it fails without printing anything.
https://github.com/pytorch/pytorch/blob/229d33f7f9b8abcdd9ba17777eff2f2dbbe4afc9/aten/src/ATen/native/transformers/sdp_utils_cpp.h#L269

It seems that if the attention mask requires a gradient (which it does here, since we add the B and D matrices to it), it cannot be used with cuDNN attention. So I guess we can only use memory-efficient attention here.
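
Concretely (an illustrative snippet with hypothetical shapes, not the module code): because the additive mask depends on the learnable rel. pos. parameters, it ends up with requires_grad=True, which the cuDNN check linked above rejects.

import torch

# Stand-in for the B + D rel. pos. score matrices, which depend on learnable parameters.
pos_bias = torch.nn.Parameter(torch.randn(1, 4, 32, 32))

# Folding them into the additive mask makes the mask part of the autograd
# graph, so it requires a gradient; cuDNN attention then refuses it.
attn_mask = pos_bias.expand(8, -1, -1, -1)
print(attn_mask.requires_grad)  # True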

@NeoLegends NeoLegends merged commit 46fe27b into main Jan 20, 2026
2 checks passed
@NeoLegends NeoLegends deleted the moritz-rel-pos-conf-sdpa branch January 20, 2026 10:53