@@ -649,6 +649,44 @@ struct UnrealizedConversionCastOpPattern
649649 }
650650};
651651
652+ // This pattern distributes arith.constant op into subgroup-level constants
653+ struct WgToSgArithConstantOp : public OpConversionPattern <arith::ConstantOp> {
654+ using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
655+
656+ LogicalResult
657+ matchAndRewrite (arith::ConstantOp op, OneToNOpAdaptor adaptor,
658+ ConversionPatternRewriter &rewriter) const override {
659+ auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue ());
660+ auto vecType = dyn_cast<VectorType>(op.getType ());
661+ if (!vecAttr || !vecAttr.isSplat () || !vecType)
662+ return failure ();
663+
664+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op.getResult ());
665+ if (!layout || !layout.getSgLayout ())
666+ return failure ();
667+
668+ ArrayRef<int64_t > wgShape = vecType.getShape ();
669+ SmallVector<int64_t > sgShape;
670+ int count;
671+ std::tie (sgShape, count) = getSgShapeAndCount (wgShape, layout);
672+
673+ // Current limitation: constant of vector with single value.
674+ // TODO: support more complex cases, e.g., vector with multiple values.
675+ Attribute singleVal = vecAttr.getSplatValue <Attribute>();
676+
677+ auto newType = VectorType::get (sgShape, vecType.getElementType ());
678+ auto sgAttr = DenseElementsAttr::get (newType, singleVal);
679+ auto cstOp =
680+ rewriter.create <arith::ConstantOp>(op.getLoc (), newType, sgAttr);
681+ if (auto newLayout = layout.dropSgLayoutAndData ())
682+ xegpu::setLayoutAttr (cstOp->getResult (0 ), newLayout);
683+ SmallVector<Value> newConsts (count, cstOp);
684+
685+ rewriter.replaceOpWithMultiple (op, {newConsts});
686+ return success ();
687+ }
688+ };
689+
652690} // namespace
653691
654692namespace mlir {
@@ -657,7 +695,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
657695 patterns.add <WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
658696 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
659697 UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
660- WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
698+ WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
699+ WgToSgArithConstantOp>(
661700 patterns.getContext ());
662701}
663702} // namespace xegpu
@@ -770,6 +809,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
770809 return isLegal (xegpu::getLayoutAttr (op.getResult ()));
771810 });
772811
812+ target.addDynamicallyLegalOp <arith::ConstantOp>(
813+ [=](arith::ConstantOp op) -> bool {
814+ auto vecType = dyn_cast<VectorType>(op.getType ());
815+ if (!vecType)
816+ return true ;
817+ return isLegal (xegpu::getLayoutAttr (op.getResult ()));
818+ });
819+
773820 target.addDynamicallyLegalOp <xegpu::ConvertLayoutOp>(
774821 [=](xegpu::ConvertLayoutOp op) -> bool {
775822 return isLegal (op.getInputLayout ()) && isLegal (op.getTargetLayout ());
0 commit comments