Skip to content

Commit a76c71b

Browse files
authored
[mlir][amdgpu] Add scaled_ext_packed{8,16} operations (llvm#159830)
1 parent d6191b8 commit a76c71b

File tree

5 files changed

+198
-1
lines changed

5 files changed

+198
-1
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,97 @@ def AMDGPU_ExtPackedFp8Op :
112112
}];
113113
}
114114

115+
def IsValidBlockSize: AttrConstraint<
116+
CPred<"::llvm::is_contained({16, 32}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">,
117+
"whose value is 16 or 32">;
118+
119+
def AMDGPU_ScaledExtPacked816Op
120+
: AMDGPU_Op<"scaled_ext_packed816", [Pure, AllShapesMatch<["source", "res"]>]>,
121+
Arguments<(
122+
ins AnyTypeOf<[FixedVectorOfShapeAndType<[8], F4E2M1FN>,
123+
FixedVectorOfShapeAndType<[8], F8E4M3FN>,
124+
FixedVectorOfShapeAndType<[8], F8E5M2>,
125+
FixedVectorOfShapeAndType<[16], F6E2M3FN>,
126+
FixedVectorOfShapeAndType<[16], F6E3M2FN>]>:$source,
127+
FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale,
128+
ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize,
129+
ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane,
130+
ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<2>]>:$firstScaleByte)>,
131+
Results<(
132+
outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>,
133+
FixedVectorOfShapeAndType<[8], F16>,
134+
FixedVectorOfShapeAndType<[8], BF16>,
135+
FixedVectorOfShapeAndType<[16], F32>,
136+
FixedVectorOfShapeAndType<[16], F16>,
137+
FixedVectorOfShapeAndType<[16], BF16>]>:$res)> {
138+
139+
let summary = "Extend a vector of packed floating point values";
140+
141+
let description = [{
142+
The scales applied to the input microfloats are stored in two bytes which
143+
come from the `scale` input provided in a *half* of the wave identified
144+
by `firstScaleLane`. The pair of bytes used is selected by
145+
`firstScaleByte`. The 16 vectors in consecutive lanes starting from
146+
`firstScaleLane` (which we'll call the scale vectors) will be used by both
147+
halves of the wave (with lane L reading from the (L % 16)th scale vector), but
148+
each half will use a different byte.
149+
150+
When the block size is 32, `firstScaleByte` can be either 0 or 2,
151+
selecting halves of the scale vectors. Lanes 0-15 will read from
152+
`firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1.
153+
For example:
154+
```mlir
155+
// Input: 8-element vector of F8E4M3FN, converting to F32
156+
// Lanes 0-15 read from byte 0, lanes 16-31 read from byte 1
157+
%result = amdgpu.scaled_ext_packed816 %source scale(%scales)
158+
blockSize(32) firstScaleLane(0) firstScaleByte(0)
159+
: vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
160+
161+
// Input: 16-element vector of F6E2M3FN, converting to F16
162+
// Lanes 0-15 read from byte 2, lanes 16-31 read from byte 3
163+
%result = amdgpu.scaled_ext_packed816 %source scale(%scales)
164+
blockSize(32) firstScaleLane(1) firstScaleByte(2)
165+
: vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
166+
```
167+
168+
However, when the block size is 16, `firstScaleByte` can be 0 or 1.
169+
Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors,
170+
while lanes 16-31 read from `firstScaleByte` + 2.
171+
For example:
172+
```mlir
173+
// Input: 8-element vector of F8E5M2, converting to BF16
174+
// Lanes 0-15 read from byte 0, lanes 16-31 read from byte 2 (0+2)
175+
%result = amdgpu.scaled_ext_packed816 %source scale(%scales)
176+
blockSize(16) firstScaleLane(0) firstScaleByte(0)
177+
: vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
178+
179+
// Input: 16-element vector of F6E3M2FN, converting to F32
180+
// Lanes 0-15 read from byte 1, lanes 16-31 read from byte 3 (1+2)
181+
%result = amdgpu.scaled_ext_packed816 %source scale(%scales)
182+
blockSize(16) firstScaleLane(1) firstScaleByte(1)
183+
: vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
184+
```
185+
186+
Note: the layout for the scales generally mirrors the layout that the WMMA
187+
instructions use for matrix scales. These selection operands allow
188+
one to choose portions of the matrix to convert.
189+
190+
Available on gfx1250+.
191+
}];
192+
193+
let assemblyFormat = [{
194+
attr-dict $source
195+
`scale` `(` $scale `)`
196+
`blockSize` `(` $blockSize `)`
197+
`firstScaleLane` `(` $firstScaleLane`)`
198+
`firstScaleByte` `(` $firstScaleByte `)`
199+
`:` type($source) `,` type($scale) `->` type($res)
200+
}];
201+
202+
let hasVerifier = 1;
203+
204+
}
205+
115206
def AMDGPU_ScaledExtPackedOp
116207
: AMDGPU_Op<"scaled_ext_packed", [Pure]>,
117208
Arguments<(
@@ -860,7 +951,7 @@ def AMDGPU_MFMAOp :
860951
based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the
861952
types of the source and destination arguments.
862953

863-
For information on the layouts of the input and output matrces (which are stored
954+
For information on the layouts of the input and output matrices (which are stored
864955
in `sourceA`, `sourceB`, `destC`, and `destD`), see the CDNA ISA documentation.
865956

866957
The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,14 @@ class VectorOfLengthAndType<list<int> allowedLengths,
623623
VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
624624
"::mlir::VectorType">;
625625

626+
class FixedVectorOfShapeAndType<list<int> shape, Type elType>: ShapedContainerType<
627+
[elType],
628+
And<[IsVectorOfShape<shape>, IsFixedVectorOfAnyRankTypePred]>,
629+
"vector<" # !interleave(shape, "x") # "x" # elType # ">",
630+
"::mlir::VectorType">,
631+
BuildableType<"::mlir::VectorType::get({" # !interleave(shape, " ,") # "} , " # elType.builderCall # " );">;
632+
633+
626634
// Any fixed-length vector where the number of elements is from the given
627635
// `allowedLengths` list and the type is from the given `allowedTypes` list
628636
class FixedVectorOfLengthAndType<list<int> allowedLengths,

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,25 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
338338
context);
339339
}
340340

341+
//===----------------------------------------------------------------------===//
342+
// ScaledExtPacked816Op
343+
//===----------------------------------------------------------------------===//
344+
LogicalResult ScaledExtPacked816Op::verify() {
345+
int blockSize = getBlockSize();
346+
assert((blockSize == 16 || blockSize == 32) && "invalid block size");
347+
int firstScaleByte = getFirstScaleByte();
348+
if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
349+
return emitOpError(
350+
"blockSize of 16 can only have firstScaleByte be 0 or 1.");
351+
}
352+
if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
353+
return emitOpError(
354+
"blockSize of 32 can only have firstScaleByte be 0 or 2.");
355+
}
356+
357+
return success();
358+
}
359+
341360
//===----------------------------------------------------------------------===//
342361
// WMMAOp
343362
//===----------------------------------------------------------------------===//

mlir/test/Dialect/AMDGPU/invalid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,27 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 :
238238
amdgpu.gather_to_lds %mem1[%idx1], %mem2[%idx1] : vector<2xf16>, memref<32xf16>, memref<32xf16, strided<[?]>, #gpu.address_space<workgroup>>
239239
func.return
240240
}
241+
242+
// -----
243+
244+
func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
245+
// expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1.}}
246+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
247+
func.return
248+
}
249+
250+
// -----
251+
252+
func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
253+
// expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2.}}
254+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
255+
func.return
256+
}
257+
258+
// -----
259+
260+
func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
261+
// expected-error@+1 {{'amdgpu.scaled_ext_packed816' op failed to verify that all of {source, res} have same shape}}
262+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16>
263+
func.return
264+
}

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,61 @@ func.func @scaled_ext_scalar_f4e2m1_bf16(%v: vector<2xf4E2M1FN>, %scale: f32) ->
221221
func.return %ret : vector<2xbf16>
222222
}
223223

224+
// CHECK-LABEL: func.func @scaled_ext_packed816_fp4
225+
func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
226+
// CHECK: amdgpu.scaled_ext_packed816
227+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
228+
// CHECK: amdgpu.scaled_ext_packed816
229+
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
230+
// CHECK: amdgpu.scaled_ext_packed816
231+
%ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
232+
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
233+
}
234+
235+
// CHECK-LABEL: func.func @scaled_ext_packed816_fp8
236+
func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
237+
// CHECK: amdgpu.scaled_ext_packed816
238+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
239+
// CHECK: amdgpu.scaled_ext_packed816
240+
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16>
241+
// CHECK: amdgpu.scaled_ext_packed816
242+
%ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
243+
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
244+
}
245+
246+
// CHECK-LABEL: func.func @scaled_ext_packed816_bf8
247+
func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
248+
// CHECK: amdgpu.scaled_ext_packed816
249+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
250+
// CHECK: amdgpu.scaled_ext_packed816
251+
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
252+
// CHECK: amdgpu.scaled_ext_packed816
253+
%ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32>
254+
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
255+
}
256+
257+
// CHECK-LABEL: func.func @scaled_ext_packed816_fp6
258+
func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
259+
// CHECK: amdgpu.scaled_ext_packed816
260+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
261+
// CHECK: amdgpu.scaled_ext_packed816
262+
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
263+
// CHECK: amdgpu.scaled_ext_packed816
264+
%ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
265+
func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32>
266+
}
267+
268+
// CHECK-LABEL: func.func @scaled_ext_packed816_bf16
269+
func.func @scaled_ext_packed816_bf16(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
270+
// CHECK: amdgpu.scaled_ext_packed816
271+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
272+
// CHECK: amdgpu.scaled_ext_packed816
273+
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16>
274+
// CHECK: amdgpu.scaled_ext_packed816
275+
%ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
276+
func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32>
277+
}
278+
224279
// CHECK-LABEL: func.func @packed_scaled_trunc_f8e4m3_f32
225280
// CHECK: amdgpu.packed_scaled_trunc
226281
func.func @packed_scaled_trunc_f8e4m3_f32(%v: vector<2xf32>, %scale: f32) -> vector<4xf8E4M3FN> {

0 commit comments

Comments
 (0)