This PR implements the flash attention algorithm from the original paper #12
In the original algorithm, the intermediate scores `S_i` are not stored in shared memory. Instead, the output `O_i` is accumulated incrementally as each block is processed. This PR adopts that approach, removing the need to materialize `S_i` and aligning the implementation more directly with the paper.

Additionally, the kernel launch configuration has been changed to use one thread per row. This removes the outer `T_c` loop, making the control flow much closer to the pseudocode in the paper and easier to reason about and compare against the reference algorithm.

For the tensor sizes used here, this version uses more shared memory and runs slightly slower (by ~1 ms) on a 3060; however, it may be easier to understand and extend, especially for readers learning how the algorithm works.
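For illustration, here is a minimal sketch (not the PR's actual kernel) of what a one-thread-per-row forward pass with incrementally accumulated output can look like. The kernel name, tile size `BC`, and head dimension `D` are assumptions for the example; batch/head indexing, masking, and non-square edge handling are omitted.

```cuda
// Hypothetical sketch: one thread per query row, K/V tiles staged in shared
// memory, and the output rescaled on the fly with the running max/denominator
// (online softmax), so the scores S_i are never materialized.
#include <cuda_runtime.h>
#include <math.h>

#define D  64   // head dimension (assumed for the sketch)
#define BC 32   // key/value tile size (assumed for the sketch)

// Q, K, V, O are [N, D] row-major for a single head; N is the sequence length.
__global__ void flash_attn_row_kernel(const float* Q, const float* K,
                                      const float* V, float* O, int N) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;  // one thread per query row

    __shared__ float Ks[BC][D];
    __shared__ float Vs[BC][D];

    float q[D], o[D];
    if (row < N) {
        for (int x = 0; x < D; ++x) { q[x] = Q[row * D + x]; o[x] = 0.f; }
    }

    float m = -INFINITY;                    // running row max
    float l = 0.f;                          // running softmax denominator
    const float scale = rsqrtf((float)D);   // 1 / sqrt(d)

    // Walk over the key/value tiles for this row.
    for (int start = 0; start < N; start += BC) {
        // Cooperatively stage the current K/V tile into shared memory.
        for (int i = threadIdx.x; i < BC * D; i += blockDim.x) {
            int r = start + i / D, c = i % D;
            Ks[i / D][c] = (r < N) ? K[r * D + c] : 0.f;
            Vs[i / D][c] = (r < N) ? V[r * D + c] : 0.f;
        }
        __syncthreads();

        if (row < N) {
            for (int j = 0; j < BC && start + j < N; ++j) {
                // Scaled dot product of this query row with key j.
                float s = 0.f;
                for (int x = 0; x < D; ++x) s += q[x] * Ks[j][x];
                s *= scale;

                // Online softmax update: rescale previous accumulators.
                float m_new = fmaxf(m, s);
                float corr  = expf(m - m_new);
                float p     = expf(s - m_new);
                l = l * corr + p;
                for (int x = 0; x < D; ++x)
                    o[x] = o[x] * corr + p * Vs[j][x];
                m = m_new;
            }
        }
        __syncthreads();
    }

    if (row < N) {
        for (int x = 0; x < D; ++x) O[row * D + x] = o[x] / l;
    }
}
```

With this structure the launch is just one thread per query row, e.g. `flash_attn_row_kernel<<<(N + 127) / 128, 128>>>(Q, K, V, O, N);` for an assumed block size of 128, and the per-thread loop mirrors the per-row accumulation in the paper's pseudocode.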
Results
=== profiling manual attention ===
Self CPU time total: 97.501ms
Self CUDA time total: 97.638ms
=== profiling minimal flash attention ===
Self CPU time total: 15.558ms
Self CUDA time total: 6.453ms