Skip to content

Commit 2915a04

Browse files
committed
xegpu: add set_op_layout_attr and convert_operand_layout ops
1 parent 2d03839 commit 2915a04

File tree

5 files changed

+288
-80
lines changed

5 files changed

+288
-80
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
5252
}];
5353

5454
let arguments = (ins TransformHandleTypeInterface : $target,
55-
OptionalAttr<I64Attr> : $operandIndex);
55+
DefaultValuedOptionalAttr<I64Attr, "0"> : $operandIndex);
5656

5757
let results = (outs TransformHandleTypeInterface : $descHandle);
5858
let assemblyFormat =
@@ -82,7 +82,7 @@ def SetResultLayoutOp : Op<Transform_Dialect, "xegpu.set_result_layout", [
8282
}];
8383

8484
let arguments = (ins TransformHandleTypeInterface : $target,
85-
OptionalAttr<I64Attr> : $resultIndex,
85+
DefaultValuedOptionalAttr<I64Attr, "0"> : $resultIndex,
8686
DenseI32ArrayAttr : $sgLayout,
8787
DenseI32ArrayAttr : $sgData,
8888
DenseI32ArrayAttr : $instData);
@@ -102,15 +102,52 @@ def SetResultLayoutOp : Op<Transform_Dialect, "xegpu.set_result_layout", [
102102
}];
103103
}
104104

105-
def SetOperandLayoutOp : Op<Transform_Dialect, "xegpu.set_operand_layout", [
105+
def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
  TransformOpInterface, TransformEachOpTrait,
  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {

  let summary = "Set xegpu.layout attribute of an op.";
  let description = [{
    Sets the `xegpu.layout` attribute of an op. Sets either the
    `layout_result_{index}` or `layout_operand_{index}` attribute. The target
    operand/result value is defined by the `index` argument. The layout is
    defined by the `sg_layout`, `sg_data` and `inst_data` attributes.

    Exactly one of the `result` and `operand` keywords must be given; it
    selects whether `index` refers to a result or to an operand of the target
    op. `index` defaults to 0 when omitted. `sg_layout`, `sg_data` and
    `inst_data` must each have exactly two elements.

    Currently only `xegpu.dpas` target ops are supported. The transform
    produces a silenceable failure if any of the above conditions is violated.
    The target handle is only read; the op has no results.
  }];

  // `result` and `operand` are mutually exclusive selector flags; the
  // exclusivity is validated in applyToOne rather than by the parser.
  let arguments = (ins TransformHandleTypeInterface : $target,
                   DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
                   DenseI32ArrayAttr : $sgLayout,
                   DenseI32ArrayAttr : $sgData,
                   DenseI32ArrayAttr : $instData,
                   DefaultValuedAttr<UnitAttr, "false">:$result,
                   DefaultValuedAttr<UnitAttr, "false">:$operand
                   );

  let results = (outs);

  let assemblyFormat =
      "$target (`result` $result^)? (`operand` $operand^)? (`index` `=` $index^)? `sg_layout` `=` $sgLayout `sg_data` `=` "
      "$sgData `inst_data` `=` $instData attr-dict `:` type($target)";

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure applyToOne(
        ::mlir::transform::TransformRewriter & rewriter,
        ::mlir::Operation * target,
        ::mlir::transform::ApplyToEachResultList & results,
        ::mlir::transform::TransformState & state);
  }];
}
141+
142+
def ConvertOperandLayoutOp : Op<Transform_Dialect, "xegpu.convert_operand_layout", [
143+
TransformOpInterface, TransformEachOpTrait,
144+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
145+
]> {
146+
147+
let summary = "Convert xegpu.layout attribute for an xegpu op operand.";
148+
let description = [{
149+
Adds an `xegpu.convert_layout` op
150+
to convert the `xegpu.layout` attribute of an operand. The target operand is
114151
defined by the `operandIndex` argument. The layout is defined by the
115152
`sg_layout`, `sg_data` and `inst_data` attributes.
116153
}];

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 116 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -325,14 +325,15 @@ xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
325325
xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter,
326326
xegpu::CreateNdDescOp descOp,
327327
xegpu::LayoutAttr layout) {
328-
auto ctx = rewriter.getContext();
329328
auto oldTensorDesc = descOp.getResult();
330329
auto descShapedType = cast<ShapedType>(oldTensorDesc.getType());
331-
// This discards any block_tdesc_attr attributes.
332-
auto descType = xegpu::TensorDescType::get(ctx, descShapedType.getShape(),
333-
descShapedType.getElementType(),
334-
/*encoding=*/nullptr,
335-
/*layout=*/layout);
330+
// TODO inherit desc attributes from old op (if any)
331+
auto descType = xegpu::TensorDescType::get(
332+
descShapedType.getShape(), descShapedType.getElementType(),
333+
/*array_length=*/1,
334+
/*boundary_check=*/true,
335+
/*memory_space=*/xegpu::MemorySpace::Global,
336+
/*layout=*/layout);
336337

337338
rewriter.setInsertionPointAfter(descOp);
338339
auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
@@ -589,7 +590,7 @@ transform::GetDescOp::applyToOne(transform::TransformRewriter &rewriter,
589590
return diag;
590591
}
591592

592-
int64_t operandIndex = getOperandIndex() ? getOperandIndex().value() : 0;
593+
int64_t operandIndex = getOperandIndex();
593594
if (operandIndex >= targetOp.getNumOperands()) {
594595
return emitSilenceableFailure(getLoc())
595596
<< "operandIndex exceeds the number of op operands.";
@@ -618,7 +619,7 @@ DiagnosedSilenceableFailure transform::SetResultLayoutOp::applyToOne(
618619
transform::ApplyToEachResultList &results,
619620
transform::TransformState &state) {
620621

621-
int64_t resultIndex = getResultIndex() ? getResultIndex().value() : 0;
622+
int64_t resultIndex = getResultIndex();
622623
if (resultIndex >= target->getNumResults()) {
623624
return emitSilenceableFailure(getLoc())
624625
<< "resultIndex exceeds the number of op results.";
@@ -642,30 +643,21 @@ DiagnosedSilenceableFailure transform::SetResultLayoutOp::applyToOne(
642643
<< "Expected inst_data to be a 2D vector";
643644
}
644645

645-
// For now only desc op or dpas op are supported.
646+
// For now only create_nd_desc op is supported.
646647
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
647-
auto dpasOp = dyn_cast<xegpu::DpasOp>(target);
648-
if (!descOp && !dpasOp) {
648+
if (!descOp) {
649649
auto diag = emitSilenceableFailure(getLoc())
650-
<< "Expected a xegpu.create_nd_desc or xegpu.dpas op, but got: " << target->getName();
650+
<< "Expected a xegpu.create_nd_desc op, but got: "
651+
<< target->getName();
651652
diag.attachNote(target->getLoc()) << "target op";
652653
return diag;
653654
}
654655

656+
// Set layout attr in desc op's return type. Replaces old desc op.
655657
auto layoutAttr =
656658
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
657-
if (descOp) {
658-
// Replace desc op with a new op that has the layout attr in return type.
659-
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
660-
results.push_back(newdescOp.getOperation());
661-
}
662-
if (dpasOp) {
663-
// Set layout attribute for the dpas op result.
664-
// NOTE this actually does not create a new handle ...
665-
// NOTE should not invalidate the handle ... should be a separate op?
666-
xegpu::setLayoutAttr(dpasOp.getOperation()->getResults()[0], layoutAttr);
667-
results.push_back(dpasOp.getOperation());
668-
}
659+
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
660+
results.push_back(newdescOp.getOperation());
669661
return DiagnosedSilenceableFailure::success();
670662
}
671663

@@ -676,7 +668,76 @@ void transform::SetResultLayoutOp::getEffects(
676668
modifiesPayload(effects);
677669
}
678670

679-
DiagnosedSilenceableFailure transform::SetOperandLayoutOp::applyToOne(
671+
/// Attaches an `xegpu.layout` attribute to one result or one operand of
/// `target`.
///
/// Exactly one of the `result` / `operand` flags must be set; it selects
/// whether `index` addresses a result or an operand of `target`. The layout is
/// built from the `sg_layout`, `sg_data` and `inst_data` arrays, each of which
/// must have exactly two elements. Only `xegpu.dpas` targets are accepted for
/// now.
///
/// Returns a silenceable failure when any of these preconditions is violated
/// and success otherwise. No entries are appended to `results`.
DiagnosedSilenceableFailure transform::SetOpLayoutAttrOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {

  const bool resultTarget = getResult();
  const bool operandTarget = getOperand();

  if (resultTarget && operandTarget) {
    return emitSilenceableFailure(getLoc())
           << "`result` and `operand` cannot be both set.";
  }
  if (!resultTarget && !operandTarget) {
    return emitSilenceableFailure(getLoc())
           << "Either `result` or `operand` must be set.";
  }

  const int64_t index = getIndex();
  // The attribute is a signed 64-bit integer. Reject negative values
  // explicitly: they would pass the upper-bound checks below and then be
  // narrowed to an unsigned index, resulting in an out-of-bounds access.
  if (index < 0) {
    return emitSilenceableFailure(getLoc()) << "index must be non-negative.";
  }
  if (resultTarget && index >= target->getNumResults()) {
    return emitSilenceableFailure(getLoc())
           << "index exceeds the number of op results.";
  }
  if (!resultTarget && index >= target->getNumOperands()) {
    return emitSilenceableFailure(getLoc())
           << "index exceeds the number of op operands.";
  }

  auto sgLayout = getSgLayout();
  if (sgLayout.size() != 2) {
    return emitSilenceableFailure(getLoc())
           << "Expected sg_layout to be a 2D vector";
  }

  auto sgData = getSgData();
  if (sgData.size() != 2) {
    return emitSilenceableFailure(getLoc())
           << "Expected sg_data to be a 2D vector";
  }

  auto instData = getInstData();
  if (instData.size() != 2) {
    return emitSilenceableFailure(getLoc())
           << "Expected inst_data to be a 2D vector";
  }

  // For now only dpas op is supported.
  if (!isa<xegpu::DpasOp>(target)) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Expected a xegpu.dpas op, but got: " << target->getName();
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  auto layoutAttr =
      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
  // Set the layout attribute for the selected op result or operand. The index
  // has been validated to lie within [0, numResults/numOperands) above.
  if (resultTarget) {
    xegpu::setLayoutAttr(target->getResult(index), layoutAttr);
  } else {
    xegpu::setLayoutAttr(target->getOpOperand(index), layoutAttr);
  }
  return DiagnosedSilenceableFailure::success();
}
733+
734+
void transform::SetOpLayoutAttrOp::getEffects(
735+
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
736+
onlyReadsHandle(getTargetMutable(), effects);
737+
modifiesPayload(effects);
738+
}
739+
740+
DiagnosedSilenceableFailure transform::ConvertOperandLayoutOp::applyToOne(
680741
transform::TransformRewriter &rewriter, Operation *target,
681742
transform::ApplyToEachResultList &results,
682743
transform::TransformState &state) {
@@ -714,26 +775,48 @@ DiagnosedSilenceableFailure transform::SetOperandLayoutOp::applyToOne(
714775
<< "Expected inst_data to be a 2D vector";
715776
}
716777

717-
// Replace descriptor op using layout attribute.
718-
Value opVec = targetOp.getOperation()->getOperand(operandIndex);
778+
// Find desc op.
779+
Value opVec = target->getOperand(operandIndex);
719780
auto maybeDescOp = findDescriptorOp(opVec, targetOp.getOperation());
720781
if (!maybeDescOp) {
721782
return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
722783
}
723784
auto descOp = *maybeDescOp;
724-
// Set layout attribute.
785+
// Get load op.
786+
auto maybeLoadOp = getUserOfType<xegpu::LoadNdOp>(descOp.getResult());
787+
if (!maybeLoadOp) {
788+
return emitSilenceableFailure(getLoc())
789+
<< "Expected a xegpu.load_nd op as a user of the descriptor op.";
790+
}
791+
auto loadOp = *maybeLoadOp;
792+
// Get load op operand value layout
793+
auto producerLayoutAttr = xegpu::getLayoutAttr(loadOp.getOperand(0));
794+
if (!producerLayoutAttr) {
795+
return emitSilenceableFailure(getLoc())
796+
<< "Operand producer op does not have a layout attr.";
797+
}
798+
799+
// New layout attr
725800
auto layoutAttr =
726801
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
727-
descOp = setDescLayout(rewriter, descOp, layoutAttr);
728-
if (operandIndex == 2) {
729-
// C operand: set layout attribute for the dpas op result.
730-
xegpu::setLayoutAttr(targetOp.getOperation()->getResults()[0], layoutAttr);
802+
803+
if (producerLayoutAttr != layoutAttr) {
804+
rewriter.setInsertionPointAfter(loadOp.getOperation());
805+
auto source = loadOp.getResult();
806+
auto convLayoutOp = rewriter.create<xegpu::ConvertLayoutOp>(
807+
loadOp.getLoc(), source.getType(), source, producerLayoutAttr,
808+
layoutAttr);
809+
// Replace load op result with the converted layout.
810+
rewriter.replaceUsesWithIf(
811+
source, convLayoutOp.getResult(), [&](OpOperand &use) {
812+
return use.getOwner() != convLayoutOp.getOperation();
813+
});
731814
}
732815

733816
return DiagnosedSilenceableFailure::success();
734817
}
735818

736-
void transform::SetOperandLayoutOp::getEffects(
819+
void transform::ConvertOperandLayoutOp::getEffects(
737820
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
738821
onlyReadsHandle(getTargetMutable(), effects);
739822
modifiesPayload(effects);

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,53 @@ def __init__(
6767

6868

6969
@_ods_cext.register_operation(_Dialect, replace=True)
70-
class SetOperandLayoutOp(SetOperandLayoutOp):
71-
"""Specialization for SetOperandLayoutOp class."""
70+
class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
    """Specialization for SetOpLayoutAttrOp class.

    Builds a `transform.xegpu.set_op_layout_attr` op.

    Args:
        target: Handle to the payload op whose layout attribute is set.
        sg_layout: `sg_layout` values of the layout.
        sg_data: `sg_data` values of the layout.
        inst_data: `inst_data` values of the layout.
        index: Index of the result/operand to annotate. Forwarded unchanged;
            `None` leaves the attribute unset.
        result: Select a result of the target op.
        operand: Select an operand of the target op.
        loc: Optional location, forwarded to the generated builder.
        ip: Optional insertion point, forwarded to the generated builder.

    If neither `result` nor `operand` is given (both `None`), `result` is
    selected by default.
    """

    def __init__(
        self,
        target: Union[Operation, Value],
        sg_layout: Union[Sequence[int], Attribute],
        sg_data: Union[Sequence[int], Attribute],
        inst_data: Union[Sequence[int], Attribute],
        *,
        index: Union[int, Attribute, None] = None,
        result: Union[bool, Attribute, None] = None,
        operand: Union[bool, Attribute, None] = None,
        loc=None,
        ip=None,
    ):
        # Default to annotating a result when the caller selected nothing.
        if result is None and operand is None:
            result = True
        super().__init__(
            target,
            sg_layout,
            sg_data,
            inst_data,
            index=index,
            result=result,
            operand=operand,
            loc=loc,
            ip=ip,
        )
111+
112+
113+
114+
@_ods_cext.register_operation(_Dialect, replace=True)
115+
class ConvertOperandLayoutOp(ConvertOperandLayoutOp):
116+
"""Specialization for ConvertOperandLayoutOp class."""
72117

73118
def __init__(
74119
self,

0 commit comments

Comments
 (0)