Rel. pos. MHSA: use fused op for attention computation #88
Conversation
force-pushed from 324fc4a to 358de26
force-pushed from eac4313 to 58699f6
Did you check which SDPA backend it would actually use, and which one it does use? I looked a bit at the selection logic; I think you can see it here: https://github.com/pytorch/pytorch/blob/7ba4680f3755a560af81aa0f688791e367aa3609/aten/src/ATen/native/transformers/attention.cpp#L718 For example, I think Flash Attention will not be used, because Flash Attention does not support an attention mask.
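(For reference, one way to see which backends can actually run these inputs is to force each one via torch.nn.attention.sdpa_kernel and see which ones raise. This is only a rough sketch with made-up shapes, assuming a CUDA device; it is not the PR's code:)

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Dummy inputs with an additive mask, roughly like the rel.-pos. case here
# (shapes and dtype are made up for illustration).
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.randn(2, 8, 128, 128, device="cuda", dtype=torch.float16)

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH):
    try:
        # Restrict SDPA to a single backend; it raises if that backend
        # cannot handle these inputs (e.g. flash attention with attn_mask).
        with sdpa_kernel([backend]):
            F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        print(backend, "usable")
    except RuntimeError as err:
        print(backend, "not usable:", err)
```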
Yes, Flash Attention is not used because of the mask. In profiling I found that it uses "memory efficient" attention. I've now made some changes that should make the code compatible with cuDNN attention on the relevant GPUs (mainly around proper bfloat16/float16 support), but it seems like the AppTek runtime image is built in such a way that cuDNN attention is disabled (or it's disabled for H200s?), so I wasn't able to test it properly. I think I can merge once the tests pass, but I'll leave some time in case you want to re-review the dtype changes.
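(Sketch of what the dtype handling boils down to, as far as I understand it: the fused SDPA kernels want the additive mask in the same dtype as q/k/v. Shapes and names below are made up; this is not the actual diff:)

```python
import torch
import torch.nn.functional as F

# Minimal sketch of the dtype handling (not the PR's actual code): the fused
# SDPA kernels expect q/k/v and the additive mask to share one dtype, so the
# rel.-pos. bias (often computed in float32) is cast to match the queries.
q = torch.randn(2, 4, 16, 32, dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
bias = torch.randn(2, 4, 16, 16)  # additive rel.-pos. terms, float32

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias.to(q.dtype))
print(out.dtype)  # torch.bfloat16
```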
Are you sure? That would be very suboptimal for everything you compute there. You should use CuDNN for a lot of other things as well. |
Ah, sorry, I was unclear. I was referring just to cuDNN fused attention; for cuDNN in general I don't know, it's probably fine. As you can see in the test output, it unfortunately does not print the reason why cuDNN attention is not being used.
Python 3.13.7 (main, Sep 18 2025, 16:28:29) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.__version__
'2.8.0'
You can go through all of the conditions and just check (print) them yourself. |
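(A cheap first check, before walking through attention.cpp line by line, is whether the individual backends are even enabled globally in this build/runtime; these toggles live in torch.backends.cuda:)

```python
import torch

# Print which SDPA backends are globally enabled in this build/runtime.
# This only shows the global toggles, not the per-input checks from
# attention.cpp, which still have to be inspected separately.
print("flash:        ", torch.backends.cuda.flash_sdp_enabled())
print("mem_efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("cudnn:        ", torch.backends.cuda.cudnn_sdp_enabled())
print("math:         ", torch.backends.cuda.math_sdp_enabled())
```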
I searched for the SDPA input validation functions that don't print a debug message when a check fails, and found the one that checks the attention mask: in some cases it fails without printing anything. It seems that if the attention mask requires a gradient (which it does here, since we add the B and D matrices to it), it cannot be used with cuDNN attention. So I guess we can only use memory-efficient attention here.
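(To make the "mask requires grad" point concrete, here is a rough, hypothetical sketch of an ESPNet/Transformer-XL-style decomposition in which the learned positional terms end up in the additive mask; the names, shapes, and the rel_shift helper are illustrative, not this repo's code:)

```python
import torch
import torch.nn.functional as F

batch, heads, time, d_k = 2, 4, 16, 32
q = torch.randn(batch, heads, time, d_k)
k = torch.randn(batch, heads, time, d_k)
v = torch.randn(batch, heads, time, d_k)
# Projected relative positional embeddings and the "v" bias are learned
# parameters, so anything computed from them requires grad.
p = torch.randn(batch, heads, 2 * time - 1, d_k, requires_grad=True)
pos_bias_v = torch.randn(heads, 1, d_k, requires_grad=True)

def rel_shift(x):
    # Transformer-XL-style relative shift (illustrative helper).
    b, h, t, pos = x.shape                  # pos == 2 * t - 1
    x = F.pad(x, (1, 0))                    # (b, h, t, pos + 1)
    x = x.view(b, h, pos + 1, t)[:, :, 1:]  # drop the padded column
    return x.reshape(b, h, t, pos)[..., :t] # (b, h, t, t)

# Matrices (b)+(d): query (with the "v" bias) against the positional embeddings.
matrix_bd = rel_shift(torch.matmul(q + pos_bias_v, p.transpose(-2, -1)))

# SDPA already applies 1/sqrt(d_k) to q @ k^T, so the extra term is pre-scaled
# before being passed as the additive mask.
attn_mask = matrix_bd / d_k ** 0.5
print(attn_mask.requires_grad)  # True -> the cuDNN backend rejects this mask

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```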
For efficiency in training, and also because torch's scaled_dot_product_attention is automatically exported into an optimized ONNX attention op.
Open Qs:
Tests pass, so the output continues to be torch.allclose(...) to the ESPNet output even when the fused op is used.
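(Conceptually, that allclose check asserts that the fused op with the rel.-pos. terms passed as an additive mask matches a naive softmax(QK^T/sqrt(d) + mask)V reference; an illustrative sketch, not the repo's actual test:)

```python
import torch
import torch.nn.functional as F

# Illustrative equivalence check (not the repo's test): the fused op with an
# additive mask should match a naive reference implementation.
q = torch.randn(2, 4, 16, 32)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.randn(2, 4, 16, 16)  # e.g. the rel.-pos. bias terms

scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5 + mask
ref = torch.softmax(scores, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.allclose(ref, fused, atol=1e-5))  # expected: True
```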