Skip to content

Commit e087759

Browse files
authored
[Concurrency] Custom executors with move-only Job (#63569)
1 parent 8b2ecdb commit e087759

26 files changed

+556
-73
lines changed

include/swift/ABI/Task.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ class alignas(2 * alignof(void*)) Job :
136136
return Flags.getPriority();
137137
}
138138

139+
uint32_t getJobId() const {
140+
return Id;
141+
}
142+
139143
/// Given that we've fully established the job context in the current
140144
/// thread, actually start running this job. To establish the context
141145
/// correctly, call swift_job_run or runJobInExecutorContext.

include/swift/AST/Decl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3860,6 +3860,9 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
38603860
/// Find the 'RemoteCallArgument(label:name:value:)' initializer function.
38613861
ConstructorDecl *getDistributedRemoteCallArgumentInitFunction() const;
38623862

3863+
/// Get the move-only `enqueue(Job)` protocol requirement function on the `Executor` protocol.
3864+
AbstractFunctionDecl *getExecutorOwnedEnqueueFunction() const;
3865+
38633866
/// Collect the set of protocols to which this type should implicitly
38643867
/// conform, such as AnyObject (for classes).
38653868
void getImplicitProtocols(SmallVectorImpl<ProtocolDecl *> &protocols);

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6408,6 +6408,11 @@ WARNING(hashvalue_implementation,Deprecation,
64086408
"conform type %0 to 'Hashable' by implementing 'hash(into:)' instead",
64096409
(Type))
64106410

6411+
WARNING(executor_enqueue_unowned_implementation,Deprecation,
6412+
"'Executor.enqueue(UnownedJob)' is deprecated as a protocol requirement; "
6413+
"conform type %0 to 'Executor' by implementing 'func enqueue(Job)' instead",
6414+
(Type))
6415+
64116416
//------------------------------------------------------------------------------
64126417
// MARK: property wrapper diagnostics
64136418
//------------------------------------------------------------------------------

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ PROTOCOL(SIMDScalar)
7676
PROTOCOL(BinaryInteger)
7777
PROTOCOL(FixedWidthInteger)
7878
PROTOCOL(RangeReplaceableCollection)
79+
PROTOCOL(Executor)
7980
PROTOCOL(SerialExecutor)
8081
PROTOCOL(GlobalActor)
8182

include/swift/AST/KnownSDKTypes.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ KNOWN_SDK_TYPE_DECL(ObjectiveC, ObjCBool, StructDecl, 0)
3939
// standardized
4040
KNOWN_SDK_TYPE_DECL(Concurrency, UnsafeContinuation, NominalTypeDecl, 2)
4141
KNOWN_SDK_TYPE_DECL(Concurrency, MainActor, NominalTypeDecl, 0)
42+
KNOWN_SDK_TYPE_DECL(Concurrency, Job, StructDecl, 0)
43+
KNOWN_SDK_TYPE_DECL(Concurrency, UnownedJob, StructDecl, 0)
44+
KNOWN_SDK_TYPE_DECL(Concurrency, Executor, NominalTypeDecl, 0)
45+
KNOWN_SDK_TYPE_DECL(Concurrency, SerialExecutor, NominalTypeDecl, 0)
4246
KNOWN_SDK_TYPE_DECL(Concurrency, UnownedSerialExecutor, NominalTypeDecl, 0)
4347

4448
KNOWN_SDK_TYPE_DECL(Concurrency, TaskLocal, ClassDecl, 1)

include/swift/Runtime/Concurrency.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,11 @@ bool swift_task_isOnExecutor(
715715
const Metadata *selfType,
716716
const SerialExecutorWitnessTable *wtable);
717717

718+
/// Return the 64bit TaskID (if the job is an AsyncTask),
719+
/// or the 32bits of the job Id otherwise.
720+
SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
721+
uint64_t swift_task_getJobTaskId(Job *job);
722+
718723
#if SWIFT_CONCURRENCY_ENABLE_DISPATCH
719724

720725
/// Enqueue the given job on the main executor.

lib/AST/ASTContext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
11201120
case KnownProtocolKind::GlobalActor:
11211121
case KnownProtocolKind::AsyncSequence:
11221122
case KnownProtocolKind::AsyncIteratorProtocol:
1123+
case KnownProtocolKind::Executor:
11231124
case KnownProtocolKind::SerialExecutor:
11241125
M = getLoadedModule(Id_Concurrency);
11251126
break;

lib/AST/Decl.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5248,6 +5248,39 @@ VarDecl *NominalTypeDecl::getGlobalActorInstance() const {
52485248
nullptr);
52495249
}
52505250

5251+
AbstractFunctionDecl *
5252+
NominalTypeDecl::getExecutorOwnedEnqueueFunction() const {
5253+
auto &C = getASTContext();
5254+
5255+
auto proto = dyn_cast<ProtocolDecl>(this);
5256+
if (!proto)
5257+
return nullptr;
5258+
5259+
llvm::SmallVector<ValueDecl *, 2> results;
5260+
lookupQualified(getSelfNominalTypeDecl(),
5261+
DeclNameRef(C.Id_enqueue),
5262+
NL_ProtocolMembers,
5263+
results);
5264+
5265+
for (auto candidate: results) {
5266+
// we're specifically looking for the Executor protocol requirement
5267+
if (!isa<ProtocolDecl>(candidate->getDeclContext()))
5268+
continue;
5269+
5270+
if (auto *funcDecl = dyn_cast<AbstractFunctionDecl>(candidate)) {
5271+
if (funcDecl->getParameters()->size() != 1)
5272+
continue;
5273+
5274+
auto params = funcDecl->getParameters();
5275+
if (params->get(0)->getSpecifier() == ParamSpecifier::LegacyOwned) { // TODO: make this Consuming
5276+
return funcDecl;
5277+
}
5278+
}
5279+
}
5280+
5281+
return nullptr;
5282+
}
5283+
52515284
ClassDecl::ClassDecl(SourceLoc ClassLoc, Identifier Name, SourceLoc NameLoc,
52525285
ArrayRef<InheritedEntry> Inherited,
52535286
GenericParamList *GenericParams, DeclContext *Parent,

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6231,6 +6231,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
62316231
case KnownProtocolKind::CxxSequence:
62326232
case KnownProtocolKind::UnsafeCxxInputIterator:
62336233
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
6234+
case KnownProtocolKind::Executor:
62346235
case KnownProtocolKind::SerialExecutor:
62356236
case KnownProtocolKind::Sendable:
62366237
case KnownProtocolKind::UnsafeSendable:

lib/Sema/TypeCheckConcurrency.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,79 @@ void swift::diagnoseMissingExplicitSendable(NominalTypeDecl *nominal) {
12251225
}
12261226
}
12271227

1228+
void swift::tryDiagnoseExecutorConformance(ASTContext &C,
1229+
const NominalTypeDecl *nominal,
1230+
ProtocolDecl *proto) {
1231+
assert(proto->isSpecificProtocol(KnownProtocolKind::Executor) ||
1232+
proto->isSpecificProtocol(KnownProtocolKind::SerialExecutor));
1233+
1234+
auto &diags = C.Diags;
1235+
auto module = nominal->getParentModule();
1236+
Type nominalTy = nominal->getDeclaredInterfaceType();
1237+
1238+
// enqueue(_: UnownedJob)
1239+
auto enqueueDeclName = DeclName(C, DeclBaseName(C.Id_enqueue), { Identifier() });
1240+
1241+
FuncDecl *unownedEnqueueRequirement = nullptr;
1242+
FuncDecl *moveOnlyEnqueueRequirement = nullptr;
1243+
for (auto req: proto->getProtocolRequirements()) {
1244+
auto *funcDecl = dyn_cast<FuncDecl>(req);
1245+
if (!funcDecl)
1246+
continue;
1247+
1248+
if (funcDecl->getName() != enqueueDeclName)
1249+
continue;
1250+
1251+
1252+
// look for the first parameter being a Job or UnownedJob
1253+
if (funcDecl->getParameters()->size() != 1)
1254+
continue;
1255+
if (auto param = funcDecl->getParameters()->front()) {
1256+
if (param->getType()->isEqual(C.getJobDecl()->getDeclaredInterfaceType())) {
1257+
assert(moveOnlyEnqueueRequirement == nullptr);
1258+
moveOnlyEnqueueRequirement = funcDecl;
1259+
} else if (param->getType()->isEqual(C.getUnownedJobDecl()->getDeclaredInterfaceType())) {
1260+
assert(unownedEnqueueRequirement == nullptr);
1261+
unownedEnqueueRequirement = funcDecl;
1262+
}
1263+
}
1264+
1265+
// if we found both, we're done here and break out of the loop
1266+
if (unownedEnqueueRequirement && moveOnlyEnqueueRequirement)
1267+
break; // we're done looking for the requirements
1268+
}
1269+
1270+
1271+
auto conformance = module->lookupConformance(nominalTy, proto);
1272+
auto concreteConformance = conformance.getConcrete();
1273+
auto unownedEnqueueWitness = concreteConformance->getWitnessDeclRef(unownedEnqueueRequirement);
1274+
auto moveOnlyEnqueueWitness = concreteConformance->getWitnessDeclRef(moveOnlyEnqueueRequirement);
1275+
1276+
if (auto enqueueUnownedDecl = unownedEnqueueWitness.getDecl()) {
1277+
// Old UnownedJob based impl is present, warn about it suggesting the new protocol requirement.
1278+
if (enqueueUnownedDecl->getLoc().isValid()) {
1279+
diags.diagnose(enqueueUnownedDecl->getLoc(), diag::executor_enqueue_unowned_implementation, nominalTy);
1280+
}
1281+
}
1282+
1283+
if (auto unownedEnqueueDecl = unownedEnqueueWitness.getDecl()) {
1284+
if (auto moveOnlyEnqueueDecl = moveOnlyEnqueueWitness.getDecl()) {
1285+
if (unownedEnqueueDecl && unownedEnqueueDecl->getLoc().isInvalid() &&
1286+
moveOnlyEnqueueDecl && moveOnlyEnqueueDecl->getLoc().isInvalid()) {
1287+
// Neither old nor new implementation have been found, but we provide default impls for them
1288+
// that are mutually recursive, so we must error and suggest implementing the right requirement.
1289+
auto ownedRequirement = C.getExecutorDecl()->getExecutorOwnedEnqueueFunction();
1290+
nominal->diagnose(diag::type_does_not_conform, nominalTy, proto->getDeclaredInterfaceType());
1291+
ownedRequirement->diagnose(diag::no_witnesses,
1292+
getProtocolRequirementKind(ownedRequirement),
1293+
ownedRequirement->getName(),
1294+
proto->getDeclaredInterfaceType(),
1295+
/*AddFixIt=*/true);
1296+
}
1297+
}
1298+
}
1299+
}
1300+
12281301
/// Determine whether this is the main actor type.
12291302
static bool isMainActor(Type type) {
12301303
if (auto nominal = type->getAnyNominal())

0 commit comments

Comments
 (0)