Skip to content

Commit c4b7d6d

Browse files
authored
Merge pull request #61091 from xedin/issue-60958-alt
[ConstraintSystem] Use witnesses for `makeIterator` and `next` refs in `for-in` context
2 parents 1eae14c + 1a47a95 commit c4b7d6d

File tree

10 files changed

+447
-29
lines changed

10 files changed

+447
-29
lines changed

include/swift/AST/ASTContext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,12 @@ class ASTContext final {
602602
/// Get AsyncSequence.makeAsyncIterator().
603603
FuncDecl *getAsyncSequenceMakeAsyncIterator() const;
604604

605+
/// Get IteratorProtocol.next().
606+
FuncDecl *getIteratorNext() const;
607+
608+
/// Get AsyncIteratorProtocol.next().
609+
FuncDecl *getAsyncIteratorNext() const;
610+
605611
/// Check whether the standard library provides all the correct
606612
/// intrinsic support for Optional<T>.
607613
///

include/swift/Sema/Constraint.h

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ enum class ConstraintKind : char {
137137
/// name, and the type of that member, when referenced as a value, is the
138138
/// second type.
139139
UnresolvedValueMember,
140+
/// The first type conforms to the protocol in which the member requirement
141+
/// resides. Once the conformance is resolved, the value witness will be
142+
/// determined, and the type of that witness, when referenced as a value,
143+
/// will be bound to the second type.
144+
ValueWitness,
140145
/// The first type can be defaulted to the second (which currently
141146
/// cannot be dependent). This is more like a type property than a
142147
/// relational constraint.
@@ -406,11 +411,18 @@ class Constraint final : public llvm::ilist_node<Constraint>,
406411
/// The type of the member.
407412
Type Second;
408413

409-
/// If non-null, the name of a member of the first type is that
410-
/// being related to the second type.
411-
///
412-
/// Used for ValueMember an UnresolvedValueMember constraints.
413-
DeclNameRef Name;
414+
union {
415+
/// If non-null, the name of a member of the first type is that
416+
/// being related to the second type.
417+
///
418+
/// Used for ValueMember an UnresolvedValueMember constraints.
419+
DeclNameRef Name;
420+
421+
/// If non-null, the member being referenced.
422+
///
423+
/// Used for ValueWitness constraints.
424+
ValueDecl *Ref;
425+
} Member;
414426

415427
/// The DC in which the use appears.
416428
DeclContext *UseDC;
@@ -525,6 +537,12 @@ class Constraint final : public llvm::ilist_node<Constraint>,
525537
FunctionRefKind functionRefKind,
526538
ConstraintLocator *locator);
527539

540+
/// Create a new value witness constraint.
541+
static Constraint *createValueWitness(
542+
ConstraintSystem &cs, ConstraintKind kind, Type first, Type second,
543+
ValueDecl *requirement, DeclContext *useDC,
544+
FunctionRefKind functionRefKind, ConstraintLocator *locator);
545+
528546
/// Create an overload-binding constraint.
529547
static Constraint *createBindOverload(ConstraintSystem &cs, Type type,
530548
OverloadChoice choice,
@@ -672,6 +690,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
672690

673691
case ConstraintKind::ValueMember:
674692
case ConstraintKind::UnresolvedValueMember:
693+
case ConstraintKind::ValueWitness:
675694
case ConstraintKind::PropertyWrapper:
676695
return ConstraintClassification::Member;
677696

@@ -711,6 +730,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
711730

712731
case ConstraintKind::ValueMember:
713732
case ConstraintKind::UnresolvedValueMember:
733+
case ConstraintKind::ValueWitness:
714734
return Member.First;
715735

716736
case ConstraintKind::SyntacticElement:
@@ -732,6 +752,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
732752

733753
case ConstraintKind::ValueMember:
734754
case ConstraintKind::UnresolvedValueMember:
755+
case ConstraintKind::ValueWitness:
735756
return Member.Second;
736757

737758
default:
@@ -757,13 +778,20 @@ class Constraint final : public llvm::ilist_node<Constraint>,
757778
DeclNameRef getMember() const {
758779
assert(Kind == ConstraintKind::ValueMember ||
759780
Kind == ConstraintKind::UnresolvedValueMember);
760-
return Member.Name;
781+
return Member.Member.Name;
782+
}
783+
784+
/// Retrieve the requirement being referenced by a value witness constraint.
785+
ValueDecl *getRequirement() const {
786+
assert(Kind == ConstraintKind::ValueWitness);
787+
return Member.Member.Ref;
761788
}
762789

763790
/// Determine the kind of function reference we have for a member reference.
764791
FunctionRefKind getFunctionRefKind() const {
765792
if (Kind == ConstraintKind::ValueMember ||
766-
Kind == ConstraintKind::UnresolvedValueMember)
793+
Kind == ConstraintKind::UnresolvedValueMember ||
794+
Kind == ConstraintKind::ValueWitness)
767795
return static_cast<FunctionRefKind>(TheFunctionRefKind);
768796

769797
// Conservative answer: drop all of the labels.
@@ -823,7 +851,8 @@ class Constraint final : public llvm::ilist_node<Constraint>,
823851
/// Retrieve the DC in which the member was used.
824852
DeclContext *getMemberUseDC() const {
825853
assert(Kind == ConstraintKind::ValueMember ||
826-
Kind == ConstraintKind::UnresolvedValueMember);
854+
Kind == ConstraintKind::UnresolvedValueMember ||
855+
Kind == ConstraintKind::ValueWitness);
827856
return Member.UseDC;
828857
}
829858

include/swift/Sema/ConstraintSystem.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4354,6 +4354,26 @@ class ConstraintSystem {
43544354
}
43554355
}
43564356

4357+
/// Add a value witness constraint to the constraint system.
4358+
void addValueWitnessConstraint(
4359+
Type baseTy, ValueDecl *requirement, Type memberTy, DeclContext *useDC,
4360+
FunctionRefKind functionRefKind, ConstraintLocatorBuilder locator) {
4361+
assert(baseTy);
4362+
assert(memberTy);
4363+
assert(requirement);
4364+
assert(useDC);
4365+
switch (simplifyValueWitnessConstraint(
4366+
ConstraintKind::ValueWitness, baseTy, requirement, memberTy, useDC,
4367+
functionRefKind, TMF_GenerateConstraints, locator)) {
4368+
case SolutionKind::Unsolved:
4369+
llvm_unreachable("Unsolved result when generating constraints!");
4370+
4371+
case SolutionKind::Solved:
4372+
case SolutionKind::Error:
4373+
break;
4374+
}
4375+
}
4376+
43574377
/// Add an explicit conversion constraint (e.g., \c 'x as T').
43584378
///
43594379
/// \param fromType The type of the expression being converted.

lib/AST/ASTContext.cpp

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,12 @@ struct ASTContext::Implementation {
217217
/// The declaration of 'AsyncSequence.makeAsyncIterator()'.
218218
FuncDecl *MakeAsyncIterator = nullptr;
219219

220+
/// The declaration of 'IteratorProtocol.next()'.
221+
FuncDecl *IteratorNext = nullptr;
222+
223+
/// The declaration of 'AsyncIteratorProtocol.next()'.
224+
FuncDecl *AsyncIteratorNext = nullptr;
225+
220226
/// The declaration of Swift.Optional<T>.Some.
221227
EnumElementDecl *OptionalSomeDecl = nullptr;
222228

@@ -806,31 +812,40 @@ FuncDecl *ASTContext::getPlusFunctionOnString() const {
806812
return getImpl().PlusFunctionOnString;
807813
}
808814

809-
FuncDecl *ASTContext::getSequenceMakeIterator() const {
810-
if (getImpl().MakeIterator) {
811-
return getImpl().MakeIterator;
812-
}
813-
814-
auto proto = getProtocol(KnownProtocolKind::Sequence);
815-
if (!proto)
816-
return nullptr;
817-
818-
for (auto result : proto->lookupDirect(Id_makeIterator)) {
815+
static FuncDecl *lookupRequirement(ProtocolDecl *proto,
816+
Identifier requirement) {
817+
for (auto result : proto->lookupDirect(requirement)) {
819818
if (result->getDeclContext() != proto)
820819
continue;
821820

822821
if (auto func = dyn_cast<FuncDecl>(result)) {
823822
if (func->getParameters()->size() != 0)
824823
continue;
825824

826-
getImpl().MakeIterator = func;
827825
return func;
828826
}
829827
}
830828

831829
return nullptr;
832830
}
833831

832+
FuncDecl *ASTContext::getSequenceMakeIterator() const {
833+
if (getImpl().MakeIterator) {
834+
return getImpl().MakeIterator;
835+
}
836+
837+
auto proto = getProtocol(KnownProtocolKind::Sequence);
838+
if (!proto)
839+
return nullptr;
840+
841+
if (auto *func = lookupRequirement(proto, Id_makeIterator)) {
842+
getImpl().MakeIterator = func;
843+
return func;
844+
}
845+
846+
return nullptr;
847+
}
848+
834849
FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const {
835850
if (getImpl().MakeAsyncIterator) {
836851
return getImpl().MakeAsyncIterator;
@@ -840,17 +855,43 @@ FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const {
840855
if (!proto)
841856
return nullptr;
842857

843-
for (auto result : proto->lookupDirect(Id_makeAsyncIterator)) {
844-
if (result->getDeclContext() != proto)
845-
continue;
858+
if (auto *func = lookupRequirement(proto, Id_makeAsyncIterator)) {
859+
getImpl().MakeAsyncIterator = func;
860+
return func;
861+
}
846862

847-
if (auto func = dyn_cast<FuncDecl>(result)) {
848-
if (func->getParameters()->size() != 0)
849-
continue;
863+
return nullptr;
864+
}
850865

851-
getImpl().MakeAsyncIterator = func;
852-
return func;
853-
}
866+
FuncDecl *ASTContext::getIteratorNext() const {
867+
if (getImpl().IteratorNext) {
868+
return getImpl().IteratorNext;
869+
}
870+
871+
auto proto = getProtocol(KnownProtocolKind::IteratorProtocol);
872+
if (!proto)
873+
return nullptr;
874+
875+
if (auto *func = lookupRequirement(proto, Id_next)) {
876+
getImpl().IteratorNext = func;
877+
return func;
878+
}
879+
880+
return nullptr;
881+
}
882+
883+
FuncDecl *ASTContext::getAsyncIteratorNext() const {
884+
if (getImpl().AsyncIteratorNext) {
885+
return getImpl().AsyncIteratorNext;
886+
}
887+
888+
auto proto = getProtocol(KnownProtocolKind::AsyncIteratorProtocol);
889+
if (!proto)
890+
return nullptr;
891+
892+
if (auto *func = lookupRequirement(proto, Id_next)) {
893+
getImpl().AsyncIteratorNext = func;
894+
return func;
854895
}
855896

856897
return nullptr;

lib/Sema/CSBindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,6 +1473,7 @@ void PotentialBindings::infer(Constraint *constraint) {
14731473

14741474
case ConstraintKind::ValueMember:
14751475
case ConstraintKind::UnresolvedValueMember:
1476+
case ConstraintKind::ValueWitness:
14761477
case ConstraintKind::PropertyWrapper: {
14771478
// If current type variable represents a member type of some reference,
14781479
// it would be bound once member is resolved either to a actual member

lib/Sema/CSGen.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4057,6 +4057,21 @@ generateForEachStmtConstraints(
40574057
AwaitExpr::createImplicit(ctx, /*awaitLoc=*/SourceLoc(), nextCall);
40584058
}
40594059

4060+
// The iterator type must conform to IteratorProtocol.
4061+
{
4062+
ProtocolDecl *iteratorProto = TypeChecker::getProtocol(
4063+
cs.getASTContext(), stmt->getForLoc(),
4064+
isAsync ? KnownProtocolKind::AsyncIteratorProtocol
4065+
: KnownProtocolKind::IteratorProtocol);
4066+
if (!iteratorProto)
4067+
return None;
4068+
4069+
cs.setContextualType(
4070+
nextRef->getBase(),
4071+
TypeLoc::withoutLoc(iteratorProto->getDeclaredInterfaceType()),
4072+
CTP_ForEachSequence);
4073+
}
4074+
40604075
SolutionApplicationTarget nextTarget(nextCall, dc, CTP_Unused,
40614076
/*contextualType=*/Type(),
40624077
/*isDiscarded=*/false);

0 commit comments

Comments
 (0)