Thanks for the great work! Recently I have been working on a flash attention v2 kernel, and your repo has been an inspiration.
However, I noticed that when benchmarking with the torch profiler, as in your repo, flash attention is faster than the manual attention for smaller N but much slower for larger N.
When I use cuda.Event.record instead, flash attention is consistently much slower than the manual attention (I think this would also happen with your kernel), and I don't think it is a thread-per-row problem, because my kernel is warp-based and uses a block per row.
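For reference, this is roughly how I time with CUDA events (a minimal sketch; `fn` stands in for either the flash attention call or the manual attention call, and the warmup/iteration counts are arbitrary):

```python
import torch

def time_fn(fn, *args, warmup=10, iters=100):
    # Warm up so one-time costs (kernel compilation, autotuning,
    # allocator growth) are excluded from the measurement.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()  # wait for all queued kernels to finish
    return start.elapsed_time(end) / iters  # average milliseconds per call
```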
Therefore, my conjecture is that the torch profiler is also accounting for the time PyTorch takes to initialize the manual layer, which makes the manual attention look slower for small N, even though at the kernel level it consistently executes faster than our flash attention kernels.
Does anyone know the answer for sure, or can anyone suggest a way to test this conjecture experimentally?
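One experiment I can think of (a rough sketch under my assumptions; `fn` is a placeholder for either implementation) is to rerun the profiler with explicit wait/warmup steps in its schedule, so any one-time setup of the PyTorch layer is excluded, and then compare only the per-kernel CUDA times. If the gap at small N disappears, the initialization overhead was being counted:

```python
import torch
from torch.profiler import profile, schedule, ProfilerActivity

def profile_fn(fn, *args, name="fn"):
    # wait=1, warmup=3: the first 4 iterations are not recorded, so
    # one-time setup costs are not attributed to the op being measured.
    sched = schedule(wait=1, warmup=3, active=10, repeat=1)
    with profile(activities=[ProfilerActivity.CUDA], schedule=sched) as prof:
        for _ in range(14):  # wait + warmup + active steps
            fn(*args)
            prof.step()
    print(name)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))
```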