Skip to content

Commit 2945e0c

Browse files
nbpatel authored and tkarna committed
[MLIR][XeGPU] Add pattern for arith.constant for wg to sg distribution (llvm#151977)
1 parent 9b8a495 commit 2945e0c

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 48 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -649,6 +649,44 @@ struct UnrealizedConversionCastOpPattern
649649
}
650650
};
651651

652+
// This pattern distributes arith.constant op into subgroup-level constants
653+
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
654+
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
655+
656+
LogicalResult
657+
matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
658+
ConversionPatternRewriter &rewriter) const override {
659+
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
660+
auto vecType = dyn_cast<VectorType>(op.getType());
661+
if (!vecAttr || !vecAttr.isSplat() || !vecType)
662+
return failure();
663+
664+
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
665+
if (!layout || !layout.getSgLayout())
666+
return failure();
667+
668+
ArrayRef<int64_t> wgShape = vecType.getShape();
669+
SmallVector<int64_t> sgShape;
670+
int count;
671+
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
672+
673+
// Current limitation: constant of vector with single value.
674+
// TODO: support more complex cases, e.g., vector with multiple values.
675+
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
676+
677+
auto newType = VectorType::get(sgShape, vecType.getElementType());
678+
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
679+
auto cstOp =
680+
rewriter.create<arith::ConstantOp>(op.getLoc(), newType, sgAttr);
681+
if (auto newLayout = layout.dropSgLayoutAndData())
682+
xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
683+
SmallVector<Value> newConsts(count, cstOp);
684+
685+
rewriter.replaceOpWithMultiple(op, {newConsts});
686+
return success();
687+
}
688+
};
689+
652690
} // namespace
653691

654692
namespace mlir {
@@ -657,7 +695,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
657695
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
658696
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
659697
UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
660-
WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
698+
WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
699+
WgToSgArithConstantOp>(
661700
patterns.getContext());
662701
}
663702
} // namespace xegpu
@@ -770,6 +809,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
770809
return isLegal(xegpu::getLayoutAttr(op.getResult()));
771810
});
772811

812+
// arith.constant legality: constants whose type is not a VectorType are
// always legal; vector constants are legal exactly when isLegal accepts the
// layout attribute found on their result.
target.addDynamicallyLegalOp<arith::ConstantOp>(
    [=](arith::ConstantOp constantOp) -> bool {
      auto resultVecType = dyn_cast<VectorType>(constantOp.getType());
      // Short-circuit keeps isLegal from being queried for non-vectors.
      return !resultVecType ||
             isLegal(xegpu::getLayoutAttr(constantOp.getResult()));
    });
819+
773820
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
774821
[=](xegpu::ConvertLayoutOp op) -> bool {
775822
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -373,4 +373,11 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
373373
} {sg_id_range = #xegpu.range<[3, 19]>}
374374
gpu.return
375375
}
376+
377+
// Test for distributing a splat vector constant: the input is a
// vector<256x128xf32> splat of 1.0 whose result layout has
// sg_layout = [8, 4] and sg_data = [32, 32]; the expected output contains a
// splat constant of type vector<32x32xf32>.
// CHECK-LABEL: distribute_constant
378+
gpu.func @distribute_constant() {
379+
// CHECK: arith.constant dense<1.000000e+00> : vector<32x32xf32>
380+
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
381+
gpu.return
382+
}
376383
}

0 commit comments

Comments
 (0)