Why does num_warps affect numerical precision? #7751
Unanswered · yaozhenghangma asked this question in Q&A · 0 comments
I'm testing a Flash Attention kernel, and I've placed the test code at the end. In short, when performing an attention operation on `q`, `k`, and `v` matrices with `num_batchs = 8`, `seq_len = 128`, `num_heads = 128`, and `head_dim = 128`, I observed that `num_warps = 4` and `num_warps = 8` produce different output results. The maximum difference between the output matrices is `0.00048828125`, and this difference remains consistent across repeated runs. Moreover, when keeping all other parameters the same and setting `seq_len` to 32 or smaller, the discrepancy disappears.

I'm using an Nvidia T4 GPU, with CUDA version 12.4 and Triton version 3.1.0.
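The test code referenced above is not reproduced in this extract. Below is a hypothetical sketch of a comparison harness for the two launch configurations; the `attention` wrapper, its `num_warps` keyword argument, and the fp16 dtype are assumptions, not the original code. For reference, 0.00048828125 equals 2^-11, which is one fp16 ULP for magnitudes in [0.5, 1), assuming fp16 outputs.

```python
# Hypothetical comparison harness (a sketch, not the original test code from
# this post). `attention` is assumed to be a wrapper around the Triton Flash
# Attention kernel that forwards `num_warps` to the kernel launch.
import torch

def compare_num_warps(attention,
                      num_batchs=8, num_heads=128, seq_len=128, head_dim=128):
    torch.manual_seed(0)
    shape = (num_batchs, num_heads, seq_len, head_dim)
    q = torch.randn(shape, device="cuda", dtype=torch.float16)
    k = torch.randn(shape, device="cuda", dtype=torch.float16)
    v = torch.randn(shape, device="cuda", dtype=torch.float16)

    out4 = attention(q, k, v, num_warps=4)
    out8 = attention(q, k, v, num_warps=8)

    # 0.00048828125 == 2**-11, i.e. one fp16 ULP for magnitudes in [0.5, 1),
    # so a max difference of that size corresponds to a single rounding step.
    diff = (out4.float() - out8.float()).abs().max().item()
    print(f"max abs diff between num_warps=4 and num_warps=8: {diff}")
```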