Skip to content

Commit 12a463f

Browse files
committed
[ASTContext] Add {Async}IteratorProtocol::next to list of known decls
1 parent 88f24e5 commit 12a463f

File tree

2 files changed

+67
-20
lines changed

2 files changed

+67
-20
lines changed

include/swift/AST/ASTContext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,12 @@ class ASTContext final {
600600
/// Get AsyncSequence.makeAsyncIterator().
601601
FuncDecl *getAsyncSequenceMakeAsyncIterator() const;
602602

603+
/// Get IteratorProtocol.next().
604+
FuncDecl *getIteratorNext() const;
605+
606+
/// Get AsyncIteratorProtocol.next().
607+
FuncDecl *getAsyncIteratorNext() const;
608+
603609
/// Check whether the standard library provides all the correct
604610
/// intrinsic support for Optional<T>.
605611
///

lib/AST/ASTContext.cpp

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,12 @@ struct ASTContext::Implementation {
210210
/// The declaration of 'AsyncSequence.makeAsyncIterator()'.
211211
FuncDecl *MakeAsyncIterator = nullptr;
212212

213+
/// The declaration of 'IteratorProtocol.next()'.
214+
FuncDecl *IteratorNext = nullptr;
215+
216+
/// The declaration of 'AsyncIteratorProtocol.next()'.
217+
FuncDecl *AsyncIteratorNext = nullptr;
218+
213219
/// The declaration of Swift.Optional<T>.Some.
214220
EnumElementDecl *OptionalSomeDecl = nullptr;
215221

@@ -779,31 +785,40 @@ FuncDecl *ASTContext::getPlusFunctionOnString() const {
779785
return getImpl().PlusFunctionOnString;
780786
}
781787

782-
FuncDecl *ASTContext::getSequenceMakeIterator() const {
783-
if (getImpl().MakeIterator) {
784-
return getImpl().MakeIterator;
785-
}
786-
787-
auto proto = getProtocol(KnownProtocolKind::Sequence);
788-
if (!proto)
789-
return nullptr;
790-
791-
for (auto result : proto->lookupDirect(Id_makeIterator)) {
788+
static FuncDecl *lookupRequirement(ProtocolDecl *proto,
789+
Identifier requirement) {
790+
for (auto result : proto->lookupDirect(requirement)) {
792791
if (result->getDeclContext() != proto)
793792
continue;
794793

795794
if (auto func = dyn_cast<FuncDecl>(result)) {
796795
if (func->getParameters()->size() != 0)
797796
continue;
798797

799-
getImpl().MakeIterator = func;
800798
return func;
801799
}
802800
}
803801

804802
return nullptr;
805803
}
806804

805+
FuncDecl *ASTContext::getSequenceMakeIterator() const {
806+
if (getImpl().MakeIterator) {
807+
return getImpl().MakeIterator;
808+
}
809+
810+
auto proto = getProtocol(KnownProtocolKind::Sequence);
811+
if (!proto)
812+
return nullptr;
813+
814+
if (auto *func = lookupRequirement(proto, Id_makeIterator)) {
815+
getImpl().MakeIterator = func;
816+
return func;
817+
}
818+
819+
return nullptr;
820+
}
821+
807822
FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const {
808823
if (getImpl().MakeAsyncIterator) {
809824
return getImpl().MakeAsyncIterator;
@@ -813,17 +828,43 @@ FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const {
813828
if (!proto)
814829
return nullptr;
815830

816-
for (auto result : proto->lookupDirect(Id_makeAsyncIterator)) {
817-
if (result->getDeclContext() != proto)
818-
continue;
831+
if (auto *func = lookupRequirement(proto, Id_makeAsyncIterator)) {
832+
getImpl().MakeAsyncIterator = func;
833+
return func;
834+
}
819835

820-
if (auto func = dyn_cast<FuncDecl>(result)) {
821-
if (func->getParameters()->size() != 0)
822-
continue;
836+
return nullptr;
837+
}
823838

824-
getImpl().MakeAsyncIterator = func;
825-
return func;
826-
}
839+
FuncDecl *ASTContext::getIteratorNext() const {
840+
if (getImpl().IteratorNext) {
841+
return getImpl().IteratorNext;
842+
}
843+
844+
auto proto = getProtocol(KnownProtocolKind::IteratorProtocol);
845+
if (!proto)
846+
return nullptr;
847+
848+
if (auto *func = lookupRequirement(proto, Id_next)) {
849+
getImpl().IteratorNext = func;
850+
return func;
851+
}
852+
853+
return nullptr;
854+
}
855+
856+
FuncDecl *ASTContext::getAsyncIteratorNext() const {
857+
if (getImpl().AsyncIteratorNext) {
858+
return getImpl().AsyncIteratorNext;
859+
}
860+
861+
auto proto = getProtocol(KnownProtocolKind::AsyncIteratorProtocol);
862+
if (!proto)
863+
return nullptr;
864+
865+
if (auto *func = lookupRequirement(proto, Id_next)) {
866+
getImpl().AsyncIteratorNext = func;
867+
return func;
827868
}
828869

829870
return nullptr;

0 commit comments

Comments
 (0)