Skip to content

Commit 5c2f77c

Browse files
committed
[AST] PackMatching: Make pack expansion type check pluggable
The constraint solver would supply its own version of the function that can check for pack expansion type variables as well.
1 parent 6769e39 commit 5c2f77c

File tree

2 files changed

+52
-29
lines changed

2 files changed

+52
-29
lines changed

include/swift/AST/PackExpansionMatcher.h

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,13 @@ class TypeListPackMatcher {
5454
ArrayRef<Element> lhsElements;
5555
ArrayRef<Element> rhsElements;
5656

57+
std::function<bool(Type)> IsPackExpansionType;
5758
protected:
5859
TypeListPackMatcher(ASTContext &ctx, ArrayRef<Element> lhs,
59-
ArrayRef<Element> rhs)
60-
: ctx(ctx), lhsElements(lhs), rhsElements(rhs) {}
60+
ArrayRef<Element> rhs,
61+
std::function<bool(Type)> isPackExpansionType)
62+
: ctx(ctx), lhsElements(lhs), rhsElements(rhs),
63+
IsPackExpansionType(isPackExpansionType) {}
6164

6265
public:
6366
SmallVector<MatchedPair, 4> pairs;
@@ -86,8 +89,8 @@ class TypeListPackMatcher {
8689
auto lhsType = getElementType(lhsElt);
8790
auto rhsType = getElementType(rhsElt);
8891

89-
if (lhsType->template is<PackExpansionType>() ||
90-
rhsType->template is<PackExpansionType>()) {
92+
if (IsPackExpansionType(lhsType) ||
93+
IsPackExpansionType(rhsType)) {
9194
break;
9295
}
9396

@@ -115,8 +118,8 @@ class TypeListPackMatcher {
115118
auto lhsType = getElementType(lhsElt);
116119
auto rhsType = getElementType(rhsElt);
117120

118-
if (lhsType->template is<PackExpansionType>() ||
119-
rhsType->template is<PackExpansionType>()) {
121+
if (IsPackExpansionType(lhsType) ||
122+
IsPackExpansionType(rhsType)) {
120123
break;
121124
}
122125

@@ -139,7 +142,7 @@ class TypeListPackMatcher {
139142
// to what remains of the right hand side.
140143
if (lhsElts.size() == 1) {
141144
auto lhsType = getElementType(lhsElts[0]);
142-
if (auto *lhsExpansion = lhsType->template getAs<PackExpansionType>()) {
145+
if (IsPackExpansionType(lhsType)) {
143146
unsigned lhsIdx = prefixLength;
144147
unsigned rhsIdx = prefixLength;
145148

@@ -154,7 +157,7 @@ class TypeListPackMatcher {
154157
auto rhs = createPackBinding(rhsTypes);
155158

156159
// FIXME: Check lhs flags
157-
pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx);
160+
pairs.emplace_back(lhsType, rhs, lhsIdx, rhsIdx);
158161
return false;
159162
}
160163
}
@@ -163,7 +166,7 @@ class TypeListPackMatcher {
163166
// to what remains of the left hand side.
164167
if (rhsElts.size() == 1) {
165168
auto rhsType = getElementType(rhsElts[0]);
166-
if (auto *rhsExpansion = rhsType->template getAs<PackExpansionType>()) {
169+
if (IsPackExpansionType(rhsType)) {
167170
unsigned lhsIdx = prefixLength;
168171
unsigned rhsIdx = prefixLength;
169172

@@ -178,7 +181,7 @@ class TypeListPackMatcher {
178181
auto lhs = createPackBinding(lhsTypes);
179182

180183
// FIXME: Check rhs flags
181-
pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx);
184+
pairs.emplace_back(lhs, rhsType, lhsIdx, rhsIdx);
182185
return false;
183186
}
184187
}
@@ -197,14 +200,11 @@ class TypeListPackMatcher {
197200
Type getElementType(const Element &) const;
198201
ParameterTypeFlags getElementFlags(const Element &) const;
199202

200-
PackExpansionType *createPackBinding(ArrayRef<Type> types) const {
203+
Type createPackBinding(ArrayRef<Type> types) const {
201204
// If there is only one element and it's a PackExpansionType,
202205
// return it directly.
203-
if (types.size() == 1) {
204-
if (auto *expansionType = types.front()->getAs<PackExpansionType>()) {
205-
return expansionType;
206-
}
207-
}
206+
if (types.size() == 1 && IsPackExpansionType(types.front()))
207+
return types.front();
208208

209209
// Otherwise, wrap the elements in PackExpansionType(PackType(...)).
210210
auto *packType = PackType::get(ctx, types);
@@ -220,10 +220,12 @@ class TypeListPackMatcher {
220220
/// other side.
221221
class TuplePackMatcher : public TypeListPackMatcher<TupleTypeElt> {
222222
public:
223-
TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple)
224-
: TypeListPackMatcher(lhsTuple->getASTContext(),
225-
lhsTuple->getElements(),
226-
rhsTuple->getElements()) {}
223+
TuplePackMatcher(
224+
TupleType *lhsTuple, TupleType *rhsTuple,
225+
std::function<bool(Type)> isPackExpansionType =
226+
[](Type T) { return T->is<PackExpansionType>(); })
227+
: TypeListPackMatcher(lhsTuple->getASTContext(), lhsTuple->getElements(),
228+
rhsTuple->getElements(), isPackExpansionType) {}
227229
};
228230

229231
/// Performs a structural match of two lists of (unlabeled) function
@@ -235,9 +237,12 @@ class TuplePackMatcher : public TypeListPackMatcher<TupleTypeElt> {
235237
/// other side.
236238
class ParamPackMatcher : public TypeListPackMatcher<AnyFunctionType::Param> {
237239
public:
238-
ParamPackMatcher(ArrayRef<AnyFunctionType::Param> lhsParams,
239-
ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx)
240-
: TypeListPackMatcher(ctx, lhsParams, rhsParams) {}
240+
ParamPackMatcher(
241+
ArrayRef<AnyFunctionType::Param> lhsParams,
242+
ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx,
243+
std::function<bool(Type)> isPackExpansionType =
244+
[](Type T) { return T->is<PackExpansionType>(); })
245+
: TypeListPackMatcher(ctx, lhsParams, rhsParams, isPackExpansionType) {}
241246
};
242247

243248
/// Performs a structural match of two lists of types.
@@ -248,8 +253,11 @@ class ParamPackMatcher : public TypeListPackMatcher<AnyFunctionType::Param> {
248253
/// other side.
249254
class PackMatcher : public TypeListPackMatcher<Type> {
250255
public:
251-
PackMatcher(ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx)
252-
: TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {}
256+
PackMatcher(
257+
ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx,
258+
std::function<bool(Type)> isPackExpansionType =
259+
[](Type T) { return T->is<PackExpansionType>(); })
260+
: TypeListPackMatcher(ctx, lhsTypes, rhsTypes, isPackExpansionType) {}
253261
};
254262

255263
} // end namespace swift

lib/Sema/CSSimplify.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,6 +2081,18 @@ static bool isInPatternMatchingContext(ConstraintLocatorBuilder locator) {
20812081

20822082
namespace {
20832083

2084+
static bool isPackExpansionType(Type type) {
2085+
if (type->is<PackExpansionType>())
2086+
return true;
2087+
2088+
if (auto *typeVar = type->getAs<TypeVariableType>())
2089+
return typeVar->getImpl()
2090+
.getLocator()
2091+
->isLastElement<LocatorPathElt::PackExpansionType>();
2092+
2093+
return false;
2094+
}
2095+
20842096
class TupleMatcher {
20852097
TupleType *tuple1;
20862098
TupleType *tuple2;
@@ -2103,7 +2115,7 @@ class TupleMatcher {
21032115
// case too eventually.
21042116
if (tuple1->containsPackExpansionType() ||
21052117
tuple2->containsPackExpansionType()) {
2106-
TuplePackMatcher matcher(tuple1, tuple2);
2118+
TuplePackMatcher matcher(tuple1, tuple2, isPackExpansionType);
21072119
if (matcher.match())
21082120
return true;
21092121

@@ -2294,7 +2306,8 @@ ConstraintSystem::matchPackTypes(PackType *pack1, PackType *pack2,
22942306
TypeMatchOptions subflags = getDefaultDecompositionOptions(flags);
22952307

22962308
PackMatcher matcher(pack1->getElementTypes(), pack2->getElementTypes(),
2297-
getASTContext());
2309+
getASTContext(), isPackExpansionType);
2310+
22982311
if (matcher.match())
22992312
return getTypeMatchFailure(locator);
23002313

@@ -3347,7 +3360,8 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
33473360
// case too eventually.
33483361
if (AnyFunctionType::containsPackExpansionType(func1Params) ||
33493362
AnyFunctionType::containsPackExpansionType(func2Params)) {
3350-
ParamPackMatcher matcher(func1Params, func2Params, getASTContext());
3363+
ParamPackMatcher matcher(func1Params, func2Params, getASTContext(),
3364+
isPackExpansionType);
33513365
if (matcher.match())
33523366
return getTypeMatchFailure(locator);
33533367

@@ -13121,7 +13135,8 @@ ConstraintSystem::simplifyExplicitGenericArgumentsConstraint(
1312113135

1312213136
// Match the opened generic parameters to the specialized arguments.
1312313137
auto specializedArgs = type2->castTo<PackType>()->getElementTypes();
13124-
PackMatcher matcher(openedGenericParams, specializedArgs, getASTContext());
13138+
PackMatcher matcher(openedGenericParams, specializedArgs, getASTContext(),
13139+
isPackExpansionType);
1312513140
if (matcher.match())
1312613141
return SolutionKind::Error;
1312713142

0 commit comments

Comments
 (0)