Benchmark Issue #13

Description

@zack041

Thanks for the great work! Recently I have been working on a FlashAttention-2 kernel, and your repo has been an inspiration.

However, I noticed that when benchmarking with the torch profiler, as in your repo, flash attention is faster than manual attention for smaller N but much slower for larger N.
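
For concreteness, here is a minimal sketch of that profiler-side comparison. `manual_attn` is a naive implementation that materializes the full score matrix; `flash_attn` is only a placeholder built on PyTorch's `scaled_dot_product_attention`, standing in for the custom kernel; shapes and dtypes are illustrative.

```python
import torch
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

def manual_attn(q, k, v):
    # Naive attention: materializes the full (N x N) score matrix.
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

def flash_attn(q, k, v):
    # Placeholder for the custom kernel; PyTorch's fused SDPA is used here
    # only so the snippet runs end to end.
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    manual_attn(q, k, v)
    flash_attn(q, k, v)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
```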

However, when I use cuda.Event.record, flash attention is consistently much slower than manual attention (I think this would also happen with your kernel), and I don't think it is a thread-per-row problem, since my kernel is warp-based and uses one block per row.
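
For the event-based path, a minimal timing helper with explicit warm-up and synchronization (the `warmup` and `iters` defaults are arbitrary):

```python
import torch

def time_with_events(fn, *args, warmup=10, iters=100):
    """Average GPU time per call in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn(*args)                     # warm-up: exclude one-time setup costs
    torch.cuda.synchronize()          # start timing from an idle GPU
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()          # wait for all timed kernels to finish
    return start.elapsed_time(end) / iters
```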

My conjecture is therefore that the torch profiler is counting the time PyTorch takes to initialize the manual layer, which makes it look slower for small N, while at the kernel level it consistently executes faster than our flash attention implementations.

Does anyone know the answer for sure, or can anyone suggest a way to test this conjecture experimentally?
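
One way to probe it, reusing the placeholder definitions and the timing helper from the sketches above (run in a fresh process so the "cold" numbers are meaningful): compare the cold first call against the warm steady state for each implementation, and run the profiler only after warming both up. If the profiler's ranking at small N flips once everything is warm, the gap came from one-time setup cost rather than kernel time.

```python
# Continues from the sketches above (manual_attn, flash_attn, q, k, v,
# time_with_events). Run the cold measurements before any other calls.
print("manual, first call  :", time_with_events(manual_attn, q, k, v, warmup=0, iters=1))
print("manual, steady state:", time_with_events(manual_attn, q, k, v, warmup=10, iters=100))
print("flash,  first call  :", time_with_events(flash_attn, q, k, v, warmup=0, iters=1))
print("flash,  steady state:", time_with_events(flash_attn, q, k, v, warmup=10, iters=100))

# Profile only after warm-up, and compare the CUDA-time column (pure kernel
# time) with the CPU-time column (launch / setup overhead) for each path.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    manual_attn(q, k, v)
    flash_attn(q, k, v)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```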
