Commit 6026ca3

[mlir][XeGPU] add unroll patterns for load_matrix and store_matrix (llvm#154637)
1 parent a26cd2d commit 6026ca3

File tree: 7 files changed (+177 -16 lines)

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
@@ -67,8 +67,8 @@ def XeGPUBlocking: Pass<"xegpu-blocking"> {
     to a hardware instruction.
   }];
   let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
-  ];
+      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect",
+      "index::IndexDialect"];
 }
 
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 6 additions & 1 deletion
@@ -20,6 +20,7 @@ class OpResult;
 class OpBuilder;
 class ValueRange;
 class TypeConverter;
+class OpFoldResult;
 
 namespace xegpu {
 class DistributeLayoutAttr;
@@ -143,6 +144,11 @@ void doSCFStructuralTypeConversionWithTensorType(Operation *op,
 /// if no GPU module parent or XeVM target attribute exists.
 std::optional<std::string> getChipStr(Operation *op);
 
+/// Generates element-wise addition ops of two arrays with same length.
+SmallVector<OpFoldResult> addElementwise(OpBuilder &builder, Location loc,
+                                         ArrayRef<OpFoldResult> lhs,
+                                         ArrayRef<OpFoldResult> rhs);
+
 /// Generates element-wise addition ops of two arrays with automatic alignment.
 /// When the input arrays have different sizes, the shorter array is
 /// right-aligned with the longer array, and the unmatched leading elements from
@@ -156,7 +162,6 @@ std::optional<std::string> getChipStr(Operation *op);
 SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
                                               ArrayRef<OpFoldResult> lhs,
                                               ArrayRef<OpFoldResult> rhs);
-
 } // namespace xegpu
 
 } // namespace mlir
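
A minimal usage sketch of the two helpers declared above (not part of this commit; builder, loc, and the OpFoldResult arguments are assumed to come from a surrounding rewrite). addElementwise asserts equal lengths, while addWithRightAligned right-aligns the shorter array, so lhs = [a, b, c], rhs = [x, y] yields [a, b+x, c+y]:

#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"

using namespace mlir;

// Hypothetical helper, purely illustrative.
static void addOffsetsExample(OpBuilder &builder, Location loc,
                              OpFoldResult a, OpFoldResult b, OpFoldResult c,
                              OpFoldResult x, OpFoldResult y) {
  // Equal ranks: element-wise sums [a+x, b+y].
  SmallVector<OpFoldResult> same =
      xegpu::addElementwise(builder, loc, {a, b}, {x, y});
  // Mixed ranks: the leading `a` passes through unchanged.
  SmallVector<OpFoldResult> aligned =
      xegpu::addWithRightAligned(builder, loc, {a, b, c}, {x, y});
  (void)same;
  (void)aligned;
}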

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 3 additions & 2 deletions
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
@@ -157,10 +158,10 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
 std::optional<SmallVector<int64_t>>
 XeGPUBlockingPass::getTileShape(Operation *op) const {
   if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
-          xegpu::UpdateOffsetOp>(op))
+          xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
     return getTileShape(op->getOpResult(0));
   if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
-          xegpu::LoadGatherOp>(op))
+          xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
     return getTileShape(op->getOpOperand(0));
   if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
     return getTileShape(op->getOpOperand(1));

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 82 additions & 5 deletions
@@ -682,13 +682,90 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
   }
 };
 
+struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
+  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getType();
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
+      return failure();
+
+    Type elemTy = valueTy.getElementType();
+    ArrayRef<int64_t> shape = valueTy.getShape();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
+
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(shape, *targetShape)) {
+      auto adds = xegpu::addElementwise(
+          rewriter, loc, mixedOffsets,
+          getAsIndexOpFoldResult(op.getContext(), offsets));
+      offsetsList.push_back(adds);
+    }
+
+    SmallVector<Value> newOps;
+    layout = layout.dropInstData();
+    for (SmallVector<OpFoldResult> offsets : offsetsList) {
+      auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
+          op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
+      newOps.push_back(newOp);
+    }
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
+  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
+                                PatternRewriter &rewriter) const override {
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getData().getType();
+    ArrayRef<int64_t> shape = valueTy.getShape();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
+
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(shape, *targetShape)) {
+      auto adds = xegpu::addElementwise(
+          rewriter, loc, mixedOffsets,
+          getAsIndexOpFoldResult(op.getContext(), offsets));
+      offsetsList.push_back(adds);
+    }
+
+    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
+      rewriter.create<xegpu::StoreMatrixOp>(loc, v, op.getMemDesc(), offsets,
+                                            layout.dropInstData());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
-  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
-               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
-               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
-                                                       options);
+  patterns
+      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
+           UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
+           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
+           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp>(
+          patterns.getContext(), options);
 }
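
A minimal driver sketch (not part of this diff) showing how the newly registered patterns could be exercised outside the xegpu-blocking pass. It assumes the xegpu::UnrollOptions callbacks declared in mlir/Dialect/XeGPU/Transforms/Transforms.h and hard-codes a hypothetical 8x16 native tile for the matrix ops:

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static void unrollMatrixOps(Operation *root) {
  xegpu::UnrollOptions options;

  // Only the matrix ops report a native tile; every other op returns
  // std::nullopt and is left untouched.
  options.setNativeShapeFn(
      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
          return SmallVector<int64_t>{8, 16}; // hypothetical inst_data tile
        return std::nullopt;
      });

  // Unrolled value types are plain tiles of the original element type;
  // UnrollStoreMatrixOp relies on this callback when packing its data.
  options.setUnrolledTypesFn(
      [](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
        int64_t numTiles =
            computeProduct(type.getShape()) / computeProduct(tileShape);
        return SmallVector<Type>(
            numTiles, type.cloneWith(tileShape, type.getElementType()));
      });

  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  (void)applyPatternsGreedily(root, std::move(patterns));
}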

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 34 additions & 6 deletions
@@ -134,6 +134,14 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
   if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
     return getDistributeLayoutAttr(loadNd.getTensorDesc());
 
+  // for LoadMatrixOp, the layout is attached to the property of the op
+  if (auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(defOp))
+    return loadOp.getLayoutAttr();
+
+  // for StoreMatrixOp, the layout is attached to the property of the op
+  if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
+    return storeOp.getLayoutAttr();
+
   std::string layoutName = getLayoutName(result);
   if (defOp->hasAttr(layoutName))
     return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
@@ -154,6 +162,13 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
 xegpu::DistributeLayoutAttr
 xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   Operation *op = opr.getOwner();
+
+  if (auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(op))
+    return loadOp.getLayoutAttr();
+
+  if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(op))
+    return storeOp.getLayoutAttr();
+
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))
     return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
@@ -182,6 +197,9 @@ template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
 void xegpu::setDistributeLayoutAttrs(
     Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
   op->walk([&](Operation *nestOp) {
+    if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(nestOp))
+      return;
+
     for (OpOperand &opr : nestOp->getOpOperands()) {
       auto layout = getLayoutImpl(opr.get());
       setDistributeLayoutAttr(opr, layout);
@@ -429,6 +447,21 @@ std::optional<std::string> xegpu::getChipStr(Operation *op) {
   return std::nullopt;
 }
 
+/// Generates element-wise addition ops of two arrays with same length.
+SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder,
+                                                Location loc,
+                                                ArrayRef<OpFoldResult> lhs,
+                                                ArrayRef<OpFoldResult> rhs) {
+  assert(lhs.size() == rhs.size() && "lhs and rhs must have the same size");
+  SmallVector<OpFoldResult> results;
+  for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
+    auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
+    auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
+    results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+  }
+  return results;
+}
+
 /// Generates element-wise addition ops of two arrays with automatic alignment.
 /// When the input arrays have different sizes, the shorter array is
 /// right-aligned with the longer array, and the unmatched leading elements from
@@ -448,11 +481,6 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
   ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
   SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
   a = a.slice(a.size() - b.size());
-  for (auto [l, r] : llvm::zip(a, b)) {
-    auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
-    auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
-    results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
-  }
+  results.append(addElementwise(builder, loc, a, b));
   return results;
-  return {};
 }
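
To see why the tests below expect exactly eight tiles, here is a standalone sketch (assuming only StaticTileOffsetRange from mlir/Dialect/Utils/IndexingUtils.h) that enumerates the tile origins the unroll patterns above iterate over:

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // A 32x32 value split into 8x16 inst_data tiles yields
  // (32/8) * (32/16) = 8 tile origins, hence CHECK-COUNT-8 below.
  int count = 0;
  for (llvm::SmallVector<int64_t> offsets :
       mlir::StaticTileOffsetRange({32, 32}, {8, 16})) {
    // With the default loop order the last dimension should vary fastest:
    // (0, 0), (0, 16), (8, 0), (8, 16), ...
    llvm::outs() << "(" << offsets[0] << ", " << offsets[1] << ") ";
    ++count;
  }
  llvm::outs() << "\n" << count << " tiles\n";
  return 0;
}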

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 23 additions & 0 deletions
@@ -561,3 +561,26 @@ gpu.module @test_kernel {
     gpu.return %e : vector<8x32x2xf16>
   }
 }
+
+// -----
+gpu.module @test_kernel {
+  //CHECK-LABEL: unroll_load_matrix
+  gpu.func @unroll_load_matrix(%arg0: memref<4096xi8, 3>) -> vector<32x32xf32> {
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+    //CHECK-COUNT-8: xegpu.load_matrix {{.*}} : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x16xf32>
+    //CHECK-COUNT-8: vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
+    %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : !xegpu.mem_desc<32x32xf32> -> vector<32x32xf32>
+    gpu.return %1 : vector<32x32xf32>
+  }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: unroll_store_matrix
+  gpu.func @unroll_store_matrix(%value: vector<32x32xf32>, %arg0: memref<32768xi8, 3>) {
+    %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    // CHECK-COUNT-8: xegpu.store_matrix {{.*}} : vector<8x16xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+    xegpu.store_matrix %value, %mdesc[0, 0] {layout = #xegpu.layout<inst_data = [8, 16]>} : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>
+    gpu.return
+  }
+}

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 27 additions & 0 deletions
@@ -26,6 +26,33 @@ gpu.module @test_1_1_assignment {
     gpu.return
   }
 
+  // CHECK-LABEL: create_nd_tdesc_from_higher_rank_memref
+  // CHECK-SAME: [[ARG_0:%.*]]: memref<3x256x128xf32>
+  gpu.func @create_nd_tdesc_from_higher_rank_memref(%src: memref<3x256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
+    //CHECK: [[C0:%.+]] = arith.constant 0 : index
+    //CHECK: [[C0_2:%.+]] = arith.constant 0 : index
+    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_2]] : index
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[MODY:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[MODX:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: [[C0_3:%.+]] = arith.constant 0 : index
+    //CHECK: [[Y:%.+]] = index.add [[MODY]], [[C0_3]]
+    //CHECK: [[C0_4:%.+]] = arith.constant 0 : index
+    //CHECK: [[X:%.+]] = index.add [[MODX]], [[C0_4]]
+    //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[Y]], [[X]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[1, 0, 0] : memref<3x256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
 // CHECK-LABEL: load_nd_tdesc
 // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
 gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
