Skip to content

Commit 3530aab

Browse files
authored
[AMD][gfx1250] Enable mixed precision (scaled) dot support in Triton (#8938)
This PR enables `test_scaled_dot@test_core.py`. The general lowering steps are the same as on gfx9: the operand with mxfp8 or mxfp4 type is upcast to the same [b]f16 type as the other operand, using the new gfx1250 `cvt.scale.pk8.*` instructions. Once both operand types are aligned, the existing code generation for the 16x16x32 [b]f16 WMMA instruction can be reused. Since the data layouts of the input and scale tensors differ on gfx1250 from gfx9, most of the changes in this PR are about: - preparing the layout for the new cvt.scale.pk8 instructions; - selecting the op_sel value corresponding to the layout created in the step above. More details are provided as inline comments in the code changes. As for testing: relevant lit tests are added. Runtime test: pytest -s -v python/test/unit/language/test_core.py::test_scaled_dot.
1 parent 1c15a29 commit 3530aab

File tree

10 files changed

+452
-57
lines changed

10 files changed

+452
-57
lines changed

test/TritonGPU/amd/accelerate-amd-matmul-wmma-gfx1250.mlir

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,136 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
194194
tt.return
195195
}
196196
}
197+
198+
// -----
199+
200+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
201+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
202+
#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
203+
// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[0, 32], [0, 64], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = []}>
204+
// CHECK-LABEL: wmma_dot_scaled_mxfp8_bf16
205+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
206+
tt.func public @wmma_dot_scaled_mxfp8_bf16(
207+
%arg0: tensor<32x128x!tt.ptr<f8E4M3FN>, #blocked4>,
208+
%arg1: tensor<32x4x!tt.ptr<i8>, #blocked2>,
209+
%arg2: tensor<128x32x!tt.ptr<bf16>, #blocked>,
210+
%output: tensor<32x32x!tt.ptr<f32>, #blocked>
211+
) {
212+
// CHECK: tt.load %arg1 {amdg.decomposed_dot_scaled_source = true} : tensor<32x4x!tt.ptr<i8>, #blocked1>
213+
// CHECK: %[[SCALE:.*]] = tt.reshape {{.*}} : tensor<32x4x32xi8, #blocked3> -> tensor<32x128xi8, #linear>
214+
// CHECK: %[[CVT0:.*]] = ttg.convert_layout %[[SCALE]] : tensor<32x128xi8, #linear> -> tensor<32x128xi8, #blocked>
215+
// CHECK: %[[UPCASTED:.*]] = amdg.scaled_upcast_fp8 {{.*}} scale %[[CVT0]] : tensor<32x128xf8E4M3FN, #blocked>, tensor<32x128xi8, #blocked> -> tensor<32x128xbf16, #blocked>
216+
// CHECK: %[[SEL:.*]] = arith.select {{.*}}, {{.*}}, %[[UPCASTED]]
217+
// CHECK: %[[CVT1:.*]] = ttg.convert_layout %[[SEL]] : tensor<32x128xbf16, #blocked> -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
218+
// CHECK: %[[OPND0:.*]] = ttg.convert_layout %[[CVT1]] : tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
219+
// CHECK: tt.dot %[[OPND0]]
220+
%a = tt.load %arg0 : tensor<32x128x!tt.ptr<f8E4M3FN>, #blocked4>
221+
%scale = tt.load %arg1 : tensor<32x4x!tt.ptr<i8>, #blocked2>
222+
%b = tt.load %arg2 : tensor<128x32x!tt.ptr<bf16>, #blocked>
223+
%c = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
224+
%res = tt.dot_scaled %a scale %scale, %b, %c lhs = e4m3 rhs = bf16 {fastMath = false} : tensor<32x128xf8E4M3FN, #blocked4>, tensor<32x4xi8, #blocked2> * tensor<128x32xbf16, #blocked> -> tensor<32x32xf32, #blocked>
225+
226+
tt.store %output, %res : tensor<32x32x!tt.ptr<f32>, #blocked>
227+
tt.return
228+
}
229+
}
230+
231+
// -----
232+
233+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
234+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
235+
#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
236+
// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[32, 0], [64, 0]], block = []}>
237+
// CHECK-LABEL: wmma_dot_scaled_f16_mxfp8
238+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
239+
tt.func public @wmma_dot_scaled_f16_mxfp8(
240+
%arg0: tensor<32x128x!tt.ptr<f16>, #blocked4>,
241+
%arg1: tensor<32x4x!tt.ptr<i8>, #blocked2>,
242+
%arg2: tensor<128x32x!tt.ptr<f8E5M2>, #blocked>,
243+
%output: tensor<32x32x!tt.ptr<f32>, #blocked>
244+
) {
245+
// CHECK: %[[TRANS:.*]] = tt.trans {{.*}} {order = array<i32: 0, 2, 1>} : tensor<4x32x32xi8, #blocked4> -> tensor<4x32x32xi8, #blocked5>
246+
// CHECK: %[[SCALE:.*]] = tt.reshape %[[TRANS]] : tensor<4x32x32xi8, #blocked5> -> tensor<128x32xi8, #linear>
247+
// CHECK: %[[CVT0:.*]] = ttg.convert_layout %[[SCALE]] : tensor<128x32xi8, #linear> -> tensor<128x32xi8, #blocked2>
248+
// CHECK: %[[UPCASTED:.*]] = amdg.scaled_upcast_fp8 {{.*}} scale %[[CVT0]] : tensor<128x32xf8E5M2, #blocked2>, tensor<128x32xi8, #blocked2> -> tensor<128x32xf16, #blocked2>
249+
// CHECK: %[[SEL:.*]] = arith.select {{.*}}, %cst, %[[UPCASTED]] : tensor<128x32xi1, #blocked2>, tensor<128x32xf16, #blocked2>
250+
// CHECK: %[[CVT1:.*]] = ttg.convert_layout %[[SEL]] : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>>
251+
// CHECK: %[[OPND1:.*]] = ttg.convert_layout %[[CVT1]] : tensor<128x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
252+
// CHECK: = tt.dot {{.*}}, %[[OPND1]]
253+
%a = tt.load %arg0 : tensor<32x128x!tt.ptr<f16>, #blocked4>
254+
%scale = tt.load %arg1 : tensor<32x4x!tt.ptr<i8>, #blocked2>
255+
%b = tt.load %arg2 : tensor<128x32x!tt.ptr<f8E5M2>, #blocked>
256+
%c = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
257+
%res = tt.dot_scaled %a, %b scale %scale, %c lhs = fp16 rhs = e5m2 {fastMath = false} : tensor<32x128xf16, #blocked4> * tensor<128x32xf8E5M2, #blocked>, tensor<32x4xi8, #blocked2> -> tensor<32x32xf32, #blocked>
258+
259+
tt.store %output, %res : tensor<32x32x!tt.ptr<f32>, #blocked>
260+
tt.return
261+
}
262+
}
263+
264+
// -----
265+
266+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
267+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
268+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
269+
#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
270+
// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[0, 32], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
271+
// CHECK-LABEL: wmma_dot_scaled_mxfp4_bf16
272+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
273+
tt.func public @wmma_dot_scaled_mxfp4_bf16(
274+
%arg0: tensor<16x32x!tt.ptr<i8>, #blocked5>,
275+
%arg1: tensor<16x2x!tt.ptr<i8>, #blocked2>,
276+
%arg2: tensor<64x16x!tt.ptr<bf16>, #blocked>,
277+
%output: tensor<16x16x!tt.ptr<f32>, #blocked>
278+
) {
279+
// CHECK: tt.load %arg1 {amdg.decomposed_dot_scaled_source = true} : tensor<16x2x!tt.ptr<i8>, #blocked1>
280+
// CHECK: %[[SCALE:.*]] = tt.reshape {{.*}} : tensor<16x2x32xi8, #blocked3> -> tensor<16x64xi8, #linear>
281+
// CHECK: %[[CVT0:.*]] = ttg.convert_layout %[[SCALE]] : tensor<16x64xi8, #linear> -> tensor<16x64xi8, #blocked>
282+
// CHECK: %[[UPCASTED:.*]] = amdg.scaled_upcast_fp4 {{.*}} scale %[[CVT0]] {axis = 1 : i32} : tensor<16x32xi8, #blocked>, tensor<16x64xi8, #blocked> -> tensor<16x64xbf16, #blocked>
283+
// CHECK: %[[SEL:.*]] = arith.select {{.*}}, %{{.*}}, %[[UPCASTED]] : tensor<16x64xi1, #blocked>, tensor<16x64xbf16, #blocked>
284+
// CHECK: %[[CVT1:.*]] = ttg.convert_layout %[[SEL]] : tensor<16x64xbf16, #blocked> -> tensor<16x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
285+
// CHECK: %[[OPND0:.*]] = ttg.convert_layout %[[CVT1]] : tensor<16x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<16x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
286+
// CHECK: tt.dot %[[OPND0]]
287+
%a = tt.load %arg0 : tensor<16x32x!tt.ptr<i8>, #blocked5>
288+
%scale = tt.load %arg1 : tensor<16x2x!tt.ptr<i8>, #blocked2>
289+
%b = tt.load %arg2 : tensor<64x16x!tt.ptr<bf16>, #blocked>
290+
%c = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked>
291+
%res = tt.dot_scaled %a scale %scale, %b, %c lhs = e2m1 rhs = bf16 {fastMath = false} : tensor<16x32xi8, #blocked5>, tensor<16x2xi8, #blocked2> * tensor<64x16xbf16, #blocked> -> tensor<16x16xf32, #blocked>
292+
293+
tt.store %output, %res : tensor<16x16x!tt.ptr<f32>, #blocked>
294+
tt.return
295+
}
296+
}
297+
298+
// -----
299+
300+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
301+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
302+
#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
303+
// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [32, 0]], warp = [[0, 0], [0, 0]], block = []}>
304+
// CHECK-LABEL: wmma_dot_scaled_fp16_mxfp4
305+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
306+
tt.func public @wmma_dot_scaled_fp16_mxfp4(
307+
%arg0: tensor<16x64x!tt.ptr<f16>, #blocked5>,
308+
%arg1: tensor<16x2x!tt.ptr<i8>, #blocked2>,
309+
%arg2: tensor<32x16x!tt.ptr<i8>, #blocked>,
310+
%output: tensor<16x16x!tt.ptr<f32>, #blocked>
311+
) {
312+
// CHECK: tt.load %arg1 {amdg.decomposed_dot_scaled_source = true} : tensor<16x2x!tt.ptr<i8>, #blocked1>
313+
// CHECK: %[[SCALE:.*]] = tt.reshape {{.*}} : tensor<2x32x16xi8, #blocked5> -> tensor<64x16xi8, #linear>
314+
// CHECK: %[[CVT0:.*]] = ttg.convert_layout %[[SCALE]] : tensor<64x16xi8, #linear> -> tensor<64x16xi8, #blocked2>
315+
// CHECK: %[[UPCASTED:.*]] = amdg.scaled_upcast_fp4 {{.*}} scale %[[CVT0]] {axis = 0 : i32} : tensor<32x16xi8, #blocked2>, tensor<64x16xi8, #blocked2> -> tensor<64x16xf16, #blocked2>
316+
// CHECK: %[[SEL:.*]] = arith.select {{.*}}, %cst, %[[UPCASTED]] : tensor<64x16xi1, #blocked2>, tensor<64x16xf16, #blocked2>
317+
// CHECK: %[[CVT1:.*]] = ttg.convert_layout %[[SEL]] : tensor<64x16xf16, #blocked2> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>>
318+
// CHECK: %[[OPND1:.*]] = ttg.convert_layout %[[CVT1]] : tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
319+
// CHECK: tt.dot {{.*}}, %[[OPND1]]
320+
%a = tt.load %arg0 : tensor<16x64x!tt.ptr<f16>, #blocked5>
321+
%scale = tt.load %arg1 : tensor<16x2x!tt.ptr<i8>, #blocked2>
322+
%b = tt.load %arg2 : tensor<32x16x!tt.ptr<i8>, #blocked>
323+
%c = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked>
324+
%res = tt.dot_scaled %a, %b scale %scale, %c lhs = fp16 rhs = e2m1 {fastMath = false} : tensor<16x64xf16, #blocked5> * tensor<32x16xi8, #blocked>, tensor<16x2xi8, #blocked2> -> tensor<16x16xf32, #blocked>
325+
326+
tt.store %output, %res : tensor<16x16x!tt.ptr<f32>, #blocked>
327+
tt.return
328+
}
329+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// RUN: triton-opt %s -split-input-file --allocate-amdgpu-shared-memory --convert-triton-amdgpu-to-llvm="arch=gfx1250" --canonicalize --cse | FileCheck %s
2+
3+
// -----
4+
5+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
6+
#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 32]}>
7+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
8+
tt.func public @wmma_dot_scaled_mxfp8_bf16(%arg0: tensor<32x128xf8E4M3FN, #blocked>, %arg1: tensor<32x128xi8, #blocked>, %arg2: tensor<32x128x!tt.ptr<bf16>, #blocked>) {
9+
// CHECK: %[[SCALE:.*]] = llvm.extractvalue %arg1[0] : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
10+
// CHECK: %[[SCALE_1:.*]] = llvm.extractvalue %arg1[8] : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
11+
// CHECK: %[[SCALE_2:.*]] = llvm.extractvalue %arg1[16] : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
12+
// CHECK: %[[SCALE_3:.*]] = llvm.extractvalue %arg1[24] : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
13+
14+
// CHECK: llvm.insertelement %[[SCALE]], {{.*}} : vector<4xi8>
15+
// CHECK: llvm.insertelement %[[SCALE]], {{.*}} : vector<4xi8>
16+
// CHECK: llvm.insertelement %[[SCALE]], {{.*}} : vector<4xi8>
17+
// CHECK: %[[V0:.*]] = llvm.insertelement %[[SCALE]], {{.*}} : vector<4xi8>
18+
// CHECK: %[[SCALE_INT32:.*]] = llvm.bitcast %[[V0]] : vector<4xi8> to i32
19+
// CHECK: rocdl.cvt.scale.pk8.bf16.fp8 {{.*}}, %[[SCALE_INT32]][0] : vector<8xbf16>
20+
21+
// CHECK: llvm.insertelement %[[SCALE_1]], {{.*}} : vector<4xi8>
22+
// CHECK: llvm.insertelement %[[SCALE_1]], {{.*}} : vector<4xi8>
23+
// CHECK: llvm.insertelement %[[SCALE_1]], {{.*}} : vector<4xi8>
24+
// CHECK: %[[V1:.*]] = llvm.insertelement %[[SCALE_1]], {{.*}} : vector<4xi8>
25+
// CHECK: %[[SCALE_INT32_1:.*]] = llvm.bitcast %[[V1]] : vector<4xi8> to i32
26+
// CHECK: rocdl.cvt.scale.pk8.bf16.fp8 {{.*}}, %[[SCALE_INT32_1]][0] : vector<8xbf16>
27+
28+
// CHECK: llvm.insertelement %[[SCALE_2]], {{.*}} : vector<4xi8>
29+
// CHECK: llvm.insertelement %[[SCALE_2]], {{.*}} : vector<4xi8>
30+
// CHECK: llvm.insertelement %[[SCALE_2]], {{.*}} : vector<4xi8>
31+
// CHECK: %[[V2:.*]] = llvm.insertelement %[[SCALE_2]], {{.*}} : vector<4xi8>
32+
// CHECK: %[[SCALE_INT32_2:.*]] = llvm.bitcast %[[V2]] : vector<4xi8> to i32
33+
// CHECK: rocdl.cvt.scale.pk8.bf16.fp8 {{.*}}, %[[SCALE_INT32_2]][0] : vector<8xbf16>
34+
35+
// CHECK: llvm.insertelement %[[SCALE_3]], {{.*}} : vector<4xi8>
36+
// CHECK: llvm.insertelement %[[SCALE_3]], {{.*}} : vector<4xi8>
37+
// CHECK: llvm.insertelement %[[SCALE_3]], {{.*}} : vector<4xi8>
38+
// CHECK: %[[V3:.*]] = llvm.insertelement %[[SCALE_3]], {{.*}} : vector<4xi8>
39+
// CHECK: %[[SCALE_INT32_3:.*]] = llvm.bitcast %[[V3]] : vector<4xi8> to i32
40+
// CHECK: rocdl.cvt.scale.pk8.bf16.fp8 {{.*}}, %[[SCALE_INT32_3]][0] : vector<8xbf16>
41+
%7 = amdg.scaled_upcast_fp8 %arg0 scale %arg1 : tensor<32x128xf8E4M3FN, #blocked>, tensor<32x128xi8, #blocked> -> tensor<32x128xbf16, #blocked>
42+
tt.store %arg2, %7 : tensor<32x128x!tt.ptr<bf16>, #blocked>
43+
tt.return
44+
}
45+
}
46+
47+
// -----
48+
49+
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
50+
#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [4, 1], instrShape = [16, 16, 32]}>
51+
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 4, maxPhase = 4, order = [1, 0]}>
52+
#smem = #ttg.shared_memory
53+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 2048 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
54+
tt.func public @cvt_scale_pk8_bf16_fp4(%output: tensor<16x64x!tt.ptr<bf16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, %15: tensor<16x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %27: tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>) attributes {noinline = false} {
55+
// CHECK: %[[SCALE:.*]] = llvm.extractvalue %arg2[0] : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
56+
// CHECK: %[[SCALE_1:.*]] = llvm.extractvalue %arg2[8] : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
57+
// CHECK: %[[SCALE_2:.*]] = llvm.extractvalue %arg2[16] : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
58+
// CHECK: %[[SCALE_3:.*]] = llvm.extractvalue %arg2[24] : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>
59+
60+
// CHECK: llvm.insertelement %[[SCALE]], {{.*}} : vector<4xi8>
61+
// CHECK: llvm.insertelement %[[SCALE]], {{.*}} : vector<4xi8>
62+
// CHECK: llvm.insertelement %[[SCALE]], {{.*}} : vector<4xi8>
63+
// CHECK: %[[V0:.*]] = llvm.insertelement %[[SCALE]], {{.*}} : vector<4xi8>
64+
// CHECK: %[[SCALE_INT32:.*]] = llvm.bitcast %[[V0]] : vector<4xi8> to i32
65+
// CHECK: rocdl.cvt.scale.pk8.bf16.fp4 {{.*}}, %[[SCALE_INT32]][0] : vector<8xbf16>
66+
67+
// CHECK: llvm.insertelement %[[SCALE_1]], {{.*}} : vector<4xi8>
68+
// CHECK: llvm.insertelement %[[SCALE_1]], {{.*}} : vector<4xi8>
69+
// CHECK: llvm.insertelement %[[SCALE_1]], {{.*}} : vector<4xi8>
70+
// CHECK: %[[V1:.*]] = llvm.insertelement %[[SCALE_1]], {{.*}} : vector<4xi8>
71+
// CHECK: %[[SCALE_INT32_1:.*]] = llvm.bitcast %[[V1]] : vector<4xi8> to i32
72+
// CHECK: rocdl.cvt.scale.pk8.bf16.fp4 {{.*}}, %[[SCALE_INT32_1]][0] : vector<8xbf16>
73+
74+
// CHECK: llvm.insertelement %[[SCALE_2]], {{.*}} : vector<4xi8>
75+
// CHECK: llvm.insertelement %[[SCALE_2]], {{.*}} : vector<4xi8>
76+
// CHECK: llvm.insertelement %[[SCALE_2]], {{.*}} : vector<4xi8>
77+
// CHECK: %[[V2:.*]] = llvm.insertelement %[[SCALE_2]], {{.*}} : vector<4xi8>
78+
// CHECK: %[[SCALE_INT32_2:.*]] = llvm.bitcast %[[V2]] : vector<4xi8> to i32
79+
// CHECK: rocdl.cvt.scale.pk8.bf16.fp4 {{.*}}, %[[SCALE_INT32_2]][0] : vector<8xbf16>
80+
81+
// CHECK: llvm.insertelement %[[SCALE_3]], {{.*}} : vector<4xi8>
82+
// CHECK: llvm.insertelement %[[SCALE_3]], {{.*}} : vector<4xi8>
83+
// CHECK: llvm.insertelement %[[SCALE_3]], {{.*}} : vector<4xi8>
84+
// CHECK: %[[V3:.*]] = llvm.insertelement %[[SCALE_3]], {{.*}} : vector<4xi8>
85+
// CHECK: %[[SCALE_INT32_3:.*]] = llvm.bitcast %[[V3]] : vector<4xi8> to i32
86+
// CHECK: rocdl.cvt.scale.pk8.bf16.fp4 {{.*}}, %[[SCALE_INT32_3]][0] : vector<8xbf16>
87+
88+
%28 = amdg.scaled_upcast_fp4 %15 scale %27 {axis = 1 : i32} : tensor<16x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> -> tensor<16x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
89+
tt.store %output, %28 : tensor<16x64x!tt.ptr<bf16>, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
90+
tt.return
91+
}
92+
}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ def ScaledUpcastFp4Op : TT_AMDGPU_Op<"scaled_upcast_fp4", [Pure, DeclareOpInterf
604604

605605
let arguments = (ins
606606
RankedTensorOf<[I8]>:$input,
607-
RankedTensorOf<[BF16]>:$scale,
607+
RankedTensorOf<[BF16, I8]>:$scale,
608608
I32Attr:$axis);
609609
let results = (outs RankedTensorOf<[AnyTypeOf<[F16, BF16, F32]>]>:$output);
610610

@@ -636,14 +636,15 @@ def ScaledUpcastFp8Op : TT_AMDGPU_Op<"scaled_upcast_fp8", [
636636

637637
let arguments = (ins
638638
RankedTensorOf<[AnyTypeOf<[F8E4M3FN, F8E5M2]>]>:$input,
639-
RankedTensorOf<[BF16]>:$scale);
639+
RankedTensorOf<[BF16, I8]>:$scale);
640640
let results = (outs RankedTensorOf<[AnyTypeOf<[F16, BF16, F32]>]>:$output);
641641

642642
let assemblyFormat = [{
643643
$input `scale` $scale attr-dict
644644
`:` type($input) `,` type($scale) `->` type($output)
645645
}];
646646
}
647+
647648
//===----------------------------------------------------------------------===//
648649
// InThreadTransposeOp
649650
//===----------------------------------------------------------------------===//

third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PATTERNTRITONAMDGPUTOLLVM_H_
33

44
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
5+
#include "third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h"
56

67
namespace mlir::triton::AMD {
78

@@ -17,7 +18,7 @@ void populateConcatOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
1718

1819
void populateScaledUpcastOpToLLVMPatterns(
1920
mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
20-
mlir::PatternBenefit benefit);
21+
const AMD::TargetInfo &targetInfo, mlir::PatternBenefit benefit);
2122

2223
} // namespace mlir::triton::AMD
2324

0 commit comments

Comments
 (0)