Commit da3ab2a
authored
[codegen] Use Python identifier as prefix for IR SSA names (#7521)
Quite some time ago I implemented [`-mlir-use-nameloc-as-prefix`
upstream](llvm/llvm-project#119996). I actually
implemented this because I was pulling my hair out trying to debug
Triton pipelines but I never got around to plumbing it all the way
through. So here's the plumbing.
The way this works is like this
```python
# demo.py
@triton.jit
def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offset = pid * BLOCK_SIZE
offsets = offset + tl.arange(0, BLOCK_SIZE)
load_src_store_dst = src + offsets
mask = offsets < N
x_plus_1 = tl.load(load_src_store_dst, mask=mask) + 1
tl.store(load_src_store_dst, x_plus_1, mask=mask)
# shell
MLIR_ENABLE_DUMP=1 python demo.py
```
will give you dumps like this:
```mlir
module {
tt.func public @_kernel(%src: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} , %N: i32 {tt.divisibility = 16 : i32} ) attributes {noinline = false} {
%cst = arith.constant dense<1.000000e+00> : tensor<16xf32>
%c16_i32 = arith.constant 16 : i32
%pid = tt.get_program_id x : i32
%offset = arith.muli %pid, %c16_i32 : i32
%offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
%offsets_0 = tt.splat %offset : i32 -> tensor<16xi32>
%offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32>
%load_src_store_dst = tt.splat %src : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>>
%load_src_store_dst_2 = tt.addptr %load_src_store_dst, %offsets_1 : tensor<16x!tt.ptr<f32>>, tensor<16xi32>
%mask = tt.splat %N : i32 -> tensor<16xi32>
%mask_3 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32>
%x_plus_1 = tt.load %load_src_store_dst_2, %mask_3 : tensor<16x!tt.ptr<f32>>
%x_plus_1_4 = arith.addf %x_plus_1, %cst : tensor<16xf32>
tt.store %load_src_store_dst_2, %x_plus_1_4, %mask_3 : tensor<16x!tt.ptr<f32>>
tt.return
}
}
```
Notice, the SSA name (roughly) correspond to the Python identifiers
(including func args `%src, %N`). Note, the reason we have
```mlir
%offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
%offsets_0 = tt.splat %offset : i32 -> tensor<16xi32>
%offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32>
```
is because the way it's plumbed is the "target" of the assignment
determines (contextually) the SSA names of all of the intermediate
values of the rhs. While this is subject to bike-shedding tbh there's
not really another way to do it (I tried...).
**Furthermore**, because `NameLoc` attributes are just `Location`
attributes these names will persist/be propagated through passes
(assuming the passes correctly/adequately propagate):
```mlir
// -----// IR Dump After TritonAMDGPUConvertToBufferOps (tritonamdgpu-convert-buffer-ops) ('builtin.module' operation) //----- //
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @_kernel(%src: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} , %N: i32 {tt.divisibility = 16 : i32} ) attributes {noinline = false} {
%cst = arith.constant dense<1.000000e+00> : tensor<16xf32, #blocked>
%c16_i32 = arith.constant 16 : i32
%pid = tt.get_program_id x : i32
%offset = arith.muli %pid, %c16_i32 : i32
%offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked>
%offsets_0 = tt.splat %offset : i32 -> tensor<16xi32, #blocked>
%offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32, #blocked>
%load_src_store_dst = tt.addptr %src, %offset : !tt.ptr<f32>, i32
%mask = tt.splat %N : i32 -> tensor<16xi32, #blocked>
%mask_2 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32, #blocked>
%x_plus_1 = amdgpu.buffer_load %load_src_store_dst[%offsets], %mask_2 : tensor<16xf32, #blocked>
%x_plus_1_3 = arith.addf %x_plus_1, %cst : tensor<16xf32, #blocked>
amdgpu.buffer_store %x_plus_1_3, %load_src_store_dst[%offsets], %mask_2 : tensor<16xf32, #blocked>
tt.return
}
}
```
Notice `%x_plus_1 = amdgpu.buffer_load` keeps the SSA name from
`%x_plus_1 = tt.load`. This is the *real* value prop (at least for me).
And on down:
```mlir
// -----// IR Dump After ConvertBuiltinFuncToLLVM (convert-builtin-func-to-llvm) ('builtin.module' operation) //----- //
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} {
llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
llvm.func @_kernel(%src: !llvm.ptr<1> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} , %N: i32 {tt.divisibility = 16 : i32} , %arg2: !llvm.ptr<1> ) attributes {noinline = false, nvvm.kernel = 1 : ui1, nvvm.reqntid = array<i32: 128>} {
%0 = llvm.mlir.undef : vector<1xf32>
%1 = llvm.mlir.constant(3 : i32) : i32
%2 = llvm.mlir.constant(true) : i1
%3 = llvm.mlir.constant(4 : i32) : i32
%4 = llvm.mlir.constant(-2147483648 : i32) : i32
%5 = llvm.mlir.constant(2147483646 : i32) : i32
%6 = llvm.mlir.constant(822243328 : i32) : i32
%7 = llvm.mlir.constant(0 : i16) : i16
%8 = llvm.mlir.constant(15 : i32) : i32
%9 = llvm.mlir.constant(5 : i32) : i32
%10 = llvm.mlir.constant(0 : i32) : i32
%11 = llvm.mlir.constant(32 : i32) : i32
%12 = llvm.mlir.constant(127 : i32) : i32
%13 = llvm.mlir.constant(0 : index) : i32
%14 = llvm.mlir.constant(16 : i32) : i32
%15 = llvm.mlir.constant(1.000000e+00 : f32) : f32
%pid = rocdl.workgroup.id.x : i32
%offset = llvm.mul %pid, %14 : i32
%offsets = rocdl.workitem.id.x : i32
%offsets_0 = llvm.and %offsets, %12 : i32
%offsets_1 = llvm.urem %offsets_0, %11 : i32
%offsets_2 = llvm.udiv %offsets_0, %11 : i32
%offsets_3 = llvm.shl %offsets_1, %10 : i32
%offsets_4 = llvm.or %10, %offsets_3 : i32
%offsets_5 = llvm.shl %offsets_2, %9 : i32
%offsets_6 = llvm.or %offsets_4, %offsets_5 : i32
%offsets_7 = llvm.and %offsets_6, %8 : i32
%offsets_8 = llvm.lshr %offsets_7, %10 : i32
%offsets_9 = llvm.xor %10, %offsets_8 : i32
%offsets_10 = llvm.xor %10, %offsets_9 : i32
%offsets_11 = llvm.xor %offsets_10, %10 : i32
%offsets_12 = llvm.add %offsets_11, %13 : i32
%offsets_13 = llvm.add %offset, %offsets_12 : i32
%load_src_store_dst = llvm.getelementptr %src[%offset] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
%mask = llvm.icmp "slt" %offsets_13, %N : i32
%x_plus_1 = rocdl.make.buffer.rsrc %load_src_store_dst, %7, %5, %6 : <1> to <8>
%x_plus_1_14 = llvm.mul %offsets_12, %3 : i32
%x_plus_1_15 = llvm.select %mask, %x_plus_1_14, %4 : i1, i32
%x_plus_1_16 = rocdl.raw.ptr.buffer.load %x_plus_1, %x_plus_1_15, %10, %10 : f32
%x_plus_1_17 = llvm.bitcast %x_plus_1_16 : f32 to vector<1xf32>
%x_plus_1_18 = llvm.extractelement %x_plus_1_17[%13 : i32] : vector<1xf32>
%x_plus_1_19 = llvm.fadd %x_plus_1_18, %15 : f32
%16 = llvm.and %offsets_1, %14 : i32
%17 = llvm.icmp "eq" %16, %10 : i32
%18 = llvm.and %2, %17 : i1
%19 = llvm.and %offsets_2, %1 : i32
%20 = llvm.icmp "eq" %19, %10 : i32
%21 = llvm.and %18, %20 : i1
%22 = llvm.and %21, %mask : i1
%23 = llvm.insertelement %x_plus_1_19, %0[%10 : i32] : vector<1xf32>
%24 = llvm.bitcast %23 : vector<1xf32> to f32
%25 = llvm.select %22, %x_plus_1_14, %4 : i1, i32
rocdl.raw.ptr.buffer.store %24, %x_plus_1, %25, %10, %10 : f32
llvm.return
}
}
```
Note, the "explosion" in e.g. `%offsets_*` is due to the choices made in
the passes themselves, not the flag/plumbing (i.e., location
propagation).1 parent 388149b commit da3ab2a
File tree
6 files changed
+297
-54
lines changed- python
- src
- test
- gluon
- unit
- cuda
- language
- triton/compiler
6 files changed
+297
-54
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
152 | 152 | | |
153 | 153 | | |
154 | 154 | | |
| 155 | + | |
155 | 156 | | |
156 | 157 | | |
157 | 158 | | |
| |||
372 | 373 | | |
373 | 374 | | |
374 | 375 | | |
375 | | - | |
376 | | - | |
377 | | - | |
378 | | - | |
379 | | - | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
380 | 385 | | |
381 | 386 | | |
382 | 387 | | |
| |||
929 | 934 | | |
930 | 935 | | |
931 | 936 | | |
| 937 | + | |
| 938 | + | |
| 939 | + | |
| 940 | + | |
| 941 | + | |
| 942 | + | |
| 943 | + | |
| 944 | + | |
| 945 | + | |
| 946 | + | |
| 947 | + | |
| 948 | + | |
| 949 | + | |
| 950 | + | |
| 951 | + | |
| 952 | + | |
| 953 | + | |
| 954 | + | |
| 955 | + | |
| 956 | + | |
| 957 | + | |
| 958 | + | |
932 | 959 | | |
933 | 960 | | |
934 | 961 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
165 | 165 | | |
166 | 166 | | |
167 | 167 | | |
168 | | - | |
169 | | - | |
| 168 | + | |
| 169 | + | |
170 | 170 | | |
171 | 171 | | |
172 | 172 | | |
| |||
200 | 200 | | |
201 | 201 | | |
202 | 202 | | |
203 | | - | |
| 203 | + | |
204 | 204 | | |
205 | | - | |
| 205 | + | |
206 | 206 | | |
207 | 207 | | |
208 | 208 | | |
| |||
257 | 257 | | |
258 | 258 | | |
259 | 259 | | |
260 | | - | |
261 | | - | |
| 260 | + | |
| 261 | + | |
262 | 262 | | |
263 | 263 | | |
264 | 264 | | |
| |||
283 | 283 | | |
284 | 284 | | |
285 | 285 | | |
286 | | - | |
| 286 | + | |
287 | 287 | | |
288 | | - | |
| 288 | + | |
289 | 289 | | |
290 | 290 | | |
291 | 291 | | |
| |||
676 | 676 | | |
677 | 677 | | |
678 | 678 | | |
679 | | - | |
| 679 | + | |
680 | 680 | | |
681 | 681 | | |
682 | 682 | | |
683 | 683 | | |
684 | | - | |
| 684 | + | |
685 | 685 | | |
686 | 686 | | |
687 | 687 | | |
688 | 688 | | |
689 | | - | |
| 689 | + | |
690 | 690 | | |
691 | | - | |
692 | | - | |
693 | | - | |
694 | | - | |
695 | | - | |
696 | | - | |
| 691 | + | |
| 692 | + | |
| 693 | + | |
| 694 | + | |
| 695 | + | |
| 696 | + | |
697 | 697 | | |
698 | | - | |
699 | | - | |
700 | | - | |
| 698 | + | |
| 699 | + | |
| 700 | + | |
701 | 701 | | |
702 | 702 | | |
703 | 703 | | |
704 | 704 | | |
| 705 | + | |
705 | 706 | | |
706 | 707 | | |
707 | 708 | | |
| |||
736 | 737 | | |
737 | 738 | | |
738 | 739 | | |
739 | | - | |
| 740 | + | |
740 | 741 | | |
741 | 742 | | |
742 | 743 | | |
743 | 744 | | |
744 | | - | |
| 745 | + | |
745 | 746 | | |
746 | 747 | | |
747 | 748 | | |
748 | 749 | | |
749 | 750 | | |
750 | 751 | | |
751 | | - | |
752 | | - | |
753 | | - | |
754 | | - | |
755 | | - | |
756 | | - | |
| 752 | + | |
| 753 | + | |
| 754 | + | |
| 755 | + | |
| 756 | + | |
| 757 | + | |
757 | 758 | | |
758 | | - | |
759 | | - | |
| 759 | + | |
| 760 | + | |
760 | 761 | | |
761 | 762 | | |
762 | 763 | | |
763 | 764 | | |
| 765 | + | |
764 | 766 | | |
765 | 767 | | |
766 | 768 | | |
| |||
972 | 974 | | |
973 | 975 | | |
974 | 976 | | |
| 977 | + | |
975 | 978 | | |
976 | | - | |
| 979 | + | |
977 | 980 | | |
978 | 981 | | |
979 | 982 | | |
| |||
1003 | 1006 | | |
1004 | 1007 | | |
1005 | 1008 | | |
1006 | | - | |
| 1009 | + | |
1007 | 1010 | | |
1008 | 1011 | | |
1009 | 1012 | | |
| |||
1202 | 1205 | | |
1203 | 1206 | | |
1204 | 1207 | | |
1205 | | - | |
| 1208 | + | |
| 1209 | + | |
1206 | 1210 | | |
1207 | 1211 | | |
1208 | 1212 | | |
1209 | | - | |
| 1213 | + | |
1210 | 1214 | | |
1211 | 1215 | | |
1212 | | - | |
| 1216 | + | |
1213 | 1217 | | |
1214 | | - | |
| 1218 | + | |
1215 | 1219 | | |
1216 | 1220 | | |
1217 | 1221 | | |
| |||
1223 | 1227 | | |
1224 | 1228 | | |
1225 | 1229 | | |
| 1230 | + | |
1226 | 1231 | | |
1227 | 1232 | | |
1228 | 1233 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
19 | | - | |
| 19 | + | |
| 20 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
32 | 32 | | |
33 | 33 | | |
34 | 34 | | |
35 | | - | |
36 | | - | |
| 35 | + | |
| 36 | + | |
37 | 37 | | |
38 | 38 | | |
39 | 39 | | |
| |||
73 | 73 | | |
74 | 74 | | |
75 | 75 | | |
76 | | - | |
77 | | - | |
| 76 | + | |
| 77 | + | |
78 | 78 | | |
79 | | - | |
80 | | - | |
| 79 | + | |
| 80 | + | |
81 | 81 | | |
82 | | - | |
| 82 | + | |
83 | 83 | | |
84 | | - | |
85 | | - | |
| 84 | + | |
| 85 | + | |
0 commit comments