Skip to content

Commit 997830e

Browse files
committed
xegpu: transform op layout attrs support parameters/values/ints
1 parent 2915a04 commit 997830e

File tree

5 files changed

+697
-141
lines changed

5 files changed

+697
-141
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@
99
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
1010
#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
1111

12-
#include "mlir/Bytecode/BytecodeOpInterface.h"
13-
#include "mlir/Dialect/SCF/IR/SCF.h"
1412
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1513
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
1614
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1716

1817
#define GET_OP_CLASSES
1918
#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc>

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 186 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,20 @@
99
#ifndef XEGPU_EXTENSION
1010
#define XEGPU_EXTENSION
1111

12+
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
1213
include "mlir/Dialect/Transform/IR/TransformDialect.td"
1314
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
1415
include "mlir/Dialect/Transform/IR/TransformTypes.td"
15-
include "mlir/IR/OpBase.td"
1616
include "mlir/Interfaces/SideEffectInterfaces.td"
17+
include "mlir/IR/OpBase.td"
18+
19+
// This is roughly similar to OpFoldResult assuming the handle produces a single
20+
// value in the payload IR.
21+
def TransformAnyParamTypeOrAnyHandle : Type<
22+
Or<[TransformHandleTypeInterface.predicate,
23+
TransformParamTypeInterface.predicate]>,
24+
"transform any param type or any handle type">;
25+
1726

1827
def HoistDescOp : Op<Transform_Dialect, "xegpu.hoist_desc_ops", [
1928
TransformOpInterface, TransformEachOpTrait,
@@ -68,8 +77,9 @@ def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
6877
}
6978

7079
def SetResultLayoutOp : Op<Transform_Dialect, "xegpu.set_result_layout", [
71-
TransformOpInterface, TransformEachOpTrait,
72-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
80+
AttrSizedOperandSegments,
81+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
82+
TransformOpInterface
7383
]> {
7484

7585
let summary = "Set xegpu.layout attribute to an xegpu op result.";
@@ -81,30 +91,60 @@ def SetResultLayoutOp : Op<Transform_Dialect, "xegpu.set_result_layout", [
8191
defined, `index=0` is used. Returns a handle to a transformed op.
8292
}];
8393

84-
let arguments = (ins TransformHandleTypeInterface : $target,
85-
DefaultValuedOptionalAttr<I64Attr, "0"> : $resultIndex,
86-
DenseI32ArrayAttr : $sgLayout,
87-
DenseI32ArrayAttr : $sgData,
88-
DenseI32ArrayAttr : $instData);
94+
let arguments = (ins
95+
TransformHandleTypeInterface : $target,
96+
I64Attr : $resultIndex,
97+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
98+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
99+
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
100+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
101+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
102+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
103+
);
89104

90105
let results = (outs TransformHandleTypeInterface : $transformed);
91-
92-
let assemblyFormat =
93-
"$target (`index` `=` $resultIndex^)? `sg_layout` `=` $sgLayout `sg_data` `=` "
94-
"$sgData `inst_data` `=` $instData attr-dict `:` functional-type(operands, results)";
106+
let builders = [
107+
OpBuilder<(ins "Value":$target,
108+
"int64_t":$resultIndex,
109+
"ArrayRef<OpFoldResult>":$mixedSgLayout,
110+
"ArrayRef<OpFoldResult>":$mixedSgData,
111+
"ArrayRef<OpFoldResult>":$mixedInstData
112+
)>,
113+
];
114+
115+
let assemblyFormat = [{
116+
$target `index` `=` $resultIndex
117+
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
118+
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
119+
`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
120+
attr-dict `:` functional-type(operands, results)
121+
}];
95122

96123
let extraClassDeclaration = [{
97-
::mlir::DiagnosedSilenceableFailure applyToOne(
98-
::mlir::transform::TransformRewriter & rewriter,
99-
::mlir::Operation * target,
100-
::mlir::transform::ApplyToEachResultList & results,
101-
::mlir::transform::TransformState & state);
124+
::mlir::DiagnosedSilenceableFailure apply(
125+
::mlir::transform::TransformRewriter &rewriter,
126+
::mlir::transform::TransformResults &transformResults,
127+
::mlir::transform::TransformState &state);
128+
129+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
130+
Builder b(getContext());
131+
return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
132+
}
133+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
134+
Builder b(getContext());
135+
return getMixedValues(getStaticSgData(), getSgData(), b);
136+
}
137+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
138+
Builder b(getContext());
139+
return getMixedValues(getStaticInstData(), getInstData(), b);
140+
}
102141
}];
103142
}
104143

105144
def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
106-
TransformOpInterface, TransformEachOpTrait,
107-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
145+
AttrSizedOperandSegments,
146+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
147+
TransformOpInterface
108148
]> {
109149

110150
let summary = "Set xegpu.layout attribute of an op.";
@@ -117,31 +157,61 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
117157

118158
let arguments = (ins TransformHandleTypeInterface : $target,
119159
DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
120-
DenseI32ArrayAttr : $sgLayout,
121-
DenseI32ArrayAttr : $sgData,
122-
DenseI32ArrayAttr : $instData,
160+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
161+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
162+
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
163+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
164+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
165+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
123166
DefaultValuedAttr<UnitAttr, "false">:$result,
124167
DefaultValuedAttr<UnitAttr, "false">:$operand
125168
);
126169

127170
let results = (outs);
128-
129-
let assemblyFormat =
130-
"$target (`result` $result^)? (`operand` $operand^)? (`index` `=` $index^)? `sg_layout` `=` $sgLayout `sg_data` `=` "
131-
"$sgData `inst_data` `=` $instData attr-dict `:` type($target)";
171+
let builders = [
172+
OpBuilder<(ins "Value":$target,
173+
"int64_t":$index,
174+
"ArrayRef<OpFoldResult>":$mixedSgLayout,
175+
"ArrayRef<OpFoldResult>":$mixedSgData,
176+
"ArrayRef<OpFoldResult>":$mixedInstData,
177+
CArg<"bool", "false">:$result,
178+
CArg<"bool", "false">:$operand
179+
)>,
180+
];
181+
182+
let assemblyFormat = [{
183+
$target (`result` $result^)? (`operand` $operand^)? (`index` `=` $index^)?
184+
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
185+
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
186+
`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
187+
attr-dict `:` type($target) (`,` type($sg_layout)^)? (`,` type($sg_data)^)? (`,` type($inst_data)^)?
188+
}];
132189

133190
let extraClassDeclaration = [{
134-
::mlir::DiagnosedSilenceableFailure applyToOne(
135-
::mlir::transform::TransformRewriter & rewriter,
136-
::mlir::Operation * target,
137-
::mlir::transform::ApplyToEachResultList & results,
138-
::mlir::transform::TransformState & state);
191+
::mlir::DiagnosedSilenceableFailure apply(
192+
::mlir::transform::TransformRewriter &rewriter,
193+
::mlir::transform::TransformResults &transformResults,
194+
::mlir::transform::TransformState &state);
195+
196+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
197+
Builder b(getContext());
198+
return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
199+
}
200+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
201+
Builder b(getContext());
202+
return getMixedValues(getStaticSgData(), getSgData(), b);
203+
}
204+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
205+
Builder b(getContext());
206+
return getMixedValues(getStaticInstData(), getInstData(), b);
207+
}
139208
}];
140209
}
141210

142211
def ConvertOperandLayoutOp : Op<Transform_Dialect, "xegpu.convert_operand_layout", [
143-
TransformOpInterface, TransformEachOpTrait,
144-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
212+
AttrSizedOperandSegments,
213+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
214+
TransformOpInterface
145215
]> {
146216

147217
let summary = "Convert xegpu.layout attribute for an xegpu op operand.";
@@ -154,46 +224,106 @@ def ConvertOperandLayoutOp : Op<Transform_Dialect, "xegpu.convert_operand_layout
154224

155225
let arguments = (ins TransformHandleTypeInterface : $target,
156226
I64Attr : $operandIndex,
157-
DenseI32ArrayAttr : $sgLayout,
158-
DenseI32ArrayAttr : $sgData,
159-
DenseI32ArrayAttr : $instData);
227+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
228+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
229+
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
230+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
231+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
232+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
233+
);
160234

161235
let results = (outs);
162-
163-
let assemblyFormat =
164-
"$target `index` `=` $operandIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
165-
"$sgData `inst_data` `=` $instData attr-dict `:` type($target)";
236+
let builders = [
237+
OpBuilder<(ins "Value":$target,
238+
"int64_t":$index,
239+
"ArrayRef<OpFoldResult>":$mixedSgLayout,
240+
"ArrayRef<OpFoldResult>":$mixedSgData,
241+
"ArrayRef<OpFoldResult>":$mixedInstData
242+
)>,
243+
];
244+
245+
let assemblyFormat = [{
246+
$target `index` `=` $operandIndex
247+
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
248+
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
249+
`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
250+
attr-dict `:` type($target) (`,` type($sg_layout)^)? (`,` type($sg_data)^)? (`,` type($inst_data)^)?
251+
}];
166252

167253
let extraClassDeclaration = [{
168-
::mlir::DiagnosedSilenceableFailure applyToOne(
169-
::mlir::transform::TransformRewriter & rewriter,
170-
::mlir::Operation * target,
171-
::mlir::transform::ApplyToEachResultList & results,
172-
::mlir::transform::TransformState & state);
254+
::mlir::DiagnosedSilenceableFailure apply(
255+
::mlir::transform::TransformRewriter &rewriter,
256+
::mlir::transform::TransformResults &transformResults,
257+
::mlir::transform::TransformState &state);
258+
259+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
260+
Builder b(getContext());
261+
return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
262+
}
263+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
264+
Builder b(getContext());
265+
return getMixedValues(getStaticSgData(), getSgData(), b);
266+
}
267+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
268+
Builder b(getContext());
269+
return getMixedValues(getStaticInstData(), getInstData(), b);
270+
}
173271
}];
174272
}
175273

176-
def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch",
177-
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
178-
DeclareOpInterfaceMethods<TransformOpInterface>]> {
274+
def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
275+
AttrSizedOperandSegments,
276+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
277+
TransformOpInterface
278+
]> {
179279

180280
let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
181281
let description = [{
182282
Given an xegpu operation residing in a `scf.for` loop, this transform inserts cooperative `xegpu.prefetch` operations for the A (index = 0) or B (index = 1) operand. The prefetch tile size is determined by the `sg_layout` and `sg_data` attributes.
183283
}];
184284

185285
let arguments = (ins TransformHandleTypeInterface : $target,
186-
TransformHandleTypeInterface : $loopOp,
286+
TransformHandleTypeInterface : $loop,
187287
I64Attr : $operandIndex,
188-
DenseI32ArrayAttr : $sgLayout,
189-
DenseI32ArrayAttr : $sgData);
288+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
289+
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
290+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
291+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data
292+
);
190293

191-
let results = (outs TransformHandleTypeInterface : $transformedTargetOp,
192-
TransformHandleTypeInterface : $transformedLoopOp);
294+
let results = (outs TransformHandleTypeInterface : $transformedTarget,
295+
TransformHandleTypeInterface : $transformedLoop);
296+
let builders = [
297+
OpBuilder<(ins "Value":$target,
298+
"Value":$loop,
299+
"int64_t":$operandIndex,
300+
"ArrayRef<OpFoldResult>":$mixedSgLayout,
301+
"ArrayRef<OpFoldResult>":$mixedSgData
302+
)>,
303+
];
304+
305+
let assemblyFormat = [{
306+
$target $loop `index` `=` $operandIndex
307+
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
308+
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
309+
attr-dict `:` functional-type(operands, results)
310+
}];
193311

194-
let assemblyFormat =
195-
"$target $loopOp `index` `=` $operandIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
196-
"$sgData attr-dict `:` functional-type(operands, results)";
312+
let extraClassDeclaration = [{
313+
::mlir::DiagnosedSilenceableFailure apply(
314+
::mlir::transform::TransformRewriter &rewriter,
315+
::mlir::transform::TransformResults &transformResults,
316+
::mlir::transform::TransformState &state);
317+
318+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
319+
Builder b(getContext());
320+
return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
321+
}
322+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
323+
Builder b(getContext());
324+
return getMixedValues(getStaticSgData(), getSgData(), b);
325+
}
326+
}];
197327
}
198328

199329
// TODO this should be handled with gpu transform ops.

0 commit comments

Comments
 (0)