@@ -378,7 +378,7 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
378378 return failure ();
379379
380380 // User controlled propagation function.
381- if (!controlFn (genericOp ))
381+ if (!controlFn (&packOp. getSourceMutable () ))
382382 return failure ();
383383
384384 // TODO: Enable propagation in the presence of linalg.index and
@@ -488,7 +488,7 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
488488 return failure ();
489489
490490 // User controlled propagation function.
491- if (!controlFn (padOp ))
491+ if (!controlFn (&packOp. getSourceMutable () ))
492492 return failure ();
493493
494494 if (!padOp.getResult ().hasOneUse ())
@@ -844,7 +844,7 @@ class BubbleUpPackOpThroughReshapeOp final
844844 }
845845
846846 // User controlled propagation function.
847- if (!controlFn (srcOp ))
847+ if (!controlFn (&packOp. getSourceMutable () ))
848848 return failure ();
849849
850850 return TypeSwitch<Operation *, LogicalResult>(srcOp)
@@ -880,10 +880,13 @@ class BubbleUpPackOpThroughReshapeOp final
880880// / %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
881881// / inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
882882// / : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
883- static LogicalResult
884- pushDownUnPackOpThroughExpandShape (tensor::UnPackOp unPackOp,
885- tensor::ExpandShapeOp expandOp,
886- PatternRewriter &rewriter) {
883+ static LogicalResult pushDownUnPackOpThroughExpandShape (
884+ tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
885+ PatternRewriter &rewriter, ControlPropagationFn controlFn) {
886+ // User controlled propagation function.
887+ if (!controlFn (&expandOp.getSrcMutable ()))
888+ return failure ();
889+
887890 SmallVector<int64_t > innerTileSizes = unPackOp.getStaticTiles ();
888891 ArrayRef<int64_t > innerDimsPos = unPackOp.getInnerDimsPos ();
889892 ArrayRef<int64_t > outerDimsPerm = unPackOp.getOuterDimsPerm ();
@@ -970,13 +973,10 @@ class PushDownUnPackOpThroughReshapeOp final
970973 }
971974
972975 Operation *consumerOp = *result.user_begin ();
973- // User controlled propagation function.
974- if (!controlFn (consumerOp))
975- return failure ();
976-
977976 return TypeSwitch<Operation *, LogicalResult>(consumerOp)
978977 .Case ([&](tensor::ExpandShapeOp op) {
979- return pushDownUnPackOpThroughExpandShape (unPackOp, op, rewriter);
978+ return pushDownUnPackOpThroughExpandShape (unPackOp, op, rewriter,
979+ controlFn);
980980 })
981981 .Default ([](Operation *) { return failure (); });
982982 }
@@ -1038,7 +1038,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
10381038// / inner_dims_pos = [3] inner_tiles = [32] into %0
10391039// /
10401040static FailureOr<std::tuple<GenericOp, Value>>
1041- pushDownUnPackOpThroughGenericOp (RewriterBase &rewriter, GenericOp genericOp) {
1041+ pushDownUnPackOpThroughGenericOp (RewriterBase &rewriter, GenericOp genericOp,
1042+ ControlPropagationFn controlFn) {
10421043 if (genericOp.getNumResults () != 1 )
10431044 return failure ();
10441045
@@ -1055,6 +1056,10 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
10551056 tensor::UnPackOp producerUnPackOp =
10561057 unPackedOperand->get ().getDefiningOp <tensor::UnPackOp>();
10571058 assert (producerUnPackOp && " expect a valid UnPackOp" );
1059+
1060+ if (!controlFn (unPackedOperand))
1061+ return failure ();
1062+
10581063 auto packInfo =
10591064 getPackingInfoFromOperand (unPackedOperand, genericOp, producerUnPackOp);
10601065 if (failed (packInfo))
@@ -1122,10 +1127,8 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
11221127
11231128 LogicalResult matchAndRewrite (GenericOp genericOp,
11241129 PatternRewriter &rewriter) const override {
1125- if (!controlFn (genericOp))
1126- return failure ();
1127-
1128- auto genericAndRepl = pushDownUnPackOpThroughGenericOp (rewriter, genericOp);
1130+ auto genericAndRepl =
1131+ pushDownUnPackOpThroughGenericOp (rewriter, genericOp, controlFn);
11291132 if (failed (genericAndRepl))
11301133 return failure ();
11311134 rewriter.replaceOp (genericOp, std::get<1 >(*genericAndRepl));
@@ -1150,7 +1153,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
11501153 if (!unpackOp)
11511154 return failure ();
11521155
1153- if (!controlFn (padOp))
1156+ if (!controlFn (& padOp. getSourceMutable () ))
11541157 return failure ();
11551158
11561159 Location loc = padOp.getLoc ();
0 commit comments