@@ -8929,21 +8929,46 @@ static Optional<SolutionApplicationTarget> applySolutionToForEachStmt(
89298929
89308930 Expr *nextCall = rewrittenTarget->getAsExpr ();
89318931 // Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol`
8932- // requirement is `async throws`
8932+ // witness could be `async throws`.
89338933 if (isAsync) {
8934- auto &ctx = cs.getASTContext ();
8935- auto nextRefType =
8936- solution
8937- .getResolvedType (
8938- cast<ApplyExpr>(cast<AwaitExpr>(nextCall)->getSubExpr ())
8939- ->getFn ())
8940- ->castTo <FunctionType>();
8941-
8942- // If the inferred witness is throwing, we need to wrap the call
8943- // into `try` expression.
8944- if (nextRefType->isThrowing ())
8945- nextCall = TryExpr::createImplicit (ctx, /* tryLoc=*/ SourceLoc (),
8946- nextCall, nextCall->getType ());
8934+ // Cannot use `forEachChildExpr` here because we need to
8935+ // to wrap a call in `try` and then stop immediately after.
8936+ struct TryInjector : ASTWalker {
8937+ ASTContext &C;
8938+ const Solution &S;
8939+
8940+ bool ShouldStop = false ;
8941+
8942+ TryInjector (ASTContext &ctx, const Solution &solution)
8943+ : C(ctx), S(solution) {}
8944+
8945+ PreWalkResult<Expr *> walkToExprPre (Expr *E) override {
8946+ if (ShouldStop)
8947+ return Action::Stop ();
8948+
8949+ if (auto *call = dyn_cast<CallExpr>(E)) {
8950+ // There is a single call expression in `nextCall`.
8951+ ShouldStop = true ;
8952+
8953+ auto nextRefType =
8954+ S.getResolvedType (call->getFn ())->castTo <FunctionType>();
8955+
8956+ // If the inferred witness is throwing, we need to wrap the call
8957+ // into `try` expression.
8958+ if (nextRefType->isThrowing ()) {
8959+ auto *tryExpr = TryExpr::createImplicit (
8960+ C, /* tryLoc=*/ call->getStartLoc (), call, call->getType ());
8961+ // Cannot stop here because we need to make sure that
8962+ // the new expression gets injected into AST.
8963+ return Action::SkipChildren (tryExpr);
8964+ }
8965+ }
8966+
8967+ return Action::Continue (E);
8968+ }
8969+ };
8970+
8971+ nextCall->walk (TryInjector (cs.getASTContext (), solution));
89478972 }
89488973
89498974 stmt->setNextCall (nextCall);
0 commit comments