Commit 2d03839

committed

xegpu: add set_result_layout op

1 parent 4c92e6b commit 2d03839

5 files changed, +239 -2 lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 35 additions & 0 deletions
@@ -67,6 +67,41 @@ def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
   }];
 }

+def SetResultLayoutOp : Op<Transform_Dialect, "xegpu.set_result_layout", [
+    TransformOpInterface, TransformEachOpTrait,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+]> {
+
+  let summary = "Set the xegpu.layout attribute on an xegpu op result.";
+  let description = [{
+    Given an xegpu operation, this transform adds an `xegpu.layout` attribute
+    to its result's tensor descriptor. The target result is selected by the
+    `index` argument; if `index` is not specified, `index = 0` is used. The
+    layout is defined by the `sg_layout`, `sg_data`, and `inst_data`
+    attributes. Returns a handle to the transformed op.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface : $target,
+                       OptionalAttr<I64Attr> : $resultIndex,
+                       DenseI32ArrayAttr : $sgLayout,
+                       DenseI32ArrayAttr : $sgData,
+                       DenseI32ArrayAttr : $instData);
+
+  let results = (outs TransformHandleTypeInterface : $transformed);
+
+  let assemblyFormat =
+      "$target (`index` `=` $resultIndex^)? `sg_layout` `=` $sgLayout `sg_data` `=` "
+      "$sgData `inst_data` `=` $instData attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def SetOperandLayoutOp : Op<Transform_Dialect, "xegpu.set_operand_layout", [
     TransformOpInterface, TransformEachOpTrait,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
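
For reference, the assembly format above admits usage along these lines; this is an illustrative sketch mirroring the tests added later in this commit (handle names and layout values are examples only):

    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %root : (!transform.any_op) -> !transform.any_op
    // `index = 0` may be omitted, in which case result 0 is used.
    %1 = transform.xegpu.set_result_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op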

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 64 additions & 0 deletions
@@ -612,6 +612,70 @@ void transform::GetDescOp::getEffects(
   modifiesPayload(effects);
 }

+
+DiagnosedSilenceableFailure transform::SetResultLayoutOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+
+  int64_t resultIndex = getResultIndex() ? getResultIndex().value() : 0;
+  if (resultIndex >= target->getNumResults()) {
+    return emitSilenceableFailure(getLoc())
+           << "resultIndex exceeds the number of op results.";
+  }
+
+  auto sgLayout = getSgLayout();
+  if (sgLayout.size() != 2) {
+    return emitSilenceableFailure(getLoc())
+           << "Expected sg_layout to be a 2D vector";
+  }
+
+  auto sgData = getSgData();
+  if (sgData.size() != 2) {
+    return emitSilenceableFailure(getLoc())
+           << "Expected sg_data to be a 2D vector";
+  }
+
+  auto instData = getInstData();
+  if (instData.size() != 2) {
+    return emitSilenceableFailure(getLoc())
+           << "Expected inst_data to be a 2D vector";
+  }
+
+  // For now, only create_nd_tdesc and dpas ops are supported.
+  auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
+  auto dpasOp = dyn_cast<xegpu::DpasOp>(target);
+  if (!descOp && !dpasOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Expected an xegpu.create_nd_tdesc or xegpu.dpas op, but got: " << target->getName();
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  auto layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
+  if (descOp) {
+    // Replace the desc op with a new op that has the layout attr in its return type.
+    auto newDescOp = setDescLayout(rewriter, descOp, layoutAttr);
+    results.push_back(newDescOp.getOperation());
+  }
+  if (dpasOp) {
+    // Set the layout attribute for the dpas op result.
+    // NOTE: this path does not actually create a new op, so no new handle is produced;
+    // arguably the incoming handle should not be invalidated here, or this should be a separate op.
+    xegpu::setLayoutAttr(dpasOp.getOperation()->getResults()[0], layoutAttr);
+    results.push_back(dpasOp.getOperation());
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetResultLayoutOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
 DiagnosedSilenceableFailure transform::SetOperandLayoutOp::applyToOne(
     transform::TransformRewriter &rewriter, Operation *target,
     transform::ApplyToEachResultList &results,
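
For the create_nd_tdesc case, setDescLayout rebuilds the descriptor op so the layout ends up in the result type. Based on the CHECK lines in the test added below, the rewrite is expected to look roughly like this (a sketch, not the exact printer output):

    // Before:
    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
    // After:
    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16>
        -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>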

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 28 additions & 0 deletions
@@ -38,6 +38,34 @@ def __init__(
         )


+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetResultLayoutOp(SetResultLayoutOp):
+    """Specialization for SetResultLayoutOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        sg_layout: Union[Sequence[int], Attribute],
+        sg_data: Union[Sequence[int], Attribute],
+        inst_data: Union[Sequence[int], Attribute],
+        *,
+        index: Optional[Union[int, Attribute]] = None,
+        loc=None,
+        ip=None,
+    ):
+        transformed_type = transform.AnyOpType.get()
+        super().__init__(
+            transformed_type,
+            target,
+            sg_layout,
+            sg_data,
+            inst_data,
+            resultIndex=index,
+            loc=loc,
+            ip=ip,
+        )
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class SetOperandLayoutOp(SetOperandLayoutOp):
     """Specialization for SetOperandLayoutOp class."""

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 64 additions & 0 deletions
@@ -53,6 +53,70 @@ module attributes {transform.with_named_sequence} {

 // -----

+// CHECK-LABEL: @get_desc_op_default_index
+func.func @get_desc_op_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
+  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas %[[V1]]
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.get_desc_op %{{.*}}
+    %1 = transform.xegpu.get_desc_op %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_result_layout_create_op
+func.func @set_result_layout_create_op(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+  %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_result_layout %{{.*}}
+    %1 = transform.xegpu.set_result_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_result_layout_create_op_default_index
+func.func @set_result_layout_create_op_default_index(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+  %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_result_layout %{{.*}}
+    %1 = transform.xegpu.set_result_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @set_operand_layout_a
 func.func @set_operand_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
   // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 48 additions & 2 deletions
@@ -35,7 +35,7 @@ def getDescOp():


 @run
-def getDescOpDefault():
+def getDescOpDefaultIndex():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
@@ -46,10 +46,56 @@ def getDescOpDefault():
             sequence.bodyTarget,
         )
         transform.YieldOp()
-    # CHECK-LABEL: TEST: getDescOp
+    # CHECK-LABEL: TEST: getDescOpDefaultIndex
     # CHECK: transform.xegpu.get_desc_op %


+@run
+def setResultLayout():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetResultLayoutOp(
+            sequence.bodyTarget,
+            index=0,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+            inst_data=[8, 16],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setResultLayout
+    # CHECK: %0 = transform.xegpu.set_result_layout %
+    # CHECK: index = 0
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]
+
+
+@run
+def setResultLayoutDefaultIndex():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetResultLayoutOp(
+            sequence.bodyTarget,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+            inst_data=[8, 16],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setResultLayoutDefaultIndex
+    # CHECK: %0 = transform.xegpu.set_result_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]
+
+
 @run
 def setOperandLayout():
     sequence = transform.SequenceOp(
