Skip to content

Commit 59ff689

Browse files
committed
[Distributed] IRGen: Teach distributed thunk accessor to load witness tables (when needed)
1 parent dc978f2 commit 59ff689

File tree

4 files changed

+116
-7
lines changed

4 files changed

+116
-7
lines changed

lib/IRGen/GenDistributed.cpp

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,16 @@ class DistributedAccessor {
121121
const SILParameterInfo &param,
122122
Explosion &arguments);
123123

124+
/// Load witness table addresses (if any) from the given buffer
125+
/// into the given argument explosion.
126+
///
127+
/// Number of witnesses to load is provided by \c numTables but
128+
/// it's checked against the number of \c expectedWitnessTables.
129+
void emitLoadOfWitnessTables(llvm::Value *witnessTables,
130+
llvm::Value *numTables,
131+
unsigned expectedWitnessTables,
132+
Explosion &arguments);
133+
124134
FunctionPointer getPointerToTarget() const;
125135

126136
Callee getCalleeForDistributedTarget(llvm::Value *self) const;
@@ -140,6 +150,11 @@ static CanSILFunctionType getAccessorType(IRGenModule &IGM,
140150
ParameterConvention::Direct_Guaranteed);
141151
};
142152

153+
auto getUIntParameter = [&]() {
154+
return SILParameterInfo(Context.getUIntType()->getCanonicalType(),
155+
ParameterConvention::Direct_Guaranteed);
156+
};
157+
143158
// `self` of the distributed actor is going to be passed as an argument
144159
// to this accessor function.
145160
auto extInfo = SILExtInfoBuilder()
@@ -162,6 +177,8 @@ static CanSILFunctionType getAccessorType(IRGenModule &IGM,
162177
/*argumentTypes=*/getRawPointerParameter(),
163178
/*resultBuffer=*/getRawPointerParameter(),
164179
/*substitutions=*/getRawPointerParameter(),
180+
/*witnessTables=*/getRawPointerParameter(),
181+
/*numWitnessTables=*/getUIntParameter(),
165182
/*actor=*/targetTy->getParameters().back()},
166183
/*Yields=*/{},
167184
/*Results=*/{},
@@ -432,6 +449,37 @@ DistributedAccessor::loadArgument(unsigned argumentIdx,
432449
return {alignedOffset.getAddress(), typeInfo.getSize(IGF, paramTy)};
433450
}
434451

452+
void DistributedAccessor::emitLoadOfWitnessTables(llvm::Value *witnessTables,
453+
llvm::Value *numTables,
454+
unsigned expectedWitnessTables,
455+
Explosion &arguments) {
456+
auto contBB = IGF.createBasicBlock("");
457+
auto unreachableBB = IGF.createBasicBlock("incorrect-witness-tables");
458+
459+
auto incorrectNum = IGF.Builder.CreateICmpNE(
460+
numTables, llvm::ConstantInt::get(IGM.SizeTy, expectedWitnessTables));
461+
462+
// Make sure that we have a correct number of witness tables provided to us.
463+
IGF.Builder.CreateCondBr(incorrectNum, unreachableBB, contBB);
464+
{
465+
IGF.Builder.emitBlock(unreachableBB);
466+
IGF.Builder.CreateUnreachable();
467+
}
468+
469+
IGF.Builder.emitBlock(contBB);
470+
471+
witnessTables = IGF.Builder.CreateBitCast(witnessTables, IGM.Int8PtrPtrTy);
472+
473+
for (unsigned i = 0, n = expectedWitnessTables; i != n; ++i) {
474+
auto offset = Size(i * IGM.getPointerSize());
475+
auto alignment = IGM.getPointerAlignment();
476+
477+
auto witnessTableAddr = IGF.emitAddressAtOffset(
478+
witnessTables, Offset(offset), IGM.Int8PtrTy, Alignment(alignment));
479+
arguments.add(witnessTableAddr.getAddress());
480+
}
481+
}
482+
435483
void DistributedAccessor::emit() {
436484
auto targetTy = Target->getLoweredFunctionType();
437485
SILFunctionConventions targetConv(targetTy, IGF.getSILModule());
@@ -457,6 +505,10 @@ void DistributedAccessor::emit() {
457505
auto *resultBuffer = params.claimNext();
458506
// UnsafeRawPointer that represents a list of substitutions
459507
auto *substitutions = params.claimNext();
508+
// UnsafeRawPointer that represents a list of witness tables
509+
auto *witnessTables = params.claimNext();
510+
// Integer that represented the number of witness tables
511+
auto *numWitnessTables = params.claimNext();
460512
// Reference to a `self` of the actor to be called.
461513
auto *actorSelf = params.claimNext();
462514

@@ -490,12 +542,22 @@ void DistributedAccessor::emit() {
490542
computeArguments(argBuffer, argTypes, arguments);
491543

492544
// Add all of the substitutions to the explosion
493-
if (auto *environment = Target->getGenericEnvironment()) {
545+
if (auto *genericEnvironment = Target->getGenericEnvironment()) {
494546
// swift.type **
495547
llvm::Value *substitutionBuffer =
496548
IGF.Builder.CreateBitCast(substitutions, IGM.TypeMetadataPtrPtrTy);
497549

498-
for (unsigned index : indices(environment->getGenericParams())) {
550+
// Collect the generic arguments expected by the distributed thunk.
551+
// We need this to determine the expected number of witness tables
552+
// to load from the buffer provided by the caller.
553+
llvm::SmallVector<llvm::Type *, 4> targetGenericArguments;
554+
expandPolymorphicSignature(IGM, targetTy, targetGenericArguments);
555+
556+
unsigned numGenericArgs = genericEnvironment->getGenericParams().size();
557+
unsigned expectedWitnessTables =
558+
targetGenericArguments.size() - numGenericArgs;
559+
560+
for (unsigned index = 0; index < numGenericArgs; ++index) {
499561
auto offset =
500562
Size(index * IGM.DataLayout.getTypeAllocSize(IGM.TypeMetadataPtrTy));
501563
auto alignment =
@@ -506,6 +568,9 @@ void DistributedAccessor::emit() {
506568
IGM.TypeMetadataPtrTy, Alignment(alignment));
507569
arguments.add(IGF.Builder.CreateLoad(substitution, "substitution"));
508570
}
571+
572+
emitLoadOfWitnessTables(witnessTables, numWitnessTables,
573+
expectedWitnessTables, arguments);
509574
}
510575

511576
// Step two, let's form and emit a call to the distributed method

stdlib/public/Concurrency/Actor.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1724,7 +1724,9 @@ void *swift_distributed_get_generic_environment(const char *targetNameStart,
17241724
/// argumentBuffer: Builtin.RawPointer,
17251725
/// argumentTypes: UnsafeBufferPointer<Any.Type>,
17261726
/// resultBuffer: Builtin.RawPointer,
1727-
/// substitutions: UnsafeRawPointer?
1727+
/// substitutions: UnsafeRawPointer?,
1728+
/// witnessTables: UnsafeRawPointer?,
1729+
/// numWitnessTables: UInt
17281730
/// ) async throws
17291731
using TargetExecutorSignature =
17301732
AsyncSignature<void(/*on=*/DefaultActor *,
@@ -1733,6 +1735,8 @@ using TargetExecutorSignature =
17331735
/*argumentTypes=*/const Metadata *const *,
17341736
/*resultBuffer=*/void *,
17351737
/*substitutions=*/void *,
1738+
/*witnessTables=*/void **,
1739+
/*numWitnessTables=*/size_t,
17361740
/*resumeFunc=*/TaskContinuationFunction *,
17371741
/*callContext=*/AsyncContext *),
17381742
/*throws=*/true>;
@@ -1747,12 +1751,16 @@ TargetExecutorSignature::FunctionType swift_distributed_execute_target;
17471751
/// - a list of all argument types (with substitutions applied)
17481752
/// - a result buffer as a raw pointer
17491753
/// - a list of substitutions
1754+
/// - a list of witness tables
1755+
/// - a number of witness tables in the buffer
17501756
/// - a reference to an actor to execute method on.
17511757
using DistributedAccessorSignature =
17521758
AsyncSignature<void(/*argumentBuffer=*/void *,
17531759
/*argumentTypes=*/const Metadata *const *,
17541760
/*resultBuffer=*/void *,
17551761
/*substitutions=*/void *,
1762+
/*witnessTables=*/void **,
1763+
/*numWitnessTables=*/size_t,
17561764
/*actor=*/HeapObject *),
17571765
/*throws=*/true>;
17581766

@@ -1783,6 +1791,8 @@ void ::swift_distributed_execute_target(
17831791
const Metadata *const *argumentTypes,
17841792
void *resultBuffer,
17851793
void *substitutions,
1794+
void **witnessTables,
1795+
size_t numWitnessTables,
17861796
TaskContinuationFunction *resumeFunc,
17871797
AsyncContext *callContext) {
17881798
auto *accessor = findDistributedAccessor(targetNameStart, targetNameLength);
@@ -1820,5 +1830,7 @@ void ::swift_distributed_execute_target(
18201830
argumentBuffer, argumentTypes,
18211831
resultBuffer,
18221832
substitutions,
1833+
witnessTables,
1834+
numWitnessTables,
18231835
actor);
18241836
}

stdlib/public/Distributed/DistributedActorSystem.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,9 @@ extension DistributedActorSystem {
324324
argumentBuffer: hargs.buffer._rawValue, // TODO(distributed): pass the invocationDecoder instead, so we can decode inside IRGen directly into the argument explosion
325325
argumentTypes: argumentTypesBuffer.baseAddress!._rawValue,
326326
resultBuffer: resultBuffer._rawValue,
327-
substitutions: UnsafeRawPointer(substitutionsBuffer)
327+
substitutions: UnsafeRawPointer(substitutionsBuffer),
328+
witnessTables: witnessTablesBuffer,
329+
numWitnessTables: UInt(numWitnessTables)
328330
)
329331

330332
func onReturn<R>(_ resultTy: R.Type) async throws {
@@ -345,7 +347,9 @@ func _executeDistributedTarget(
345347
argumentBuffer: Builtin.RawPointer, // HeterogeneousBuffer of arguments
346348
argumentTypes: Builtin.RawPointer,
347349
resultBuffer: Builtin.RawPointer,
348-
substitutions: UnsafeRawPointer?
350+
substitutions: UnsafeRawPointer?,
351+
witnessTables: UnsafeRawPointer?,
352+
numWitnessTables: UInt
349353
) async throws
350354

351355
// ==== ----------------------------------------------------------------------------------------------------------------

test/Distributed/Runtime/distributed_actor_remoteCall.swift

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,19 @@ enum E : Sendable, Codable {
3434
case foo, bar
3535
}
3636

37+
struct S<T: Codable> : Codable {
38+
var data: T
39+
}
40+
3741
@_silgen_name("swift_distributed_actor_is_remote")
3842
func __isRemoteActor(_ actor: AnyObject) -> Bool
3943

4044
distributed actor Greeter {
45+
distributed func generic1<T: Codable, U: Codable>(t: T, u: U) {
46+
print("---> T = \(t), type(of:) = \(type(of: t))")
47+
print("---> U = \(u), type(of:) = \(type(of: u))")
48+
}
49+
4150
distributed func empty() {
4251
}
4352

@@ -60,7 +69,6 @@ distributed actor Greeter {
6069
distributed func enumResult() -> E {
6170
.bar
6271
}
63-
6472
}
6573

6674

@@ -158,7 +166,6 @@ struct FakeInvocation: DistributedTargetInvocationEncoder, DistributedTargetInvo
158166
fatalError("Cannot cast argument\(anyArgument) to expected \(Argument.self)")
159167
}
160168

161-
print(" > decode argument: \(argument)")
162169
pointer.initialize(to: argument)
163170
argumentIndex += 1
164171
}
@@ -194,6 +201,8 @@ let answerName = "$s4main7GreeterC6answerSiyFTE"
194201
let largeResultName = "$s4main7GreeterC11largeResultAA11LargeStructVyFTE"
195202
let enumResultName = "$s4main7GreeterC10enumResultAA1EOyFTE"
196203
let echoName = "$s4main7GreeterC4echo4name3ageS2S_SitFTE"
204+
// <T: Codable, U: Codable>(t: T, u: U)
205+
let generic1Name = "$s4main7GreeterC8generic11t1uyx_q_tSeRzSERzSeR_SER_r0_lFTE"
197206

198207
func test() async throws {
199208
let system = FakeActorSystem()
@@ -255,6 +264,25 @@ func test() async throws {
255264
)
256265
// CHECK: RETURN: Echo: name: Caplin, age: 42
257266

267+
var generic1Invocation = system.makeInvocationEncoder()
268+
269+
try generic1Invocation.recordGenericSubstitution(Int.self)
270+
try generic1Invocation.recordGenericSubstitution(String.self)
271+
try generic1Invocation.recordArgument(42)
272+
try generic1Invocation.recordArgument("Ultimate Question!")
273+
try generic1Invocation.doneRecording()
274+
275+
try await system.executeDistributedTarget(
276+
on: local,
277+
mangledTargetName: generic1Name,
278+
invocationDecoder: &generic1Invocation,
279+
handler: FakeResultHandler()
280+
)
281+
282+
// CHECK: ---> T = 42, type(of:) = Int
283+
// CHECK-NEXT: ---> U = Ultimate Question!, type(of:) = String
284+
// CHECK-NEXT: RETURN: ()
285+
258286
print("done")
259287
// CHECK-NEXT: done
260288
}

0 commit comments

Comments
 (0)