Skip to content

Commit 72fca4c

Browse files
committed
[Distributed] IRGen: Make distributed accessors generic over invocation decoder type
Since invocation decoder is no longer required to be a class-bound type, accessors have to be made generic over actual type of decoder. Each accessor still knows what is the expected type of a decoder is (based on the distributed actor it's associated with), so it still does direct method calls but it has to load data decoder instance first iff the decoder is a class.
1 parent 9bf03b6 commit 72fca4c

File tree

1 file changed

+142
-82
lines changed

1 file changed

+142
-82
lines changed

lib/IRGen/GenDistributed.cpp

Lines changed: 142 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,31 @@ struct ArgumentDecoderInfo {
7272
/// The instance of the decoder this information belongs to.
7373
llvm::Value *Decoder;
7474

75-
/// The type of `decodeNextArgument` method.
76-
CanSILFunctionType MethodType;
77-
7875
/// The pointer to `decodeNextArgument` method which
7976
/// could be used to form a call to it.
8077
FunctionPointer MethodPtr;
8178

79+
/// The type of `decodeNextArgument` method.
80+
CanSILFunctionType MethodType;
81+
8282
/// Protocol requirements associated with the generic
8383
/// parameter `Argument` of this decode method.
8484
GenericSignature::RequiredProtocols ProtocolRequirements;
8585

86-
ArgumentDecoderInfo(llvm::Value *decoder, CanSILFunctionType decodeMethodTy,
87-
FunctionPointer decodePtr)
88-
: Decoder(decoder), MethodType(decodeMethodTy), MethodPtr(decodePtr),
89-
ProtocolRequirements(findProtocolRequirements(decodeMethodTy)) {}
86+
// Witness metadata for conformance to DistributedTargetInvocationDecoder
87+
// protocol.
88+
WitnessMetadata Witness;
89+
90+
ArgumentDecoderInfo(llvm::Value *decoder, llvm::Value *decoderType,
91+
llvm::Value *decoderWitnessTable,
92+
FunctionPointer decodeNextArgumentPtr,
93+
CanSILFunctionType decodeNextArgumentTy)
94+
: Decoder(decoder), MethodPtr(decodeNextArgumentPtr),
95+
MethodType(decodeNextArgumentTy),
96+
ProtocolRequirements(findProtocolRequirements(decodeNextArgumentTy)) {
97+
Witness.SelfMetadata = decoderType;
98+
Witness.SelfWitnessTable = decoderWitnessTable;
99+
}
90100

91101
CanSILFunctionType getMethodType() const { return MethodType; }
92102

@@ -131,8 +141,8 @@ class DistributedAccessor {
131141
void emit();
132142

133143
private:
134-
void decodeArguments(llvm::Value *decoder, llvm::Value *argumentTypes,
135-
Explosion &arguments);
144+
void decodeArguments(const ArgumentDecoderInfo &decoder,
145+
llvm::Value *argumentTypes, Explosion &arguments);
136146

137147
/// Load an argument value from the given decoder \c decoder
138148
/// to the given explosion \c arguments. Information describing
@@ -163,8 +173,11 @@ class DistributedAccessor {
163173

164174
Callee getCalleeForDistributedTarget(llvm::Value *self) const;
165175

166-
/// Given an instance of argument decoder, find `decodeNextArgument`.
167-
ArgumentDecoderInfo findArgumentDecoder(llvm::Value *decoder);
176+
/// Given an instance of invocation decoder, its type metadata,
177+
/// and protocol witness table, find `decodeNextArgument`.
178+
ArgumentDecoderInfo findArgumentDecoder(llvm::Value *decoder,
179+
llvm::Value *decoderTy,
180+
llvm::Value *witnessTable);
168181

169182
/// The result type of the accessor.
170183
SILType getResultType() const;
@@ -183,61 +196,80 @@ static NominalTypeDecl *getDistributedActorOf(SILFunction *thunk) {
183196
}
184197

185198
/// Compute a type of a distributed method accessor function based
186-
/// on the provided distributed method.
199+
/// on the provided distributed target.
187200
static CanSILFunctionType getAccessorType(IRGenModule &IGM,
188201
SILFunction *Target) {
189202
auto &Context = IGM.Context;
190203

191-
auto getInvocationDecoderParameter = [&]() {
192-
auto *actor = getDistributedActorOf(Target);
193-
auto *decoder = Context.getDistributedActorInvocationDecoder(actor);
194-
auto decoderTy = decoder->getInterfaceType()->getMetatypeInstanceType();
195-
auto paramType = IGM.getLoweredType(decoderTy);
196-
return SILParameterInfo(paramType.getASTType(),
197-
ParameterConvention::Direct_Guaranteed);
198-
};
199-
200-
auto getRawPointerParameter = [&]() {
201-
auto ptrType = Context.getUnsafeRawPointerType();
202-
return SILParameterInfo(ptrType->getCanonicalType(),
203-
ParameterConvention::Direct_Unowned);
204-
};
205-
206-
auto getUIntParameter = [&]() {
207-
return SILParameterInfo(Context.getUIntType()->getCanonicalType(),
208-
ParameterConvention::Direct_Unowned);
209-
};
210-
211-
// `self` of the distributed actor is going to be passed as an argument
212-
// to this accessor function.
213-
auto extInfo = SILExtInfoBuilder()
214-
.withRepresentation(SILFunctionTypeRepresentation::Thin)
215-
.withAsync()
216-
.build();
204+
// func __accessor__<D: DistributedTargetInvocationDecoder>(
205+
// inout D, <- invocation decoder
206+
// UnsafeRawPointer, <- argument types
207+
// UnsafeRawPointer, <- result buffer
208+
// UnsafeRawPointer?, <- generic parameter substitutions
209+
// UnsafeRawPointer?, <- witness tables
210+
// UInt, <- number of witness tables
211+
// <actor>
212+
// ) async throws
213+
214+
SmallVector<GenericFunctionType::Param, 8> parameters;
215+
216+
// A generic parameter that represents instance of invocation decoder.
217+
auto *decoderType =
218+
GenericTypeParamType::get(/*isTypeSequence=*/false,
219+
/*depth=*/0, /*index=*/0, Context);
220+
221+
// decoder
222+
parameters.push_back(GenericFunctionType::Param(
223+
decoderType,
224+
/*label=*/Identifier(),
225+
/*flags=*/ParameterTypeFlags().withInOut(true)));
226+
227+
// argument type buffer
228+
parameters.push_back(
229+
GenericFunctionType::Param(Context.getUnsafeRawPointerType()));
230+
231+
// result buffer
232+
parameters.push_back(
233+
GenericFunctionType::Param(Context.getUnsafeRawPointerType()));
234+
235+
// generic parameter substitutions
236+
parameters.push_back(
237+
GenericFunctionType::Param(Context.getUnsafeRawPointerType()));
238+
239+
// witness tables
240+
parameters.push_back(
241+
GenericFunctionType::Param(Context.getUnsafeRawPointerType()));
242+
243+
// number of witness tables
244+
parameters.push_back(GenericFunctionType::Param(Context.getUIntType()));
245+
246+
// actor
247+
{
248+
auto targetTy = Target->getLoweredFunctionType();
249+
auto actorLoc = targetTy->getParameters().back();
217250

218-
auto targetTy = Target->getLoweredFunctionType();
251+
parameters.push_back(
252+
GenericFunctionType::Param(actorLoc.getInterfaceType()));
253+
}
254+
255+
auto decoderProtocolTy =
256+
Context
257+
.getProtocol(KnownProtocolKind::DistributedTargetInvocationDecoder)
258+
->getDeclaredInterfaceType();
219259

220-
assert(targetTy->isAsync());
221-
assert(targetTy->hasErrorResult());
222-
223-
// Accessor gets argument/result value buffer and a reference to `self` of
224-
// the actor and produces a call to the distributed thunk forwarding
225-
// its result(s) out.
226-
return SILFunctionType::get(
227-
/*genericSignature=*/nullptr, extInfo, SILCoroutineKind::None,
228-
ParameterConvention::Direct_Guaranteed,
229-
{/*argumentDecoder=*/getInvocationDecoderParameter(),
230-
/*argumentTypes=*/getRawPointerParameter(),
231-
/*resultBuffer=*/getRawPointerParameter(),
232-
/*substitutions=*/getRawPointerParameter(),
233-
/*witnessTables=*/getRawPointerParameter(),
234-
/*numWitnessTables=*/getUIntParameter(),
235-
/*actor=*/targetTy->getParameters().back()},
236-
/*Yields=*/{},
237-
/*Results=*/{},
238-
/*ErrorResult=*/targetTy->getErrorResult(),
239-
/*patternSubs=*/SubstitutionMap(),
240-
/*invocationSubs=*/SubstitutionMap(), Context);
260+
auto signature = GenericSignature::get(
261+
{decoderType},
262+
{{RequirementKind::Conformance, decoderType, decoderProtocolTy}});
263+
264+
auto accessorTy = GenericFunctionType::get(
265+
signature, parameters, Context.TheEmptyTupleType,
266+
ASTExtInfoBuilder()
267+
.withRepresentation(FunctionTypeRepresentation::Thin)
268+
.withAsync()
269+
.withThrows()
270+
.build());
271+
272+
return IGM.getLoweredType(accessorTy).castTo<SILFunctionType>();
241273
}
242274

243275
llvm::Function *
@@ -279,7 +311,7 @@ DistributedAccessor::DistributedAccessor(IRGenFunction &IGF,
279311
IGM.DebugInfo->emitArtificialFunction(IGF, IGF.CurFn);
280312
}
281313

282-
void DistributedAccessor::decodeArguments(llvm::Value *decoder,
314+
void DistributedAccessor::decodeArguments(const ArgumentDecoderInfo &decoder,
283315
llvm::Value *argumentTypes,
284316
Explosion &arguments) {
285317
auto fnType = Target->getLoweredFunctionType();
@@ -295,10 +327,6 @@ void DistributedAccessor::decodeArguments(llvm::Value *decoder,
295327
argumentTypes =
296328
IGF.Builder.CreateBitCast(argumentTypes, IGM.TypeMetadataPtrPtrTy);
297329

298-
/// The argument decoder associated with the distributed actor
299-
/// this accessor belong to.
300-
ArgumentDecoderInfo decoderInfo = findArgumentDecoder(decoder);
301-
302330
for (unsigned i = 0, n = parameters.size(); i != n; ++i) {
303331
const auto &param = parameters[i];
304332
auto paramTy = param.getSILStorageInterfaceType();
@@ -326,7 +354,7 @@ void DistributedAccessor::decodeArguments(llvm::Value *decoder,
326354
auto *argumentTy = IGF.Builder.CreateLoad(typeLoc, "arg_type");
327355

328356
// Decode and load argument value using loaded type metadata.
329-
decodeArgument(i, decoderInfo, argumentTy, param, arguments);
357+
decodeArgument(i, decoder, argumentTy, param, arguments);
330358
}
331359
}
332360

@@ -574,6 +602,10 @@ void DistributedAccessor::emit() {
574602
auto *numWitnessTables = params.claimNext();
575603
// Reference to a `self` of the actor to be called.
576604
auto *actorSelf = params.claimNext();
605+
// Metadata that represents passed in the invocation decoder.
606+
auto *decoderType = params.claimNext();
607+
// Witness table for decoder conformance to DistributedTargetInvocationDecoder
608+
auto *decoderProtocolWitness = params.claimNext();
577609

578610
GenericContextScope scope(IGM, targetTy->getInvocationGenericSignature());
579611

@@ -600,9 +632,18 @@ void DistributedAccessor::emit() {
600632
arguments.add(typedResultBuffer);
601633
}
602634

603-
// Step one is to load all of the data from argument buffer,
604-
// so it could be forwarded to the distributed method.
605-
decodeArguments(argDecoder, argTypes, arguments);
635+
// There is always at least one parameter associated with accessor - `self`
636+
// of the distributed actor.
637+
if (targetTy->getNumParameters() > 1) {
638+
/// The argument decoder associated with the distributed actor
639+
/// this accessor belong to.
640+
ArgumentDecoderInfo decoder =
641+
findArgumentDecoder(argDecoder, decoderType, decoderProtocolWitness);
642+
643+
// Step one is to load all of the data from argument buffer,
644+
// so it could be forwarded to the distributed method.
645+
decodeArguments(decoder, argTypes, arguments);
646+
}
606647

607648
// Add all of the substitutions to the explosion
608649
if (auto *genericEnvironment = Target->getGenericEnvironment()) {
@@ -700,23 +741,48 @@ DistributedAccessor::getCalleeForDistributedTarget(llvm::Value *self) const {
700741
return {std::move(info), getPointerToTarget(), self};
701742
}
702743

703-
ArgumentDecoderInfo
704-
DistributedAccessor::findArgumentDecoder(llvm::Value *decoder) {
744+
ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder(
745+
llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
705746
auto *actor = getDistributedActorOf(Target);
706747
auto expansionContext = IGM.getMaximalTypeExpansionContext();
707748

708749
auto *decodeFn = IGM.Context.getDistributedActorArgumentDecodingMethod(actor);
709750
assert(decodeFn && "no suitable decoder?");
710751

711-
auto *decoderDecl = decodeFn->getDeclContext()->getSelfNominalTypeDecl();
712752
auto methodTy = IGM.getSILTypes().getConstantFunctionType(
713753
expansionContext, SILDeclRef(decodeFn));
714754

715755
auto fpKind = FunctionPointerKind::defaultAsync();
716756
auto signature = IGM.getSignature(methodTy, fpKind);
717757

718758
// If the decoder class is `final`, let's emit a direct reference.
719-
if (decoderDecl->isFinal()) {
759+
auto *decoderDecl = decodeFn->getDeclContext()->getSelfNominalTypeDecl();
760+
761+
// If decoder is a class, need to load it first because generic parameter
762+
// is passed indirectly. This is good for structs and enums because
763+
// `decodeNextArgument` is a mutating method, but not for classes because
764+
// in that case heap object is mutated directly.
765+
if (isa<ClassDecl>(decoderDecl)) {
766+
auto selfTy = methodTy->getSelfParameter().getSILStorageType(
767+
IGM.getSILModule(), methodTy, expansionContext);
768+
769+
auto &classTI = IGM.getTypeInfo(selfTy).as<ClassTypeInfo>();
770+
auto &classLayout = classTI.getClassLayout(IGM, selfTy,
771+
/*forBackwardDeployment=*/false);
772+
773+
llvm::Value *typedDecoderPtr = IGF.Builder.CreateBitCast(
774+
decoder, classLayout.getType()->getPointerTo()->getPointerTo());
775+
776+
Explosion instance;
777+
778+
classTI.loadAsTake(IGF, {typedDecoderPtr, classTI.getBestKnownAlignment()},
779+
instance);
780+
781+
decoder = instance.claimNext();
782+
}
783+
784+
if (isa<StructDecl>(decoderDecl) || isa<EnumDecl>(decoderDecl) ||
785+
decoderDecl->isFinal()) {
720786
auto *decodeSIL = IGM.getSILModule().lookUpFunction(SILDeclRef(decodeFn));
721787
auto *fnPtr = IGM.getAddrOfSILFunction(decodeSIL, NotForDefinition,
722788
/*isDynamicallyReplacible=*/false);
@@ -725,19 +791,13 @@ DistributedAccessor::findArgumentDecoder(llvm::Value *decoder) {
725791
classifyFunctionPointerKind(decodeSIL), fnPtr,
726792
/*secondaryValue=*/nullptr, signature);
727793

728-
return {decoder, methodTy, methodPtr};
794+
return {decoder, decoderTy, witnessTable, methodPtr, methodTy};
729795
}
730796

731-
auto selfTy = methodTy->getSelfParameter().getSILStorageType(
732-
IGM.getSILModule(), methodTy, expansionContext);
733-
734-
auto *metadata = emitHeapMetadataRefForHeapObject(IGF, decoder, selfTy,
735-
/*suppress cast*/ true);
736-
737797
auto methodPtr =
738-
emitVirtualMethodValue(IGF, metadata, SILDeclRef(decodeFn), methodTy);
798+
emitVirtualMethodValue(IGF, decoderTy, SILDeclRef(decodeFn), methodTy);
739799

740-
return {decoder, methodTy, methodPtr};
800+
return {decoder, decoderTy, witnessTable, methodPtr, methodTy};
741801
}
742802

743803
SILType DistributedAccessor::getResultType() const {

0 commit comments

Comments
 (0)