Skip to content

Commit ea29723

Browse files
committed
[Distributed] Add name parameter to recordArgument for better interop
1 parent ce2c771 commit ea29723

22 files changed

+113
-93
lines changed

lib/AST/DistributedDecl.cpp

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -635,60 +635,72 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const
635635

636636
// === Check all parameters
637637
auto params = getParameters();
638-
if (params->size() != 1) {
638+
if (params->size() != 2) {
639639
return false;
640640
}
641641

642-
// --- Check parameter: _ argument
643-
auto argumentParam = params->get(0);
644-
if (!argumentParam->getArgumentName().is("")) {
645-
return false;
646-
}
642+
// --- Check parameter: label
643+
auto labelParam = params->get(0);
644+
if (!labelParam->getArgumentName().is("name")) {
645+
return false;
646+
}
647+
if (!labelParam->getInterfaceType()->isEqual(C.getStringType())) {
648+
return false;
649+
}
647650

648-
// === Check generic parameters in detail
649-
// --- Check: Argument: SerializationRequirement
650-
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
651+
// --- Check parameter: _ argument
652+
auto argumentParam = params->get(1);
653+
if (!argumentParam->getArgumentName().is("")) {
654+
return false;
655+
}
651656

652-
auto sig = getGenericSignature();
653-
auto requirements = sig.getRequirements();
657+
// === Check generic parameters in detail
658+
// --- Check: Argument: SerializationRequirement
659+
GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0];
654660

655-
if (requirements.size() != expectedRequirementsNum) {
656-
return false;
657-
}
661+
auto sig = getGenericSignature();
662+
auto requirements = sig.getRequirements();
658663

659-
// --- Check the expected requirements
660-
// --- all the Argument requirements ---
661-
// conforms_to: Argument Decodable
662-
// conforms_to: Argument Encodable
663-
// ...
664+
if (requirements.size() != expectedRequirementsNum) {
665+
return false;
666+
}
664667

665-
auto func = dyn_cast<FuncDecl>(this);
666-
if (!func) {
667-
return false;
668-
}
668+
// --- Check the expected requirements
669+
// --- all the Argument requirements ---
670+
// e.g.
671+
// conforms_to: Argument Decodable
672+
// conforms_to: Argument Encodable
673+
// ...
669674

670-
auto resultType = func->mapTypeIntoContext(argumentParam->getInterfaceType())
671-
->getDesugaredType();
672-
auto resultParamType = func->mapTypeIntoContext(
673-
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
674-
// The result of the function must be the `Res` generic argument.
675-
if (!resultType->isEqual(resultParamType)) {
676-
return false;
677-
}
675+
auto func = dyn_cast<FuncDecl>(this);
676+
if (!func) {
677+
return false;
678+
}
678679

679-
for (auto requirementProto : requirementProtos) {
680-
auto conformance = module->lookupConformance(resultType, requirementProto);
681-
if (conformance.isInvalid()) {
680+
auto resultType =
681+
func->mapTypeIntoContext(argumentParam->getInterfaceType())
682+
->getDesugaredType();
683+
auto resultParamType = func->mapTypeIntoContext(
684+
ArgumentParam->getInterfaceType()->getMetatypeInstanceType());
685+
// The result of the function must be the `Res` generic argument.
686+
if (!resultType->isEqual(resultParamType)) {
682687
return false;
683688
}
684-
}
685689

686-
// === Check result type: Void
687-
if (!func->getResultInterfaceType()->isVoid()) {
688-
return false;
689-
}
690+
for (auto requirementProto : requirementProtos) {
691+
auto conformance =
692+
module->lookupConformance(resultType, requirementProto);
693+
if (conformance.isInvalid()) {
694+
return false;
695+
}
696+
}
690697

691-
return true;
698+
// === Check result type: Void
699+
if (!func->getResultInterfaceType()->isVoid()) {
700+
return false;
701+
}
702+
703+
return true;
692704
}
693705

694706
bool
@@ -879,8 +891,8 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons
879891
}
880892

881893
// --- Check parameter: _ errorType
882-
auto argumentParam = params->get(0);
883-
if (!argumentParam->getArgumentName().is("")) {
894+
auto errorTypeParam = params->get(0);
895+
if (!errorTypeParam->getArgumentName().is("")) {
884896
return false;
885897
}
886898

lib/Sema/CodeSynthesisDistributedActor.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,14 @@ deriveBodyDistributed_thunk(AbstractFunctionDecl *thunk, void *context) {
269269
auto recordArgumentDeclRef = UnresolvedDeclRefExpr::createImplicit(
270270
C, recordArgumentDecl->getName());
271271

272+
auto argumentName = param->getArgumentName().str();
272273
auto recordArgArgsList = ArgumentList::forImplicitCallTo(
273274
recordArgumentDeclRef->getName(),
274275
{
276+
// name:
277+
new (C) StringLiteralExpr(argumentName, SourceRange(),
278+
/*implicit=*/true),
279+
// _ argument:
275280
new (C) DeclRefExpr(
276281
ConcreteDeclRef(param), dloc, implicit,
277282
AccessSemantics::Ordinary,

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
305305
decl->getDescriptiveKind(), decl->getName(), identifier);
306306
decl->diagnose(diag::note_distributed_actor_system_conformance_missing_adhoc_requirement,
307307
decl->getName(), identifier,
308-
"mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws\n");
308+
"mutating func recordArgument<Argument: SerializationRequirement>(name: String, _ argument: Argument) throws\n");
309309
anyMissingAdHocRequirements = true;
310310
}
311311
if (checkAdHocRequirementAccessControl(decl, Proto, recordArgumentDecl))

stdlib/public/Distributed/DistributedActorSystem.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,7 @@ public protocol DistributedTargetInvocationEncoder {
418418
// ///
419419
// /// Record an argument of `Argument` type.
420420
// /// This will be invoked for every argument of the target, in declaration order.
421-
// mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws
422-
// TODO(distributed): offer recordArgument(label:type:)
421+
// mutating func recordArgument<Argument: SerializationRequirement>(name: String, _ argument: Argument) throws
423422

424423
/// Record the error type of the distributed method.
425424
/// This method will not be invoked if the target is not throwing.

test/Distributed/Inputs/BadDistributedActorSystems.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,8 @@ public struct FakeInvocationEncoder : DistributedTargetInvocationEncoder {
257257
genericSubs.append(type)
258258
}
259259

260-
public mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {
261-
print(" > encode argument: \(argument)")
260+
public mutating func recordArgument<Argument: SerializationRequirement>(name: String, _ argument: Argument) throws {
261+
print(" > encode argument name:\(name), argument: \(argument)")
262262
arguments.append(argument)
263263
}
264264
public mutating func recordErrorType<E: Error>(_ type: E.Type) throws {

test/Distributed/Inputs/FakeDistributedActorSystems.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,18 +267,22 @@ public struct FakeInvocationEncoder : DistributedTargetInvocationEncoder {
267267
genericSubs.append(type)
268268
}
269269

270-
public mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {
271-
print(" > encode argument: \(argument)")
270+
public mutating func recordArgument<Argument: SerializationRequirement>(
271+
name: String, _ argument: Argument) throws {
272+
print(" > encode argument name:\(name), value: \(argument)")
272273
arguments.append(argument)
273274
}
275+
274276
public mutating func recordErrorType<E: Error>(_ type: E.Type) throws {
275277
print(" > encode error type: \(String(reflecting: type))")
276278
self.errorType = type
277279
}
280+
278281
public mutating func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {
279282
print(" > encode return type: \(String(reflecting: type))")
280283
self.returnType = type
281284
}
285+
282286
public mutating func doneRecording() throws {
283287
print(" > done recording")
284288
}

test/Distributed/Inputs/dynamic_replacement_da_decl.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ struct FakeInvocationEncoder: DistributedTargetInvocationEncoder {
8787
typealias SerializationRequirement = Codable
8888

8989
mutating func recordGenericSubstitution<T>(_ type: T.Type) throws {}
90-
mutating func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {}
90+
mutating func recordArgument<Argument: SerializationRequirement>(name: String, _ argument: Argument) throws {}
9191
mutating func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {}
9292
mutating func recordErrorType<E: Error>(_ type: E.Type) throws {}
9393
mutating func doneRecording() throws {}

test/Distributed/Runtime/distributed_actor_decode.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvoc
105105
typealias SerializationRequirement = Codable
106106

107107
func recordGenericSubstitution<T>(_ type: T.Type) throws {}
108-
func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws {}
108+
func recordArgument<Argument: SerializationRequirement>(name: String, _ argument: Argument) throws {}
109109
func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {}
110110
func recordErrorType<E: Error>(_ type: E.Type) throws {}
111111
func doneRecording() throws {}

test/Distributed/Runtime/distributed_actor_deinit.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class FakeDistributedInvocation: DistributedTargetInvocationEncoder, Distributed
123123
typealias SerializationRequirement = Codable
124124

125125
func recordGenericSubstitution<T>(_ type: T.Type) throws { }
126-
func recordArgument<Argument: SerializationRequirement>(_ argument: Argument) throws { }
126+
func recordArgument<Argument: SerializationRequirement>(name: String, _ argument: Argument) throws { }
127127
func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws { }
128128
func recordErrorType<E: Error>(_ type: E.Type) throws { }
129129
func doneRecording() throws { }

test/Distributed/Runtime/distributed_actor_func_calls_remoteCall_genericFunc.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func test() async throws {
4040

4141
let r1 = try await ref.generic("Caplin")
4242
// CHECK: > encode generic sub: Swift.String
43-
// CHECK: > encode argument: Caplin
43+
// CHECK: > encode argument name:, value: Caplin
4444
// CHECK: > encode return type: Swift.String
4545
// CHECK: > done recording
4646
// CHECK: >> remoteCall: on:main.Greeter, target:main.Greeter.generic(_:), invocation:FakeInvocationEncoder(genericSubs: [Swift.String], arguments: ["Caplin"], returnType: Optional(Swift.String), errorType: nil), throwing:Swift.Never, returning:Swift.String
@@ -54,9 +54,9 @@ func test() async throws {
5454
)
5555
// CHECK: > encode generic sub: Swift.String
5656
// CHECK: > encode generic sub: Swift.Int
57-
// CHECK: > encode argument: 2.0
58-
// CHECK: > encode argument: Caplin
59-
// CHECK: > encode argument: [1, 2, 3]
57+
// CHECK: > encode argument name:strict, value: 2.0
58+
// CHECK: > encode argument name:, value: Caplin
59+
// CHECK: > encode argument name:, value: [1, 2, 3]
6060
// CHECK: > encode return type: Swift.String
6161
// CHECK: > done recording
6262
// CHECK: >> remoteCall: on:main.Greeter, target:main.Greeter.generic2(strict:_:_:), invocation:FakeInvocationEncoder(genericSubs: [Swift.String, Swift.Int], arguments: [2.0, "Caplin", [1, 2, 3]], returnType: Optional(Swift.String), errorType: nil), throwing:Swift.Never, returning:Swift.String

0 commit comments

Comments
 (0)