Byeong-Chan commented on Apr 5, 2024

Description

This PR implements a matrix-multiplication-optimized forward pass for flash attention (~300 lines).

I got these results on my RTX 3060 (requires sm_80 or newer):

With the original float minimal kernel:

=== profiling manual attention ===
...
Self CPU time total: 834.368ms
Self CUDA time total: 835.075ms

=== profiling minimal flash attention === 
...
Self CPU time total: 668.000us
Self CUDA time total: 687.000us

attn values sanity check: True

With the half-precision matmul-optimized kernel:

=== profiling manual attention ===
...
Self CPU time total: 849.544ms
Self CUDA time total: 849.698ms

=== profiling minimal flash attention ===
...
Self CPU time total: 89.000us
Self CUDA time total: 93.000us

attn values sanity check: True

