Skip to content

Commit 4b182f6

Browse files
committed
Move findAsyncNode to BraceStmt
I need to determine when top-level code contains an `await` to determine whether to make the source file an async context. This logic is perfectly encapsulated in the `FindInnerAsync` AST walker. Unfortunately, that is pushed down in Sema/ConstraintSystem and isn't available at the AST level. I've pulled it up into the brace statement so that I can use that as part of determining whether the source file is async in `DeclContext::isAsyncContext`. Unfortunately, statements don't have an AST context or evaluator or I would make this a request.
1 parent 7188f40 commit 4b182f6

File tree

3 files changed

+65
-53
lines changed

3 files changed

+65
-53
lines changed

include/swift/AST/Stmt.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ class BraceStmt final : public Stmt,
161161

162162
SourceLoc getLBraceLoc() const { return LBLoc; }
163163
SourceLoc getRBraceLoc() const { return RBLoc; }
164-
164+
165165
SourceRange getSourceRange() const { return SourceRange(LBLoc, RBLoc); }
166166

167167
bool empty() const { return getNumElements() == 0; }
@@ -182,7 +182,9 @@ class BraceStmt final : public Stmt,
182182
ArrayRef<ASTNode> getElements() const {
183183
return {getTrailingObjects<ASTNode>(), Bits.BraceStmt.NumElements};
184184
}
185-
185+
186+
ASTNode findAsyncNode();
187+
186188
static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Brace; }
187189
};
188190

lib/AST/Stmt.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "swift/AST/Stmt.h"
1818
#include "swift/AST/ASTContext.h"
19+
#include "swift/AST/ASTWalker.h"
1920
#include "swift/AST/Decl.h"
2021
#include "swift/AST/Expr.h"
2122
#include "swift/AST/Pattern.h"
@@ -155,6 +156,65 @@ BraceStmt *BraceStmt::create(ASTContext &ctx, SourceLoc lbloc,
155156
return ::new(Buffer) BraceStmt(lbloc, elts, rbloc, implicit);
156157
}
157158

159+
ASTNode BraceStmt::findAsyncNode() {
160+
// TODO: Statements don't track their ASTContext/evaluator, so I am not making
161+
// this a request. It probably should be a request at some point.
162+
//
163+
// While we're at it, it would be very nice if this could be a const
164+
// operation, but the AST-walking is not a const operation.
165+
166+
// A walker that looks for 'async' and 'await' expressions
167+
// that aren't nested within closures or nested declarations.
168+
class FindInnerAsync : public ASTWalker {
169+
ASTNode AsyncNode;
170+
171+
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
172+
// If we've found an 'await', record it and terminate the traversal.
173+
if (isa<AwaitExpr>(expr)) {
174+
AsyncNode = expr;
175+
return {false, nullptr};
176+
}
177+
178+
// Do not recurse into other closures.
179+
if (isa<ClosureExpr>(expr))
180+
return {false, expr};
181+
182+
return {true, expr};
183+
}
184+
185+
bool walkToDeclPre(Decl *decl) override {
186+
// Do not walk into function or type declarations.
187+
if (auto *patternBinding = dyn_cast<PatternBindingDecl>(decl)) {
188+
if (patternBinding->isAsyncLet())
189+
AsyncNode = patternBinding;
190+
191+
return true;
192+
}
193+
194+
return false;
195+
}
196+
197+
std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override {
198+
if (auto forEach = dyn_cast<ForEachStmt>(stmt)) {
199+
if (forEach->getAwaitLoc().isValid()) {
200+
AsyncNode = forEach;
201+
return {false, nullptr};
202+
}
203+
}
204+
205+
return {true, stmt};
206+
}
207+
208+
public:
209+
ASTNode getAsyncNode() { return AsyncNode; }
210+
};
211+
212+
FindInnerAsync asyncFinder;
213+
walk(asyncFinder);
214+
215+
return asyncFinder.getAsyncNode();
216+
}
217+
158218
SourceLoc ReturnStmt::getStartLoc() const {
159219
if (ReturnLoc.isInvalid() && Result)
160220
return Result->getStartLoc();

lib/Sema/ConstraintSystem.cpp

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -6111,55 +6111,5 @@ ASTNode constraints::findAsyncNode(ClosureExpr *closure) {
61116111
auto *body = closure->getBody();
61126112
if (!body)
61136113
return ASTNode();
6114-
6115-
// A walker that looks for 'async' and 'await' expressions
6116-
// that aren't nested within closures or nested declarations.
6117-
class FindInnerAsync : public ASTWalker {
6118-
ASTNode AsyncNode;
6119-
6120-
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
6121-
// If we've found an 'await', record it and terminate the traversal.
6122-
if (isa<AwaitExpr>(expr)) {
6123-
AsyncNode = expr;
6124-
return {false, nullptr};
6125-
}
6126-
6127-
// Do not recurse into other closures.
6128-
if (isa<ClosureExpr>(expr))
6129-
return {false, expr};
6130-
6131-
return {true, expr};
6132-
}
6133-
6134-
bool walkToDeclPre(Decl *decl) override {
6135-
// Do not walk into function or type declarations.
6136-
if (auto *patternBinding = dyn_cast<PatternBindingDecl>(decl)) {
6137-
if (patternBinding->isAsyncLet())
6138-
AsyncNode = patternBinding;
6139-
6140-
return true;
6141-
}
6142-
6143-
return false;
6144-
}
6145-
6146-
std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override {
6147-
if (auto forEach = dyn_cast<ForEachStmt>(stmt)) {
6148-
if (forEach->getAwaitLoc().isValid()) {
6149-
AsyncNode = forEach;
6150-
return {false, nullptr};
6151-
}
6152-
}
6153-
6154-
return {true, stmt};
6155-
}
6156-
6157-
public:
6158-
ASTNode getAsyncNode() { return AsyncNode; }
6159-
};
6160-
6161-
FindInnerAsync asyncFinder;
6162-
body->walk(asyncFinder);
6163-
6164-
return asyncFinder.getAsyncNode();
6114+
return body->findAsyncNode();
61656115
}

0 commit comments

Comments
 (0)