Commit 7af8cad
[TritonGPU] Support persistent matmul in warp specialization (#6239)
This PR extends the "pattern" for load-MMA warp specialization to
support persistent kernels with MMAv5. It reuses more of the existing
MMAv5 pipelining code in `TC05MMAPipeline.cpp`, primarily the analysis
that determines whether the op can be pipelined and where the
accumulator override point is. As a result, the analysis runs over the
flattened loop.
However, because warp specialization is async and cannot rely on
execution order, there are a few cases supported by the analysis step
that cannot be codegen'd at the moment. (There are likewise cases that
could be codegen'd but aren't supported by the analysis; these can be
ironed out on an as-needed basis.)
At a high level, the extended "pattern" now looks for users of the
accumulator other than the MMA op itself in the next iteration, and if it
finds any, places the users in a new partition and adds additional
synchronization, multi-buffering the accumulator if needed. This allows
the epilogue, which is a conditional user of the accumulator, to be
placed in its own partition, overlapping the epilogue with the
load<->MMA loop. The accumulator can also be multi-buffered, enabling
the next MMA to start running before the TMEM load completes in the user
partition.
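
The partitioning described above can be pictured with a small CPU-side analogy. This is only a sketch under stated assumptions, not the compiler's implementation: Python threads stand in for warp partitions, semaphores stand in for the added synchronization, and `run_pipeline`, `num_acc_buffers` and the arithmetic used as a fake MMA are all hypothetical names chosen for illustration.

```python
import threading


def run_pipeline(num_tiles, k_iters, num_acc_buffers):
    """Toy model: one 'MMA' thread and one 'epilogue' thread share
    num_acc_buffers accumulator slots (all names are hypothetical)."""
    acc = [0] * num_acc_buffers
    # empty[s] is available when slot s may be overwritten by the MMA thread.
    empty = [threading.Semaphore(1) for _ in range(num_acc_buffers)]
    # full[s] is available when slot s holds a finished tile.
    full = [threading.Semaphore(0) for _ in range(num_acc_buffers)]
    results = []

    def mma_partition():
        for tile in range(num_tiles):
            slot = tile % num_acc_buffers
            empty[slot].acquire()      # wait until the epilogue released the slot
            acc[slot] = 0              # start a new tile: accumulator is overridden
            for k in range(k_iters):
                acc[slot] += (tile + 1) * (k + 1)   # stand-in for one MMA
            full[slot].release()       # hand the finished tile to the epilogue

    def epilogue_partition():
        for tile in range(num_tiles):
            slot = tile % num_acc_buffers
            full[slot].acquire()       # wait for the MMA thread to finish the tile
            results.append((tile, acc[slot]))  # stand-in for load + store
            empty[slot].release()      # slot can be reused by a later tile

    threads = [threading.Thread(target=mma_partition),
               threading.Thread(target=epilogue_partition)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


if __name__ == "__main__":
    print(run_pipeline(num_tiles=4, k_iters=3, num_acc_buffers=2))
```

With `num_acc_buffers=1` the MMA thread has to wait for the epilogue before starting the next tile; with two or more slots it can begin the next tile while the previous one is still being read, which is the overlap the multi-buffered accumulator is meant to enable.
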
This PR has lots of code motion due to refactoring utilities to be more
widely available:
* Move MMAInfo and the analysis to determine it into
`MMAv5PipelineUtility.h`
* Move some utilities from `PipeliningUtility.h` to `Utility.h`
* Some misc code cleanup and bugfixes along the way
* Fix lowering of tensordesc ops to insert addrspacecast when the
pointer types are actually different. Grid constant tensordescs have to
be generic address space due to NVPTX backend restriction/bug, but we
treat them as addrspace=1 pointers internally.
Performance results for `matmul_kernel_tma_persistent` on `M, N, K =
8192, 8192, 512` in `09-persistent-matmul.py`:
* With SWP, the best config is `BLOCK_{M,N,K} = (128, 256, 64)`, 4
stages and 4 warps at 1088 TFLOPS
* With WS, the best config is `BLOCK_{M,N,K} = (128, 256, 128)`, 4
stages and 4 warps at 1140 TFLOPS
That's a ~5% increase in performance.

1 parent fcf33a3 commit 7af8cad
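
The ~5% figure can be checked from the two throughput numbers quoted in the performance results. The sketch below assumes the usual `2 * M * N * K` FLOP count for a matmul when converting TFLOPS into a per-call time; that convention and the variable names are assumptions, not something stated in the commit message.

```python
# Problem size and reported throughputs from the commit message.
M, N, K = 8192, 8192, 512
swp_tflops = 1088.0   # software pipelining (SWP), best config
ws_tflops = 1140.0    # warp specialization (WS), best config

# Assumed convention: one multiply and one add per (m, n, k) triple.
flops = 2 * M * N * K

speedup = ws_tflops / swp_tflops
gain_pct = (speedup - 1.0) * 100.0


def time_us(tflops):
    """Time for one matmul in microseconds at the given throughput."""
    return flops / (tflops * 1e12) * 1e6


print(f"relative gain: {gain_pct:.2f}%")
print(f"SWP: {time_us(swp_tflops):.1f} us, WS: {time_us(ws_tflops):.1f} us")
```
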
File tree (29 files changed, +1700 -695 lines changed):
- include/triton/Dialect/TritonGPU/Transforms
- lib/Dialect/TritonGPU
  - IR
  - Transforms
    - Pipeliner
    - WarpSpecialization
- python
  - src
  - test/unit/language
  - tutorials
- test
  - Conversion
  - TritonGPU
- third_party
  - amd/lib/TritonAMDGPUTransforms
  - nvidia/lib/TritonNVIDIAGPUToLLVM