@@ -54,10 +54,13 @@ class TypeListPackMatcher {
54
54
ArrayRef<Element> lhsElements;
55
55
ArrayRef<Element> rhsElements;
56
56
57
+ std::function<bool (Type)> IsPackExpansionType;
57
58
protected:
58
59
TypeListPackMatcher (ASTContext &ctx, ArrayRef<Element> lhs,
59
- ArrayRef<Element> rhs)
60
- : ctx(ctx), lhsElements(lhs), rhsElements(rhs) {}
60
+ ArrayRef<Element> rhs,
61
+ std::function<bool (Type)> isPackExpansionType)
62
+ : ctx(ctx), lhsElements(lhs), rhsElements(rhs),
63
+ IsPackExpansionType (isPackExpansionType) {}
61
64
62
65
public:
63
66
SmallVector<MatchedPair, 4 > pairs;
@@ -86,8 +89,8 @@ class TypeListPackMatcher {
86
89
auto lhsType = getElementType (lhsElt);
87
90
auto rhsType = getElementType (rhsElt);
88
91
89
- if (lhsType-> template is <PackExpansionType>( ) ||
90
- rhsType-> template is <PackExpansionType>( )) {
92
+ if (IsPackExpansionType (lhsType ) ||
93
+ IsPackExpansionType (rhsType )) {
91
94
break ;
92
95
}
93
96
@@ -115,8 +118,8 @@ class TypeListPackMatcher {
115
118
auto lhsType = getElementType (lhsElt);
116
119
auto rhsType = getElementType (rhsElt);
117
120
118
- if (lhsType-> template is <PackExpansionType>( ) ||
119
- rhsType-> template is <PackExpansionType>( )) {
121
+ if (IsPackExpansionType (lhsType ) ||
122
+ IsPackExpansionType (rhsType )) {
120
123
break ;
121
124
}
122
125
@@ -139,7 +142,7 @@ class TypeListPackMatcher {
139
142
// to what remains of the right hand side.
140
143
if (lhsElts.size () == 1 ) {
141
144
auto lhsType = getElementType (lhsElts[0 ]);
142
- if (auto *lhsExpansion = lhsType-> template getAs <PackExpansionType>( )) {
145
+ if (IsPackExpansionType (lhsType )) {
143
146
unsigned lhsIdx = prefixLength;
144
147
unsigned rhsIdx = prefixLength;
145
148
@@ -154,7 +157,7 @@ class TypeListPackMatcher {
154
157
auto rhs = createPackBinding (rhsTypes);
155
158
156
159
// FIXME: Check lhs flags
157
- pairs.emplace_back (lhsExpansion , rhs, lhsIdx, rhsIdx);
160
+ pairs.emplace_back (lhsType , rhs, lhsIdx, rhsIdx);
158
161
return false ;
159
162
}
160
163
}
@@ -163,7 +166,7 @@ class TypeListPackMatcher {
163
166
// to what remains of the left hand side.
164
167
if (rhsElts.size () == 1 ) {
165
168
auto rhsType = getElementType (rhsElts[0 ]);
166
- if (auto *rhsExpansion = rhsType-> template getAs <PackExpansionType>( )) {
169
+ if (IsPackExpansionType (rhsType )) {
167
170
unsigned lhsIdx = prefixLength;
168
171
unsigned rhsIdx = prefixLength;
169
172
@@ -178,7 +181,7 @@ class TypeListPackMatcher {
178
181
auto lhs = createPackBinding (lhsTypes);
179
182
180
183
// FIXME: Check rhs flags
181
- pairs.emplace_back (lhs, rhsExpansion , lhsIdx, rhsIdx);
184
+ pairs.emplace_back (lhs, rhsType , lhsIdx, rhsIdx);
182
185
return false ;
183
186
}
184
187
}
@@ -197,14 +200,11 @@ class TypeListPackMatcher {
197
200
Type getElementType (const Element &) const ;
198
201
ParameterTypeFlags getElementFlags (const Element &) const ;
199
202
200
- PackExpansionType * createPackBinding (ArrayRef<Type> types) const {
203
+ Type createPackBinding (ArrayRef<Type> types) const {
201
204
// If there is only one element and it's a PackExpansionType,
202
205
// return it directly.
203
- if (types.size () == 1 ) {
204
- if (auto *expansionType = types.front ()->getAs <PackExpansionType>()) {
205
- return expansionType;
206
- }
207
- }
206
+ if (types.size () == 1 && IsPackExpansionType (types.front ()))
207
+ return types.front ();
208
208
209
209
// Otherwise, wrap the elements in PackExpansionType(PackType(...)).
210
210
auto *packType = PackType::get (ctx, types);
@@ -220,10 +220,12 @@ class TypeListPackMatcher {
220
220
// / other side.
221
221
class TuplePackMatcher : public TypeListPackMatcher <TupleTypeElt> {
222
222
public:
223
- TuplePackMatcher (TupleType *lhsTuple, TupleType *rhsTuple)
224
- : TypeListPackMatcher(lhsTuple->getASTContext (),
225
- lhsTuple->getElements(),
226
- rhsTuple->getElements()) {}
223
+ TuplePackMatcher (
224
+ TupleType *lhsTuple, TupleType *rhsTuple,
225
+ std::function<bool (Type)> isPackExpansionType =
226
+ [](Type T) { return T->is <PackExpansionType>(); })
227
+ : TypeListPackMatcher(lhsTuple->getASTContext (), lhsTuple->getElements(),
228
+ rhsTuple->getElements(), isPackExpansionType) {}
227
229
};
228
230
229
231
// / Performs a structural match of two lists of (unlabeled) function
@@ -235,9 +237,12 @@ class TuplePackMatcher : public TypeListPackMatcher<TupleTypeElt> {
235
237
// / other side.
236
238
class ParamPackMatcher : public TypeListPackMatcher <AnyFunctionType::Param> {
237
239
public:
238
- ParamPackMatcher (ArrayRef<AnyFunctionType::Param> lhsParams,
239
- ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx)
240
- : TypeListPackMatcher(ctx, lhsParams, rhsParams) {}
240
+ ParamPackMatcher (
241
+ ArrayRef<AnyFunctionType::Param> lhsParams,
242
+ ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx,
243
+ std::function<bool (Type)> isPackExpansionType =
244
+ [](Type T) { return T->is <PackExpansionType>(); })
245
+ : TypeListPackMatcher(ctx, lhsParams, rhsParams, isPackExpansionType) {}
241
246
};
242
247
243
248
// / Performs a structural match of two lists of types.
@@ -248,8 +253,11 @@ class ParamPackMatcher : public TypeListPackMatcher<AnyFunctionType::Param> {
248
253
// / other side.
249
254
class PackMatcher : public TypeListPackMatcher <Type> {
250
255
public:
251
- PackMatcher (ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx)
252
- : TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {}
256
+ PackMatcher (
257
+ ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx,
258
+ std::function<bool (Type)> isPackExpansionType =
259
+ [](Type T) { return T->is <PackExpansionType>(); })
260
+ : TypeListPackMatcher(ctx, lhsTypes, rhsTypes, isPackExpansionType) {}
253
261
};
254
262
255
263
} // end namespace swift
0 commit comments