
Commit e5be006

[AMD] Use warp shuffle for fp8 MFMA to dot operand layout conversion (#5362)
This relands #5139: it adds a shortcut case for the fp8 MFMA to dot operand layout conversion that avoids going through shared memory, to speed up FP8 attention kernels.

---------

Co-authored-by: ilia-cher <[email protected]>
1 parent 89c0b0a commit e5be006

File tree

5 files changed, +378 -2 lines changed

include/triton/Analysis/Utility.h

Lines changed: 5 additions & 0 deletions
@@ -231,6 +231,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy);

+// Check if MFMA layout can be converted to the dot operand
+// layout using warp shuffle.
+bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
+                                       RankedTensorType dstTy);
+
 // TODO: Move utility functions that belong to ConvertLayoutOp to class
 // ConvertLayoutOpHelper in the future
 bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);

lib/Analysis/Utility.cpp

Lines changed: 24 additions & 1 deletion
@@ -10,6 +10,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
+#include "triton/Conversion/MLIRTypes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"

@@ -644,6 +645,25 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
   return ans;
 }

+bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
+                                       RankedTensorType dstTy) {
+  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
+  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+  if (!mfmaLayout || !dotOperandLayout)
+    return false;
+
+  // Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
+  return dotOperandLayout.getParent() == mfmaLayout &&
+         dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
+         dotOperandLayout.getKWidth() == 8 &&
+         getContigPerThread(mfmaLayout)[1] == 4 &&
+         ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
+          (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
+         triton::type::isFloat8(srcTy.getElementType()) &&
+         triton::type::isFloat8(dstTy.getElementType()) &&
+         mfmaLayout.getWarpsPerCTA()[1] == 1;
+}
+
 // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
 // under kBlock, kWarp or kLane (in that order). The idea here is that if we
 // have a transformation that's the identity on kBlock, we don't need to use

@@ -708,7 +728,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
   // supported yet in Triton's backend.
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !isBlockedToDotShortcut(srcTy, dstTy) &&
-         !matchMmaV3AndDotOperandLayout(srcTy, dstTy);
+         !matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
+         // to be removed when generalized warp shuffle conversions
+         // are ready:
+         !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
 }

 bool atomicNeedsSharedMemory(Value value) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 6 additions & 0 deletions
@@ -391,6 +391,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return failure();
     }

+    // The following check can be removed when generalized warp shuffle
+    // conversions are ready:
+    if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
+      return failure();
+    }
+
     assert(cvtNeedsSharedMemory(srcTy, dstTy));

     SmallVector<Value> inVals =
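The AMD-specific lowering that actually performs the shuffle (presumably among the remaining changed files, not reproduced in this excerpt) picks up exactly the case that the generic pattern above now declines. As a rough, hypothetical sketch only, where the class name, includes, and elided rewrite body are assumptions rather than code from this commit, such a backend pattern would gate itself on the same helper:

// Hypothetical sketch of a backend-specific pattern gated on the new helper.
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace {
using namespace mlir;

struct MFMAToDotOperandShuffleSketch
    : public OpConversionPattern<triton::gpu::ConvertLayoutOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcTy = cast<RankedTensorType>(op.getSrc().getType());
    auto dstTy = cast<RankedTensorType>(op.getType());
    // Fire only for the fp8 MFMA -> dot-operand case; everything else keeps
    // going through the shared-memory path.
    if (!matchMFMAAndDotOperandShuffleCase(srcTy, dstTy))
      return failure();
    // Rewrite body elided in this sketch: unpack the eight fp8 values each
    // lane holds, exchange the packed 4xi8 vectors across lanes with
    // rocdl.ds_bpermute, select between own and shuffled vectors per lane
    // (see the FileCheck patterns below), and replace `op` with the repacked
    // result via rewriter.replaceOp.
    return failure();
  }
};
} // namespace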

test/Conversion/amd/mfma-shortcut.mlir

Lines changed: 189 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s
+// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s

 #mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
 #dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>

@@ -27,3 +27,191 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: mfma_dot_cvt_f8_mfma32
+  tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
+    // CHECK-NOT: store
+    // CHECK-NOT: load
+
+    // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
+    // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]
+
+    // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
+    // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
+
+    // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
+    // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
+    // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
+
+    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
+    // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
+
+    // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
+    // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
+
+    // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
+    // CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
+    // CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]]
+
+    // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
+    // CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
+    // CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]]
+
+    // Input (8 values): (vec0, vec1)
+    // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
+    //               resVec0      resVec1
+    // lanes 0-31:  (vec0      , vec0 >> 32)   (mask0=1)
+    // lanes 32-63: (vec1 >> 32, vec1      )   (mask0=0)
+
+    // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]]
+    // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]]
+
+    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
+    // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
+    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
+    // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
+
+    // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
+    // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]
+
+    // CHECK: llvm.return
+    %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
+    tt.return
+  }
+}
+
+// -----
+
+#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: mfma_dot_cvt_bf8_mfma32
+  tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) {
+    // CHECK-NOT: store
+    // CHECK-NOT: load
+    // CHECK: rocdl.ds_bpermute
+    // CHECK: llvm.return
+    %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
+    tt.return
+  }
+}
+
+// -----
+
+#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: mfma_dot_cvt_f8_mfma16
+  tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) {
+    // CHECK-NOT: store
+    // CHECK-NOT: load
+
+    // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3]
+    // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7]
+
+    // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32)
+    // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
+    // CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32)
+    // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
+
+    // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x
+    // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]]
+    // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
+
+    // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]]
+    // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]]
+
+    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]]
+    // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
+
+    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
+    // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
+
+    // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]]
+    // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
+
+    // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
+    // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
+
+    // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]]
+    // CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
+    // CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]]
+
+    // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
+    // CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
+    // CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]]
+
+    // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
+    // CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
+    // CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]]
+
+    // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
+    // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
+    // CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]]
+    // CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
+    // CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]]
+
+    // Input (8 values): (vec0, vec1)
+    // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
+    //               resVec0      resVec1
+    // lanes 0-15:  (vec0      , vec0 >> 16)   (mask0=1, mask1=1)
+    // lanes 16-31: (vec0 >> 16, vec0 >> 32)   (mask0=1, mask1=0)
+    // lanes 32-47: (vec1 >> 32, vec1 >> 48)   (mask0=0, mask1=1)
+    // lanes 48-63: (vec1 >> 48, vec1      )   (mask0=0, mask1=0)
+
+    // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8>
+    // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8>
+    // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
+
+    // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8>
+    // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8>
+    // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
+
+    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
+    // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
+    // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
+    // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
+
+    // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3]
+    // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7]
+
+    // CHECK: llvm.return
+    %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0>
+    tt.return
+  }
+}
+
+// -----
+
+#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: mfma_dot_cvt_bf8_mfma16
+  tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) {
+    // CHECK-NOT: store
+    // CHECK-NOT: load
+    // CHECK: rocdl.ds_bpermute
+    // CHECK: llvm.return
+    %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0>
+    tt.return
+  }
+}
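To make the lane bookkeeping in the two comment tables above easier to follow, here is a small host-side C++ simulation. It is not part of the commit, and the value encoding is made up for illustration. It models rocdl.ds_bpermute with byte address ((lane + n) % 64) << 2 as lane i reading the packed 4xi8 register of lane (i + n) % 64, and asserts that the selected results match the tables for both the 32x32 and 16x16 cases.

// Standalone simulation of the cross-lane exchange asserted by the tests above.
#include <array>
#include <cassert>
#include <cstdint>

constexpr int kWaveSize = 64;
using Wave = std::array<uint32_t, kWaveSize>; // one packed 4xfp8 register per lane

// ds_bpermute with byte address ((lane + n) % 64) << 2: lane i receives the
// register held by lane (i + n) % 64.
Wave bpermute(const Wave &v, int n) {
  Wave out{};
  for (int lane = 0; lane < kWaveSize; ++lane)
    out[lane] = v[(lane + n) % kWaveSize];
  return out;
}

int main() {
  // vec0/vec1 stand for the first/second packed half of the 8 fp8 values each
  // lane owns in the transposed MFMA layout; the payload here is just a tag.
  Wave vec0{}, vec1{};
  for (int lane = 0; lane < kWaveSize; ++lane) {
    vec0[lane] = 0x1000 + lane;
    vec1[lane] = 0x2000 + lane;
  }

  // MFMA 32x32: one exchange at distance 32 between the two half-waves.
  Wave shflVec0 = bpermute(vec0, 32), shflVec1 = bpermute(vec1, 32);
  for (int lane = 0; lane < kWaveSize; ++lane) {
    bool mask0 = lane < 32;
    uint32_t resVec0 = mask0 ? vec0[lane] : shflVec1[lane];
    uint32_t resVec1 = mask0 ? shflVec0[lane] : vec1[lane];
    if (mask0) { // lanes 0-31: (vec0, vec0 >> 32)
      assert(resVec0 == uint32_t(0x1000 + lane));
      assert(resVec1 == uint32_t(0x1000 + lane + 32));
    } else {     // lanes 32-63: (vec1 >> 32, vec1)
      assert(resVec0 == uint32_t(0x2000 + lane - 32));
      assert(resVec1 == uint32_t(0x2000 + lane));
    }
  }

  // MFMA 16x16: exchanges at distances 16, 32 and 48, selected per quarter-wave.
  Wave v0_16 = bpermute(vec0, 16), v0_32 = bpermute(vec0, 32);
  Wave v1_32 = bpermute(vec1, 32), v1_48 = bpermute(vec1, 48);
  for (int lane = 0; lane < kWaveSize; ++lane) {
    bool mask0 = lane < 32;
    bool mask1 = (lane % 32) < 16;
    uint32_t resVec0 = mask0 ? (mask1 ? vec0[lane] : v0_16[lane])
                             : (mask1 ? v1_32[lane] : v1_48[lane]);
    uint32_t resVec1 = mask0 ? (mask1 ? v0_16[lane] : v0_32[lane])
                             : (mask1 ? v1_48[lane] : vec1[lane]);
    if (lane < 16) {        // (vec0, vec0 >> 16)
      assert(resVec0 == uint32_t(0x1000 + lane) && resVec1 == uint32_t(0x1000 + lane + 16));
    } else if (lane < 32) { // (vec0 >> 16, vec0 >> 32)
      assert(resVec0 == uint32_t(0x1000 + lane + 16) && resVec1 == uint32_t(0x1000 + lane + 32));
    } else if (lane < 48) { // (vec1 >> 32, vec1 >> 48)
      assert(resVec0 == uint32_t(0x2000 + lane - 32) && resVec1 == uint32_t(0x2000 + lane - 16));
    } else {                // (vec1 >> 48, vec1)
      assert(resVec0 == uint32_t(0x2000 + lane - 16) && resVec1 == uint32_t(0x2000 + lane));
    }
  }
  return 0;
}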
