Skip to content

Commit f9ec3b1

Browse files
authored
Merge pull request #71796 from xedin/make-dist-new-requirements-conditionally-available
[Distributed] Make new protocol requirements conditionally available
2 parents 161183c + 2bd1825 commit f9ec3b1

20 files changed

+295
-34
lines changed

include/swift/SIL/SILFunction.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,17 @@ class SILFunction
264264
/// @_dynamicReplacement(for:) function.
265265
SILFunction *ReplacedFunction = nullptr;
266266

267+
/// This SILFunction REFerences an ad-hoc protocol requirement witness in
268+
/// order to keep it alive, such that it main be obtained in IRGen. Without
269+
/// this explicit reference, the witness would seem not-used, and not be
270+
/// accessible for IRGen.
271+
///
272+
/// Specifically, one such case is the DistributedTargetInvocationDecoder's
273+
/// 'decodeNextArgument' which must be retained, as it is only used from IRGen
274+
/// and such, appears as-if unused in SIL and would get optimized away.
275+
// TODO: Consider making this a general "references adhoc functions" and make it an array?
276+
SILFunction *RefAdHocRequirementFunction = nullptr;
277+
267278
Identifier ObjCReplacementFor;
268279

269280
/// The head of a single-linked list of currently alive BasicBlockBitfield.
@@ -585,6 +596,27 @@ class SILFunction
585596
ReplacedFunction = nullptr;
586597
}
587598

599+
SILFunction *getReferencedAdHocRequirementWitnessFunction() const {
600+
return RefAdHocRequirementFunction;
601+
}
602+
// Marks that this `SILFunction` uses the passed in ad-hoc protocol
603+
// requirement witness `f` and therefore must retain it explicitly,
604+
// otherwise we might not be able to get a reference to it.
605+
void setReferencedAdHocRequirementWitnessFunction(SILFunction *f) {
606+
assert(RefAdHocRequirementFunction == nullptr && "already set");
607+
608+
if (f == nullptr)
609+
return;
610+
RefAdHocRequirementFunction = f;
611+
RefAdHocRequirementFunction->incrementRefCount();
612+
}
613+
void dropReferencedAdHocRequirementWitnessFunction() {
614+
if (!RefAdHocRequirementFunction)
615+
return;
616+
RefAdHocRequirementFunction->decrementRefCount();
617+
RefAdHocRequirementFunction = nullptr;
618+
}
619+
588620
bool hasObjCReplacement() const {
589621
return !ObjCReplacementFor.empty();
590622
}

lib/AST/DistributedDecl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ swift::getAssociatedDistributedInvocationDecoderDecodeNextArgumentFunction(
256256
return nullptr;
257257

258258
auto systemTy = getConcreteReplacementForProtocolActorSystemType(thunk);
259-
if (!systemTy)
259+
if (!systemTy || systemTy->is<GenericTypeParamType>())
260260
return nullptr;
261261

262262
auto decoderTy =

lib/IRGen/GenDistributed.cpp

Lines changed: 152 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,22 @@ struct ArgumentDecoderInfo {
8181
/// The type of `decodeNextArgument` method.
8282
CanSILFunctionType MethodType;
8383

84-
// Witness metadata for conformance to DistributedTargetInvocationDecoder
85-
// protocol.
84+
/// Witness metadata for conformance to DistributedTargetInvocationDecoder
85+
/// protocol.
8686
WitnessMetadata Witness;
8787

88+
/// Indicates whether `decodeNextArgument` is referenced through
89+
/// a protocol witness thunk.
90+
bool UsesWitnessDispatch;
91+
8892
ArgumentDecoderInfo(llvm::Value *decoder, llvm::Value *decoderType,
8993
llvm::Value *decoderWitnessTable,
9094
FunctionPointer decodeNextArgumentPtr,
91-
CanSILFunctionType decodeNextArgumentTy)
95+
CanSILFunctionType decodeNextArgumentTy,
96+
bool usesWitnessDispatch)
9297
: Decoder(decoder), MethodPtr(decodeNextArgumentPtr),
93-
MethodType(decodeNextArgumentTy) {
98+
MethodType(decodeNextArgumentTy),
99+
UsesWitnessDispatch(usesWitnessDispatch) {
94100
Witness.SelfMetadata = decoderType;
95101
Witness.SelfWitnessTable = decoderWitnessTable;
96102
}
@@ -101,6 +107,20 @@ struct ArgumentDecoderInfo {
101107
return const_cast<WitnessMetadata *>(&Witness);
102108
}
103109

110+
/// Protocol requirements associated with the generic
111+
/// parameter `Argument` of this decode method.
112+
GenericSignature::RequiredProtocols getProtocolRequirements() const {
113+
if (UsesWitnessDispatch)
114+
return {};
115+
116+
auto signature = MethodType->getInvocationGenericSignature();
117+
auto genericParams = signature.getGenericParams();
118+
119+
// func decodeNextArgument<Arg : #SerializationRequirement#>() throws -> Arg
120+
assert(genericParams.size() == 1);
121+
return signature->getRequiredProtocols(genericParams.front());
122+
}
123+
104124
/// Form a callee to a decode method - `decodeNextArgument`.
105125
Callee getCallee() const;
106126
};
@@ -140,6 +160,10 @@ class DistributedAccessor {
140160
llvm::Value *argumentType, const SILParameterInfo &param,
141161
Explosion &arguments);
142162

163+
void lookupWitnessTables(llvm::Value *value,
164+
ArrayRef<ProtocolDecl *> protocols,
165+
Explosion &witnessTables);
166+
143167
/// Load witness table addresses (if any) from the given buffer
144168
/// into the given argument explosion.
145169
///
@@ -385,13 +409,18 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
385409
// substitution Argument -> <argument metadata>
386410
decodeArgs.add(argumentType);
387411

412+
// Lookup witness tables for the requirement on the argument type.
413+
lookupWitnessTables(argumentType, decoder.getProtocolRequirements(),
414+
decodeArgs);
415+
388416
Address calleeErrorSlot;
389417
llvm::Value *decodeError = nullptr;
390418

391419
emission->begin();
392420
{
393421
emission->setArgs(decodeArgs, /*isOutlined=*/false,
394-
/*witnessMetadata=*/decoder.getWitnessMetadata());
422+
decoder.UsesWitnessDispatch ? decoder.getWitnessMetadata()
423+
: nullptr);
395424

396425
Explosion result;
397426
emission->emitToExplosion(result, /*isOutlined=*/false);
@@ -492,6 +521,43 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
492521
}
493522
}
494523

524+
void DistributedAccessor::lookupWitnessTables(
525+
llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
526+
Explosion &witnessTables) {
527+
if (protocols.empty())
528+
return;
529+
530+
auto conformsToProtocol = IGM.getConformsToProtocolFunctionPointer();
531+
532+
for (auto *protocol : protocols) {
533+
if (!Lowering::TypeConverter::protocolRequiresWitnessTable(protocol))
534+
continue;
535+
536+
auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor(protocol);
537+
auto *witnessTable =
538+
IGF.Builder.CreateCall(conformsToProtocol, {value, protocolDescriptor});
539+
540+
auto failBB = IGF.createBasicBlock("missing-witness");
541+
auto contBB = IGF.createBasicBlock("");
542+
543+
auto isNull = IGF.Builder.CreateICmpEQ(
544+
witnessTable, llvm::ConstantPointerNull::get(IGM.WitnessTablePtrTy));
545+
IGF.Builder.CreateCondBr(isNull, failBB, contBB);
546+
547+
// This operation shouldn't fail because runtime should have checked that
548+
// a particular argument type conforms to `SerializationRequirement`
549+
// of the distributed actor the decoder is used for. If it does fail
550+
// then accessor should trap.
551+
{
552+
IGF.Builder.emitBlock(failBB);
553+
IGF.emitTrap("missing witness table", /*EmitUnreachable=*/true);
554+
}
555+
556+
IGF.Builder.emitBlock(contBB);
557+
witnessTables.add(witnessTable);
558+
}
559+
}
560+
495561
void DistributedAccessor::emitLoadOfWitnessTables(llvm::Value *witnessTables,
496562
llvm::Value *numTables,
497563
unsigned expectedWitnessTables,
@@ -731,21 +797,91 @@ DistributedAccessor::getCalleeForDistributedTarget(llvm::Value *self) const {
731797
ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder(
732798
llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
733799
auto &C = IGM.Context;
800+
auto *actor = getDistributedActorOf(Target);
801+
auto expansionContext = IGM.getMaximalTypeExpansionContext();
802+
803+
auto *decodeFn = C.getDistributedActorArgumentDecodingMethod(actor);
804+
805+
// If distributed actor is generic over actor system, we have to
806+
// use witness to reference `decodeNextArgument`.
807+
if (!decodeFn) {
808+
auto decoderProtocol = C.getDistributedTargetInvocationDecoderDecl();
809+
auto decodeNextArgRequirement =
810+
decoderProtocol->getSingleRequirement(C.Id_decodeNextArgument);
811+
assert(decodeNextArgRequirement);
812+
SILDeclRef decodeNextArgumentRef(decodeNextArgRequirement);
813+
814+
llvm::Constant *fnPtr =
815+
IGM.getAddrOfDispatchThunk(decodeNextArgumentRef, NotForDefinition);
816+
auto fnType = IGM.getSILTypes().getConstantFunctionType(
817+
IGM.getMaximalTypeExpansionContext(), decodeNextArgumentRef);
818+
819+
auto sig = IGM.getSignature(fnType);
820+
auto fn = FunctionPointer::forDirect(fnType, fnPtr,
821+
/*secondaryValue=*/nullptr, sig, true);
822+
return {decoder, decoderTy, witnessTable,
823+
fn, fnType, /*usesWitnessDispatch=*/true};
824+
}
825+
826+
auto methodTy = IGM.getSILTypes().getConstantFunctionType(
827+
expansionContext, SILDeclRef(decodeFn));
828+
829+
auto fpKind = FunctionPointerKind::defaultAsync();
830+
auto signature = IGM.getSignature(methodTy, fpKind);
831+
832+
// If the decoder class is `final`, let's emit a direct reference.
833+
auto *decoderDecl = decodeFn->getDeclContext()->getSelfNominalTypeDecl();
734834

735-
auto decoderProtocol = C.getDistributedTargetInvocationDecoderDecl();
736-
SILDeclRef decodeNextArgumentRef(
737-
decoderProtocol->getSingleRequirement(C.Id_decodeNextArgument));
835+
// If decoder is a class, need to load it first because generic parameter
836+
// is passed indirectly. This is good for structs and enums because
837+
// `decodeNextArgument` is a mutating method, but not for classes because
838+
// in that case heap object is mutated directly.
839+
bool usesDispatchThunk = false;
738840

739-
llvm::Constant *fnPtr =
740-
IGM.getAddrOfDispatchThunk(decodeNextArgumentRef, NotForDefinition);
841+
if (auto classDecl = dyn_cast<ClassDecl>(decoderDecl)) {
842+
auto selfTy = methodTy->getSelfParameter().getSILStorageType(
843+
IGM.getSILModule(), methodTy, expansionContext);
741844

742-
auto fnType = IGM.getSILTypes().getConstantFunctionType(
743-
IGM.getMaximalTypeExpansionContext(), decodeNextArgumentRef);
845+
auto &classTI = IGM.getTypeInfo(selfTy).as<ClassTypeInfo>();
846+
auto &classLayout = classTI.getClassLayout(IGM, selfTy,
847+
/*forBackwardDeployment=*/false);
848+
849+
llvm::Value *typedDecoderPtr = IGF.Builder.CreateBitCast(
850+
decoder, classLayout.getType()->getPointerTo()->getPointerTo());
851+
852+
Explosion instance;
853+
854+
classTI.loadAsTake(IGF,
855+
{typedDecoderPtr, classTI.getStorageType(),
856+
classTI.getBestKnownAlignment()},
857+
instance);
858+
859+
decoder = instance.claimNext();
860+
861+
/// When using library evolution functions have another "dispatch thunk"
862+
/// so we must use this instead of the decodeFn directly.
863+
usesDispatchThunk =
864+
getMethodDispatch(decodeFn) == swift::MethodDispatch::Class &&
865+
classDecl->hasResilientMetadata();
866+
}
867+
868+
FunctionPointer methodPtr;
869+
870+
if (usesDispatchThunk) {
871+
auto fnPtr = IGM.getAddrOfDispatchThunk(SILDeclRef(decodeFn), NotForDefinition);
872+
methodPtr = FunctionPointer::createUnsigned(
873+
methodTy, fnPtr, signature, /*useSignature=*/true);
874+
} else {
875+
SILFunction *decodeSILFn = IGM.getSILModule().lookUpFunction(SILDeclRef(decodeFn));
876+
auto fnPtr = IGM.getAddrOfSILFunction(decodeSILFn, NotForDefinition,
877+
/*isDynamicallyReplaceable=*/false);
878+
methodPtr = FunctionPointer::forDirect(
879+
classifyFunctionPointerKind(decodeSILFn), fnPtr,
880+
/*secondaryValue=*/nullptr, signature);
881+
}
744882

745-
auto sig = IGM.getSignature(fnType);
746-
auto fn = FunctionPointer::forDirect(fnType, fnPtr,
747-
/*secondaryValue=*/nullptr, sig, true);
748-
return {decoder, decoderTy, witnessTable, fn, fnType};
883+
return {decoder, decoderTy, witnessTable,
884+
methodPtr, methodTy, /*usesWitnessDispatch=*/false};
749885
}
750886

751887
SILType DistributedAccessor::getResultType() const {

lib/SIL/IR/SILFunction.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ void SILFunction::createSnapshot(int id) {
278278
newSnapshot->DeclCtxt = DeclCtxt;
279279
newSnapshot->Profiler = Profiler;
280280
newSnapshot->ReplacedFunction = ReplacedFunction;
281+
newSnapshot->RefAdHocRequirementFunction = RefAdHocRequirementFunction;
281282
newSnapshot->ObjCReplacementFor = ObjCReplacementFor;
282283
newSnapshot->SemanticsAttrSet = SemanticsAttrSet;
283284
newSnapshot->SpecializeAttrSet = SpecializeAttrSet;

lib/SIL/IR/SILFunctionBuilder.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,16 @@ void SILFunctionBuilder::addFunctionAttributes(
274274

275275
F->setDynamicallyReplacedFunction(replacedFunc);
276276
}
277+
} else if (constant.isDistributedThunk()) {
278+
// It's okay for `decodeFuncDecl` to be null because system could be
279+
// generic.
280+
if (auto decodeFuncDecl =
281+
getAssociatedDistributedInvocationDecoderDecodeNextArgumentFunction(
282+
decl)) {
283+
auto decodeRef = SILDeclRef(decodeFuncDecl);
284+
auto *adHocFunc = getOrCreateDeclaration(decodeFuncDecl, decodeRef);
285+
F->setReferencedAdHocRequirementWitnessFunction(adHocFunc);
286+
}
277287
}
278288
}
279289

lib/SIL/IR/SILModule.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ SILModule::~SILModule() {
152152
for (SILFunction &F : *this) {
153153
F.dropAllReferences();
154154
F.dropDynamicallyReplacedFunction();
155+
F.dropReferencedAdHocRequirementWitnessFunction();
155156
F.clearSpecializeAttrs();
156157
}
157158

@@ -430,6 +431,7 @@ void SILModule::eraseFunction(SILFunction *F) {
430431
// (References are not needed anymore.)
431432
F->clear();
432433
F->dropDynamicallyReplacedFunction();
434+
F->dropReferencedAdHocRequirementWitnessFunction();
433435
// Drop references for any _specialize(target:) functions.
434436
F->clearSpecializeAttrs();
435437
}

lib/SIL/IR/SILPrinter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3414,6 +3414,12 @@ void SILFunction::print(SILPrintContext &PrintCtx) const {
34143414
OS << "\"] ";
34153415
}
34163416

3417+
if (auto *usedFunc = getReferencedAdHocRequirementWitnessFunction()) {
3418+
OS << "[ref_adhoc_requirement_witness \"";
3419+
OS << usedFunc->getName();
3420+
OS << "\"] ";
3421+
}
3422+
34173423
if (hasObjCReplacement()) {
34183424
OS << "[objc_replacement_for \"";
34193425
OS << getObjCReplacement().str();

0 commit comments

Comments
 (0)