Skip to content

Commit bb7a563

Browse files
committed
Switch async for-each loop over to _nextElement and drop @rethrows.
This couples together several changes to move entirely from `@rethrows` over to typed throws: * Use the `Failure` type to determine whether an async for-each loop will throw, rather than depending on rethrows checking * Introduce a special carve-out for `rethrows` functions that have a generic requirement on an `AsyncSequence` or `AsyncIteratorProtocol`, which uses that requirement's `Failure` type as potentially being part of the thrown error type. This allows existing generic functions like the following to continue to work: func f<S: AsyncSequence>(_: S) rethrows * Switch SIL generation for the async for-each loop from the prior `next()` over to the typed-throws version `_nextElement`. * Remove `@rethrows` from `AsyncSequence` and `AsyncIteratorProtocol` entirely. We are now fully dependent on typed throws.
1 parent a5bdb12 commit bb7a563

21 files changed

+219
-53
lines changed

include/swift/AST/Effects.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ enum class PolymorphicEffectKind : uint8_t {
8787
/// This is the conformance-based 'rethrows' /'reasync' case.
8888
ByConformance,
8989

90+
/// The function is only permitted to be `rethrows` because it depends
91+
/// on a conformance to `AsyncSequence` or `AsyncIteratorProtocol`,
92+
/// which historically were "@rethrows" protocols.
93+
AsyncSequenceRethrows,
94+
9095
/// The function has this effect unconditionally.
9196
///
9297
/// This is a plain old 'throws' / 'async' function.

include/swift/AST/KnownIdentifiers.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ IDENTIFIER(enqueue)
9090
IDENTIFIER(erasing)
9191
IDENTIFIER(error)
9292
IDENTIFIER(errorDomain)
93+
IDENTIFIER(Failure)
9394
IDENTIFIER(first)
9495
IDENTIFIER(forKeyedSubscript)
9596
IDENTIFIER(Foundation)
@@ -122,6 +123,7 @@ IDENTIFIER(load)
122123
IDENTIFIER(main)
123124
IDENTIFIER_WITH_NAME(MainEntryPoint, "$main")
124125
IDENTIFIER(next)
126+
IDENTIFIER_(nextElement)
125127
IDENTIFIER_(nsErrorDomain)
126128
IDENTIFIER(objectAtIndexedSubscript)
127129
IDENTIFIER(objectForKeyedSubscript)

include/swift/AST/Stmt.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,7 @@ class ForEachStmt : public LabeledStmt {
971971

972972
// Set by Sema:
973973
ProtocolConformanceRef sequenceConformance = ProtocolConformanceRef();
974+
Type sequenceType;
974975
PatternBindingDecl *iteratorVar = nullptr;
975976
Expr *nextCall = nullptr;
976977
OpaqueValueExpr *elementExpr = nullptr;
@@ -1001,9 +1002,12 @@ class ForEachStmt : public LabeledStmt {
10011002
void setConvertElementExpr(Expr *expr) { convertElementExpr = expr; }
10021003
Expr *getConvertElementExpr() const { return convertElementExpr; }
10031004

1004-
void setSequenceConformance(ProtocolConformanceRef conformance) {
1005+
void setSequenceConformance(Type type,
1006+
ProtocolConformanceRef conformance) {
1007+
sequenceType = type;
10051008
sequenceConformance = conformance;
10061009
}
1010+
Type getSequenceType() const { return sequenceType; }
10071011
ProtocolConformanceRef getSequenceConformance() const {
10081012
return sequenceConformance;
10091013
}

lib/AST/ASTContext.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,12 @@ FuncDecl *ASTContext::getAsyncIteratorNext() const {
957957
if (!proto)
958958
return nullptr;
959959

960+
if (auto *func = lookupRequirement(
961+
proto, getIdentifier("_nextElement"))) {
962+
getImpl().AsyncIteratorNext = func;
963+
return func;
964+
}
965+
960966
if (auto *func = lookupRequirement(proto, Id_next)) {
961967
getImpl().AsyncIteratorNext = func;
962968
return func;

lib/AST/Effects.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ void swift::simple_display(llvm::raw_ostream &out,
114114
case PolymorphicEffectKind::ByConformance:
115115
out << "by conformance";
116116
break;
117+
case PolymorphicEffectKind::AsyncSequenceRethrows:
118+
out << "by async sequence implicit @rethrows";
119+
break;
117120
case PolymorphicEffectKind::Always:
118121
out << "always";
119122
break;

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1914,7 +1914,7 @@ AssociatedTypeInference::computeFailureTypeWitness(
19141914
auto proto = req.getProtocolDecl();
19151915
if (proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol) ||
19161916
proto->isSpecificProtocol(KnownProtocolKind::AsyncSequence)) {
1917-
auto failureAssocType = proto->getAssociatedType(ctx.getIdentifier("Failure"));
1917+
auto failureAssocType = proto->getAssociatedType(ctx.Id_Failure);
19181918
auto failureType = DependentMemberType::get(req.getFirstType(), failureAssocType);
19191919
return AbstractTypeWitness(assocType, dc->mapTypeIntoContext(failureType));
19201920
}

lib/Sema/CSApply.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9194,7 +9194,7 @@ static llvm::Optional<SequenceIterationInfo> applySolutionToForEachStmt(
91949194
type, sequenceProto);
91959195
assert(!sequenceConformance.isInvalid() &&
91969196
"Couldn't find sequence conformance");
9197-
stmt->setSequenceConformance(sequenceConformance);
9197+
stmt->setSequenceConformance(type, sequenceConformance);
91989198

91999199
// Apply the solution to the filtering condition, if there is one.
92009200
if (auto *whereExpr = stmt->getWhere()) {

lib/Sema/CSGen.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4592,11 +4592,15 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
45924592
// Now, result type of `.makeIterator()` is used to form a call to
45934593
// `.next()`. `next()` is called on each iteration of the loop.
45944594
{
4595+
FuncDecl *nextFn = isAsync ? ctx.getAsyncIteratorNext()
4596+
: ctx.getIteratorNext();
4597+
Identifier nextId = nextFn ? nextFn->getName().getBaseIdentifier()
4598+
: ctx.Id_next;
45954599
auto *nextRef = UnresolvedDotExpr::createImplicit(
45964600
ctx,
45974601
new (ctx) DeclRefExpr(makeIteratorVar, DeclNameLoc(stmt->getForLoc()),
45984602
/*Implicit=*/true),
4599-
ctx.Id_next, /*labels=*/ArrayRef<Identifier>());
4603+
nextId, /*labels=*/ArrayRef<Identifier>());
46004604
nextRef->setFunctionRefKind(FunctionRefKind::SingleApply);
46014605

46024606
Expr *nextCall = CallExpr::createImplicitEmpty(ctx, nextRef);

lib/Sema/CSSimplify.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10704,7 +10704,8 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
1070410704

1070510705
// Handle `next` reference.
1070610706
if (getContextualTypePurpose(baseExpr) == CTP_ForEachSequence &&
10707-
isRefTo(memberRef, ctx.Id_next, /*labels=*/{})) {
10707+
(isRefTo(memberRef, ctx.Id_next, /*labels=*/{}) ||
10708+
isRefTo(memberRef, ctx.Id_nextElement, /*labels=*/{}))) {
1070810709
auto *iteratorProto = cast<ProtocolDecl>(
1070910710
getContextualType(baseExpr, /*forConstraint=*/false)
1071010711
->getAnyNominal());
@@ -10930,7 +10931,8 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
1093010931
if (auto *base = dyn_cast<DeclRefExpr>(UDE->getBase())) {
1093110932
if (auto var = dyn_cast_or_null<VarDecl>(base->getDecl())) {
1093210933
if (var->getNameStr().contains("$generator") &&
10933-
UDE->getName().getBaseIdentifier() == Context.Id_next)
10934+
(UDE->getName().getBaseIdentifier() == Context.Id_next ||
10935+
UDE->getName().getBaseIdentifier() == Context.Id_nextElement))
1093410936
return success();
1093510937
}
1093610938
}

lib/Sema/MiscDiagnostics.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6239,17 +6239,14 @@ llvm::Optional<Identifier> TypeChecker::omitNeedlessWords(VarDecl *var) {
62396239

62406240
bool swift::diagnoseUnhandledThrowsInAsyncContext(DeclContext *dc,
62416241
ForEachStmt *forEach) {
6242-
if (!forEach->getAwaitLoc().isValid())
6243-
return false;
6244-
6245-
auto conformanceRef = forEach->getSequenceConformance();
6246-
if (conformanceRef.hasEffect(EffectKind::Throws) &&
6247-
forEach->getTryLoc().isInvalid()) {
6248-
auto &ctx = dc->getASTContext();
6249-
ctx.Diags
6250-
.diagnose(forEach->getAwaitLoc(), diag::throwing_call_unhandled, "call")
6251-
.fixItInsert(forEach->getAwaitLoc(), "try");
6252-
return true;
6242+
auto &ctx = dc->getASTContext();
6243+
if (auto thrownError = TypeChecker::canThrow(ctx, forEach)) {
6244+
if (forEach->getTryLoc().isInvalid()) {
6245+
ctx.Diags
6246+
.diagnose(forEach->getAwaitLoc(), diag::throwing_call_unhandled, "call")
6247+
.fixItInsert(forEach->getAwaitLoc(), "try");
6248+
return true;
6249+
}
62536250
}
62546251

62556252
return false;

0 commit comments

Comments
 (0)