Skip to content

Commit 7be05de

Browse files
committed
[Concurrency] Simplify checking for local functions.
Treat them as local captures and require them to be @Concurrent.
1 parent c26c502 commit 7be05de

File tree

4 files changed

+24
-168
lines changed

4 files changed

+24
-168
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4277,8 +4277,6 @@ ERROR(local_function_executed_concurrently,none,
42774277
ERROR(concurrent_mutation_of_local_capture,none,
42784278
"mutation of captured %0 %1 in concurrently-executing code",
42794279
(DescriptiveDeclKind, DeclName))
4280-
NOTE(concurrent_access_here,none,
4281-
"access in concurrently-executed code here", ())
42824280
NOTE(actor_isolated_sync_func,none,
42834281
"calls to %0 %1 from outside of its actor context are "
42844282
"implicitly asynchronous",

lib/Sema/TypeCheckConcurrency.cpp

Lines changed: 19 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -520,14 +520,8 @@ ActorIsolationRestriction ActorIsolationRestriction::forDeclaration(
520520
case DeclKind::Constructor:
521521
case DeclKind::Func:
522522
case DeclKind::Subscript: {
523-
// Local captures can only be referenced in their local context or a
524-
// context that is guaranteed not to run concurrently with it.
523+
// Local captures are checked separately.
525524
if (cast<ValueDecl>(decl)->isLocalCapture()) {
526-
// Local functions are safe to capture; their bodies are checked based on
527-
// where that capture is used.
528-
if (isa<FuncDecl>(decl))
529-
return forUnrestricted();
530-
531525
return forLocalCapture(decl->getDeclContext());
532526
}
533527

@@ -968,39 +962,13 @@ bool swift::diagnoseNonConcurrentTypesInFunctionType(
968962
}
969963

970964
namespace {
971-
/// Check whether a particular context may execute concurrently within
972-
/// another context.
973-
class ConcurrentExecutionChecker {
974-
/// Keeps track of the first location at which a given local function is
975-
/// referenced from a context that may execute concurrently with the
976-
/// context in which it was introduced.
977-
llvm::SmallDenseMap<const FuncDecl *, SourceLoc, 4> concurrentRefs;
978-
979-
public:
980-
/// Determine whether (and where) a given local function is referenced
981-
/// from a context that may execute concurrently with the context in
982-
/// which it is declared.
983-
///
984-
/// \returns the source location of the first reference to the local
985-
/// function that may be concurrent. If the result is an invalid
986-
/// source location, there are no such references.
987-
SourceLoc getConcurrentReferenceLoc(const FuncDecl *localFunc);
988-
989-
/// Determine whether code in the given use context might execute
990-
/// concurrently with code in the definition context.
991-
bool mayExecuteConcurrentlyWith(
992-
const DeclContext *useContext, const DeclContext *defContext);
993-
};
994-
995965
/// Check for adherence to the actor isolation rules, emitting errors
996966
/// when actor-isolated declarations are used in an unsafe manner.
997967
class ActorIsolationChecker : public ASTWalker {
998968
ASTContext &ctx;
999969
SmallVector<const DeclContext *, 4> contextStack;
1000970
SmallVector<ApplyExpr*, 4> applyStack;
1001971

1002-
ConcurrentExecutionChecker concurrentExecutionChecker;
1003-
1004972
using MutableVarSource = llvm::PointerUnion<DeclRefExpr *, InOutExpr *>;
1005973
using MutableVarParent = llvm::PointerUnion<InOutExpr *, LoadExpr *>;
1006974

@@ -1016,10 +984,7 @@ namespace {
1016984
/// Determine whether code in the given use context might execute
1017985
/// concurrently with code in the definition context.
1018986
bool mayExecuteConcurrentlyWith(
1019-
const DeclContext *useContext, const DeclContext *defContext) {
1020-
return concurrentExecutionChecker.mayExecuteConcurrentlyWith(
1021-
useContext, defContext);
1022-
}
987+
const DeclContext *useContext, const DeclContext *defContext);
1023988

1024989
/// If the subexpression is a reference to a mutable local variable from a
1025990
/// different context, record its parent. We'll query this as part of
@@ -1689,6 +1654,22 @@ namespace {
16891654
return true;
16901655
}
16911656

1657+
if (auto func = dyn_cast<FuncDecl>(value)) {
1658+
if (func->isConcurrent())
1659+
return false;
1660+
1661+
func->diagnose(
1662+
diag::local_function_executed_concurrently,
1663+
func->getDescriptiveKind(), func->getName())
1664+
.fixItInsert(func->getAttributeInsertionLoc(false), "@concurrent ");
1665+
1666+
// Add the @concurrent attribute implicitly, so we don't diagnose
1667+
// again.
1668+
const_cast<FuncDecl *>(func)->getAttrs().add(
1669+
new (ctx) ConcurrentAttr(true));
1670+
return true;
1671+
}
1672+
16921673
// Concurrent access to some other local.
16931674
ctx.Diags.diagnose(
16941675
loc, diag::concurrent_access_local,
@@ -1910,111 +1891,7 @@ namespace {
19101891
};
19111892
}
19121893

1913-
SourceLoc ConcurrentExecutionChecker::getConcurrentReferenceLoc(
1914-
const FuncDecl *localFunc) {
1915-
1916-
// If we've already computed a result, we're done.
1917-
auto known = concurrentRefs.find(localFunc);
1918-
if (known != concurrentRefs.end())
1919-
return known->second;
1920-
1921-
// Record that there are no concurrent references to this local function. This
1922-
// prevents infinite recursion if two local functions call each other.
1923-
concurrentRefs[localFunc] = SourceLoc();
1924-
1925-
class ConcurrentLocalRefWalker : public ASTWalker {
1926-
ConcurrentExecutionChecker &checker;
1927-
const FuncDecl *targetFunc;
1928-
SmallVector<const DeclContext *, 4> contextStack;
1929-
1930-
const DeclContext *getDeclContext() const {
1931-
return contextStack.back();
1932-
}
1933-
1934-
public:
1935-
ConcurrentLocalRefWalker(
1936-
ConcurrentExecutionChecker &checker, const FuncDecl *targetFunc
1937-
) : checker(checker), targetFunc(targetFunc) {
1938-
contextStack.push_back(targetFunc->getDeclContext());
1939-
}
1940-
1941-
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
1942-
if (auto *closure = dyn_cast<AbstractClosureExpr>(expr)) {
1943-
contextStack.push_back(closure);
1944-
return { true, expr };
1945-
}
1946-
1947-
if (auto *declRef = dyn_cast<DeclRefExpr>(expr)) {
1948-
// If this is a reference to the target function from a context
1949-
// that may execute concurrently with the context where the target
1950-
// function was declared, record the location.
1951-
if (declRef->getDecl() == targetFunc &&
1952-
checker.mayExecuteConcurrentlyWith(
1953-
getDeclContext(), contextStack.front())) {
1954-
SourceLoc &loc = checker.concurrentRefs[targetFunc];
1955-
if (loc.isInvalid())
1956-
loc = declRef->getLoc();
1957-
1958-
return { false, expr };
1959-
}
1960-
1961-
return { true, expr };
1962-
}
1963-
1964-
return { true, expr };
1965-
}
1966-
1967-
Expr *walkToExprPost(Expr *expr) override {
1968-
if (auto *closure = dyn_cast<AbstractClosureExpr>(expr)) {
1969-
assert(contextStack.back() == closure);
1970-
contextStack.pop_back();
1971-
}
1972-
1973-
return expr;
1974-
}
1975-
1976-
bool walkToDeclPre(Decl *decl) override {
1977-
if (isa<NominalTypeDecl>(decl) || isa<ExtensionDecl>(decl))
1978-
return false;
1979-
1980-
if (auto func = dyn_cast<AbstractFunctionDecl>(decl)) {
1981-
contextStack.push_back(func);
1982-
}
1983-
1984-
return true;
1985-
}
1986-
1987-
bool walkToDeclPost(Decl *decl) override {
1988-
if (auto func = dyn_cast<AbstractFunctionDecl>(decl)) {
1989-
assert(contextStack.back() == func);
1990-
contextStack.pop_back();
1991-
}
1992-
1993-
return true;
1994-
}
1995-
};
1996-
1997-
// Walk the body of the enclosing function, where all references to the
1998-
// given local function would occur.
1999-
Stmt *enclosingBody = nullptr;
2000-
DeclContext *enclosingDC = localFunc->getDeclContext();
2001-
if (auto enclosingFunc = dyn_cast<AbstractFunctionDecl>(enclosingDC))
2002-
enclosingBody = enclosingFunc->getBody();
2003-
else if (auto enclosingClosure = dyn_cast<ClosureExpr>(enclosingDC))
2004-
enclosingBody = enclosingClosure->getBody();
2005-
else if (auto enclosingTopLevelCode = dyn_cast<TopLevelCodeDecl>(enclosingDC))
2006-
enclosingBody = enclosingTopLevelCode->getBody();
2007-
else
2008-
return SourceLoc();
2009-
2010-
assert(enclosingBody && "Cannot have a local function here");
2011-
ConcurrentLocalRefWalker walker(*this, localFunc);
2012-
enclosingBody->walk(walker);
2013-
2014-
return concurrentRefs[localFunc];
2015-
}
2016-
2017-
bool ConcurrentExecutionChecker::mayExecuteConcurrentlyWith(
1894+
bool ActorIsolationChecker::mayExecuteConcurrentlyWith(
20181895
const DeclContext *useContext, const DeclContext *defContext) {
20191896
// Walk the context chain from the use to the definition.
20201897
while (useContext != defContext) {
@@ -2029,25 +1906,6 @@ bool ConcurrentExecutionChecker::mayExecuteConcurrentlyWith(
20291906
// If the function is @concurrent... it can be run concurrently.
20301907
if (func->isConcurrent())
20311908
return true;
2032-
2033-
// If we find a local function that was referenced in code that can be
2034-
// executed concurrently with where the local function was declared, the
2035-
// local function can be run concurrently.
2036-
SourceLoc concurrentLoc = getConcurrentReferenceLoc(func);
2037-
if (concurrentLoc.isValid()) {
2038-
ASTContext &ctx = func->getASTContext();
2039-
func->diagnose(
2040-
diag::local_function_executed_concurrently,
2041-
func->getDescriptiveKind(), func->getName())
2042-
.fixItInsert(func->getAttributeInsertionLoc(false), "@concurrent ");
2043-
ctx.Diags.diagnose(concurrentLoc, diag::concurrent_access_here);
2044-
2045-
// Add the @concurrent attribute implicitly, so we don't diagnose
2046-
// again.
2047-
const_cast<FuncDecl *>(func)->getAttrs().add(
2048-
new (ctx) ConcurrentAttr(true));
2049-
return true;
2050-
}
20511909
}
20521910
}
20531911

test/Concurrency/actor_isolation.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ func checkLocalFunctions() async {
356356
}
357357

358358
func local2() { // expected-error{{concurrently-executed local function 'local2()' must be marked as '@concurrent'}}{{3-3=@concurrent }}
359-
j = 42 // expected-error{{mutation of captured var 'j' in concurrently-executing code}}
359+
j = 42
360360
}
361361

362362
// Okay to call locally.
@@ -371,7 +371,7 @@ func checkLocalFunctions() async {
371371

372372
// Escaping closures can make the local function execute concurrently.
373373
acceptConcurrentClosure {
374-
local2() // expected-note{{access in concurrently-executed code here}}
374+
local2()
375375
}
376376

377377
print(i)
@@ -380,7 +380,7 @@ func checkLocalFunctions() async {
380380
var k = 17
381381
func local4() {
382382
acceptConcurrentClosure {
383-
local3() // expected-note{{access in concurrently-executed code here}}
383+
local3()
384384
}
385385
}
386386

test/Concurrency/async_task_groups.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ func first_allMustSucceed() async throws {
8787
}
8888

8989
func first_ignoreFailures() async throws {
90-
func work() async -> Int { 42 }
91-
func boom() async throws -> Int { throw Boom() }
90+
@concurrent func work() async -> Int { 42 }
91+
@concurrent func boom() async throws -> Int { throw Boom() }
9292

9393
let first: Int = try await Task.withGroup(resultType: Int.self) { group in
9494
await group.add { await work() }

0 commit comments

Comments
 (0)