Skip to content

Commit c6820a5

Browse files
committed
[IRGen] Distributed: Expand distributed actor accessor to support protocol requirements
Given the following protocol: ``` protocol Greeter : DistributedActor { distributed func greet() } ``` The changes make it possible to synthesize a distributed accessor thunk for the requirement `greet` which would be dispatched to the underlying concrete actor implementation at runtime.
1 parent 0f41071 commit c6820a5

File tree

2 files changed

+140
-67
lines changed

2 files changed

+140
-67
lines changed

lib/IRGen/GenDistributed.cpp

Lines changed: 134 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,19 @@ llvm::Value *irgen::emitDistributedActorInitializeRemote(
7070

7171
namespace {
7272

73+
using ThunkOrRequirement = llvm::PointerUnion<SILFunction *, AbstractFunctionDecl *>;
74+
75+
static LinkEntity
76+
getAccessorLinking(ThunkOrRequirement accessorFor) {
77+
if (auto *method = accessorFor.dyn_cast<SILFunction *>()) {
78+
assert(method->isDistributed());
79+
return LinkEntity::forDistributedTargetAccessor(method);
80+
}
81+
82+
auto *requirement = accessorFor.get<AbstractFunctionDecl *>();
83+
return LinkEntity::forDistributedTargetAccessor(requirement);
84+
}
85+
7386
struct ArgumentDecoderInfo {
7487
/// The instance of the decoder this information belongs to.
7588
llvm::Value *Decoder;
@@ -128,32 +141,47 @@ struct ArgumentDecoderInfo {
128141
struct AccessorTarget {
129142
private:
130143
IRGenFunction &IGF;
131-
SILFunction *Target;
144+
ThunkOrRequirement Target;
132145

133146
CanSILFunctionType Type;
134147

148+
mutable std::optional<WitnessMetadata> Witness;
149+
135150
public:
136-
AccessorTarget(IRGenFunction &IGF, SILFunction *target)
137-
: IGF(IGF), Target(target), Type(target->getLoweredFunctionType()) {}
151+
AccessorTarget(IRGenFunction &IGF, ThunkOrRequirement target)
152+
: IGF(IGF), Target(target) {
153+
if (auto *thunk = target.dyn_cast<SILFunction *>()) {
154+
Type = thunk->getLoweredFunctionType();
155+
} else {
156+
auto *requirement = target.get<AbstractFunctionDecl *>();
157+
Type = IGF.IGM.getSILTypes().getConstantFunctionType(
158+
IGF.IGM.getMaximalTypeExpansionContext(),
159+
SILDeclRef(requirement).asDistributed());
160+
}
161+
}
138162

139-
DeclContext *getDeclContext() const { return Target->getDeclContext(); }
163+
DeclContext *getDeclContext() const {
164+
if (auto *thunk = Target.dyn_cast<SILFunction *>())
165+
return thunk->getDeclContext();
166+
return Target.get<AbstractFunctionDecl *>();
167+
}
140168

141169
CanSILFunctionType getType() const { return Type; }
142170

143-
bool isGeneric() const { return Target->isGeneric(); }
171+
bool isGeneric() const {
172+
auto sig = Type->getInvocationGenericSignature();
173+
return sig && !sig->areAllParamsConcrete();
174+
}
144175

145-
Callee getCallee(llvm::Value *actorSelf) const;
176+
Callee getCallee(llvm::Value *actorSelf);
146177

147-
LinkEntity getLinking() const {
148-
return LinkEntity::forDistributedTargetAccessor(Target);
149-
}
178+
LinkEntity getLinking() const { return getAccessorLinking(Target); }
150179

151-
WitnessMetadata *getWitnessMetadata() const {
152-
return nullptr;
153-
}
180+
/// Witness metadata is computed lazily upon the first request.
181+
WitnessMetadata *getWitnessMetadata(llvm::Value *actorSelf);
154182

155183
public:
156-
FunctionPointer getPointerToTarget() const;
184+
FunctionPointer getPointerToTarget(llvm::Value *actorSelf);
157185
};
158186

159187
class DistributedAccessor {
@@ -175,7 +203,7 @@ class DistributedAccessor {
175203
SmallVector<std::pair<Address, /*type=*/llvm::Value *>, 4> LoadedArguments;
176204

177205
public:
178-
DistributedAccessor(IRGenFunction &IGF, SILFunction *target,
206+
DistributedAccessor(IRGenFunction &IGF, ThunkOrRequirement target,
179207
CanSILFunctionType accessorTy);
180208

181209
void emit();
@@ -313,27 +341,24 @@ static CanSILFunctionType getAccessorType(IRGenModule &IGM) {
313341
}
314342

315343
llvm::Function *
316-
IRGenModule::getAddrOfDistributedTargetAccessor(SILFunction *F,
344+
IRGenModule::getAddrOfDistributedTargetAccessor(LinkEntity accessor,
317345
ForDefinition_t forDefinition) {
318-
auto entity = LinkEntity::forDistributedTargetAccessor(F);
319-
320-
llvm::Function *&entry = GlobalFuncs[entity];
346+
llvm::Function *&entry = GlobalFuncs[accessor];
321347
if (entry) {
322348
if (forDefinition)
323-
updateLinkageForDefinition(*this, entry, entity);
349+
updateLinkageForDefinition(*this, entry, accessor);
324350
return entry;
325351
}
326352

327353
Signature signature = getSignature(getAccessorType(*this));
328-
LinkInfo link = LinkInfo::get(*this, entity, forDefinition);
354+
LinkInfo link = LinkInfo::get(*this, accessor, forDefinition);
329355

330356
return createFunction(*this, link, signature);
331357
}
332358

333-
void IRGenModule::emitDistributedTargetAccessor(SILFunction *target) {
334-
assert(target->isDistributed());
335-
336-
auto *f = getAddrOfDistributedTargetAccessor(target, ForDefinition);
359+
void IRGenModule::emitDistributedTargetAccessor(ThunkOrRequirement target) {
360+
auto *f = getAddrOfDistributedTargetAccessor(getAccessorLinking(target),
361+
ForDefinition);
337362

338363
if (!f->isDeclaration())
339364
return;
@@ -343,7 +368,7 @@ void IRGenModule::emitDistributedTargetAccessor(SILFunction *target) {
343368
}
344369

345370
DistributedAccessor::DistributedAccessor(IRGenFunction &IGF,
346-
SILFunction *target,
371+
ThunkOrRequirement target,
347372
CanSILFunctionType accessorTy)
348373
: IGM(IGF.IGM), IGF(IGF), Target(IGF, target), AccessorType(accessorTy),
349374
AsyncLayout(getAsyncContextLayout(IGM, AccessorType, AccessorType,
@@ -540,6 +565,35 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
540565
}
541566
}
542567

568+
static llvm::Value *lookupWitnessTable(IRGenFunction &IGF, llvm::Value *witness,
569+
ProtocolDecl *protocol) {
570+
assert(Lowering::TypeConverter::protocolRequiresWitnessTable(protocol));
571+
572+
auto &IGM = IGF.IGM;
573+
auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor(protocol);
574+
auto *witnessTable = IGF.Builder.CreateCall(
575+
IGM.getConformsToProtocolFunctionPointer(), {witness, protocolDescriptor});
576+
577+
auto failBB = IGF.createBasicBlock("missing-witness");
578+
auto contBB = IGF.createBasicBlock("");
579+
580+
auto isNull = IGF.Builder.CreateICmpEQ(
581+
witnessTable, llvm::ConstantPointerNull::get(IGM.WitnessTablePtrTy));
582+
IGF.Builder.CreateCondBr(isNull, failBB, contBB);
583+
584+
// This operation shouldn't fail because the compuler should have
585+
// checked that the given witness conforms to the protocol. If it
586+
// does fail then accessor should trap.
587+
{
588+
IGF.Builder.emitBlock(failBB);
589+
IGF.emitTrap("missing witness table", /*EmitUnreachable=*/true);
590+
}
591+
592+
IGF.Builder.emitBlock(contBB);
593+
594+
return witnessTable;
595+
}
596+
543597
void DistributedAccessor::lookupWitnessTables(
544598
llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
545599
Explosion &witnessTables) {
@@ -552,28 +606,7 @@ void DistributedAccessor::lookupWitnessTables(
552606
if (!Lowering::TypeConverter::protocolRequiresWitnessTable(protocol))
553607
continue;
554608

555-
auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor(protocol);
556-
auto *witnessTable =
557-
IGF.Builder.CreateCall(conformsToProtocol, {value, protocolDescriptor});
558-
559-
auto failBB = IGF.createBasicBlock("missing-witness");
560-
auto contBB = IGF.createBasicBlock("");
561-
562-
auto isNull = IGF.Builder.CreateICmpEQ(
563-
witnessTable, llvm::ConstantPointerNull::get(IGM.WitnessTablePtrTy));
564-
IGF.Builder.CreateCondBr(isNull, failBB, contBB);
565-
566-
// This operation shouldn't fail because runtime should have checked that
567-
// a particular argument type conforms to `SerializationRequirement`
568-
// of the distributed actor the decoder is used for. If it does fail
569-
// then accessor should trap.
570-
{
571-
IGF.Builder.emitBlock(failBB);
572-
IGF.emitTrap("missing witness table", /*EmitUnreachable=*/true);
573-
}
574-
575-
IGF.Builder.emitBlock(contBB);
576-
witnessTables.add(witnessTable);
609+
witnessTables.add(lookupWitnessTable(IGF, value, protocol));
577610
}
578611
}
579612

@@ -759,7 +792,7 @@ void DistributedAccessor::emit() {
759792

760793
emission->begin();
761794
emission->setArgs(arguments, /*isOutlined=*/false,
762-
Target.getWitnessMetadata());
795+
Target.getWitnessMetadata(actorSelf));
763796

764797
// Load result of the thunk into the location provided by the caller.
765798
// This would only generate code for direct results, if thunk has an
@@ -790,39 +823,75 @@ void DistributedAccessor::emit() {
790823
}
791824
}
792825

793-
FunctionPointer AccessorTarget::getPointerToTarget() const {
826+
FunctionPointer AccessorTarget::getPointerToTarget(llvm::Value *actorSelf) {
794827
auto &IGM = IGF.IGM;
795-
auto fpKind = classifyFunctionPointerKind(Target);
796-
auto signature = IGM.getSignature(Type, fpKind);
797828

798-
auto *fnPtr =
799-
llvm::ConstantExpr::getBitCast(IGM.getAddrOfAsyncFunctionPointer(Target),
800-
signature.getType()->getPointerTo());
829+
if (auto *thunk = Target.dyn_cast<SILFunction *>()) {
830+
auto fpKind = classifyFunctionPointerKind(thunk);
831+
auto signature = IGM.getSignature(Type, fpKind);
832+
833+
auto *fnPtr =
834+
llvm::ConstantExpr::getBitCast(IGM.getAddrOfAsyncFunctionPointer(thunk),
835+
signature.getType()->getPointerTo());
836+
837+
return FunctionPointer::forDirect(
838+
FunctionPointer::Kind(Type), fnPtr,
839+
IGM.getAddrOfSILFunction(thunk, NotForDefinition), signature);
840+
}
841+
842+
auto *requirementDecl = Target.get<AbstractFunctionDecl *>();
843+
auto *protocol = requirementDecl->getDeclContext()->getSelfProtocolDecl();
844+
SILDeclRef requirementRef = SILDeclRef(requirementDecl).asDistributed();
845+
846+
if (!IGM.isResilient(protocol, ResilienceExpansion::Maximal)) {
847+
auto *witness = getWitnessMetadata(actorSelf);
848+
return emitWitnessMethodValue(IGF, witness->SelfWitnessTable,
849+
requirementRef);
850+
}
801851

802-
return FunctionPointer::forDirect(
803-
FunctionPointer::Kind(Type), fnPtr,
804-
IGM.getAddrOfSILFunction(Target, NotForDefinition), signature);
852+
auto fnPtr = IGM.getAddrOfDispatchThunk(requirementRef, NotForDefinition);
853+
auto sig = IGM.getSignature(Type);
854+
return FunctionPointer::forDirect(Type, fnPtr,
855+
/*secondaryValue=*/nullptr, sig, true);
805856
}
806857

807-
Callee AccessorTarget::getCallee(llvm::Value *actorSelf) const {
858+
Callee AccessorTarget::getCallee(llvm::Value *actorSelf) {
808859
CalleeInfo info{Type, Type, SubstitutionMap()};
809-
return {std::move(info), getPointerToTarget(), actorSelf};
860+
return {std::move(info), getPointerToTarget(actorSelf), actorSelf};
861+
}
862+
863+
WitnessMetadata *AccessorTarget::getWitnessMetadata(llvm::Value *actorSelf) {
864+
if (Target.is<SILFunction *>())
865+
return nullptr;
866+
867+
if (!Witness) {
868+
WitnessMetadata witness;
869+
870+
auto *requirement = Target.get<AbstractFunctionDecl *>();
871+
auto *protocol = requirement->getDeclContext()->getSelfProtocolDecl();
872+
assert(protocol);
873+
874+
witness.SelfMetadata = actorSelf;
875+
witness.SelfWitnessTable = lookupWitnessTable(
876+
IGF, emitHeapMetadataRefForUnknownHeapObject(IGF, actorSelf), protocol);
877+
878+
Witness = witness;
879+
}
880+
881+
return &(*Witness);
810882
}
811883

812884
ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder(
813885
llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
814886
auto &C = IGM.Context;
815-
DeclContext *targetContext = Target.getDeclContext();
887+
auto *thunk = cast<AbstractFunctionDecl>(Target.getDeclContext());
816888
auto expansionContext = IGM.getMaximalTypeExpansionContext();
817889

818890
/// If the context was a function, unwrap it and look for the decode method
819891
/// based off a concrete class; If we're not in a concrete class, we'll be
820892
/// using a witness for the decoder so returning null is okey.
821-
FuncDecl *decodeFn = nullptr;
822-
if (auto func = dyn_cast<AbstractFunctionDecl>(targetContext)) {
823-
decodeFn = C.getDistributedActorArgumentDecodingMethod(
824-
func->getDeclContext()->getSelfNominalTypeDecl());
825-
}
893+
FuncDecl *decodeFn = C.getDistributedActorArgumentDecodingMethod(
894+
thunk->getDeclContext()->getSelfNominalTypeDecl());
826895

827896
// If distributed actor is generic over actor system, we have to
828897
// use witness to reference `decodeNextArgument`.

lib/IRGen/IRGenModule.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1813,10 +1813,14 @@ private: \
18131813
Address getAddrOfObjCISAMask();
18141814

18151815
llvm::Function *
1816-
getAddrOfDistributedTargetAccessor(SILFunction *F,
1816+
getAddrOfDistributedTargetAccessor(LinkEntity accessor,
18171817
ForDefinition_t forDefinition);
18181818

1819-
void emitDistributedTargetAccessor(SILFunction *method);
1819+
/// Emit a distributed accessor function for the given distributed thunk or
1820+
/// protocol requirement.
1821+
void emitDistributedTargetAccessor(
1822+
llvm::PointerUnion<SILFunction *, AbstractFunctionDecl *>
1823+
thunkOrRequirement);
18201824

18211825
llvm::Constant *getAddrOfAccessibleFunctionRecord(SILFunction *accessibleFn);
18221826

0 commit comments

Comments
 (0)