Skip to content

Commit 447b3cc

Browse files
GleasonKTensorFlow MLIR Team
authored andcommitted
[StableHLO] Add shape refinement callback to specify additional patterns.
PiperOrigin-RevId: 705955699
1 parent 9d2924a commit 447b3cc

File tree

4 files changed

+74
-31
lines changed

4 files changed

+74
-31
lines changed

stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,10 @@ class RefinementKey {
369369
// Which correlates to <func, sym_int_values, arg_types>
370370
class RefineShapeState {
371371
public:
372+
RefineShapeState(
373+
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn)
374+
: additionalPatternsFn(additionalPatternsFn) {}
375+
372376
enum class RefinementState {
373377
NOT_ALREADY_REFINED,
374378
ALREADY_REFINED,
@@ -431,7 +435,14 @@ class RefineShapeState {
431435
});
432436
}
433437

438+
void addAdditionalPatterns(RewritePatternSet& patterns) {
439+
if (additionalPatternsFn.has_value())
440+
additionalPatternsFn.value()(&patterns);
441+
}
442+
434443
private:
444+
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn;
445+
435446
// Maps refined functions to the refinement context: the values of dimension
436447
// arguments and the types of non-global-constant arguments. A function is
437448
// added here when we start refining it.
@@ -1001,7 +1012,7 @@ struct UpdateRegionTypePattern : public OpRewritePattern<ReturnOp> {
10011012
LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
10021013
RefineShapeState& state) {
10031014
MLIRContext* context = func.getContext();
1004-
RewritePatternSet patterns(context);
1015+
RewritePatternSet patterns(func->getContext());
10051016
GreedyRewriteConfig config;
10061017

10071018
// The algorithm behind this pass consists of a single traversal of the
@@ -1019,6 +1030,9 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
10191030
populateStablehloRefineShapesPatterns(&patterns, context);
10201031
patterns.add<RefineCallOpPattern>(context, state);
10211032

1033+
// Populate additional patterns for StableHLO extensions.
1034+
state.addAdditionalPatterns(patterns);
1035+
10221036
// The folding patterns implement partial evaluation of shape computations
10231037
// which is a critical part of implementing type refinement for ops like
10241038
// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
@@ -1103,15 +1117,23 @@ struct StablehloRefineShapesPass
11031117

11041118
// Start with empty state, and no dim args / token args.
11051119
MLIRContext* context = func.getContext();
1106-
RefineShapeState state;
1107-
RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes()));
1108-
if (failed(refineFunction(*context, state, key)))
1109-
return signalPassFailure();
1120+
if (failed(refineEntryFunction(*context, func))) return signalPassFailure();
11101121
}
11111122
};
11121123

11131124
} // namespace
11141125

1126+
LogicalResult refineEntryFunction(
1127+
MLIRContext& context, func::FuncOp func,
1128+
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn) {
1129+
// Start with empty state, and no dim args / token args.
1130+
RefineShapeState state(additionalPatternsFn);
1131+
RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes()));
1132+
if (failed(refineFunction(context, state, key)))
1133+
return func.emitError("Failed to refine entry function");
1134+
return success();
1135+
}
1136+
11151137
func::FuncOp getStablehloRefineShapesTarget(ModuleOp module) {
11161138
// Only one function per module is supported at the moment to avoid the need
11171139
// to think about iterative type inference algorithms.

stablehlo/stablehlo/transforms/StablehloRefineShapes.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License.
1616
#ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H
1717
#define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H
1818

19-
#include "llvm/ADT/SmallVector.h"
2019
#include "mlir/Dialect/Func/IR/FuncOps.h"
2120
#include "mlir/IR/BuiltinOps.h"
2221
#include "mlir/IR/Operation.h"
@@ -101,6 +100,18 @@ LogicalResult refineReturnShape(PatternRewriter& rewriter, OpType op,
101100
return refineReturnShape(rewriter, op, shape);
102101
}
103102

103+
// Entrypoint for any pass adding extensibility to the StableHLO shape
104+
// refinement pass. If program is inlined before shape refinement,
105+
// populateShapeRefinementPatterns can be safely used, but if shape refinement
106+
// needs to operate on programs with functions and calls, then
107+
// additionalPatterns will need to be populated and passed in.
108+
using AdditionalShapeRefinementPatternsFn =
109+
std::function<void(RewritePatternSet*)>;
110+
LogicalResult refineEntryFunction(
111+
MLIRContext& context, func::FuncOp func,
112+
std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn =
113+
std::nullopt);
114+
104115
// Custom call used to buffer operands for shape refinement
105116
// This is a temporary artifact that is introduced by StablehloRefineArguments
106117
// and is washed away during StablehloRefineShapes.

stablehlo_ext/transforms/stablehlo_refine_shapes.cpp

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ limitations under the License.
1313
==============================================================================*/
1414

1515
#include <cstdint>
16+
#include <functional>
1617

1718
#include "llvm/ADT/SmallVector.h"
1819
#include "mlir/Dialect/Func/IR/FuncOps.h"
20+
#include "mlir/IR/MLIRContext.h"
1921
#include "mlir/IR/PatternMatch.h"
2022
#include "mlir/Interfaces/InferTypeOpInterface.h"
2123
#include "mlir/Support/LogicalResult.h"
@@ -138,32 +140,20 @@ struct StablehloRefineShapesPass
138140
auto func = stablehlo::getStablehloRefineShapesTarget(getOperation());
139141
if (!func) return signalPassFailure();
140142

141-
// The algorithm behind this pass consists of a single traversal of the
142-
// function. This is sufficient because we only support one function per
143-
// program at the moment.
144-
// TODO(#1048): Find out why .maxIterations = 1 no longer works.
145-
// There have been recent refactors to applyPatternsAndFoldGreedily
146-
// upstream, and that might be the reason.
147-
GreedyRewriteConfig config;
148-
config.useTopDownTraversal = true;
149-
config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive;
150-
config.maxIterations = 3;
151-
config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
152-
config.strictMode = GreedyRewriteStrictness::AnyOp;
153-
154-
RewritePatternSet patterns(&getContext());
155-
stablehlo::populateStablehloRefineShapesPatterns(&patterns, &getContext());
156-
stablehlo::populateStablehloShapeFolderPatterns(&patterns, &getContext());
157-
patterns.add<RefineDynamicReduceWindowOpPattern>(&getContext());
158-
patterns.add<RefineDynamicRngBitGeneratorOpPattern>(&getContext());
159-
patterns.add<RefineDynamicTopKOpPattern>(&getContext());
160-
if (failed(
161-
applyPatternsAndFoldGreedily(func, std::move(patterns), config))) {
162-
func.emitError()
163-
<< "Greedy rewriter in StablehloRefineShapes does not converge after "
164-
<< config.maxIterations << " iterations.";
143+
// Start with empty state, and no dim args / token args.
144+
MLIRContext* context = func.getContext();
145+
146+
// Populate additional patterns for StableHLO extensions.
147+
std::function<void(RewritePatternSet*)> additionalPatternsFn =
148+
[&](RewritePatternSet* patterns) {
149+
patterns->add<RefineDynamicReduceWindowOpPattern>(context);
150+
patterns->add<RefineDynamicRngBitGeneratorOpPattern>(context);
151+
patterns->add<RefineDynamicTopKOpPattern>(context);
152+
};
153+
154+
if (failed(stablehlo::refineEntryFunction(*context, func,
155+
additionalPatternsFn)))
165156
return signalPassFailure();
166-
}
167157
}
168158
};
169159

tests/stablehlo_ext/stablehlo_refine_shapes.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,23 @@ func.func @refine_dynamic_top_k(%arg0: tensor<16xf32>) -> (tensor<?xf32>, tensor
4040
%1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor<ui64>) -> (tensor<?xf32>, tensor<?xi32>)
4141
return %1#0, %1#1 : tensor<?xf32>, tensor<?xi32>
4242
}
43+
44+
// -----
45+
46+
// CHECK-LABEL: module @refine_call
47+
module @refine_call {
48+
// CHECK: func.func @main{{.*}}-> (tensor<4xf32>, tensor<4xi32>)
49+
func.func @main(%arg1: tensor<16xf32>) -> (tensor<?xf32>, tensor<?xi32>) {
50+
%0 = stablehlo.bitcast_convert %arg1 : (tensor<16xf32>) -> tensor<?xf32>
51+
// CHECK: refine_call_callee{{.*}}-> (tensor<4xf32>, tensor<4xi32>)
52+
%2:2 = call @refine_call_callee(%0) : (tensor<?xf32>) -> (tensor<?xf32>, tensor<?xi32>)
53+
return %2#0, %2#1 : tensor<?xf32>, tensor<?xi32>
54+
}
55+
// CHECK: refine_call_callee(%arg0: tensor<16xf32>) -> (tensor<4xf32>, tensor<4xi32>)
56+
func.func @refine_call_callee(%arg0: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xi32>) {
57+
// CHECK: stablehlo.dynamic_top_k{{.*}} -> (tensor<4xf32>, tensor<4xi32>)
58+
%k = stablehlo.constant dense<4> : tensor<ui64>
59+
%1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<?xf32>, tensor<ui64>) -> (tensor<?xf32>, tensor<?xi32>)
60+
return %1#0, %1#1 : tensor<?xf32>, tensor<?xi32>
61+
}
62+
}

0 commit comments

Comments
 (0)