@@ -14,52 +14,48 @@ using namespace mlir;
1414using namespace mlir ::func;
1515
1616// ===----------------------------------------------------------------------===//
17- // ValueDecomposer
17+ // Helper functions
1818// ===----------------------------------------------------------------------===//
1919
20- void ValueDecomposer::decomposeValue (OpBuilder &builder, Location loc,
21- Type type, Value value,
22- SmallVectorImpl<Value> &results) {
23- for (auto &conversion : decomposeValueConversions)
24- if (conversion (builder, loc, type, value, results))
25- return ;
26- results.push_back (value);
20+ // / If the given value can be decomposed with the type converter, decompose it.
21+ // / Otherwise, return the given value.
22+ // TODO: Value decomposition should happen automatically through a 1:N adaptor.
23+ // This function will disappear when the 1:1 and 1:N drivers are merged.
24+ static SmallVector<Value> decomposeValue (OpBuilder &builder, Location loc,
25+ Value value,
26+ const TypeConverter *converter) {
27+ // Try to convert the given value's type. If that fails, just return the
28+ // given value.
29+ SmallVector<Type> convertedTypes;
30+ if (failed (converter->convertType (value.getType (), convertedTypes)))
31+ return {value};
32+ if (convertedTypes.empty ())
33+ return {};
34+
35+ // If the given value's type is already legal, just return the given value.
36+ TypeRange convertedTypeRange (convertedTypes);
37+ if (convertedTypeRange == TypeRange (value.getType ()))
38+ return {value};
39+
40+ // Try to materialize a target conversion. If the materialization did not
41+ // produce values of the requested type, the materialization failed. Just
42+ // return the given value in that case.
43+ SmallVector<Value> result = converter->materializeTargetConversion (
44+ builder, loc, convertedTypeRange, value);
45+ if (result.empty ())
46+ return {value};
47+ return result;
2748}
2849
29- // ===----------------------------------------------------------------------===//
30- // DecomposeCallGraphTypesOpConversionPattern
31- // ===----------------------------------------------------------------------===//
32-
33- namespace {
34- // / Base OpConversionPattern class to make a ValueDecomposer available to
35- // / inherited patterns.
36- template <typename SourceOp>
37- class DecomposeCallGraphTypesOpConversionPattern
38- : public OpConversionPattern<SourceOp> {
39- public:
40- DecomposeCallGraphTypesOpConversionPattern (const TypeConverter &typeConverter,
41- MLIRContext *context,
42- ValueDecomposer &decomposer,
43- PatternBenefit benefit = 1 )
44- : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45- decomposer (decomposer) {}
46-
47- protected:
48- ValueDecomposer &decomposer;
49- };
50- } // namespace
51-
5250// ===----------------------------------------------------------------------===//
5351// DecomposeCallGraphTypesForFuncArgs
5452// ===----------------------------------------------------------------------===//
5553
5654namespace {
57- // / Expand function arguments according to the provided TypeConverter and
58- // / ValueDecomposer.
55+ // / Expand function arguments according to the provided TypeConverter.
5956struct DecomposeCallGraphTypesForFuncArgs
60- : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61- using DecomposeCallGraphTypesOpConversionPattern::
62- DecomposeCallGraphTypesOpConversionPattern;
57+ : public OpConversionPattern<func::FuncOp> {
58+ using OpConversionPattern::OpConversionPattern;
6359
6460 LogicalResult
6561 matchAndRewrite (func::FuncOp op, OpAdaptor adaptor,
@@ -100,19 +96,22 @@ struct DecomposeCallGraphTypesForFuncArgs
10096// ===----------------------------------------------------------------------===//
10197
10298namespace {
103- // / Expand return operands according to the provided TypeConverter and
104- // / ValueDecomposer.
99+ // / Expand return operands according to the provided TypeConverter.
105100struct DecomposeCallGraphTypesForReturnOp
106- : public DecomposeCallGraphTypesOpConversionPattern <ReturnOp> {
107- using DecomposeCallGraphTypesOpConversionPattern::
108- DecomposeCallGraphTypesOpConversionPattern;
101+ : public OpConversionPattern <ReturnOp> {
102+ using OpConversionPattern::OpConversionPattern;
103+
109104 LogicalResult
110105 matchAndRewrite (ReturnOp op, OpAdaptor adaptor,
111106 ConversionPatternRewriter &rewriter) const final {
112107 SmallVector<Value, 2 > newOperands;
113- for (Value operand : adaptor.getOperands ())
114- decomposer.decomposeValue (rewriter, op.getLoc (), operand.getType (),
115- operand, newOperands);
108+ for (Value operand : adaptor.getOperands ()) {
109+ // TODO: We can directly take the values from the adaptor once this is a
110+ // 1:N conversion pattern.
111+ llvm::append_range (newOperands,
112+ decomposeValue (rewriter, operand.getLoc (), operand,
113+ getTypeConverter ()));
114+ }
116115 rewriter.replaceOpWithNewOp <ReturnOp>(op, newOperands);
117116 return success ();
118117 }
@@ -124,22 +123,23 @@ struct DecomposeCallGraphTypesForReturnOp
124123// ===----------------------------------------------------------------------===//
125124
126125namespace {
127- // / Expand call op operands and results according to the provided TypeConverter
128- // / and ValueDecomposer.
129- struct DecomposeCallGraphTypesForCallOp
130- : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131- using DecomposeCallGraphTypesOpConversionPattern::
132- DecomposeCallGraphTypesOpConversionPattern;
126+ // / Expand call op operands and results according to the provided TypeConverter.
127+ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern <CallOp> {
128+ using OpConversionPattern::OpConversionPattern;
133129
134130 LogicalResult
135131 matchAndRewrite (CallOp op, OpAdaptor adaptor,
136132 ConversionPatternRewriter &rewriter) const final {
137133
138134 // Create the operands list of the new `CallOp`.
139135 SmallVector<Value, 2 > newOperands;
140- for (Value operand : adaptor.getOperands ())
141- decomposer.decomposeValue (rewriter, op.getLoc (), operand.getType (),
142- operand, newOperands);
136+ for (Value operand : adaptor.getOperands ()) {
137+ // TODO: We can directly take the values from the adaptor once this is a
138+ // 1:N conversion pattern.
139+ llvm::append_range (newOperands,
140+ decomposeValue (rewriter, operand.getLoc (), operand,
141+ getTypeConverter ()));
142+ }
143143
144144 // Create the new result types for the new `CallOp` and track the indices in
145145 // the new call op's results that correspond to the old call op's results.
@@ -189,9 +189,8 @@ struct DecomposeCallGraphTypesForCallOp
189189
190190void mlir::populateDecomposeCallGraphTypesPatterns (
191191 MLIRContext *context, const TypeConverter &typeConverter,
192- ValueDecomposer &decomposer, RewritePatternSet &patterns) {
192+ RewritePatternSet &patterns) {
193193 patterns
194194 .add <DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195- DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
196- decomposer);
195+ DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
197196}
0 commit comments