99#ifndef XEGPU_EXTENSION
1010#define XEGPU_EXTENSION
1111
12+ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
1213include "mlir/Dialect/Transform/IR/TransformDialect.td"
1314include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
1415include "mlir/Dialect/Transform/IR/TransformTypes.td"
15- include "mlir/IR/OpBase.td"
1616include "mlir/Interfaces/SideEffectInterfaces.td"
17+ include "mlir/IR/OpBase.td"
18+
19+ // This is roughly similar to OpFoldResult assuming the handle produces a single
20+ // value in the payload IR.
21+ def TransformAnyParamTypeOrAnyHandle : Type<
22+ Or<[TransformHandleTypeInterface.predicate,
23+ TransformParamTypeInterface.predicate]>,
24+ "transform any param type or any handle type">;
25+
1726
1827def HoistDescOp : Op<Transform_Dialect, "xegpu.hoist_desc_ops", [
1928 TransformOpInterface, TransformEachOpTrait,
@@ -68,8 +77,9 @@ def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
6877}
6978
7079def SetResultLayoutOp : Op<Transform_Dialect, "xegpu.set_result_layout", [
71- TransformOpInterface, TransformEachOpTrait,
72- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
80+ AttrSizedOperandSegments,
81+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
82+ TransformOpInterface
7383]> {
7484
7585 let summary = "Set xegpu.layout attribute to an xegpu op result.";
@@ -81,30 +91,60 @@ def SetResultLayoutOp : Op<Transform_Dialect, "xegpu.set_result_layout", [
8191 defined, `index=0` is used. Returns a handle to a transformed op.
8292 }];
8393
84- let arguments = (ins TransformHandleTypeInterface : $target,
85- DefaultValuedOptionalAttr<I64Attr, "0"> : $resultIndex,
86- DenseI32ArrayAttr : $sgLayout,
87- DenseI32ArrayAttr : $sgData,
88- DenseI32ArrayAttr : $instData);
94+ let arguments = (ins
95+ TransformHandleTypeInterface : $target,
96+ I64Attr : $resultIndex,
97+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
98+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
99+ Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
100+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
101+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
102+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
103+ );
89104
90105 let results = (outs TransformHandleTypeInterface : $transformed);
91-
92- let assemblyFormat =
93- "$target (`index` `=` $resultIndex^)? `sg_layout` `=` $sgLayout `sg_data` `=` "
94- "$sgData `inst_data` `=` $instData attr-dict `:` functional-type(operands, results)";
106+ let builders = [
107+ OpBuilder<(ins "Value":$target,
108+ "int64_t":$resultIndex,
109+ "ArrayRef<OpFoldResult>":$mixedSgLayout,
110+ "ArrayRef<OpFoldResult>":$mixedSgData,
111+ "ArrayRef<OpFoldResult>":$mixedInstData
112+ )>,
113+ ];
114+
115+ let assemblyFormat = [{
116+ $target `index` `=` $resultIndex
117+ `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
118+ `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
119+ `inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
120+ attr-dict `:` functional-type(operands, results)
121+ }];
95122
96123 let extraClassDeclaration = [{
97- ::mlir::DiagnosedSilenceableFailure applyToOne(
98- ::mlir::transform::TransformRewriter & rewriter,
99- ::mlir::Operation * target,
100- ::mlir::transform::ApplyToEachResultList & results,
101- ::mlir::transform::TransformState & state);
124+ ::mlir::DiagnosedSilenceableFailure apply(
125+ ::mlir::transform::TransformRewriter &rewriter,
126+ ::mlir::transform::TransformResults &transformResults,
127+ ::mlir::transform::TransformState &state);
128+
129+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
130+ Builder b(getContext());
131+ return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
132+ }
133+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
134+ Builder b(getContext());
135+ return getMixedValues(getStaticSgData(), getSgData(), b);
136+ }
137+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
138+ Builder b(getContext());
139+ return getMixedValues(getStaticInstData(), getInstData(), b);
140+ }
102141 }];
103142}
104143
105144def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
106- TransformOpInterface, TransformEachOpTrait,
107- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
145+ AttrSizedOperandSegments,
146+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
147+ TransformOpInterface
108148]> {
109149
110150 let summary = "Set xegpu.layout attribute of an op.";
@@ -117,31 +157,61 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
117157
118158 let arguments = (ins TransformHandleTypeInterface : $target,
119159 DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
120- DenseI32ArrayAttr : $sgLayout,
121- DenseI32ArrayAttr : $sgData,
122- DenseI32ArrayAttr : $instData,
160+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
161+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
162+ Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
163+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
164+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
165+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
123166 DefaultValuedAttr<UnitAttr, "false">:$result,
124167 DefaultValuedAttr<UnitAttr, "false">:$operand
125168 );
126169
127170 let results = (outs);
128-
129- let assemblyFormat =
130- "$target (`result` $result^)? (`operand` $operand^)? (`index` `=` $index^)? `sg_layout` `=` $sgLayout `sg_data` `=` "
131- "$sgData `inst_data` `=` $instData attr-dict `:` type($target)";
171+ let builders = [
172+ OpBuilder<(ins "Value":$target,
173+ "int64_t":$index,
174+ "ArrayRef<OpFoldResult>":$mixedSgLayout,
175+ "ArrayRef<OpFoldResult>":$mixedSgData,
176+ "ArrayRef<OpFoldResult>":$mixedInstData,
177+ CArg<"bool", "false">:$result,
178+ CArg<"bool", "false">:$operand
179+ )>,
180+ ];
181+
182+ let assemblyFormat = [{
183+ $target (`result` $result^)? (`operand` $operand^)? (`index` `=` $index^)?
184+ `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
185+ `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
186+ `inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
187+ attr-dict `:` type($target) (`,` type($sg_layout)^)? (`,` type($sg_data)^)? (`,` type($inst_data)^)?
188+ }];
132189
133190 let extraClassDeclaration = [{
134- ::mlir::DiagnosedSilenceableFailure applyToOne(
135- ::mlir::transform::TransformRewriter & rewriter,
136- ::mlir::Operation * target,
137- ::mlir::transform::ApplyToEachResultList & results,
138- ::mlir::transform::TransformState & state);
191+ ::mlir::DiagnosedSilenceableFailure apply(
192+ ::mlir::transform::TransformRewriter &rewriter,
193+ ::mlir::transform::TransformResults &transformResults,
194+ ::mlir::transform::TransformState &state);
195+
196+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
197+ Builder b(getContext());
198+ return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
199+ }
200+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
201+ Builder b(getContext());
202+ return getMixedValues(getStaticSgData(), getSgData(), b);
203+ }
204+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
205+ Builder b(getContext());
206+ return getMixedValues(getStaticInstData(), getInstData(), b);
207+ }
139208 }];
140209}
141210
142211def ConvertOperandLayoutOp : Op<Transform_Dialect, "xegpu.convert_operand_layout", [
143- TransformOpInterface, TransformEachOpTrait,
144- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
212+ AttrSizedOperandSegments,
213+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
214+ TransformOpInterface
145215]> {
146216
147217 let summary = "Convert xegpu.layout attribute for an xegpu op operand.";
@@ -154,46 +224,106 @@ def ConvertOperandLayoutOp : Op<Transform_Dialect, "xegpu.convert_operand_layout
154224
155225 let arguments = (ins TransformHandleTypeInterface : $target,
156226 I64Attr : $operandIndex,
157- DenseI32ArrayAttr : $sgLayout,
158- DenseI32ArrayAttr : $sgData,
159- DenseI32ArrayAttr : $instData);
227+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
228+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
229+ Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
230+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
231+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
232+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
233+ );
160234
161235 let results = (outs);
162-
163- let assemblyFormat =
164- "$target `index` `=` $operandIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
165- "$sgData `inst_data` `=` $instData attr-dict `:` type($target)";
236+ let builders = [
237+ OpBuilder<(ins "Value":$target,
238+ "int64_t":$index,
239+ "ArrayRef<OpFoldResult>":$mixedSgLayout,
240+ "ArrayRef<OpFoldResult>":$mixedSgData,
241+ "ArrayRef<OpFoldResult>":$mixedInstData
242+ )>,
243+ ];
244+
245+ let assemblyFormat = [{
246+ $target `index` `=` $operandIndex
247+ `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
248+ `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
249+ `inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
250+ attr-dict `:` type($target) (`,` type($sg_layout)^)? (`,` type($sg_data)^)? (`,` type($inst_data)^)?
251+ }];
166252
167253 let extraClassDeclaration = [{
168- ::mlir::DiagnosedSilenceableFailure applyToOne(
169- ::mlir::transform::TransformRewriter & rewriter,
170- ::mlir::Operation * target,
171- ::mlir::transform::ApplyToEachResultList & results,
172- ::mlir::transform::TransformState & state);
254+ ::mlir::DiagnosedSilenceableFailure apply(
255+ ::mlir::transform::TransformRewriter &rewriter,
256+ ::mlir::transform::TransformResults &transformResults,
257+ ::mlir::transform::TransformState &state);
258+
259+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
260+ Builder b(getContext());
261+ return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
262+ }
263+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
264+ Builder b(getContext());
265+ return getMixedValues(getStaticSgData(), getSgData(), b);
266+ }
267+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
268+ Builder b(getContext());
269+ return getMixedValues(getStaticInstData(), getInstData(), b);
270+ }
173271 }];
174272}
175273
176- def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch",
177- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
178- DeclareOpInterfaceMethods<TransformOpInterface>]> {
274+ def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
275+ AttrSizedOperandSegments,
276+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
277+ TransformOpInterface
278+ ]> {
179279
180280 let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
181281 let description = [{
182282 Given an xegpu operation residing in a `scf.for` loop, this transform inserts cooperative `xegpu.prefetch` operations for the A (index = 0) or B (index = 1) operand. The prefetch tile size is determined by the `sg_layout` and `sg_data` attributes.
183283 }];
184284
185285 let arguments = (ins TransformHandleTypeInterface : $target,
186- TransformHandleTypeInterface : $loopOp ,
286+ TransformHandleTypeInterface : $loop ,
187287 I64Attr : $operandIndex,
188- DenseI32ArrayAttr : $sgLayout,
189- DenseI32ArrayAttr : $sgData);
288+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
289+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
290+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
291+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data
292+ );
190293
191- let results = (outs TransformHandleTypeInterface : $transformedTargetOp,
192- TransformHandleTypeInterface : $transformedLoopOp);
294+ let results = (outs TransformHandleTypeInterface : $transformedTarget,
295+ TransformHandleTypeInterface : $transformedLoop);
296+ let builders = [
297+ OpBuilder<(ins "Value":$target,
298+ "Value":$loop,
299+ "int64_t":$operandIndex,
300+ "ArrayRef<OpFoldResult>":$mixedSgLayout,
301+ "ArrayRef<OpFoldResult>":$mixedSgData
302+ )>,
303+ ];
304+
305+ let assemblyFormat = [{
306+ $target $loop `index` `=` $operandIndex
307+ `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
308+ `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
309+ attr-dict `:` functional-type(operands, results)
310+ }];
193311
194- let assemblyFormat =
195- "$target $loopOp `index` `=` $operandIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
196- "$sgData attr-dict `:` functional-type(operands, results)";
312+ let extraClassDeclaration = [{
313+ ::mlir::DiagnosedSilenceableFailure apply(
314+ ::mlir::transform::TransformRewriter &rewriter,
315+ ::mlir::transform::TransformResults &transformResults,
316+ ::mlir::transform::TransformState &state);
317+
318+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
319+ Builder b(getContext());
320+ return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
321+ }
322+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
323+ Builder b(getContext());
324+ return getMixedValues(getStaticSgData(), getSgData(), b);
325+ }
326+ }];
197327}
198328
199329// TODO this should be handled with gpu transform ops.
0 commit comments