Why doesn't the backward of the fused attention kernel account for the normalization constant in the softmax function? #4629
Unanswered
jeffwillette asked this question in Q&A
Replies: 0 comments
The fused attention tutorial (https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html#sphx-glr-getting-started-tutorials-06-fused-attention-py) shows a backward implementation.
In this implementation, when calculating the derivative with respect to $V$, we would expect to see the attention matrix $A$, since we are calculating $\frac{\partial}{\partial V} AV = A$. However, the implementation appears to ignore the normalization constant (the sum over the rows) of the attention matrix, and just computes $A$ as $\exp(QK^\top - \max(QK^\top, \text{dim}=1))$ instead of $\frac{\exp(QK^\top - \max(QK^\top, \text{dim}=1))}{\sum(\exp(QK^\top - \max(QK^\top, \text{dim}=1)), \text{dim}=1)}$.
triton/python/tutorials/06-fused-attention.py, lines 235 to 245 at commit 0e3cadd
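For concreteness, here is a minimal eager-mode sketch of the $dV$ term I have in mind (the function name and the exact scaling are my own, not taken from the tutorial): with $O = AV$ and $A$ the row-normalized softmax, the gradient with respect to $V$ is $A^\top \, dO$, and the row-wise normalization constant is part of $A$.

```python
import torch

def eager_attention_dv(q, k, do):
    # Hypothetical helper (not from the tutorial): compute the dV term of
    # eager attention, O = A V with A = softmax(Q K^T / sqrt(d)).
    d = q.shape[-1]
    scores = (q @ k.transpose(-2, -1)) / d**0.5
    scores = scores - scores.max(dim=-1, keepdim=True).values  # subtract row max
    p = scores.exp()                                           # unnormalized attention
    a = p / p.sum(dim=-1, keepdim=True)                        # row-wise normalization
    return a.transpose(-2, -1) @ do                            # dV = A^T @ dO
```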
Tests pass, and this appears to be equivalent to the eager attention backward. My question is: why? Is there a line I am missing that incorporates the normalization constant, or is it safely ignored because it doesn't change the output that much?
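As a quick sanity check (again my own sketch, not part of the tutorial's test suite), autograd on the eager formulation does produce a $dV$ that contains the normalized attention matrix:

```python
import torch

torch.manual_seed(0)
q, k = torch.randn(2, 4, 16, 64), torch.randn(2, 4, 16, 64)
v = torch.randn(2, 4, 16, 64, requires_grad=True)
do = torch.randn_like(v)

# Eager attention forward, then backprop an arbitrary upstream gradient dO.
a = torch.softmax((q @ k.transpose(-2, -1)) / 64**0.5, dim=-1)
out = a @ v
out.backward(do)

# v.grad matches A^T @ dO, where A is the *normalized* attention matrix.
assert torch.allclose(v.grad, a.transpose(-2, -1) @ do, atol=1e-5)
```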