Skip to content

Commit bf7cf0a

Browse files
committed
xegpu: remove load_data argument from set_dpas_layout transform op
1 parent d73ef0d commit bf7cf0a

File tree

4 files changed

+8
-83
lines changed

4 files changed

+8
-83
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,20 @@ def SetDPASLayoutOp : Op<Transform_Dialect, "xegpu.set_dpas_layout", [
5151
Given a `xegpu.dpas` operation, this transform adds `xegpu.layout`
5252
attribute to it's operand's tensor descriptor. The target operand is
5353
defined by the `tileIndex` argument. The layout is defined by the
54-
`sg_layout`, `sg_data` and `inst_data` attributes. The `load_data`
55-
attribute defines the tile size used for loading the data. It must be a
56-
multiple of the `inst_data` size.
54+
`sg_layout`, `sg_data` and `inst_data` attributes.
5755
}];
5856

5957
let arguments = (ins TransformHandleTypeInterface : $dpasOp,
6058
I64Attr : $tileIndex,
6159
DenseI32ArrayAttr : $sgLayout,
6260
DenseI32ArrayAttr : $sgData,
63-
OptionalAttr<DenseI32ArrayAttr> : $loadData,
6461
DenseI32ArrayAttr : $instData);
6562

6663
let results = (outs);
6764

6865
let assemblyFormat =
6966
"$dpasOp `index` `=` $tileIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
70-
"$sgData (`load_data` `=` $loadData^)? `inst_data` `=` $instData attr-dict `:` type($dpasOp)";
67+
"$sgData `inst_data` `=` $instData attr-dict `:` type($dpasOp)";
7168

7269
let extraClassDeclaration = [{
7370
::mlir::DiagnosedSilenceableFailure applyToOne(

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -610,59 +610,19 @@ DiagnosedSilenceableFailure transform::SetDPASLayoutOp::applyToOne(
610610
<< "Expected inst_data to be a 2D vector";
611611
}
612612

613-
llvm::ArrayRef<int> loadData = instData;
614-
if (getLoadData().has_value()) {
615-
loadData = getLoadData().value();
616-
if (loadData.size() != 2) {
617-
return emitSilenceableFailure(getLoc())
618-
<< "Expected load_data to be a 2D vector";
619-
}
620-
if (loadData[0] < instData[0] || loadData[1] < instData[1]) {
621-
return emitSilenceableFailure(getLoc())
622-
<< "load_data size must be larger or equal to inst_data size";
623-
}
624-
if (loadData[0] % instData[0] != 0 || loadData[1] % instData[1] != 0) {
625-
return emitSilenceableFailure(getLoc())
626-
<< "load_data must be evenly divisible by inst_data";
627-
}
628-
}
629-
630613
// Replace descriptor op using layout attribute.
631614
Value opVec = dpasOp.getOperation()->getOperand(tileIndex);
632615
auto maybeDescOp = findDescriptorOp(opVec, dpasOp.getOperation());
633616
if (!maybeDescOp) {
634617
return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
635618
}
636619
auto descOp = *maybeDescOp;
637-
// Layout for the load op.
638-
auto loadLayoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, loadData);
639-
descOp = setDescLayout(rewriter, descOp, loadLayoutAttr);
640-
// Layout for the instruction.
641-
auto instLayoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
620+
// Set layout attribute.
621+
auto layoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
622+
descOp = setDescLayout(rewriter, descOp, layoutAttr);
642623
if (tileIndex == 2) {
643-
// C operand: set layout attribute for the dpas op result
644-
xegpu::setLayoutAttr(dpasOp.getOperation()->getResults()[0], instLayoutAttr);
645-
}
646-
647-
if (loadLayoutAttr != instLayoutAttr) {
648-
// Insert convert layout op after load op.
649-
auto maybeLoadOp = getUserOfType<xegpu::LoadNdOp>(descOp.getResult());
650-
if (!maybeLoadOp) {
651-
return emitSilenceableFailure(getLoc())
652-
<< "Expected a xegpu.load_nd op as a user of the descriptor op.";
653-
}
654-
auto loadOp = *maybeLoadOp;
655-
rewriter.setInsertionPointAfter(loadOp.getOperation());
656-
auto source = loadOp.getResult();
657-
auto convLayoutOp = rewriter.create<xegpu::ConvertLayoutOp>(
658-
loadOp.getLoc(), source.getType(), source,
659-
loadLayoutAttr, instLayoutAttr);
660-
// Replace load op result with the converted layout.
661-
rewriter.replaceUsesWithIf(
662-
source, convLayoutOp.getResult(),
663-
[&](OpOperand &use) {
664-
return use.getOwner() != convLayoutOp.getOperation();
665-
});
624+
// C operand: set layout attribute for the dpas op result.
625+
xegpu::setLayoutAttr(dpasOp.getOperation()->getResults()[0], layoutAttr);
666626
}
667627

668628
return DiagnosedSilenceableFailure::success();

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ def __init__(
2828
sg_data: Union[Sequence[int], Attribute],
2929
inst_data: Union[Sequence[int], Attribute],
3030
*,
31-
load_data: Optional[Union[Sequence[int], Attribute]] = None,
3231
loc=None,
3332
ip=None,
3433
):
@@ -38,7 +37,6 @@ def __init__(
3837
sg_layout,
3938
sg_data,
4039
inst_data,
41-
loadData=load_data,
4240
loc=loc,
4341
ip=ip
4442
)

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -107,37 +107,7 @@ module attributes {transform.with_named_sequence} {
107107
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
108108
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
109109
// CHECK: transform.xegpu.set_dpas_layout %{{.*}}
110-
transform.xegpu.set_dpas_layout %0 index = 2 sg_layout = [8, 4] sg_data = [32, 64] load_data = [8, 16] inst_data = [8, 16] : !transform.any_op
111-
transform.yield
112-
}
113-
}
114-
115-
// -----
116-
117-
// CHECK-LABEL: @set_dpas_layout_load_a
118-
func.func @set_dpas_layout_load_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
119-
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
120-
// CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
121-
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
122-
// CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
123-
// CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
124-
// CHECK-SAME: resMap = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
125-
// CHECK-SAME: srcMap = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
126-
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
127-
%2 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
128-
%3 = xegpu.load_nd %2 : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
129-
%4 = xegpu.create_nd_tdesc %arg2[0, 0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
130-
%5 = xegpu.load_nd %4 : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
131-
// CHECK: = xegpu.dpas %[[V2]]
132-
%6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
133-
return
134-
}
135-
136-
module attributes {transform.with_named_sequence} {
137-
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
138-
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
139-
// CHECK: transform.xegpu.set_dpas_layout %{{.*}}
140-
transform.xegpu.set_dpas_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] load_data = [32, 16] inst_data = [8, 16] : !transform.any_op
110+
transform.xegpu.set_dpas_layout %0 index = 2 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
141111
transform.yield
142112
}
143113
}

0 commit comments

Comments
 (0)