Skip to content

Commit a45730c

Browse files
committed
xegpu: rename set_dpas_layout to set_operand_layout
rename tileIndex to operandIndex remove all references to dpas ops where possible
1 parent bf7cf0a commit a45730c

File tree

4 files changed

+73
-73
lines changed

4 files changed

+73
-73
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,30 +41,30 @@ def HoistDescOp : Op<Transform_Dialect, "xegpu.hoist_desc_ops", [
4141
}];
4242
}
4343

44-
def SetDPASLayoutOp : Op<Transform_Dialect, "xegpu.set_dpas_layout", [
44+
def SetOperandLayoutOp : Op<Transform_Dialect, "xegpu.set_operand_layout", [
4545
TransformOpInterface, TransformEachOpTrait,
4646
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
4747
]> {
4848

49-
let summary = "Set xegpu.layout attribute to an DPAS op operand.";
49+
let summary = "Set xegpu.layout attribute to an xegpu op operand.";
5050
let description = [{
51-
Given a `xegpu.dpas` operation, this transform adds `xegpu.layout`
51+
Given an xegpu operation, this transform adds `xegpu.layout`
5252
attribute to it's operand's tensor descriptor. The target operand is
53-
defined by the `tileIndex` argument. The layout is defined by the
53+
defined by the `operandIndex` argument. The layout is defined by the
5454
`sg_layout`, `sg_data` and `inst_data` attributes.
5555
}];
5656

57-
let arguments = (ins TransformHandleTypeInterface : $dpasOp,
58-
I64Attr : $tileIndex,
57+
let arguments = (ins TransformHandleTypeInterface : $target,
58+
I64Attr : $operandIndex,
5959
DenseI32ArrayAttr : $sgLayout,
6060
DenseI32ArrayAttr : $sgData,
6161
DenseI32ArrayAttr : $instData);
6262

6363
let results = (outs);
6464

6565
let assemblyFormat =
66-
"$dpasOp `index` `=` $tileIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
67-
"$sgData `inst_data` `=` $instData attr-dict `:` type($dpasOp)";
66+
"$target `index` `=` $operandIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
67+
"$sgData `inst_data` `=` $instData attr-dict `:` type($target)";
6868

6969
let extraClassDeclaration = [{
7070
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -81,20 +81,20 @@ def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch",
8181

8282
let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
8383
let description = [{
84-
Given a `xegpu.dpas` operation residing in a `scf.for` loop, this transform inserts cooperative `xegpu.prefetch` operations for the A (index = 0) or B (index = 1) operand. The prefetch tile size is determined by the `sg_layout` and `sg_data` attributes.
84+
Given an xegpu operation residing in a `scf.for` loop, this transform inserts cooperative `xegpu.prefetch` operations for the A (index = 0) or B (index = 1) operand. The prefetch tile size is determined by the `sg_layout` and `sg_data` attributes.
8585
}];
8686

87-
let arguments = (ins TransformHandleTypeInterface : $dpasOp,
87+
let arguments = (ins TransformHandleTypeInterface : $target,
8888
TransformHandleTypeInterface : $loopOp,
89-
I64Attr : $tileIndex,
89+
I64Attr : $operandIndex,
9090
DenseI32ArrayAttr : $sgLayout,
9191
DenseI32ArrayAttr : $sgData);
9292

93-
let results = (outs TransformHandleTypeInterface : $transformedDpasOp,
93+
let results = (outs TransformHandleTypeInterface : $transformedTargetOp,
9494
TransformHandleTypeInterface : $transformedLoopOp);
9595

9696
let assemblyFormat =
97-
"$dpasOp $loopOp `index` `=` $tileIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
97+
"$target $loopOp `index` `=` $operandIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
9898
"$sgData attr-dict `:` functional-type(operands, results)";
9999
}
100100

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -452,23 +452,24 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
452452
transform::TransformResults &results,
453453
transform::TransformState &state) {
454454

455-
auto dpasOps = state.getPayloadOps(getDpasOp());
455+
auto targetOps = state.getPayloadOps(getTarget());
456456
auto loopOps = state.getPayloadOps(getLoopOp());
457457

458-
if (!llvm::hasSingleElement(dpasOps)) {
459-
return emitDefiniteFailure() << "requires exactly one dpasOp handle (got "
460-
<< llvm::range_size(dpasOps) << ")";
458+
if (!llvm::hasSingleElement(targetOps)) {
459+
return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
460+
<< llvm::range_size(targetOps) << ")";
461461
}
462462
if (!llvm::hasSingleElement(loopOps)) {
463463
return emitDefiniteFailure() << "requires exactly one loopOp handle (got "
464464
<< llvm::range_size(loopOps) << ")";
465465
}
466466

467-
Operation *dpasPtr = *dpasOps.begin();
468-
auto dpasOp = dyn_cast<xegpu::DpasOp>(dpasPtr);
469-
if (!dpasOp) {
467+
Operation *targetPtr = *targetOps.begin();
468+
// For now only DPAS op is supported.
469+
auto targetOp = dyn_cast<xegpu::DpasOp>(targetPtr);
470+
if (!targetOp) {
470471
return emitSilenceableFailure(getLoc())
471-
<< "Expected a xegpu.dpas op, but got: " << dpasPtr->getName();
472+
<< "Expected a xegpu.dpas op, but got: " << targetPtr->getName();
472473
}
473474

474475
Operation *loopPtr = *loopOps.begin();
@@ -478,16 +479,16 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
478479
<< "Expected a scf.for op, but got: " << loopPtr->getName();
479480
}
480481

481-
auto parentLoop = dpasOp->getParentOfType<scf::ForOp>();
482+
auto parentLoop = targetOp->getParentOfType<scf::ForOp>();
482483
if (!parentLoop || parentLoop != forOp) {
483484
return emitSilenceableFailure(getLoc())
484-
<< "dpasOp is not contained in the given scf.for loop.";
485+
<< "target op is not contained in the given scf.for loop.";
485486
}
486487

487-
int64_t tileIndex = getTileIndex();
488-
if (tileIndex >= dpasOp.getNumOperands()) {
488+
int64_t operandIndex = getOperandIndex();
489+
if (operandIndex >= targetOp.getNumOperands()) {
489490
return emitSilenceableFailure(getLoc())
490-
<< "tileIndex exceeds the number of op operands.";
491+
<< "operandIndex exceeds the number of op operands.";
491492
}
492493

493494
auto sgLayout = getSgLayout();
@@ -503,8 +504,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
503504
}
504505

505506
// Find descriptor op of the operand.
506-
Value opVec = dpasOp.getOperation()->getOperand(tileIndex);
507-
auto maybeDescOp = findDescriptorOp(opVec, dpasOp.getOperation());
507+
Value opVec = targetOp.getOperation()->getOperand(operandIndex);
508+
auto maybeDescOp = findDescriptorOp(opVec, targetOp.getOperation());
508509
if (!maybeDescOp) {
509510
return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
510511
}
@@ -554,42 +555,42 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
554555
newForOp.setLowerBound(forOp.getLowerBound());
555556

556557
// Fuse with the original loop, keep track of cloned ops.
557-
SmallVector<Operation *> sourceOps{dpasOp.getOperation()}, targetOps;
558-
auto fusedLoop =
559-
fuseForLoops(newForOp, forOp, rewriter, sourceOps, targetOps);
558+
SmallVector<Operation *> sourceOps{targetOp.getOperation()}, dstOps;
559+
auto fusedLoop = fuseForLoops(newForOp, forOp, rewriter, sourceOps, dstOps);
560560
assert(fusedLoop && "failed to fuse loops");
561561

562-
// Get the cloned dpas op.
563-
auto clonedDpasOp = targetOps[0];
564-
if (!clonedDpasOp) {
562+
// Get the cloned target op.
563+
auto clonedTargetOp = dstOps[0];
564+
if (!clonedTargetOp) {
565565
return emitSilenceableFailure(getLoc())
566-
<< "Failed to find cloned dpas op in the fused loop.";
566+
<< "Failed to find cloned target op in the fused loop.";
567567
}
568568

569569
// Map result handles.
570570
results.set(cast<OpResult>(getTransformedLoopOp()), {fusedLoop});
571-
results.set(cast<OpResult>(getTransformedDpasOp()), {clonedDpasOp});
571+
results.set(cast<OpResult>(getTransformedTargetOp()), {clonedTargetOp});
572572

573573
return DiagnosedSilenceableFailure::success();
574574
}
575575

576-
DiagnosedSilenceableFailure transform::SetDPASLayoutOp::applyToOne(
576+
DiagnosedSilenceableFailure transform::SetOperandLayoutOp::applyToOne(
577577
transform::TransformRewriter &rewriter, Operation *target,
578578
transform::ApplyToEachResultList &results,
579579
transform::TransformState &state) {
580580

581-
auto dpasOp = dyn_cast<xegpu::DpasOp>(target);
582-
if (!dpasOp) {
581+
// For now only DPAS op is supported.
582+
auto targetOp = dyn_cast<xegpu::DpasOp>(target);
583+
if (!targetOp) {
583584
auto diag = emitSilenceableFailure(getLoc())
584585
<< "Expected a xegpu.dpas op, but got: " << target->getName();
585586
diag.attachNote(target->getLoc()) << "target op";
586587
return diag;
587588
}
588589

589-
int64_t tileIndex = getTileIndex();
590-
if (tileIndex >= dpasOp.getNumOperands()) {
590+
int64_t operandIndex = getOperandIndex();
591+
if (operandIndex >= targetOp.getNumOperands()) {
591592
return emitSilenceableFailure(getLoc())
592-
<< "tileIndex exceeds the number of op operands.";
593+
<< "operandIndex exceeds the number of op operands.";
593594
}
594595

595596
auto sgLayout = getSgLayout();
@@ -611,26 +612,26 @@ DiagnosedSilenceableFailure transform::SetDPASLayoutOp::applyToOne(
611612
}
612613

613614
// Replace descriptor op using layout attribute.
614-
Value opVec = dpasOp.getOperation()->getOperand(tileIndex);
615-
auto maybeDescOp = findDescriptorOp(opVec, dpasOp.getOperation());
615+
Value opVec = targetOp.getOperation()->getOperand(operandIndex);
616+
auto maybeDescOp = findDescriptorOp(opVec, targetOp.getOperation());
616617
if (!maybeDescOp) {
617618
return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
618619
}
619620
auto descOp = *maybeDescOp;
620621
// Set layout attribute.
621622
auto layoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
622623
descOp = setDescLayout(rewriter, descOp, layoutAttr);
623-
if (tileIndex == 2) {
624+
if (operandIndex == 2) {
624625
// C operand: set layout attribute for the dpas op result.
625-
xegpu::setLayoutAttr(dpasOp.getOperation()->getResults()[0], layoutAttr);
626+
xegpu::setLayoutAttr(targetOp.getOperation()->getResults()[0], layoutAttr);
626627
}
627628

628629
return DiagnosedSilenceableFailure::success();
629630
}
630631

631-
void transform::SetDPASLayoutOp::getEffects(
632+
void transform::SetOperandLayoutOp::getEffects(
632633
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
633-
onlyReadsHandle(getDpasOpMutable(), effects);
634+
onlyReadsHandle(getTargetMutable(), effects);
634635
modifiesPayload(effects);
635636
}
636637

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717

1818

1919
@_ods_cext.register_operation(_Dialect, replace=True)
20-
class SetDPASLayoutOp(SetDPASLayoutOp):
21-
"""Specialization for SetDPASLayoutOp class."""
20+
class SetOperandLayoutOp(SetOperandLayoutOp):
21+
"""Specialization for SetOperandLayoutOp class."""
2222

2323
def __init__(
2424
self,
25-
dpas_op: Union[Operation, Value],
26-
tile_index: Union[int, Attribute],
25+
target: Union[Operation, Value],
26+
index: Union[int, Attribute],
2727
sg_layout: Union[Sequence[int], Attribute],
2828
sg_data: Union[Sequence[int], Attribute],
2929
inst_data: Union[Sequence[int], Attribute],
@@ -32,8 +32,8 @@ def __init__(
3232
ip=None,
3333
):
3434
super().__init__(
35-
dpas_op,
36-
tile_index,
35+
target,
36+
index,
3737
sg_layout,
3838
sg_data,
3939
inst_data,
@@ -48,23 +48,22 @@ class InsertPrefetchOp(InsertPrefetchOp):
4848

4949
def __init__(
5050
self,
51-
dpas_op: Union[Operation, Value],
51+
target: Union[Operation, Value],
5252
loop_op: Union[Operation, Value],
53-
tile_index: Union[int, Attribute],
53+
index: Union[int, Attribute],
5454
sg_layout: Union[Sequence[int], Attribute],
5555
sg_data: Union[Sequence[int], Attribute],
5656
loc=None,
5757
ip=None,
5858
):
59-
# results = get_op_result_or_op_results(dpas_op, loop_op)
60-
transformed_dpas_type = transform.AnyOpType.get()
59+
transformed_target_type = transform.AnyOpType.get()
6160
transformed_loop_type = transform.AnyOpType.get()
6261
super().__init__(
63-
transformed_dpas_type,
62+
transformed_target_type,
6463
transformed_loop_type,
65-
_get_op_result_or_value(dpas_op),
64+
_get_op_result_or_value(target),
6665
_get_op_result_or_value(loop_op),
67-
tile_index,
66+
index,
6867
sg_layout,
6968
sg_data,
7069
loc=loc,

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ module attributes {transform.with_named_sequence} {
2727

2828
// -----
2929

30-
// CHECK-LABEL: @set_dpas_layout_a
31-
func.func @set_dpas_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
30+
// CHECK-LABEL: @set_operand_layout_a
31+
func.func @set_operand_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
3232
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
3333
// CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
3434
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
@@ -46,16 +46,16 @@ func.func @set_dpas_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x40
4646
module attributes {transform.with_named_sequence} {
4747
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
4848
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
49-
// CHECK: transform.xegpu.set_dpas_layout %{{.*}}
50-
transform.xegpu.set_dpas_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
49+
// CHECK: transform.xegpu.set_operand_layout %{{.*}}
50+
transform.xegpu.set_operand_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
5151
transform.yield
5252
}
5353
}
5454

5555
// -----
5656

57-
// CHECK-LABEL: @set_dpas_layout_b
58-
func.func @set_dpas_layout_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
57+
// CHECK-LABEL: @set_operand_layout_b
58+
func.func @set_operand_layout_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
5959
// CHECK: = xegpu.create_nd_tdesc
6060
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
6161
// CHECK: = xegpu.load_nd
@@ -75,16 +75,16 @@ func.func @set_dpas_layout_b(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x40
7575
module attributes {transform.with_named_sequence} {
7676
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
7777
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
78-
// CHECK: transform.xegpu.set_dpas_layout %{{.*}}
79-
transform.xegpu.set_dpas_layout %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [16, 16] : !transform.any_op
78+
// CHECK: transform.xegpu.set_operand_layout %{{.*}}
79+
transform.xegpu.set_operand_layout %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [16, 16] : !transform.any_op
8080
transform.yield
8181
}
8282
}
8383

8484
// -----
8585

86-
// CHECK-LABEL: @set_dpas_layout_c
87-
func.func @set_dpas_layout_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
86+
// CHECK-LABEL: @set_operand_layout_c
87+
func.func @set_operand_layout_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
8888
// CHECK: = xegpu.create_nd_tdesc
8989
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
9090
// CHECK: = xegpu.load_nd
@@ -106,8 +106,8 @@ func.func @set_dpas_layout_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x40
106106
module attributes {transform.with_named_sequence} {
107107
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
108108
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
109-
// CHECK: transform.xegpu.set_dpas_layout %{{.*}}
110-
transform.xegpu.set_dpas_layout %0 index = 2 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
109+
// CHECK: transform.xegpu.set_operand_layout %{{.*}}
110+
transform.xegpu.set_operand_layout %0 index = 2 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
111111
transform.yield
112112
}
113113
}

0 commit comments

Comments
 (0)