Skip to content

Commit 77b702d

Browse files
authored
Merge pull request #61657 from slavapestov/variadic-generic-function-check-requirements
Check generic requirements when type checking calls to variadic generic functions
2 parents e15b409 + c34f8d3 commit 77b702d

13 files changed

+490
-108
lines changed

include/swift/AST/PackExpansionMatcher.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,28 @@ class ParamPackMatcher {
8585
bool match();
8686
};
8787

88+
/// Performs a structural match of two lists of types.
89+
///
90+
/// The invariant is that each list must only contain at most one pack
91+
/// expansion type. After collecting a common prefix and suffix, the
92+
/// pack expansion on either side asborbs the remaining elements on the
93+
/// other side.
94+
class PackMatcher {
95+
ArrayRef<Type> lhsTypes;
96+
ArrayRef<Type> rhsTypes;
97+
98+
ASTContext &ctx;
99+
100+
public:
101+
SmallVector<MatchedPair, 4> pairs;
102+
103+
PackMatcher(ArrayRef<Type> lhsTypes,
104+
ArrayRef<Type> rhsTypes,
105+
ASTContext &ctx);
106+
107+
bool match();
108+
};
109+
88110
} // end namespace swift
89111

90-
#endif // SWIFT_AST_TYPE_MATCHER_H
112+
#endif // SWIFT_AST_PACK_EXPANSION_MATCHER_H

include/swift/Sema/Constraint.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ enum class ConstraintKind : char {
9797
ArgumentConversion,
9898
/// The first type is convertible to the second type, including inout.
9999
OperatorArgumentConversion,
100+
/// The first type must be a subclass of the second type (which is a
101+
/// class type).
102+
SubclassOf,
100103
/// The first type must conform to the second type (which is a
101104
/// protocol type).
102105
ConformsTo,
@@ -669,6 +672,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
669672
case ConstraintKind::BridgingConversion:
670673
case ConstraintKind::ArgumentConversion:
671674
case ConstraintKind::OperatorArgumentConversion:
675+
case ConstraintKind::SubclassOf:
672676
case ConstraintKind::ConformsTo:
673677
case ConstraintKind::LiteralConformsTo:
674678
case ConstraintKind::TransitivelyConformsTo:

include/swift/Sema/ConstraintSystem.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5423,6 +5423,15 @@ class ConstraintSystem {
54235423
FunctionRefKind functionRefKind,
54245424
ConstraintLocator *locator);
54255425

5426+
/// Attempt to simplify the given superclass constraint.
5427+
///
5428+
/// \param type The type being tested.
5429+
/// \param classType The class type which the type should be a subclass of.
5430+
/// \param locator Locator describing where this constraint occurred.
5431+
SolutionKind simplifySubclassOfConstraint(Type type, Type classType,
5432+
ConstraintLocatorBuilder locator,
5433+
TypeMatchOptions flags);
5434+
54265435
/// Attempt to simplify the given conformance constraint.
54275436
///
54285437
/// \param type The type being tested.

lib/AST/PackExpansionMatcher.cpp

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,10 @@ bool ParamPackMatcher::match() {
167167
lhsParams = lhsParams.drop_front(prefixLength).drop_back(suffixLength);
168168
rhsParams = rhsParams.drop_front(prefixLength).drop_back(suffixLength);
169169

170+
// If nothing remains, we're done.
171+
if (lhsParams.empty() && rhsParams.empty())
172+
return false;
173+
170174
// If the left hand side is a single pack expansion type, bind it
171175
// to what remains of the right hand side.
172176
if (lhsParams.size() == 1) {
@@ -198,7 +202,92 @@ bool ParamPackMatcher::match() {
198202
auto lhs = PackType::get(ctx, lhsTypes);
199203

200204
// FIXME: Check rhs flags
201-
pairs.emplace_back(lhs, rhsParams[0].getPlainType(), prefixLength);
205+
pairs.emplace_back(lhs, rhsType, prefixLength);
206+
return false;
207+
}
208+
}
209+
210+
// Otherwise, all remaining possibilities are invalid:
211+
// - Neither side has any pack expansions, and they have different lengths.
212+
// - One side has a pack expansion but the other side is too short, eg
213+
// {Int, T..., Float} vs {Int}.
214+
// - The prefix and suffix are mismatched, so we're left with something
215+
// like {T..., Int} vs {Float, U...}.
216+
return true;
217+
}
218+
219+
PackMatcher::PackMatcher(
220+
ArrayRef<Type> lhsTypes,
221+
ArrayRef<Type> rhsTypes,
222+
ASTContext &ctx)
223+
: lhsTypes(lhsTypes), rhsTypes(rhsTypes), ctx(ctx) {}
224+
225+
bool PackMatcher::match() {
226+
unsigned minLength = std::min(lhsTypes.size(), rhsTypes.size());
227+
228+
// Consume the longest possible prefix where neither type in
229+
// the pair is a pack expansion type.
230+
unsigned prefixLength = 0;
231+
for (unsigned i = 0; i < minLength; ++i) {
232+
auto lhsType = lhsTypes[i];
233+
auto rhsType = rhsTypes[i];
234+
235+
if (lhsType->is<PackExpansionType>() ||
236+
rhsType->is<PackExpansionType>()) {
237+
break;
238+
}
239+
240+
pairs.emplace_back(lhsType, rhsType, i);
241+
++prefixLength;
242+
}
243+
244+
// Consume the longest possible suffix where neither type in
245+
// the pair is a pack expansion type.
246+
unsigned suffixLength = 0;
247+
for (unsigned i = 0; i < minLength - prefixLength; ++i) {
248+
auto lhsType = lhsTypes[lhsTypes.size() - i - 1];
249+
auto rhsType = rhsTypes[rhsTypes.size() - i - 1];
250+
251+
if (lhsType->is<PackExpansionType>() ||
252+
rhsType->is<PackExpansionType>()) {
253+
break;
254+
}
255+
256+
pairs.emplace_back(lhsType, rhsType, i);
257+
++suffixLength;
258+
}
259+
260+
assert(prefixLength + suffixLength <= lhsTypes.size());
261+
assert(prefixLength + suffixLength <= rhsTypes.size());
262+
263+
// Drop the consumed prefix and suffix from each list of types.
264+
lhsTypes = lhsTypes.drop_front(prefixLength).drop_back(suffixLength);
265+
rhsTypes = rhsTypes.drop_front(prefixLength).drop_back(suffixLength);
266+
267+
// If nothing remains, we're done.
268+
if (lhsTypes.empty() && rhsTypes.empty())
269+
return false;
270+
271+
// If the left hand side is a single pack expansion type, bind it
272+
// to what remains of the right hand side.
273+
if (lhsTypes.size() == 1) {
274+
auto lhsType = lhsTypes[0];
275+
if (auto *lhsExpansionType = lhsType->getAs<PackExpansionType>()) {
276+
auto rhs = PackType::get(ctx, rhsTypes);
277+
278+
pairs.emplace_back(lhsExpansionType->getPatternType(), rhs, prefixLength);
279+
return false;
280+
}
281+
}
282+
283+
// If the right hand side is a single pack expansion type, bind it
284+
// to what remains of the left hand side.
285+
if (rhsTypes.size() == 1) {
286+
auto rhsType = rhsTypes[0];
287+
if (auto *rhsExpansionType = rhsType->getAs<PackExpansionType>()) {
288+
auto lhs = PackType::get(ctx, lhsTypes);
289+
290+
pairs.emplace_back(lhs, rhsType, prefixLength);
202291
return false;
203292
}
204293
}

lib/AST/Type.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "swift/AST/GenericEnvironment.h"
3030
#include "swift/AST/LazyResolver.h"
3131
#include "swift/AST/Module.h"
32+
#include "swift/AST/PackConformance.h"
3233
#include "swift/AST/ParameterList.h"
3334
#include "swift/AST/ProtocolConformance.h"
3435
#include "swift/AST/SILLayout.h"
@@ -4365,10 +4366,6 @@ static Type getMemberForBaseType(LookupConformanceFn lookupConformances,
43654366

43664367
// Retrieve the member type with the given name.
43674368

4368-
// Tuples don't have member types.
4369-
if (substBase->is<TupleType>())
4370-
return failed();
4371-
43724369
// If we know the associated type, look in the witness table.
43734370
if (assocType) {
43744371
auto proto = assocType->getProtocol();
@@ -4377,30 +4374,37 @@ static Type getMemberForBaseType(LookupConformanceFn lookupConformances,
43774374

43784375
if (conformance.isInvalid())
43794376
return failed();
4380-
if (!conformance.isConcrete())
4381-
return failed();
4382-
4383-
// Retrieve the type witness.
4384-
auto witness =
4385-
conformance.getConcrete()->getTypeWitnessAndDecl(assocType, options);
4386-
4387-
auto witnessTy = witness.getWitnessType();
4388-
if (!witnessTy || witnessTy->hasError())
4389-
return failed();
43904377

4391-
// This is a hacky feature allowing code completion to migrate to
4392-
// using Type::subst() without changing output.
4393-
if (options & SubstFlags::DesugarMemberTypes) {
4394-
if (auto *aliasType = dyn_cast<TypeAliasType>(witnessTy.getPointer()))
4395-
witnessTy = aliasType->getSinglyDesugaredType();
4378+
Type witnessTy;
43964379

4397-
// Another hack. If the type witness is a opaque result type. They can
4398-
// only be referred using the name of the associated type.
4399-
if (witnessTy->is<OpaqueTypeArchetypeType>())
4400-
witnessTy = witness.getWitnessDecl()->getDeclaredInterfaceType();
4380+
// Retrieve the type witness.
4381+
if (conformance.isPack()) {
4382+
auto *packConformance = conformance.getPack();
4383+
4384+
witnessTy = packConformance->getAssociatedType(
4385+
assocType->getDeclaredInterfaceType());
4386+
} else if (conformance.isConcrete()) {
4387+
auto witness =
4388+
conformance.getConcrete()->getTypeWitnessAndDecl(assocType, options);
4389+
4390+
witnessTy = witness.getWitnessType();
4391+
if (!witnessTy || witnessTy->hasError())
4392+
return failed();
4393+
4394+
// This is a hacky feature allowing code completion to migrate to
4395+
// using Type::subst() without changing output.
4396+
if (options & SubstFlags::DesugarMemberTypes) {
4397+
if (auto *aliasType = dyn_cast<TypeAliasType>(witnessTy.getPointer()))
4398+
witnessTy = aliasType->getSinglyDesugaredType();
4399+
4400+
// Another hack. If the type witness is a opaque result type. They can
4401+
// only be referred using the name of the associated type.
4402+
if (witnessTy->is<OpaqueTypeArchetypeType>())
4403+
witnessTy = witness.getWitnessDecl()->getDeclaredInterfaceType();
4404+
}
44014405
}
44024406

4403-
if (witnessTy->is<ErrorType>())
4407+
if (!witnessTy || witnessTy->is<ErrorType>())
44044408
return failed();
44054409

44064410
return witnessTy;

lib/AST/TypeWalker.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ class Traversal : public TypeVisitor<Traversal, bool>
5757
}
5858

5959
bool visitPackExpansionType(PackExpansionType *ty) {
60+
if (doIt(ty->getCountType()))
61+
return true;
62+
6063
return doIt(ty->getPatternType());
6164
}
6265

lib/Sema/CSBindings.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,7 @@ PotentialBindings::inferFromRelational(Constraint *constraint) {
12811281

12821282
switch (constraint->getKind()) {
12831283
case ConstraintKind::Subtype:
1284+
case ConstraintKind::SubclassOf:
12841285
case ConstraintKind::Conversion:
12851286
case ConstraintKind::ArgumentConversion:
12861287
case ConstraintKind::OperatorArgumentConversion: {
@@ -1358,6 +1359,7 @@ void PotentialBindings::infer(Constraint *constraint) {
13581359
case ConstraintKind::BindParam:
13591360
case ConstraintKind::BindToPointerType:
13601361
case ConstraintKind::Subtype:
1362+
case ConstraintKind::SubclassOf:
13611363
case ConstraintKind::Conversion:
13621364
case ConstraintKind::ArgumentConversion:
13631365
case ConstraintKind::OperatorArgumentConversion:

0 commit comments

Comments
 (0)