Skip to content

Commit cbce4a5

Browse files
committed
Sema: Support PackExpansionTypes in matchPackTypes()
1 parent 7d0de80 commit cbce4a5

File tree

3 files changed

+121
-11
lines changed

3 files changed

+121
-11
lines changed

include/swift/AST/PackExpansionMatcher.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,28 @@ class ParamPackMatcher {
8585
bool match();
8686
};
8787

88+
/// Performs a structural match of two lists of types.
89+
///
90+
/// The invariant is that each list must only contain at most one pack
91+
/// expansion type. After collecting a common prefix and suffix, the
92+
/// pack expansion on either side asborbs the remaining elements on the
93+
/// other side.
94+
class PackMatcher {
95+
ArrayRef<Type> lhsTypes;
96+
ArrayRef<Type> rhsTypes;
97+
98+
ASTContext &ctx;
99+
100+
public:
101+
SmallVector<MatchedPair, 4> pairs;
102+
103+
PackMatcher(ArrayRef<Type> lhsTypes,
104+
ArrayRef<Type> rhsTypes,
105+
ASTContext &ctx);
106+
107+
bool match();
108+
};
109+
88110
} // end namespace swift
89111

90-
#endif // SWIFT_AST_TYPE_MATCHER_H
112+
#endif // SWIFT_AST_PACK_EXPANSION_MATCHER_H

lib/AST/PackExpansionMatcher.cpp

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,10 @@ bool ParamPackMatcher::match() {
167167
lhsParams = lhsParams.drop_front(prefixLength).drop_back(suffixLength);
168168
rhsParams = rhsParams.drop_front(prefixLength).drop_back(suffixLength);
169169

170+
// If nothing remains, we're done.
171+
if (lhsParams.empty() && rhsParams.empty())
172+
return false;
173+
170174
// If the left hand side is a single pack expansion type, bind it
171175
// to what remains of the right hand side.
172176
if (lhsParams.size() == 1) {
@@ -198,7 +202,92 @@ bool ParamPackMatcher::match() {
198202
auto lhs = PackType::get(ctx, lhsTypes);
199203

200204
// FIXME: Check rhs flags
201-
pairs.emplace_back(lhs, rhsParams[0].getPlainType(), prefixLength);
205+
pairs.emplace_back(lhs, rhsType, prefixLength);
206+
return false;
207+
}
208+
}
209+
210+
// Otherwise, all remaining possibilities are invalid:
211+
// - Neither side has any pack expansions, and they have different lengths.
212+
// - One side has a pack expansion but the other side is too short, eg
213+
// {Int, T..., Float} vs {Int}.
214+
// - The prefix and suffix are mismatched, so we're left with something
215+
// like {T..., Int} vs {Float, U...}.
216+
return true;
217+
}
218+
219+
PackMatcher::PackMatcher(
220+
ArrayRef<Type> lhsTypes,
221+
ArrayRef<Type> rhsTypes,
222+
ASTContext &ctx)
223+
: lhsTypes(lhsTypes), rhsTypes(rhsTypes), ctx(ctx) {}
224+
225+
bool PackMatcher::match() {
226+
unsigned minLength = std::min(lhsTypes.size(), rhsTypes.size());
227+
228+
// Consume the longest possible prefix where neither type in
229+
// the pair is a pack expansion type.
230+
unsigned prefixLength = 0;
231+
for (unsigned i = 0; i < minLength; ++i) {
232+
auto lhsType = lhsTypes[i];
233+
auto rhsType = rhsTypes[i];
234+
235+
if (lhsType->is<PackExpansionType>() ||
236+
rhsType->is<PackExpansionType>()) {
237+
break;
238+
}
239+
240+
pairs.emplace_back(lhsType, rhsType, i);
241+
++prefixLength;
242+
}
243+
244+
// Consume the longest possible suffix where neither type in
245+
// the pair is a pack expansion type.
246+
unsigned suffixLength = 0;
247+
for (unsigned i = 0; i < minLength - prefixLength; ++i) {
248+
auto lhsType = lhsTypes[lhsTypes.size() - i - 1];
249+
auto rhsType = rhsTypes[rhsTypes.size() - i - 1];
250+
251+
if (lhsType->is<PackExpansionType>() ||
252+
rhsType->is<PackExpansionType>()) {
253+
break;
254+
}
255+
256+
pairs.emplace_back(lhsType, rhsType, i);
257+
++suffixLength;
258+
}
259+
260+
assert(prefixLength + suffixLength <= lhsTypes.size());
261+
assert(prefixLength + suffixLength <= rhsTypes.size());
262+
263+
// Drop the consumed prefix and suffix from each list of types.
264+
lhsTypes = lhsTypes.drop_front(prefixLength).drop_back(suffixLength);
265+
rhsTypes = rhsTypes.drop_front(prefixLength).drop_back(suffixLength);
266+
267+
// If nothing remains, we're done.
268+
if (lhsTypes.empty() && rhsTypes.empty())
269+
return false;
270+
271+
// If the left hand side is a single pack expansion type, bind it
272+
// to what remains of the right hand side.
273+
if (lhsTypes.size() == 1) {
274+
auto lhsType = lhsTypes[0];
275+
if (auto *lhsExpansionType = lhsType->getAs<PackExpansionType>()) {
276+
auto rhs = PackType::get(ctx, rhsTypes);
277+
278+
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, prefixLength);
279+
return false;
280+
}
281+
}
282+
283+
// If the right hand side is a single pack expansion type, bind it
284+
// to what remains of the left hand side.
285+
if (rhsTypes.size() == 1) {
286+
auto rhsType = rhsTypes[0];
287+
if (auto *rhsExpansionType = rhsType->getAs<PackExpansionType>()) {
288+
auto lhs = PackType::get(ctx, lhsTypes);
289+
290+
pairs.emplace_back(lhs, rhsType, prefixLength);
202291
return false;
203292
}
204293
}

lib/Sema/CSSimplify.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2412,17 +2412,16 @@ ConstraintSystem::matchPackTypes(PackType *pack1, PackType *pack2,
24122412
ConstraintKind kind, TypeMatchOptions flags,
24132413
ConstraintLocatorBuilder locator) {
24142414
TypeMatchOptions subflags = getDefaultDecompositionOptions(flags);
2415-
if (pack1->getNumElements() != pack2->getNumElements())
2416-
return getTypeMatchFailure(locator);
24172415

2418-
for (unsigned i = 0, n = pack1->getNumElements(); i != n; ++i) {
2419-
Type ty1 = pack1->getElementType(i);
2420-
Type ty2 = pack2->getElementType(i);
2416+
PackMatcher matcher(pack1->getElementTypes(), pack2->getElementTypes(),
2417+
getASTContext());
2418+
if (matcher.match())
2419+
return getTypeMatchFailure(locator);
24212420

2422-
// Compare the element types.
2423-
auto result =
2424-
matchTypes(ty1, ty2, kind, subflags,
2425-
locator.withPathElement(LocatorPathElt::PackElement(i)));
2421+
for (auto pair : matcher.pairs) {
2422+
auto result = matchTypes(pair.lhs, pair.rhs, kind, subflags,
2423+
locator.withPathElement(
2424+
LocatorPathElt::PackElement(pair.idx)));
24262425
if (result.isFailure())
24272426
return result;
24282427
}

0 commit comments

Comments
 (0)