@@ -198,14 +198,14 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
198198 case vector::CombiningKind::ADD:
199199 case vector::CombiningKind::XOR:
200200 // Initialize reduction vector to: | 0 | .. | 0 | r |
201- return rewriter.create <vector::InsertElementOp>(
202- loc, r, constantZero (rewriter, loc, vtp),
203- constantIndex (rewriter, loc, 0 ));
201+ return rewriter.create <vector::InsertOp>(loc, r,
202+ constantZero (rewriter, loc, vtp),
203+ constantIndex (rewriter, loc, 0 ));
204204 case vector::CombiningKind::MUL:
205205 // Initialize reduction vector to: | 1 | .. | 1 | r |
206- return rewriter.create <vector::InsertElementOp>(
207- loc, r, constantOne (rewriter, loc, vtp),
208- constantIndex (rewriter, loc, 0 ));
206+ return rewriter.create <vector::InsertOp>(loc, r,
207+ constantOne (rewriter, loc, vtp),
208+ constantIndex (rewriter, loc, 0 ));
209209 case vector::CombiningKind::AND:
210210 case vector::CombiningKind::OR:
211211 // Initialize reduction vector to: | r | .. | r | r |
@@ -628,31 +628,49 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
628628 const VL vl;
629629};
630630
/// Shared rewrite step used by the reduction-chain cleanup patterns.
///
/// Checks whether `inp` is produced by a vector::ReductionOp whose vector
/// operand is in turn defined by an scf::ForOp that carries the LoopEmitter
/// loop attribute. When all three conditions hold, `op` is replaced by that
/// vector operand and success() is returned. In every other case no rewrite
/// is performed and failure() is returned.
///
/// \param rewriter  rewriter used to perform the replacement.
/// \param op        operation that gets replaced when the chain matches.
/// \param inp       value inspected as the candidate reduction result.
631+ static LogicalResult cleanReducChain (PatternRewriter &rewriter, Operation *op,
632+ Value inp) {
// `inp` must come straight out of a vector reduction ...
633+ if (auto redOp = inp.getDefiningOp <vector::ReductionOp>()) {
// ... whose reduced vector is defined by an scf.for ...
634+ if (auto forOp = redOp.getVector ().getDefiningOp <scf::ForOp>()) {
// ... and that loop must carry the LoopEmitter loop attribute.
635+ if (forOp->hasAttr (LoopEmitter::getLoopEmitterLoopAttrName ())) {
// Replace `op` with the reduction's vector operand, so users of `op` read
// the loop-defined vector directly.
636+ rewriter.replaceOp (op, redOp.getVector ());
637+ return success ();
638+ }
639+ }
640+ }
// Any link of the chain missing: report no match without touching the IR.
641+ return failure ();
642+ }
643+
631644// / Reduction chain cleanup.
632645// / v = for { }
633- // / s = vsum(v) v = for { }
634- // / u = expand (s) -> for (v) { }
646+ // / s = vsum(v) v = for { }
647+ // / u = broadcast (s) -> for (v) { }
635648// / for (u) { }
/// Rewrite pattern anchored on vector::BroadcastOp. Its matchAndRewrite hands
/// the broadcast's source value to cleanReducChain, which decides whether the
/// op is replaced.
///
/// NOTE(review): in this patch view, lines marked `-` show the former
/// ReducChainRewriter<VectorOp> template. Its inline body performed the same
/// ReductionOp / scf::ForOp / loop-attribute check that cleanReducChain does.
636- template < typename VectorOp>
637- struct ReducChainRewriter : public OpRewritePattern <VectorOp > {
649+ struct ReducChainBroadcastRewriter
650+ : public OpRewritePattern<vector::BroadcastOp > {
638651public:
// Inherit the base-class constructors.
639- using OpRewritePattern<VectorOp >::OpRewritePattern;
652+ using OpRewritePattern<vector::BroadcastOp >::OpRewritePattern;
640653
641- LogicalResult matchAndRewrite (VectorOp op,
654+ LogicalResult matchAndRewrite (vector::BroadcastOp op,
642655 PatternRewriter &rewriter) const override {
643- Value inp = op.getSource ();
644- if (auto redOp = inp.getDefiningOp <vector::ReductionOp>()) {
645- if (auto forOp = redOp.getVector ().getDefiningOp <scf::ForOp>()) {
646- if (forOp->hasAttr (LoopEmitter::getLoopEmitterLoopAttrName ())) {
647- rewriter.replaceOp (op, redOp.getVector ());
648- return success ();
649- }
650- }
651- }
652- return failure ();
// Delegate to the shared helper; op.getSource() is the value checked for a
// defining vector::ReductionOp.
656+ return cleanReducChain (rewriter, op, op.getSource ());
653657 }
654658};
655659
660+ // / Reduction chain cleanup.
661+ // / v = for { }
662+ // / s = vsum(v) v = for { }
663+ // / u = insert(s) -> for (v) { }
664+ // / for (u) { }
/// Rewrite pattern anchored on vector::InsertOp. Its matchAndRewrite hands
/// the insert's value-to-store operand to cleanReducChain, which decides
/// whether the op is replaced.
665+ struct ReducChainInsertRewriter : public OpRewritePattern <vector::InsertOp> {
666+ public:
// Inherit the base-class constructors.
667+ using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
668+
669+ LogicalResult matchAndRewrite (vector::InsertOp op,
670+ PatternRewriter &rewriter) const override {
// Delegate to the shared helper; op.getValueToStore() is the value checked
// for a defining vector::ReductionOp.
671+ return cleanReducChain (rewriter, op, op.getValueToStore ());
672+ }
673+ };
656674} // namespace
657675
658676// ===----------------------------------------------------------------------===//
@@ -668,6 +686,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
668686 vector::populateVectorStepLoweringPatterns (patterns);
669687 patterns.add <ForOpRewriter>(patterns.getContext (), vectorLength,
670688 enableVLAVectorization, enableSIMDIndex32);
671- patterns.add <ReducChainRewriter<vector::InsertElementOp>,
672- ReducChainRewriter<vector::BroadcastOp>>( patterns.getContext ());
689+ patterns.add <ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
690+ patterns.getContext ());
673691}
0 commit comments