@@ -682,13 +682,90 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
682682 }
683683};
684684
685+ struct UnrollLoadMatrixOp : public UnrollPattern <xegpu::LoadMatrixOp> {
686+ using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
687+ LogicalResult matchAndRewrite (xegpu::LoadMatrixOp op,
688+ PatternRewriter &rewriter) const override {
689+ Location loc = op.getLoc ();
690+ VectorType valueTy = op.getType ();
691+ std::optional<SmallVector<int64_t >> targetShape = getTargetShape (op);
692+ if (!targetShape || targetShape->size () != (size_t )valueTy.getRank ())
693+ return failure ();
694+
695+ Type elemTy = valueTy.getElementType ();
696+ ArrayRef<int64_t > shape = valueTy.getShape ();
697+ auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr ());
698+
699+ VectorType newValueTy = valueTy.cloneWith (*targetShape, elemTy);
700+
701+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets ();
702+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
703+ for (SmallVector<int64_t > offsets :
704+ StaticTileOffsetRange (shape, *targetShape)) {
705+ auto adds = xegpu::addElementwise (
706+ rewriter, loc, mixedOffsets,
707+ getAsIndexOpFoldResult (op.getContext (), offsets));
708+ offsetsList.push_back (adds);
709+ }
710+
711+ SmallVector<Value> newOps;
712+ layout = layout.dropInstData ();
713+ for (SmallVector<OpFoldResult> offsets : offsetsList) {
714+ auto newOp = rewriter.create <xegpu::LoadMatrixOp>(
715+ op.getLoc (), newValueTy, op.getMemDesc (), offsets, layout);
716+ newOps.push_back (newOp);
717+ }
718+ Value castOp = unpack (newOps, op.getType (), *targetShape, loc, rewriter);
719+ rewriter.replaceOp (op, castOp);
720+ return success ();
721+ }
722+ };
723+
724+ struct UnrollStoreMatrixOp : public UnrollPattern <xegpu::StoreMatrixOp> {
725+ using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
726+ LogicalResult matchAndRewrite (xegpu::StoreMatrixOp op,
727+ PatternRewriter &rewriter) const override {
728+ std::optional<SmallVector<int64_t >> targetShape = getTargetShape (op);
729+ if (!targetShape)
730+ return failure ();
731+
732+ Location loc = op.getLoc ();
733+ VectorType valueTy = op.getData ().getType ();
734+ ArrayRef<int64_t > shape = valueTy.getShape ();
735+ auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr ());
736+
737+ SmallVector<Type> convertedValTypes =
738+ getUnrolledTypes (valueTy, *targetShape);
739+ SmallVector<Value> convertedValues =
740+ pack (op.getData (), convertedValTypes, *targetShape, loc, rewriter);
741+
742+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets ();
743+ SmallVector<SmallVector<OpFoldResult>> offsetsList;
744+ for (SmallVector<int64_t > offsets :
745+ StaticTileOffsetRange (shape, *targetShape)) {
746+ auto adds = xegpu::addElementwise (
747+ rewriter, loc, mixedOffsets,
748+ getAsIndexOpFoldResult (op.getContext (), offsets));
749+ offsetsList.push_back (adds);
750+ }
751+
752+ for (auto [v, offsets] : llvm::zip_equal (convertedValues, offsetsList))
753+ rewriter.create <xegpu::StoreMatrixOp>(loc, v, op.getMemDesc (), offsets,
754+ layout.dropInstData ());
755+
756+ rewriter.eraseOp (op);
757+ return success ();
758+ }
759+ };
760+
685761} // namespace
686762
687763void mlir::xegpu::populateXeGPUUnrollPatterns (
688764 RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
689- patterns.add <UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
690- UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
691- UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
692- UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext (),
693- options);
765+ patterns
766+ .add <UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
767+ UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
768+ UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
769+ UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp>(
770+ patterns.getContext (), options);
694771}
0 commit comments