Skip to content

Commit d1ff79e

Browse files
committed
[ASTContext] Add {Async}IteratorProtocol::next to list of known decls
1 parent c798a7f commit d1ff79e

File tree

2 files changed

+67
-20
lines changed

2 files changed

+67
-20
lines changed

include/swift/AST/ASTContext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,12 @@ class ASTContext final {
602602
/// Get AsyncSequence.makeAsyncIterator().
603603
FuncDecl *getAsyncSequenceMakeAsyncIterator() const;
604604

605+
/// Get IteratorProtocol.next().
606+
FuncDecl *getIteratorNext() const;
607+
608+
/// Get AsyncIteratorProtocol.next().
609+
FuncDecl *getAsyncIteratorNext() const;
610+
605611
/// Check whether the standard library provides all the correct
606612
/// intrinsic support for Optional<T>.
607613
///

lib/AST/ASTContext.cpp

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,12 @@ struct ASTContext::Implementation {
217217
/// The declaration of 'AsyncSequence.makeAsyncIterator()'.
218218
FuncDecl *MakeAsyncIterator = nullptr;
219219

220+
/// The declaration of 'IteratorProtocol.next()'.
221+
FuncDecl *IteratorNext = nullptr;
222+
223+
/// The declaration of 'AsyncIteratorProtocol.next()'.
224+
FuncDecl *AsyncIteratorNext = nullptr;
225+
220226
/// The declaration of Swift.Optional<T>.Some.
221227
EnumElementDecl *OptionalSomeDecl = nullptr;
222228

@@ -806,31 +812,40 @@ FuncDecl *ASTContext::getPlusFunctionOnString() const {
806812
return getImpl().PlusFunctionOnString;
807813
}
808814

809-
FuncDecl *ASTContext::getSequenceMakeIterator() const {
810-
if (getImpl().MakeIterator) {
811-
return getImpl().MakeIterator;
812-
}
813-
814-
auto proto = getProtocol(KnownProtocolKind::Sequence);
815-
if (!proto)
816-
return nullptr;
817-
818-
for (auto result : proto->lookupDirect(Id_makeIterator)) {
815+
static FuncDecl *lookupRequirement(ProtocolDecl *proto,
816+
Identifier requirement) {
817+
for (auto result : proto->lookupDirect(requirement)) {
819818
if (result->getDeclContext() != proto)
820819
continue;
821820

822821
if (auto func = dyn_cast<FuncDecl>(result)) {
823822
if (func->getParameters()->size() != 0)
824823
continue;
825824

826-
getImpl().MakeIterator = func;
827825
return func;
828826
}
829827
}
830828

831829
return nullptr;
832830
}
833831

832+
FuncDecl *ASTContext::getSequenceMakeIterator() const {
833+
if (getImpl().MakeIterator) {
834+
return getImpl().MakeIterator;
835+
}
836+
837+
auto proto = getProtocol(KnownProtocolKind::Sequence);
838+
if (!proto)
839+
return nullptr;
840+
841+
if (auto *func = lookupRequirement(proto, Id_makeIterator)) {
842+
getImpl().MakeIterator = func;
843+
return func;
844+
}
845+
846+
return nullptr;
847+
}
848+
834849
FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const {
835850
if (getImpl().MakeAsyncIterator) {
836851
return getImpl().MakeAsyncIterator;
@@ -840,17 +855,43 @@ FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const {
840855
if (!proto)
841856
return nullptr;
842857

843-
for (auto result : proto->lookupDirect(Id_makeAsyncIterator)) {
844-
if (result->getDeclContext() != proto)
845-
continue;
858+
if (auto *func = lookupRequirement(proto, Id_makeAsyncIterator)) {
859+
getImpl().MakeAsyncIterator = func;
860+
return func;
861+
}
846862

847-
if (auto func = dyn_cast<FuncDecl>(result)) {
848-
if (func->getParameters()->size() != 0)
849-
continue;
863+
return nullptr;
864+
}
850865

851-
getImpl().MakeAsyncIterator = func;
852-
return func;
853-
}
866+
FuncDecl *ASTContext::getIteratorNext() const {
867+
if (getImpl().IteratorNext) {
868+
return getImpl().IteratorNext;
869+
}
870+
871+
auto proto = getProtocol(KnownProtocolKind::IteratorProtocol);
872+
if (!proto)
873+
return nullptr;
874+
875+
if (auto *func = lookupRequirement(proto, Id_next)) {
876+
getImpl().IteratorNext = func;
877+
return func;
878+
}
879+
880+
return nullptr;
881+
}
882+
883+
FuncDecl *ASTContext::getAsyncIteratorNext() const {
884+
if (getImpl().AsyncIteratorNext) {
885+
return getImpl().AsyncIteratorNext;
886+
}
887+
888+
auto proto = getProtocol(KnownProtocolKind::AsyncIteratorProtocol);
889+
if (!proto)
890+
return nullptr;
891+
892+
if (auto *func = lookupRequirement(proto, Id_next)) {
893+
getImpl().AsyncIteratorNext = func;
894+
return func;
854895
}
855896

856897
return nullptr;

0 commit comments

Comments
 (0)