Skip to content

Commit b77ba9a

Browse files
committed
[Concurrency] Infer @Concurrent on closures from contextual type
1 parent ba8819e commit b77ba9a

File tree

3 files changed

+49
-13
lines changed

3 files changed

+49
-13
lines changed

lib/Sema/CSApply.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6661,9 +6661,24 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
66616661
}
66626662
}
66636663

6664-
// If we have a ClosureExpr, then we can safely propagate the 'no escape'
6664+
// If we have a ClosureExpr, then we can safely propagate the 'concurrent'
66656665
// bit to the closure without invalidating prior analysis.
66666666
auto fromEI = fromFunc->getExtInfo();
6667+
if (toEI.isConcurrent() && !fromEI.isConcurrent()) {
6668+
auto newFromFuncType = fromFunc->withExtInfo(fromEI.withConcurrent());
6669+
if (applyTypeToClosureExpr(cs, expr, newFromFuncType)) {
6670+
fromFunc = newFromFuncType->castTo<FunctionType>();
6671+
6672+
// Propagating the 'concurrent' bit might have satisfied the entire
6673+
// conversion. If so, we're done, otherwise keep converting.
6674+
if (fromFunc->isEqual(toType))
6675+
return expr;
6676+
}
6677+
}
6678+
6679+
// If we have a ClosureExpr, then we can safely propagate the 'no escape'
6680+
// bit to the closure without invalidating prior analysis.
6681+
fromEI = fromFunc->getExtInfo();
66676682
if (toEI.isNoEscape() && !fromEI.isNoEscape()) {
66686683
auto newFromFuncType = fromFunc->withExtInfo(fromEI.withNoEscape());
66696684
if (!isInDefaultArgumentContext &&

lib/Sema/CSSimplify.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,19 +1894,19 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
18941894
}
18951895
}
18961896

1897-
// A @concurrent function can be a subtype of a non-@concurrent function.
1898-
if (func1->isConcurrent() != func2->isConcurrent()) {
1899-
// Cannot add '@concurrent'.
1900-
if (func2->isConcurrent() || kind < ConstraintKind::Subtype) {
1901-
if (!shouldAttemptFixes())
1902-
return getTypeMatchFailure(locator);
1897+
// A @concurrent function can be a subtype of a non-@concurrent function.
1898+
if (func1->isConcurrent() != func2->isConcurrent()) {
1899+
// Cannot add '@concurrent'.
1900+
if (func2->isConcurrent() || kind < ConstraintKind::Subtype) {
1901+
if (!shouldAttemptFixes())
1902+
return getTypeMatchFailure(locator);
19031903

1904-
auto *fix = AddConcurrentAttribute::create(
1905-
*this, func1, func2, getConstraintLocator(locator));
1906-
if (recordFix(fix))
1907-
return getTypeMatchFailure(locator);
1908-
}
1904+
auto *fix = AddConcurrentAttribute::create(
1905+
*this, func1, func2, getConstraintLocator(locator));
1906+
if (recordFix(fix))
1907+
return getTypeMatchFailure(locator);
19091908
}
1909+
}
19101910

19111911
// A non-@noescape function type can be a subtype of a @noescape function
19121912
// type.
@@ -7864,9 +7864,16 @@ bool ConstraintSystem::resolveClosure(TypeVariableType *typeVar,
78647864
parameters.push_back(param);
78657865
}
78667866

7867+
// Propagate @concurrent from the contextual type to the closure.
7868+
auto closureExtInfo = inferredClosureType->getExtInfo();
7869+
if (auto contextualFnType = contextualType->getAs<FunctionType>()) {
7870+
if (contextualFnType->isConcurrent())
7871+
closureExtInfo = closureExtInfo.withConcurrent();
7872+
}
7873+
78677874
auto closureType =
78687875
FunctionType::get(parameters, inferredClosureType->getResult(),
7869-
inferredClosureType->getExtInfo());
7876+
closureExtInfo);
78707877
assignFixedType(typeVar, closureType, closureLocator);
78717878

78727879
// If there is a result builder to apply, do so now.

test/attr/attr_concurrent.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,17 @@ func passingConcurrentOrNot(
2424
acceptsNonConcurrent(cfn) // okay
2525
acceptsNonConcurrent(ncfn) // okay
2626
}
27+
28+
func closures() {
29+
// Okay, inferring @concurrent
30+
acceptsConcurrent { $0 }
31+
acceptsConcurrent({ $0 })
32+
acceptsConcurrent({ i in i })
33+
acceptsConcurrent({ (i: Int) -> Int in
34+
print(i)
35+
return i
36+
})
37+
38+
let closure1 = { $0 + 1 } // inferred to be non-concurrent
39+
acceptsConcurrent(closure1) // expected-error{{converting non-concurrent function value to '@concurrent (Int) -> Int' may introduce data races}}
40+
}

0 commit comments

Comments
 (0)