Skip to content

Commit e968ea1

Browse files
authored
Merge pull request #64591 from xedin/rework-tuple-with-pack-matching
[AST] PackExpansionMatcher: use common prefix/suffix algorithm for tuple matching
2 parents 5bb6f79 + b027691 commit e968ea1

File tree

4 files changed

+299
-424
lines changed

4 files changed

+299
-424
lines changed

include/swift/AST/PackExpansionMatcher.h

Lines changed: 182 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,189 @@ struct MatchedPair {
4141
: lhs(lhs), rhs(rhs), lhsIdx(lhsIdx), rhsIdx(rhsIdx) {}
4242
};
4343

44-
/// Performs a structural match of two lists of tuple elements. The invariant
45-
/// is that a pack expansion type must not be followed by an unlabeled
46-
/// element, that is, it is either the last element or the next element has
47-
/// a label.
44+
/// Performs a structural match of two lists of types.
4845
///
49-
/// In this manner, an element with a pack expansion type "absorbs" all
50-
/// unlabeled elements up to the next label. An element with any other type
51-
/// matches exactly one element on the other side.
52-
class TuplePackMatcher {
53-
ArrayRef<TupleTypeElt> lhsElts;
54-
ArrayRef<TupleTypeElt> rhsElts;
55-
46+
/// The invariant is that each list must only contain at most one pack
47+
/// expansion type. After collecting a common prefix and suffix, the
48+
/// pack expansion on either side asborbs the remaining elements on the
49+
/// other side.
50+
template <typename Element>
51+
class TypeListPackMatcher {
5652
ASTContext &ctx;
5753

54+
ArrayRef<Element> lhsElements;
55+
ArrayRef<Element> rhsElements;
56+
57+
protected:
58+
TypeListPackMatcher(ASTContext &ctx, ArrayRef<Element> lhs,
59+
ArrayRef<Element> rhs)
60+
: ctx(ctx), lhsElements(lhs), rhsElements(rhs) {}
61+
5862
public:
5963
SmallVector<MatchedPair, 4> pairs;
6064

61-
TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple);
65+
[[nodiscard]] bool match() {
66+
ArrayRef<Element> lhsElts(lhsElements);
67+
ArrayRef<Element> rhsElts(rhsElements);
68+
69+
unsigned minLength = std::min(lhsElts.size(), rhsElts.size());
70+
71+
// Consume the longest possible prefix where neither type in
72+
// the pair is a pack expansion type.
73+
unsigned prefixLength = 0;
74+
for (unsigned i = 0; i < minLength; ++i) {
75+
unsigned lhsIdx = i;
76+
unsigned rhsIdx = i;
77+
78+
auto lhsElt = lhsElts[lhsIdx];
79+
auto rhsElt = rhsElts[rhsIdx];
80+
81+
if (getElementLabel(lhsElt) != getElementLabel(rhsElt))
82+
break;
83+
84+
// FIXME: Check flags
85+
86+
auto lhsType = getElementType(lhsElt);
87+
auto rhsType = getElementType(rhsElt);
88+
89+
if (lhsType->template is<PackExpansionType>() ||
90+
rhsType->template is<PackExpansionType>()) {
91+
break;
92+
}
93+
94+
// FIXME: Check flags
95+
96+
pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx);
97+
++prefixLength;
98+
}
99+
100+
// Consume the longest possible suffix where neither type in
101+
// the pair is a pack expansion type.
102+
unsigned suffixLength = 0;
103+
for (unsigned i = 0; i < minLength - prefixLength; ++i) {
104+
unsigned lhsIdx = lhsElts.size() - i - 1;
105+
unsigned rhsIdx = rhsElts.size() - i - 1;
106+
107+
auto lhsElt = lhsElts[lhsIdx];
108+
auto rhsElt = rhsElts[rhsIdx];
109+
110+
// FIXME: Check flags
111+
112+
if (getElementLabel(lhsElt) != getElementLabel(rhsElt))
113+
break;
114+
115+
auto lhsType = getElementType(lhsElt);
116+
auto rhsType = getElementType(rhsElt);
117+
118+
if (lhsType->template is<PackExpansionType>() ||
119+
rhsType->template is<PackExpansionType>()) {
120+
break;
121+
}
122+
123+
pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx);
124+
++suffixLength;
125+
}
126+
127+
assert(prefixLength + suffixLength <= lhsElts.size());
128+
assert(prefixLength + suffixLength <= rhsElts.size());
129+
130+
// Drop the consumed prefix and suffix from each list of types.
131+
lhsElts = lhsElts.drop_front(prefixLength).drop_back(suffixLength);
132+
rhsElts = rhsElts.drop_front(prefixLength).drop_back(suffixLength);
133+
134+
// If nothing remains, we're done.
135+
if (lhsElts.empty() && rhsElts.empty())
136+
return false;
137+
138+
// If the left hand side is a single pack expansion type, bind it
139+
// to what remains of the right hand side.
140+
if (lhsElts.size() == 1) {
141+
auto lhsType = getElementType(lhsElts[0]);
142+
if (auto *lhsExpansion = lhsType->template getAs<PackExpansionType>()) {
143+
unsigned lhsIdx = prefixLength;
144+
unsigned rhsIdx = prefixLength;
145+
146+
SmallVector<Type, 2> rhsTypes;
147+
for (auto rhsElt : rhsElts) {
148+
if (!getElementLabel(rhsElt).empty())
149+
return true;
62150

63-
bool match();
151+
// FIXME: Check rhs flags
152+
rhsTypes.push_back(getElementType(rhsElt));
153+
}
154+
auto rhs = createPackBinding(rhsTypes);
155+
156+
// FIXME: Check lhs flags
157+
pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx);
158+
return false;
159+
}
160+
}
161+
162+
// If the right hand side is a single pack expansion type, bind it
163+
// to what remains of the left hand side.
164+
if (rhsElts.size() == 1) {
165+
auto rhsType = getElementType(rhsElts[0]);
166+
if (auto *rhsExpansion = rhsType->template getAs<PackExpansionType>()) {
167+
unsigned lhsIdx = prefixLength;
168+
unsigned rhsIdx = prefixLength;
169+
170+
SmallVector<Type, 2> lhsTypes;
171+
for (auto lhsElt : lhsElts) {
172+
if (!getElementLabel(lhsElt).empty())
173+
return true;
174+
175+
// FIXME: Check lhs flags
176+
lhsTypes.push_back(getElementType(lhsElt));
177+
}
178+
auto lhs = createPackBinding(lhsTypes);
179+
180+
// FIXME: Check rhs flags
181+
pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx);
182+
return false;
183+
}
184+
}
185+
186+
// Otherwise, all remaining possibilities are invalid:
187+
// - Neither side has any pack expansions, and they have different lengths.
188+
// - One side has a pack expansion but the other side is too short, eg
189+
// {Int, T..., Float} vs {Int}.
190+
// - The prefix and suffix are mismatched, so we're left with something
191+
// like {T..., Int} vs {Float, U...}.
192+
return true;
193+
}
194+
195+
private:
196+
Identifier getElementLabel(const Element &) const;
197+
Type getElementType(const Element &) const;
198+
ParameterTypeFlags getElementFlags(const Element &) const;
199+
200+
PackExpansionType *createPackBinding(ArrayRef<Type> types) const {
201+
// If there is only one element and it's a PackExpansionType,
202+
// return it directly.
203+
if (types.size() == 1) {
204+
if (auto *expansionType = types.front()->getAs<PackExpansionType>()) {
205+
return expansionType;
206+
}
207+
}
208+
209+
// Otherwise, wrap the elements in PackExpansionType(PackType(...)).
210+
auto *packType = PackType::get(ctx, types);
211+
return PackExpansionType::get(packType, packType);
212+
}
213+
};
214+
215+
/// Performs a structural match of two lists of tuple elements.
216+
///
217+
/// The invariant is that each list must only contain at most one pack
218+
/// expansion type. After collecting a common prefix and suffix, the
219+
/// pack expansion on either side asborbs the remaining elements on the
220+
/// other side.
221+
class TuplePackMatcher : public TypeListPackMatcher<TupleTypeElt> {
222+
public:
223+
TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple)
224+
: TypeListPackMatcher(lhsTuple->getASTContext(),
225+
lhsTuple->getElements(),
226+
rhsTuple->getElements()) {}
64227
};
65228

66229
/// Performs a structural match of two lists of (unlabeled) function
@@ -70,20 +233,11 @@ class TuplePackMatcher {
70233
/// expansion type. After collecting a common prefix and suffix, the
71234
/// pack expansion on either side asborbs the remaining elements on the
72235
/// other side.
73-
class ParamPackMatcher {
74-
ArrayRef<AnyFunctionType::Param> lhsParams;
75-
ArrayRef<AnyFunctionType::Param> rhsParams;
76-
77-
ASTContext &ctx;
78-
236+
class ParamPackMatcher : public TypeListPackMatcher<AnyFunctionType::Param> {
79237
public:
80-
SmallVector<MatchedPair, 4> pairs;
81-
82238
ParamPackMatcher(ArrayRef<AnyFunctionType::Param> lhsParams,
83-
ArrayRef<AnyFunctionType::Param> rhsParams,
84-
ASTContext &ctx);
85-
86-
bool match();
239+
ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx)
240+
: TypeListPackMatcher(ctx, lhsParams, rhsParams) {}
87241
};
88242

89243
/// Performs a structural match of two lists of types.
@@ -92,20 +246,10 @@ class ParamPackMatcher {
92246
/// expansion type. After collecting a common prefix and suffix, the
93247
/// pack expansion on either side asborbs the remaining elements on the
94248
/// other side.
95-
class PackMatcher {
96-
ArrayRef<Type> lhsTypes;
97-
ArrayRef<Type> rhsTypes;
98-
99-
ASTContext &ctx;
100-
249+
class PackMatcher : public TypeListPackMatcher<Type> {
101250
public:
102-
SmallVector<MatchedPair, 4> pairs;
103-
104-
PackMatcher(ArrayRef<Type> lhsTypes,
105-
ArrayRef<Type> rhsTypes,
106-
ASTContext &ctx);
107-
108-
bool match();
251+
PackMatcher(ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx)
252+
: TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {}
109253
};
110254

111255
} // end namespace swift

0 commit comments

Comments
 (0)