Commit 60a5050
[AMD] SDPA internal changes (pytorch#144320)
Summary: All the internal changes needed to enable flash attention w/ SDPA in fbcode.
Test Plan:
```
TORCH_ROCM_FA_PREFER_CK=1 buck run -m rocm621 mode/opt-amd-gpu scripts/xdwang/example:sdpa
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| Batch Size | Sequence Length | Heads | Head Dim | Flash Time (µs) | Math Time (µs) | xformers Time (µs) | Flash TFlops | Math TFlops | xformers TFlops | Speedup (Flash/Math) | Speedup (xformers/Math) | xformers trace_url | Flash trace_url |
+==============+===================+=========+============+===================+==================+======================+================+===============+===================+========================+===========================+======================+===================+
| 1 | 4096 | 32 | 64 | 455.552 | 7748.76 | 513.449 | 301.698 | 17.7369 | 267.678 | 17.0096 | 15.0916 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| 1 | 4096 | 16 | 128 | 329.971 | 4741.11 | 386.049 | 416.519 | 28.9888 | 356.014 | 14.3683 | 12.2811 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| 1 | 8192 | 32 | 64 | 1455.76 | 31869.6 | 1665.49 | 377.642 | 17.2501 | 330.087 | 21.8921 | 19.1353 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| 1 | 8192 | 16 | 128 | 1265.77 | 18972.8 | 1479.48 | 434.325 | 28.976 | 371.588 | 14.9891 | 12.824 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| 1 | 16384 | 32 | 64 | 5732.99 | 121861 | 6816.77 | 383.573 | 18.0453 | 322.59 | 21.2562 | 17.8767 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| 1 | 16384 | 16 | 128 | 4749.69 | 73776.4 | 5404.03 | 462.982 | 29.8066 | 406.923 | 15.5329 | 13.6521 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| Batch Size | Sequence Length | Heads | Head Dim | Flash Time (µs) | Math Time (µs) | xformers Time (µs) | Flash TFlops | Math TFlops | xformers TFlops | Speedup (Flash/Math) | Speedup (xformers/Math) | xformers trace_url | Flash trace_url |
+==============+===================+=========+============+===================+==================+======================+================+===============+===================+========================+===========================+======================+===================+
| 1 | 4096 | 32 | 64 | 1615.41 | 8342.67 | 1822.72 | 212.7 | 41.1855 | 188.508 | 5.16443 | 4.57705 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| 1 | 4096 | 16 | 128 | 1357.97 | 5943.53 | 1432.34 | 253.022 | 57.8104 | 239.886 | 4.37676 | 4.14953 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| 1 | 8192 | 32 | 64 | 5556.5 | 31726.7 | 6502.17 | 247.348 | 43.3197 | 211.374 | 5.70984 | 4.8794 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| 1 | 8192 | 16 | 128 | 5186 | 22529.4 | 5590.36 | 265.019 | 61.0044 | 245.85 | 4.34427 | 4.03004 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| 1 | 16384 | 32 | 64 | 22527.7 | 130413 | 26527.6 | 244.035 | 42.155 | 207.239 | 5.789 | 4.91613 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
| 1 | 16384 | 16 | 128 | 18347.9 | 87553.2 | 20358 | 299.628 | 62.791 | 270.044 | 4.77184 | 4.30068 | | |
+--------------+-------------------+---------+------------+-------------------+------------------+----------------------+----------------+---------------+-------------------+------------------------+---------------------------+----------------------+-------------------+
```
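The TFlops and speedup columns above can be reproduced from the timings. The sketch below assumes the usual non-causal attention FLOP count of `4 * batch * heads * seq_len^2 * head_dim` (two matmuls at 2 FLOPs per multiply-add); the helper names are illustrative and are not taken from the benchmark script. Under that assumption, the first table matches a forward-only count. The second table matches 2.5x that count, which would fit a backward-pass measurement, though the commit message does not say so. The "Speedup (Flash/Math)" column works out to Math time divided by Flash time.

```python
def attention_flops(batch, seq_len, heads, head_dim, multiplier=1.0):
    """Assumed FLOP count for non-causal attention (QK^T and PV matmuls).

    multiplier=1.0 models a forward pass; 2.5 is the factor that
    reproduces the second table (an inference, not stated in the commit).
    """
    return multiplier * 4 * batch * heads * seq_len**2 * head_dim


def tflops(flops, time_us):
    """Convert a FLOP count and a time in microseconds to TFLOP/s."""
    return flops / (time_us * 1e-6) / 1e12


# First row of the first table: batch=1, seq=4096, heads=32, head_dim=64.
fwd = attention_flops(1, 4096, 32, 64)
flash_tflops = tflops(fwd, 455.552)   # table reports 301.698
math_tflops = tflops(fwd, 7748.76)    # table reports 17.7369
speedup = 7748.76 / 455.552           # table reports 17.0096

# First row of the second table, same shape, with the assumed 2.5x factor.
second = attention_flops(1, 4096, 32, 64, multiplier=2.5)
second_flash_tflops = tflops(second, 1615.41)  # table reports 212.7

print(f"{flash_tflops:.3f} {math_tflops:.4f} {speedup:.4f} {second_flash_tflops:.1f}")
```

The same helpers apply to the other rows by substituting their shape and timing values.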
Reviewed By: leitian, feikou, yoyoyocmu, sijiac
Differential Revision: D67262726
Pull Request resolved: pytorch#144320
Approved by: https://github.com/jianyuh, https://github.com/eqy, https://github.com/leitian
1 parent 7d9f26d commit 60a5050
File tree
7 files changed, +640 -484 lines changed
- aten/src/ATen/native/transformers/hip/flash_attn
  - aot
  - ck
Lines changed: 6 additions & 6 deletions
@@ -126,7 +126,7 @@
@@ -254,7 +254,7 @@
@@ -418,8 +418,8 @@
@@ -574,8 +574,8 @@
Lines changed: 7 additions & 0 deletions
@@ -383,6 +383,13 @@
Lines changed: 8 additions & 0 deletions
@@ -142,6 +142,7 @@
@@ -342,6 +343,13 @@
Lines changed: 7 additions & 0 deletions
@@ -412,6 +412,13 @@
Lines changed: 7 additions & 0 deletions
@@ -341,6 +341,13 @@