Skip to content

Commit 716c72c

Browse files
authored
Merge pull request swiftlang#28380 from DougGregor/for-each-in-constraint-solver
[Type checker] Fold more for-each type checking into the constraint solver
2 parents 952dd11 + bbcaf8c commit 716c72c

16 files changed

+395
-183
lines changed

include/swift/AST/ASTContext.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,9 @@ class ASTContext final {
489489
/// Get the '+' function on two String.
490490
FuncDecl *getPlusFunctionOnString() const;
491491

492+
/// Get Sequence.makeIterator().
493+
FuncDecl *getSequenceMakeIterator() const;
494+
492495
/// Check whether the standard library provides all the correct
493496
/// intrinsic support for Optional<T>.
494497
///

include/swift/AST/Expr.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3924,6 +3924,10 @@ class OpaqueValueExpr : public Expr {
39243924
/// value to be specified later.
39253925
bool isPlaceholder() const { return Bits.OpaqueValueExpr.IsPlaceholder; }
39263926

3927+
void setIsPlaceholder(bool value) {
3928+
Bits.OpaqueValueExpr.IsPlaceholder = value;
3929+
}
3930+
39273931
SourceRange getSourceRange() const { return Range; }
39283932

39293933
static bool classof(const Expr *E) {

lib/AST/ASTContext.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ struct ASTContext::Implementation {
196196
/// The declaration of '+' function for two String.
197197
FuncDecl *PlusFunctionOnString = nullptr;
198198

199+
/// The declaration of 'Sequence.makeIterator()'.
200+
FuncDecl *MakeIterator = nullptr;
201+
199202
/// The declaration of Swift.Optional<T>.Some.
200203
EnumElementDecl *OptionalSomeDecl = nullptr;
201204

@@ -710,6 +713,31 @@ FuncDecl *ASTContext::getPlusFunctionOnString() const {
710713
return getImpl().PlusFunctionOnString;
711714
}
712715

716+
FuncDecl *ASTContext::getSequenceMakeIterator() const {
717+
if (getImpl().MakeIterator) {
718+
return getImpl().MakeIterator;
719+
}
720+
721+
auto proto = getProtocol(KnownProtocolKind::Sequence);
722+
if (!proto)
723+
return nullptr;
724+
725+
for (auto result : proto->lookupDirect(Id_makeIterator)) {
726+
if (result->getDeclContext() != proto)
727+
continue;
728+
729+
if (auto func = dyn_cast<FuncDecl>(result)) {
730+
if (func->getParameters()->size() != 0)
731+
continue;
732+
733+
getImpl().MakeIterator = func;
734+
return func;
735+
}
736+
}
737+
738+
return nullptr;
739+
}
740+
713741
#define KNOWN_STDLIB_TYPE_DECL(NAME, DECL_CLASS, NUM_GENERIC_PARAMS) \
714742
DECL_CLASS *ASTContext::get##NAME##Decl() const { \
715743
if (getImpl().NAME##Decl) \

lib/Sema/CSBindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,7 @@ ConstraintSystem::getPotentialBindings(TypeVariableType *typeVar) const {
610610

611611
case ConstraintKind::ValueMember:
612612
case ConstraintKind::UnresolvedValueMember:
613+
case ConstraintKind::ValueWitness:
613614
// If our type variable shows up in the base type, there's
614615
// nothing to do.
615616
// FIXME: Can we avoid simplification here?

lib/Sema/CSSimplify.cpp

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2,
12521252
case ConstraintKind::SelfObjectOfProtocol:
12531253
case ConstraintKind::UnresolvedValueMember:
12541254
case ConstraintKind::ValueMember:
1255+
case ConstraintKind::ValueWitness:
12551256
case ConstraintKind::BridgingConversion:
12561257
case ConstraintKind::FunctionInput:
12571258
case ConstraintKind::FunctionResult:
@@ -1316,6 +1317,7 @@ static bool matchFunctionRepresentations(FunctionTypeRepresentation rep1,
13161317
case ConstraintKind::SelfObjectOfProtocol:
13171318
case ConstraintKind::UnresolvedValueMember:
13181319
case ConstraintKind::ValueMember:
1320+
case ConstraintKind::ValueWitness:
13191321
case ConstraintKind::FunctionInput:
13201322
case ConstraintKind::FunctionResult:
13211323
case ConstraintKind::OneWayEqual:
@@ -1594,6 +1596,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
15941596
case ConstraintKind::SelfObjectOfProtocol:
15951597
case ConstraintKind::UnresolvedValueMember:
15961598
case ConstraintKind::ValueMember:
1599+
case ConstraintKind::ValueWitness:
15971600
case ConstraintKind::BridgingConversion:
15981601
case ConstraintKind::FunctionInput:
15991602
case ConstraintKind::FunctionResult:
@@ -3863,6 +3866,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
38633866
case ConstraintKind::SelfObjectOfProtocol:
38643867
case ConstraintKind::UnresolvedValueMember:
38653868
case ConstraintKind::ValueMember:
3869+
case ConstraintKind::ValueWitness:
38663870
case ConstraintKind::FunctionInput:
38673871
case ConstraintKind::FunctionResult:
38683872
case ConstraintKind::OneWayEqual:
@@ -6079,7 +6083,9 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
60796083

60806084
markMemberTypeAsPotentialHole(memberTy);
60816085
return SolutionKind::Solved;
6082-
} else if (kind == ConstraintKind::ValueMember && baseObjTy->isHole()) {
6086+
} else if ((kind == ConstraintKind::ValueMember ||
6087+
kind == ConstraintKind::ValueWitness) &&
6088+
baseObjTy->isHole()) {
60836089
// If base type is a "hole" there is no reason to record any
60846090
// more "member not found" fixes for chained member references.
60856091
increaseScore(SK_Fix);
@@ -6354,6 +6360,72 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
63546360
return SolutionKind::Error;
63556361
}
63566362

6363+
ConstraintSystem::SolutionKind
6364+
ConstraintSystem::simplifyValueWitnessConstraint(
6365+
ConstraintKind kind, Type baseType, ValueDecl *requirement, Type memberType,
6366+
DeclContext *useDC, FunctionRefKind functionRefKind,
6367+
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {
6368+
// We'd need to record original base type because it might be a type
6369+
// variable representing another missing member.
6370+
auto origBaseType = baseType;
6371+
6372+
auto formUnsolved = [&] {
6373+
// If requested, generate a constraint.
6374+
if (flags.contains(TMF_GenerateConstraints)) {
6375+
auto *witnessConstraint = Constraint::createValueWitness(
6376+
*this, kind, origBaseType, memberType, requirement, useDC,
6377+
functionRefKind, getConstraintLocator(locator));
6378+
6379+
addUnsolvedConstraint(witnessConstraint);
6380+
return SolutionKind::Solved;
6381+
}
6382+
6383+
return SolutionKind::Unsolved;
6384+
};
6385+
6386+
// Resolve the base type, if we can. If we can't resolve the base type,
6387+
// then we can't solve this constraint.
6388+
Type baseObjectType = getFixedTypeRecursive(
6389+
baseType, flags, /*wantRValue=*/true);
6390+
if (baseObjectType->isTypeVariableOrMember()) {
6391+
return formUnsolved();
6392+
}
6393+
6394+
// Check conformance to the protocol. If it doesn't conform, this constraint
6395+
// fails. Don't attempt to fix it.
6396+
// FIXME: Look in the constraint system to see if we've resolved the
6397+
// conformance already?
6398+
auto proto = requirement->getDeclContext()->getSelfProtocolDecl();
6399+
assert(proto && "Value witness constraint for a non-requirement");
6400+
auto conformance = TypeChecker::conformsToProtocol(
6401+
baseObjectType, proto, useDC,
6402+
(ConformanceCheckFlags::InExpression |
6403+
ConformanceCheckFlags::SkipConditionalRequirements));
6404+
if (!conformance) {
6405+
// The conformance failed, so mark the member type as a "hole". We cannot
6406+
// do anything further here.
6407+
if (!shouldAttemptFixes())
6408+
return SolutionKind::Error;
6409+
6410+
memberType.visit([&](Type type) {
6411+
if (auto *typeVar = type->getAs<TypeVariableType>())
6412+
recordPotentialHole(typeVar);
6413+
});
6414+
6415+
return SolutionKind::Solved;
6416+
}
6417+
6418+
// Reference the requirement.
6419+
Type resolvedBaseType = simplifyType(baseType, flags);
6420+
if (resolvedBaseType->isTypeVariableOrMember())
6421+
return formUnsolved();
6422+
6423+
auto choice = OverloadChoice(resolvedBaseType, requirement, functionRefKind);
6424+
resolveOverload(getConstraintLocator(locator), memberType, choice,
6425+
useDC);
6426+
return SolutionKind::Solved;
6427+
}
6428+
63576429
ConstraintSystem::SolutionKind
63586430
ConstraintSystem::simplifyDefaultableConstraint(
63596431
Type first, Type second,
@@ -8671,6 +8743,7 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
86718743

86728744
case ConstraintKind::ValueMember:
86738745
case ConstraintKind::UnresolvedValueMember:
8746+
case ConstraintKind::ValueWitness:
86748747
case ConstraintKind::BindOverload:
86758748
case ConstraintKind::Disjunction:
86768749
case ConstraintKind::KeyPath:
@@ -9058,6 +9131,16 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
90589131
TMF_GenerateConstraints,
90599132
constraint.getLocator());
90609133

9134+
case ConstraintKind::ValueWitness:
9135+
return simplifyValueWitnessConstraint(constraint.getKind(),
9136+
constraint.getFirstType(),
9137+
constraint.getRequirement(),
9138+
constraint.getSecondType(),
9139+
constraint.getMemberUseDC(),
9140+
constraint.getFunctionRefKind(),
9141+
TMF_GenerateConstraints,
9142+
constraint.getLocator());
9143+
90619144
case ConstraintKind::Defaultable:
90629145
return simplifyDefaultableConstraint(constraint.getFirstType(),
90639146
constraint.getSecondType(),

lib/Sema/CSSolver.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,6 +1670,7 @@ void ConstraintSystem::ArgumentInfoCollector::walk(Type argType) {
16701670

16711671
case ConstraintKind::BindToPointerType:
16721672
case ConstraintKind::ValueMember:
1673+
case ConstraintKind::ValueWitness:
16731674
case ConstraintKind::UnresolvedValueMember:
16741675
case ConstraintKind::Disjunction:
16751676
case ConstraintKind::CheckedCast:

lib/Sema/Constraint.cpp

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
7777

7878
case ConstraintKind::ValueMember:
7979
case ConstraintKind::UnresolvedValueMember:
80+
case ConstraintKind::ValueWitness:
8081
llvm_unreachable("Wrong constructor for member constraint");
8182

8283
case ConstraintKind::Defaultable:
@@ -127,6 +128,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, Type Third,
127128
case ConstraintKind::ApplicableFunction:
128129
case ConstraintKind::DynamicCallableApplicableFunction:
129130
case ConstraintKind::ValueMember:
131+
case ConstraintKind::ValueWitness:
130132
case ConstraintKind::UnresolvedValueMember:
131133
case ConstraintKind::Defaultable:
132134
case ConstraintKind::BindOverload:
@@ -156,7 +158,7 @@ Constraint::Constraint(ConstraintKind kind, Type first, Type second,
156158
ArrayRef<TypeVariableType *> typeVars)
157159
: Kind(kind), HasRestriction(false), IsActive(false), IsDisabled(false),
158160
RememberChoice(false), IsFavored(false),
159-
NumTypeVariables(typeVars.size()), Member{first, second, member, useDC},
161+
NumTypeVariables(typeVars.size()), Member{first, second, {member}, useDC},
160162
Locator(locator) {
161163
assert(kind == ConstraintKind::ValueMember ||
162164
kind == ConstraintKind::UnresolvedValueMember);
@@ -168,6 +170,28 @@ Constraint::Constraint(ConstraintKind kind, Type first, Type second,
168170
std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin());
169171
}
170172

173+
Constraint::Constraint(ConstraintKind kind, Type first, Type second,
174+
ValueDecl *requirement, DeclContext *useDC,
175+
FunctionRefKind functionRefKind,
176+
ConstraintLocator *locator,
177+
ArrayRef<TypeVariableType *> typeVars)
178+
: Kind(kind), HasRestriction(false), IsActive(false), IsDisabled(false),
179+
RememberChoice(false), IsFavored(false),
180+
NumTypeVariables(typeVars.size()), Locator(locator) {
181+
Member.First = first;
182+
Member.Second = second;
183+
Member.Member.Ref = requirement;
184+
Member.UseDC = useDC;
185+
TheFunctionRefKind = static_cast<unsigned>(functionRefKind);
186+
187+
assert(kind == ConstraintKind::ValueWitness);
188+
assert(getFunctionRefKind() == functionRefKind);
189+
assert(requirement && "Value witness constraint has no requirement");
190+
assert(useDC && "Member constraint has no use DC");
191+
192+
std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin());
193+
}
194+
171195
Constraint::Constraint(Type type, OverloadChoice choice, DeclContext *useDC,
172196
ConstraintFix *fix, ConstraintLocator *locator,
173197
ArrayRef<TypeVariableType *> typeVars)
@@ -251,6 +275,11 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const {
251275
getMember(), getMemberUseDC(), getFunctionRefKind(),
252276
getLocator());
253277

278+
case ConstraintKind::ValueWitness:
279+
return createValueWitness(
280+
cs, getKind(), getFirstType(), getSecondType(), getRequirement(),
281+
getMemberUseDC(), getFunctionRefKind(), getLocator());
282+
254283
case ConstraintKind::Disjunction:
255284
return createDisjunction(cs, getNestedConstraints(), getLocator());
256285

@@ -386,11 +415,20 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm) const {
386415
}
387416

388417
case ConstraintKind::ValueMember:
389-
Out << "[." << Member.Member << ": value] == ";
418+
Out << "[." << getMember() << ": value] == ";
390419
break;
391420
case ConstraintKind::UnresolvedValueMember:
392-
Out << "[(implicit) ." << Member.Member << ": value] == ";
421+
Out << "[(implicit) ." << getMember() << ": value] == ";
422+
break;
423+
424+
case ConstraintKind::ValueWitness: {
425+
auto requirement = getRequirement();
426+
auto selfNominal = requirement->getDeclContext()->getSelfNominalTypeDecl();
427+
Out << "[." << selfNominal->getName() << "::" << requirement->getFullName()
428+
<< ": witness] == ";
393429
break;
430+
}
431+
394432
case ConstraintKind::Defaultable:
395433
Out << " can default to ";
396434
break;
@@ -510,6 +548,7 @@ gatherReferencedTypeVars(Constraint *constraint,
510548
case ConstraintKind::Subtype:
511549
case ConstraintKind::UnresolvedValueMember:
512550
case ConstraintKind::ValueMember:
551+
case ConstraintKind::ValueWitness:
513552
case ConstraintKind::DynamicTypeOf:
514553
case ConstraintKind::EscapableFunctionOf:
515554
case ConstraintKind::OpenedExistentialOf:
@@ -648,6 +687,27 @@ Constraint *Constraint::createMember(ConstraintSystem &cs, ConstraintKind kind,
648687
functionRefKind, locator, typeVars);
649688
}
650689

690+
Constraint *Constraint::createValueWitness(
691+
ConstraintSystem &cs, ConstraintKind kind, Type first, Type second,
692+
ValueDecl *requirement, DeclContext *useDC,
693+
FunctionRefKind functionRefKind, ConstraintLocator *locator) {
694+
assert(kind == ConstraintKind::ValueWitness);
695+
696+
// Collect type variables.
697+
SmallVector<TypeVariableType *, 4> typeVars;
698+
if (first->hasTypeVariable())
699+
first->getTypeVariables(typeVars);
700+
if (second->hasTypeVariable())
701+
second->getTypeVariables(typeVars);
702+
uniqueTypeVariables(typeVars);
703+
704+
// Create the constraint.
705+
unsigned size = totalSizeToAlloc<TypeVariableType*>(typeVars.size());
706+
void *mem = cs.getAllocator().Allocate(size, alignof(Constraint));
707+
return new (mem) Constraint(kind, first, second, requirement, useDC,
708+
functionRefKind, locator, typeVars);
709+
}
710+
651711
Constraint *Constraint::createBindOverload(ConstraintSystem &cs, Type type,
652712
OverloadChoice choice,
653713
DeclContext *useDC,

0 commit comments

Comments
 (0)