@@ -369,6 +369,10 @@ class RefinementKey {
369369// Which correlates to <func, sym_int_values, arg_types>
370370class RefineShapeState {
371371 public:
372+ RefineShapeState (
373+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn)
374+ : additionalPatternsFn(additionalPatternsFn) {}
375+
372376 enum class RefinementState {
373377 NOT_ALREADY_REFINED,
374378 ALREADY_REFINED,
@@ -431,7 +435,14 @@ class RefineShapeState {
431435 });
432436 }
433437
438+ void addAdditionalPatterns (RewritePatternSet& patterns) {
439+ if (additionalPatternsFn.has_value ())
440+ additionalPatternsFn.value ()(&patterns);
441+ }
442+
434443 private:
444+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn;
445+
435446 // Maps refined functions to the refinement context: the values of dimension
436447 // arguments and the types of non-global-constant arguments. A function is
437448 // added here when we start refining it.
@@ -1001,7 +1012,7 @@ struct UpdateRegionTypePattern : public OpRewritePattern<ReturnOp> {
10011012LogicalResult applyShapeRefinementPatterns (func::FuncOp func,
10021013 RefineShapeState& state) {
10031014 MLIRContext* context = func.getContext ();
1004- RewritePatternSet patterns (context );
1015+ RewritePatternSet patterns (func-> getContext () );
10051016 GreedyRewriteConfig config;
10061017
10071018 // The algorithm behind this pass consists of a single traversal of the
@@ -1019,6 +1030,9 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
10191030 populateStablehloRefineShapesPatterns (&patterns, context);
10201031 patterns.add <RefineCallOpPattern>(context, state);
10211032
1033+ // Populate additional patterns for StableHLO extensions.
1034+ state.addAdditionalPatterns (patterns);
1035+
10221036 // The folding patterns implement partial evaluation of shape computations
10231037 // which is a critical part of implementing type refinement for ops like
10241038 // dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
@@ -1103,15 +1117,23 @@ struct StablehloRefineShapesPass
11031117
11041118 // Start with empty state, and no dim args / token args.
11051119 MLIRContext* context = func.getContext ();
1106- RefineShapeState state;
1107- RefinementKey key (func, 0 , {}, llvm::to_vector (func.getArgumentTypes ()));
1108- if (failed (refineFunction (*context, state, key)))
1109- return signalPassFailure ();
1120+ if (failed (refineEntryFunction (*context, func))) return signalPassFailure ();
11101121 }
11111122};
11121123
11131124} // namespace
11141125
1126+ LogicalResult refineEntryFunction (
1127+ MLIRContext& context, func::FuncOp func,
1128+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn) {
1129+ // Start with empty state, and no dim args / token args.
1130+ RefineShapeState state (additionalPatternsFn);
1131+ RefinementKey key (func, 0 , {}, llvm::to_vector (func.getArgumentTypes ()));
1132+ if (failed (refineFunction (context, state, key)))
1133+ return func.emitError (" Failed to refine entry function" );
1134+ return success ();
1135+ }
1136+
11151137func::FuncOp getStablehloRefineShapesTarget (ModuleOp module ) {
11161138 // Only one function per module is supported at the moment to avoid the need
11171139 // to think about iterative type inference algorithms.
0 commit comments