Skip to content

Commit 29982fc

Browse files
committed
Add fixits for distributed system adhoc requirements; make easier to adopt
rdar://114185115
1 parent 5ba26d7 commit 29982fc

6 files changed

+264
-110
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5348,9 +5348,9 @@ ERROR(broken_distributed_actor_requirement,none,
53485348
ERROR(distributed_actor_system_conformance_missing_adhoc_requirement,none,
53495349
"%kind0 is missing witness for protocol requirement %1",
53505350
(const ValueDecl *, DeclName))
5351-
NOTE(note_distributed_actor_system_conformance_missing_adhoc_requirement,none,
5352-
"protocol %0 requires function %1 with signature:\n%2",
5353-
(DeclName, DeclName, StringRef))
5351+
//NOTE(note_distributed_actor_system_conformance_missing_adhoc_requirement,none,
5352+
// "protocol %0 requires function %1 with signature:\n%2",
5353+
// (DeclName, DeclName, StringRef))
53545354

53555355
ERROR(override_implicit_unowned_executor,none,
53565356
"cannot override an actor's 'unownedExecutor' property that wasn't "

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 82 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "swift/AST/TypeVisitor.h"
2828
#include "swift/AST/ExistentialLayout.h"
2929
#include "swift/Basic/Defer.h"
30+
#include "swift/AST/ASTPrinter.h"
3031

3132
using namespace swift;
3233

@@ -222,6 +223,69 @@ static bool checkAdHocRequirementAccessControl(
222223
return true;
223224
}
224225

226+
static bool diagnoseMissingAdHocProtocolRequirement(ASTContext &C, Identifier identifier, NominalTypeDecl *decl) {
227+
assert(decl);
228+
auto FixitLocation = decl->getBraces().Start;
229+
230+
// Prepare the indent (same as `printRequirementStub`)
231+
StringRef ExtraIndent;
232+
StringRef CurrentIndent =
233+
Lexer::getIndentationForLine(C.SourceMgr, decl->getStartLoc(), &ExtraIndent);
234+
235+
llvm::SmallString<128> Text;
236+
llvm::raw_svector_ostream OS(Text);
237+
ExtraIndentStreamPrinter Printer(OS, CurrentIndent);
238+
239+
Printer << (decl->getFormalAccess() == AccessLevel::Public ? "public " : "");
240+
241+
if (identifier == C.Id_remoteCall) {
242+
Printer << "func remoteCall<Act, Err, Res>("
243+
"on actor: Act, "
244+
"target: RemoteCallTarget, "
245+
"invocation: inout InvocationEncoder, "
246+
"throwing: Err.Type, "
247+
"returning: Res.Type) "
248+
"async throws -> Res "
249+
"where Act: DistributedActor, "
250+
"Act.ID == ActorID, "
251+
"Err: Error, "
252+
"Res: SerializationRequirement";
253+
} else if (identifier == C.Id_remoteCallVoid) {
254+
Printer << "func remoteCallVoid<Act, Err>("
255+
"on actor: Act, "
256+
"target: RemoteCallTarget, "
257+
"invocation: inout InvocationEncoder, "
258+
"throwing: Err.Type"
259+
") async throws "
260+
"where Act: DistributedActor, "
261+
"Act.ID == ActorID, "
262+
"Err: Error";
263+
} else if (identifier == C.Id_recordArgument) {
264+
Printer << "mutating func recordArgument<Value: SerializationRequirement>(_ argument: RemoteCallArgument<Value>) throws";
265+
} else if (identifier == C.Id_recordReturnType) {
266+
Printer << "mutating func recordReturnType<Res: SerializationRequirement>(_ resultType: Res.Type) throws";
267+
} else if (identifier == C.Id_decodeNextArgument) {
268+
Printer << "mutating func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument";
269+
} else if (identifier == C.Id_onReturn) {
270+
Printer << "func onReturn<Success: SerializationRequirement>(value: Success) async throws";
271+
} else {
272+
llvm_unreachable("Unknown identifier for diagnosing ad-hoc missing requirement.");
273+
}
274+
275+
/// Print the "{ <#code#> }" placeholder body
276+
Printer << " {\n";
277+
Printer << ExtraIndent << getCodePlaceholder();
278+
Printer << "\n}\n";
279+
280+
decl->diagnose(
281+
diag::distributed_actor_system_conformance_missing_adhoc_requirement,
282+
decl, identifier);
283+
decl->diagnose(diag::missing_witnesses_general)
284+
.fixItInsertAfter(FixitLocation, Text.str());
285+
286+
return true;
287+
}
288+
225289
bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
226290
ASTContext &C,
227291
ProtocolDecl *Proto,
@@ -238,53 +302,21 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
238302
auto remoteCallDecl =
239303
C.getRemoteCallOnDistributedActorSystem(decl, /*isVoidReturn=*/false);
240304
if (!remoteCallDecl && diagnose) {
241-
auto identifier = C.Id_remoteCall;
242-
decl->diagnose(
243-
diag::distributed_actor_system_conformance_missing_adhoc_requirement,
244-
decl, identifier);
245-
decl->diagnose(
246-
diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
247-
Proto->getName(), identifier,
248-
"func remoteCall<Act, Err, Res>(\n"
249-
" on actor: Act,\n"
250-
" target: RemoteCallTarget,\n"
251-
" invocation: inout InvocationEncoder,\n"
252-
" throwing: Err.Type,\n"
253-
" returning: Res.Type\n"
254-
") async throws -> Res\n"
255-
" where Act: DistributedActor,\n"
256-
" Act.ID == ActorID,\n"
257-
" Err: Error,\n"
258-
" Res: SerializationRequirement\n");
259-
anyMissingAdHocRequirements = true;
305+
anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement(C, C.Id_remoteCall, decl);
260306
}
261-
if (checkAdHocRequirementAccessControl(decl, Proto, remoteCallDecl))
307+
if (checkAdHocRequirementAccessControl(decl, Proto, remoteCallDecl)) {
262308
anyMissingAdHocRequirements = true;
309+
}
263310

264311
// - remoteCallVoid
265312
auto remoteCallVoidDecl =
266313
C.getRemoteCallOnDistributedActorSystem(decl, /*isVoidReturn=*/true);
267314
if (!remoteCallVoidDecl && diagnose) {
268-
auto identifier = C.Id_remoteCallVoid;
269-
decl->diagnose(
270-
diag::distributed_actor_system_conformance_missing_adhoc_requirement,
271-
decl, identifier);
272-
decl->diagnose(
273-
diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
274-
Proto->getName(), identifier,
275-
"func remoteCallVoid<Act, Err>(\n"
276-
" on actor: Act,\n"
277-
" target: RemoteCallTarget,\n"
278-
" invocation: inout InvocationEncoder,\n"
279-
" throwing: Err.Type\n"
280-
") async throws\n"
281-
" where Act: DistributedActor,\n"
282-
" Act.ID == ActorID,\n"
283-
" Err: Error\n");
284-
anyMissingAdHocRequirements = true;
315+
anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement(C, C.Id_remoteCallVoid, decl);
285316
}
286-
if (checkAdHocRequirementAccessControl(decl, Proto, remoteCallVoidDecl))
317+
if (checkAdHocRequirementAccessControl(decl, Proto, remoteCallVoidDecl)) {
287318
anyMissingAdHocRequirements = true;
319+
}
288320

289321
return anyMissingAdHocRequirements;
290322
}
@@ -295,32 +327,20 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
295327
// - recordArgument
296328
auto recordArgumentDecl = C.getRecordArgumentOnDistributedInvocationEncoder(decl);
297329
if (!recordArgumentDecl) {
298-
auto identifier = C.Id_recordArgument;
299-
decl->diagnose(
300-
diag::distributed_actor_system_conformance_missing_adhoc_requirement,
301-
decl, identifier);
302-
decl->diagnose(diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
303-
Proto->getName(), identifier,
304-
"mutating func recordArgument<Value: SerializationRequirement>(_ argument: RemoteCallArgument<Value>) throws\n");
305-
anyMissingAdHocRequirements = true;
330+
anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement(C, C.Id_recordArgument, decl);
306331
}
307-
if (checkAdHocRequirementAccessControl(decl, Proto, recordArgumentDecl))
332+
if (checkAdHocRequirementAccessControl(decl, Proto, recordArgumentDecl)) {
308333
anyMissingAdHocRequirements = true;
334+
}
309335

310336
// - recordReturnType
311337
auto recordReturnTypeDecl = C.getRecordReturnTypeOnDistributedInvocationEncoder(decl);
312338
if (!recordReturnTypeDecl) {
313-
auto identifier = C.Id_recordReturnType;
314-
decl->diagnose(
315-
diag::distributed_actor_system_conformance_missing_adhoc_requirement,
316-
decl, identifier);
317-
decl->diagnose(diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
318-
Proto->getName(), identifier,
319-
"mutating func recordReturnType<Res: SerializationRequirement>(_ resultType: Res.Type) throws\n");
320-
anyMissingAdHocRequirements = true;
339+
anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement(C, C.Id_recordReturnType, decl);
321340
}
322-
if (checkAdHocRequirementAccessControl(decl, Proto, recordReturnTypeDecl))
341+
if (checkAdHocRequirementAccessControl(decl, Proto, recordReturnTypeDecl)) {
323342
anyMissingAdHocRequirements = true;
343+
}
324344

325345
return anyMissingAdHocRequirements;
326346
}
@@ -331,17 +351,11 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
331351
// - decodeNextArgument
332352
auto decodeNextArgumentDecl = C.getDecodeNextArgumentOnDistributedInvocationDecoder(decl);
333353
if (!decodeNextArgumentDecl) {
334-
auto identifier = C.Id_decodeNextArgument;
335-
decl->diagnose(
336-
diag::distributed_actor_system_conformance_missing_adhoc_requirement,
337-
decl, identifier);
338-
decl->diagnose(diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
339-
Proto->getName(), identifier,
340-
"mutating func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument\n");
341-
anyMissingAdHocRequirements = true;
354+
anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement(C, C.Id_decodeNextArgument, decl);
342355
}
343-
if (checkAdHocRequirementAccessControl(decl, Proto, decodeNextArgumentDecl))
356+
if (checkAdHocRequirementAccessControl(decl, Proto, decodeNextArgumentDecl)) {
344357
anyMissingAdHocRequirements = true;
358+
}
345359

346360
return anyMissingAdHocRequirements;
347361
}
@@ -352,19 +366,11 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
352366
// - onReturn
353367
auto onReturnDecl = C.getOnReturnOnDistributedTargetInvocationResultHandler(decl);
354368
if (!onReturnDecl) {
355-
auto identifier = C.Id_onReturn;
356-
decl->diagnose(
357-
diag::distributed_actor_system_conformance_missing_adhoc_requirement,
358-
decl, identifier);
359-
decl->diagnose(
360-
diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
361-
Proto->getName(), identifier,
362-
"func onReturn<Success: SerializationRequirement>(value: "
363-
"Success) async throws\n");
364-
anyMissingAdHocRequirements = true;
369+
anyMissingAdHocRequirements = diagnoseMissingAdHocProtocolRequirement(C, C.Id_onReturn, decl);
365370
}
366-
if (checkAdHocRequirementAccessControl(decl, Proto, onReturnDecl))
371+
if (checkAdHocRequirementAccessControl(decl, Proto, onReturnDecl)) {
367372
anyMissingAdHocRequirements = true;
373+
}
368374

369375
return anyMissingAdHocRequirements;
370376
}
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// RUN: %target-swift-frontend -typecheck -verify -disable-availability-checking -I %t 2>&1 %s
2+
3+
// UNSUPPORTED: back_deploy_concurrency
4+
// REQUIRES: concurrency
5+
// REQUIRES: distributed
6+
7+
import Distributed
8+
9+
struct MissingRemoteCall: DistributedActorSystem {
10+
// expected-error@-1{{struct 'MissingRemoteCall' is missing witness for protocol requirement 'remoteCall'}}
11+
// expected-note@-2{{add stubs for conformance}}{{51-51=func remoteCall<Act, Err, Res>(on actor: Act, target: RemoteCallTarget, invocation: inout InvocationEncoder, throwing: Err.Type, returning: Res.Type) async throws -> Res where Act: DistributedActor, Act.ID == ActorID, Err: Error, Res: SerializationRequirement {\n <#code#>\n}\n}}
12+
13+
// expected-error@-4{{struct 'MissingRemoteCall' is missing witness for protocol requirement 'remoteCallVoid'}}
14+
// expected-note@-5{{add stubs for conformance}}{{51-51=func remoteCallVoid<Act, Err>(on actor: Act, target: RemoteCallTarget, invocation: inout InvocationEncoder, throwing: Err.Type) async throws where Act: DistributedActor, Act.ID == ActorID, Err: Error {\n <#code#>\n}\n}}
15+
16+
typealias ActorID = ActorAddress
17+
typealias InvocationDecoder = FakeInvocationDecoder
18+
typealias InvocationEncoder = FakeInvocationEncoder
19+
typealias SerializationRequirement = Codable
20+
typealias ResultHandler = FakeResultHandler
21+
22+
func resolve<Act>(id: ActorID, as actorType: Act.Type)
23+
throws -> Act? where Act: DistributedActor {
24+
return nil
25+
}
26+
27+
func assignID<Act>(_ actorType: Act.Type) -> ActorID
28+
where Act: DistributedActor {
29+
ActorAddress(parse: "fake://123")
30+
}
31+
32+
func actorReady<Act>(_ actor: Act)
33+
where Act: DistributedActor,
34+
Act.ID == ActorID {
35+
}
36+
37+
func resignID(_ id: ActorID) {
38+
}
39+
40+
func makeInvocationEncoder() -> InvocationEncoder {
41+
}
42+
}
43+
44+
public struct PublicMissingRemoteCall: DistributedActorSystem {
45+
// expected-error@-1{{struct 'PublicMissingRemoteCall' is missing witness for protocol requirement 'remoteCall'}}
46+
// expected-note@-2{{add stubs for conformance}}{{64-64=public func remoteCall<Act, Err, Res>(on actor: Act, target: RemoteCallTarget, invocation: inout InvocationEncoder, throwing: Err.Type, returning: Res.Type) async throws -> Res where Act: DistributedActor, Act.ID == ActorID, Err: Error, Res: SerializationRequirement {\n <#code#>\n}\n}}
47+
48+
// expected-error@-4{{struct 'PublicMissingRemoteCall' is missing witness for protocol requirement 'remoteCallVoid'}}
49+
// expected-note@-5{{add stubs for conformance}}{{64-64=public func remoteCallVoid<Act, Err>(on actor: Act, target: RemoteCallTarget, invocation: inout InvocationEncoder, throwing: Err.Type) async throws where Act: DistributedActor, Act.ID == ActorID, Err: Error {\n <#code#>\n}\n}}
50+
51+
52+
public typealias ActorID = ActorAddress
53+
public typealias InvocationDecoder = FakeInvocationDecoder
54+
public typealias InvocationEncoder = FakeInvocationEncoder
55+
public typealias SerializationRequirement = Codable
56+
public typealias ResultHandler = FakeResultHandler
57+
58+
public func resolve<Act>(id: ActorID, as actorType: Act.Type)
59+
throws -> Act? where Act: DistributedActor {
60+
return nil
61+
}
62+
63+
public func assignID<Act>(_ actorType: Act.Type) -> ActorID
64+
where Act: DistributedActor {
65+
ActorAddress(parse: "fake://123")
66+
}
67+
68+
public func actorReady<Act>(_ actor: Act)
69+
where Act: DistributedActor,
70+
Act.ID == ActorID {
71+
}
72+
73+
public func resignID(_ id: ActorID) {
74+
}
75+
76+
public func makeInvocationEncoder() -> InvocationEncoder {
77+
}
78+
}
79+
80+
// ==== ------------------------------------------------------------------------
81+
82+
public struct ActorAddress: Sendable, Hashable, Codable {
83+
let address: String
84+
85+
init(parse address: String) {
86+
self.address = address
87+
}
88+
}
89+
90+
public protocol SomeProtocol: Sendable {
91+
}
92+
93+
public struct FakeInvocationEncoder: DistributedTargetInvocationEncoder {
94+
public typealias SerializationRequirement = Codable
95+
96+
public mutating func recordGenericSubstitution<T>(_ type: T.Type) throws {
97+
}
98+
99+
public mutating func recordArgument<Value: SerializationRequirement>(_ argument: RemoteCallArgument<Value>) throws {
100+
}
101+
102+
public mutating func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {
103+
}
104+
105+
public mutating func recordErrorType<E: Error>(_ type: E.Type) throws {
106+
}
107+
108+
public mutating func doneRecording() throws {
109+
}
110+
}
111+
112+
public class FakeInvocationDecoder: DistributedTargetInvocationDecoder {
113+
public typealias SerializationRequirement = Codable
114+
115+
public func decodeGenericSubstitutions() throws -> [Any.Type] {
116+
[]
117+
}
118+
119+
public func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument {
120+
fatalError()
121+
}
122+
123+
public func decodeReturnType() throws -> Any.Type? {
124+
nil
125+
}
126+
127+
public func decodeErrorType() throws -> Any.Type? {
128+
nil
129+
}
130+
}
131+
132+
public struct FakeResultHandler: DistributedTargetInvocationResultHandler {
133+
public typealias SerializationRequirement = Codable
134+
135+
public func onReturn<Success: SerializationRequirement>(value: Success) async throws {
136+
print("RETURN: \(value)")
137+
}
138+
139+
public func onReturnVoid() async throws {
140+
print("RETURN VOID")
141+
}
142+
143+
public func onThrow<Err: Error>(error: Err) async throws {
144+
print("ERROR: \(error)")
145+
}
146+
}
147+
148+

0 commit comments

Comments
 (0)