Skip to content

Commit 7d0de80

Browse files
committed
Sema: Check requirements when calling a variadic generic function
1 parent 46d5fa6 commit 7d0de80

File tree

2 files changed

+74
-12
lines changed

2 files changed

+74
-12
lines changed

lib/Sema/CSSimplify.cpp

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2607,10 +2607,9 @@ static ConstraintFix *fixRequirementFailure(ConstraintSystem &cs, Type type1,
26072607
Type type2,
26082608
ConstraintLocatorBuilder locator) {
26092609
SmallVector<LocatorPathElt, 4> path;
2610-
if (auto anchor = locator.getLocatorParts(path)) {
2611-
return fixRequirementFailure(cs, type1, type2, anchor, path);
2612-
}
2613-
return nullptr;
2610+
2611+
auto anchor = locator.getLocatorParts(path);
2612+
return fixRequirementFailure(cs, type1, type2, anchor, path);
26142613
}
26152614

26162615
static unsigned
@@ -3576,8 +3575,16 @@ ConstraintSystem::matchDeepEqualityTypes(Type type1, Type type2,
35763575
if (mismatches.empty())
35773576
return result;
35783577

3579-
if (auto last = locator.last()) {
3580-
if (last->is<LocatorPathElt::AnyRequirement>()) {
3578+
auto *loc = getConstraintLocator(locator);
3579+
3580+
auto path = loc->getPath();
3581+
if (!path.empty()) {
3582+
// If we have something like ... -> type req # -> pack element #, we're
3583+
// solving a requirement of the form T : P where T is a type parameter pack
3584+
if (path.back().is<LocatorPathElt::PackElement>())
3585+
path = path.drop_back();
3586+
3587+
if (path.back().is<LocatorPathElt::AnyRequirement>()) {
35813588
if (auto *fix = fixRequirementFailure(*this, type1, type2, locator)) {
35823589
if (recordFix(fix))
35833590
return getTypeMatchFailure(locator);
@@ -3602,7 +3609,7 @@ ConstraintSystem::matchDeepEqualityTypes(Type type1, Type type2,
36023609
}
36033610

36043611
auto *fix = GenericArgumentsMismatch::create(
3605-
*this, type1, type2, mismatches, getConstraintLocator(locator));
3612+
*this, type1, type2, mismatches, loc);
36063613

36073614
if (!recordFix(fix, impact))
36083615
return getTypeMatchSuccess();
@@ -4144,6 +4151,11 @@ static ConstraintFix *fixRequirementFailure(ConstraintSystem &cs, Type type1,
41444151
if (type1->isTypeVariableOrMember() || type2->isTypeVariableOrMember())
41454152
return nullptr;
41464153

4154+
// If we have something like ... -> type req # -> pack element #, we're
4155+
// solving a requirement of the form T : P where T is a type parameter pack
4156+
if (path.back().is<LocatorPathElt::PackElement>())
4157+
path = path.drop_back();
4158+
41474159
auto req = path.back().castTo<LocatorPathElt::AnyRequirement>();
41484160
if (req.isConditionalRequirement()) {
41494161
// path is - ... -> open generic -> type req # -> cond req #,
@@ -6051,11 +6063,17 @@ bool ConstraintSystem::repairFailures(
60516063
// record the requirement failure fix.
60526064
path.pop_back();
60536065

6054-
if (path.empty() || !path.back().is<LocatorPathElt::AnyRequirement>())
6055-
break;
6066+
// If we have something like ... -> type req # -> pack element #, we're
6067+
// solving a requirement of the form T : P where T is a type parameter pack
6068+
if (!path.empty() && path.back().is<LocatorPathElt::PackElement>())
6069+
path.pop_back();
60566070

6057-
return repairFailures(lhs, rhs, matchKind, conversionsOrFixes,
6058-
getConstraintLocator(anchor, path));
6071+
if (!path.empty() && path.back().is<LocatorPathElt::AnyRequirement>()) {
6072+
return repairFailures(lhs, rhs, matchKind, conversionsOrFixes,
6073+
getConstraintLocator(anchor, path));
6074+
}
6075+
6076+
break;
60596077
}
60606078

60616079
case ConstraintLocator::ResultBuilderBodyResult: {
@@ -7636,6 +7654,11 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint(
76367654
return recordFix(fix) ? SolutionKind::Error : SolutionKind::Solved;
76377655
}
76387656

7657+
// If we have something like ... -> type req # -> pack element #, we're
7658+
// solving a requirement of the form T : P where T is a type parameter pack
7659+
if (path.back().is<LocatorPathElt::PackElement>())
7660+
path.pop_back();
7661+
76397662
if (auto req = path.back().getAs<LocatorPathElt::AnyRequirement>()) {
76407663
// If this is a requirement associated with `Self` which is bound
76417664
// to `Any`, let's consider this "too incorrect" to continue.
@@ -13875,10 +13898,20 @@ void ConstraintSystem::addConstraint(Requirement req,
1387513898
case RequirementKind::Conformance:
1387613899
kind = ConstraintKind::ConformsTo;
1387713900
break;
13878-
case RequirementKind::Superclass:
13901+
case RequirementKind::Superclass: {
13902+
// FIXME: Should always use ConstraintKind::SubclassOf, but that breaks
13903+
// a couple of diagnostics
13904+
if (auto *typeVar = req.getFirstType()->getAs<TypeVariableType>()) {
13905+
if (typeVar->getImpl().canBindToPack()) {
13906+
kind = ConstraintKind::SubclassOf;
13907+
break;
13908+
}
13909+
}
13910+
1387913911
conformsToAnyObject = true;
1388013912
kind = ConstraintKind::Subtype;
1388113913
break;
13914+
}
1388213915
case RequirementKind::SameType:
1388313916
kind = ConstraintKind::Bind;
1388413917
break;
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %target-typecheck-verify-swift -enable-experimental-variadic-generics
2+
3+
// Test instantiation of constraint solver constraints from generic requirements
4+
// involving type pack parameters
5+
6+
protocol P {}
7+
8+
func takesP<T...: P>(_: T...) {} // expected-note {{where 'T' = 'DoesNotConformToP'}}
9+
10+
struct ConformsToP: P {}
11+
struct DoesNotConformToP {}
12+
13+
takesP() // ok
14+
takesP(ConformsToP(), ConformsToP(), ConformsToP()) // ok
15+
16+
// FIXME: Bad diagnostic
17+
takesP(ConformsToP(), DoesNotConformToP(), ConformsToP()) // expected-error {{global function 'takesP' requires that 'DoesNotConformToP' conform to 'P'}}
18+
19+
class C {}
20+
21+
class SubclassOfC: C {}
22+
class NotSubclassOfC {}
23+
24+
func takesC<T...: C>(_: T...) {} // expected-note {{where 'T' = 'NotSubclassOfC'}}
25+
26+
takesC() // ok
27+
takesC(SubclassOfC(), SubclassOfC(), SubclassOfC()) // ok
28+
29+
takesC(SubclassOfC(), NotSubclassOfC(), SubclassOfC()) // expected-error {{global function 'takesC' requires that 'NotSubclassOfC' inherit from 'C'}}

0 commit comments

Comments
 (0)