Skip to content

Commit db020c1

Browse files
authored
Merge pull request #71435 from xedin/fix-need-for-adhoc-distributed-requirements
[Distributed] Re-implement ad-hoc requirements into dynamic witness table lookup for `SerializationRequirement` conformance
2 parents c440cca + aac4e85 commit db020c1

35 files changed

+509
-825
lines changed

include/swift/AST/Decl.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7498,6 +7498,17 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
74987498
/// 'DistributedTargetInvocationResultHandler' protocol.
74997499
bool isDistributedTargetInvocationResultHandlerOnReturn() const;
75007500

7501+
/// Determines whether this declaration is a witness to a
7502+
/// protocol requirement with ad-hoc `SerializationRequirement`
7503+
/// conformance.
7504+
bool isDistributedWitnessWithAdHocSerializationRequirement() const {
7505+
return isDistributedActorSystemRemoteCall(/*isVoidResult=*/false) ||
7506+
isDistributedTargetInvocationEncoderRecordArgument() ||
7507+
isDistributedTargetInvocationEncoderRecordReturnType() ||
7508+
isDistributedTargetInvocationDecoderDecodeNextArgument() ||
7509+
isDistributedTargetInvocationResultHandlerOnReturn();
7510+
}
7511+
75017512
/// For a method of a class, checks whether it will require a new entry in the
75027513
/// vtable.
75037514
bool needsNewVTableEntry() const;

include/swift/SIL/SILFunction.h

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -264,17 +264,6 @@ class SILFunction
264264
/// @_dynamicReplacement(for:) function.
265265
SILFunction *ReplacedFunction = nullptr;
266266

267-
/// This SILFunction REFerences an ad-hoc protocol requirement witness in
268-
/// order to keep it alive, such that it main be obtained in IRGen. Without
269-
/// this explicit reference, the witness would seem not-used, and not be
270-
/// accessible for IRGen.
271-
///
272-
/// Specifically, one such case is the DistributedTargetInvocationDecoder's
273-
/// 'decodeNextArgument' which must be retained, as it is only used from IRGen
274-
/// and such, appears as-if unused in SIL and would get optimized away.
275-
// TODO: Consider making this a general "references adhoc functions" and make it an array?
276-
SILFunction *RefAdHocRequirementFunction = nullptr;
277-
278267
Identifier ObjCReplacementFor;
279268

280269
/// The head of a single-linked list of currently alive BasicBlockBitfield.
@@ -596,27 +585,6 @@ class SILFunction
596585
ReplacedFunction = nullptr;
597586
}
598587

599-
SILFunction *getReferencedAdHocRequirementWitnessFunction() const {
600-
return RefAdHocRequirementFunction;
601-
}
602-
// Marks that this `SILFunction` uses the passed in ad-hoc protocol
603-
// requirement witness `f` and therefore must retain it explicitly,
604-
// otherwise we might not be able to get a reference to it.
605-
void setReferencedAdHocRequirementWitnessFunction(SILFunction *f) {
606-
assert(RefAdHocRequirementFunction == nullptr && "already set");
607-
608-
if (f == nullptr)
609-
return;
610-
RefAdHocRequirementFunction = f;
611-
RefAdHocRequirementFunction->incrementRefCount();
612-
}
613-
void dropReferencedAdHocRequirementWitnessFunction() {
614-
if (!RefAdHocRequirementFunction)
615-
return;
616-
RefAdHocRequirementFunction->decrementRefCount();
617-
RefAdHocRequirementFunction = nullptr;
618-
}
619-
620588
bool hasObjCReplacement() const {
621589
return !ObjCReplacementFor.empty();
622590
}

include/swift/Sema/ConstraintLocator.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,6 +1293,37 @@ class ConstraintLocatorBuilder {
12931293
return false;
12941294
}
12951295

1296+
std::optional<std::pair</*witness=*/ValueDecl *, GenericTypeParamType *>>
1297+
isForWitnessGenericParameterRequirement() const {
1298+
SmallVector<LocatorPathElt, 2> path;
1299+
getLocatorParts(path);
1300+
1301+
// -> witness -> generic env -> requirement
1302+
if (path.size() < 3)
1303+
return std::nullopt;
1304+
1305+
GenericTypeParamType *GP = nullptr;
1306+
if (auto reqLoc =
1307+
path.back().getAs<LocatorPathElt::TypeParameterRequirement>()) {
1308+
path.pop_back();
1309+
if (auto openedGeneric =
1310+
path.back().getAs<LocatorPathElt::OpenedGeneric>()) {
1311+
auto signature = openedGeneric->getSignature();
1312+
auto requirement = signature.getRequirements()[reqLoc->getIndex()];
1313+
GP = requirement.getFirstType()->getAs<GenericTypeParamType>();
1314+
}
1315+
}
1316+
1317+
if (!GP)
1318+
return std::nullopt;
1319+
1320+
auto witness = path.front().getAs<LocatorPathElt::Witness>();
1321+
if (!witness)
1322+
return std::nullopt;
1323+
1324+
return std::make_pair(witness->getDecl(), GP);
1325+
}
1326+
12961327
/// Checks whether this locator is describing an argument application for a
12971328
/// non-ephemeral parameter.
12981329
bool isNonEphemeralParameterApplication() const {

include/swift/Sema/ConstraintSystem.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1627,6 +1627,11 @@ class Solution {
16271627
llvm::DenseMap<ConstraintLocator *, UnresolvedDotExpr *>
16281628
ImplicitCallAsFunctionRoots;
16291629

1630+
/// The set of conformances synthesized during solving (i.e. for
1631+
/// ad-hoc distributed `SerializationRequirement` conformances).
1632+
llvm::MapVector<ConstraintLocator *, ProtocolConformanceRef>
1633+
SynthesizedConformances;
1634+
16301635
/// Record a new argument matching choice for given locator that maps a
16311636
/// single argument to a single parameter.
16321637
void recordSingleArgMatchingChoice(ConstraintLocator *locator);
@@ -1667,11 +1672,15 @@ class Solution {
16671672
/// Compute the set of substitutions for a generic signature opened at the
16681673
/// given locator.
16691674
///
1675+
/// \param decl The underlying declaration for which the substitutions are
1676+
/// computed.
1677+
///
16701678
/// \param sig The generic signature.
16711679
///
16721680
/// \param locator The locator that describes where the substitutions came
16731681
/// from.
1674-
SubstitutionMap computeSubstitutions(GenericSignature sig,
1682+
SubstitutionMap computeSubstitutions(NullablePtr<ValueDecl> decl,
1683+
GenericSignature sig,
16751684
ConstraintLocator *locator) const;
16761685

16771686
/// Resolves the contextual substitutions for a reference to a declaration
@@ -2411,6 +2420,11 @@ class ConstraintSystem {
24112420
llvm::SmallMapVector<ConstraintLocator *, UnresolvedDotExpr *, 2>
24122421
ImplicitCallAsFunctionRoots;
24132422

2423+
/// The set of conformances synthesized during solving (i.e. for
2424+
/// ad-hoc distributed `SerializationRequirement` conformances).
2425+
llvm::MapVector<ConstraintLocator *, ProtocolConformanceRef>
2426+
SynthesizedConformances;
2427+
24142428
private:
24152429
/// Describe the candidate expression for partial solving.
24162430
/// This class used by shrink & solve methods which apply
@@ -2934,6 +2948,9 @@ class ConstraintSystem {
29342948
/// The length of \c ImplicitCallAsFunctionRoots.
29352949
unsigned numImplicitCallAsFunctionRoots;
29362950

2951+
/// The length of \c SynthesizedConformances.
2952+
unsigned numSynthesizedConformances;
2953+
29372954
/// The previous score.
29382955
Score PreviousScore;
29392956

lib/AST/DistributedDecl.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,10 @@ bool swift::checkDistributedSerializationRequirementIsExactlyCodable(
387387
bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn) const {
388388
auto &C = getASTContext();
389389
auto module = getParentModule();
390+
auto *DC = getDeclContext();
391+
392+
if (!DC->isTypeContext() || !isGeneric())
393+
return false;
390394

391395
// === Check the name
392396
auto callId = isVoidReturn ? C.Id_remoteCallVoid : C.Id_remoteCall;
@@ -398,7 +402,7 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn)
398402
ProtocolDecl *systemProto =
399403
C.getDistributedActorSystemDecl();
400404

401-
auto systemNominal = getDeclContext()->getSelfNominalTypeDecl();
405+
auto systemNominal = DC->getSelfNominalTypeDecl();
402406
auto distSystemConformance = module->lookupConformance(
403407
systemNominal->getDeclaredInterfaceType(), systemProto);
404408

lib/IRGen/GenDistributed.cpp

Lines changed: 16 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,6 @@ struct ArgumentDecoderInfo {
8181
/// The type of `decodeNextArgument` method.
8282
CanSILFunctionType MethodType;
8383

84-
/// Protocol requirements associated with the generic
85-
/// parameter `Argument` of this decode method.
86-
GenericSignature::RequiredProtocols ProtocolRequirements;
87-
8884
// Witness metadata for conformance to DistributedTargetInvocationDecoder
8985
// protocol.
9086
WitnessMetadata Witness;
@@ -94,31 +90,19 @@ struct ArgumentDecoderInfo {
9490
FunctionPointer decodeNextArgumentPtr,
9591
CanSILFunctionType decodeNextArgumentTy)
9692
: Decoder(decoder), MethodPtr(decodeNextArgumentPtr),
97-
MethodType(decodeNextArgumentTy),
98-
ProtocolRequirements(findProtocolRequirements(decodeNextArgumentTy)) {
93+
MethodType(decodeNextArgumentTy) {
9994
Witness.SelfMetadata = decoderType;
10095
Witness.SelfWitnessTable = decoderWitnessTable;
10196
}
10297

10398
CanSILFunctionType getMethodType() const { return MethodType; }
10499

105-
ArrayRef<ProtocolDecl *> getProtocolRequirements() const {
106-
return ProtocolRequirements;
100+
WitnessMetadata *getWitnessMetadata() const {
101+
return const_cast<WitnessMetadata *>(&Witness);
107102
}
108103

109104
/// Form a callee to a decode method - `decodeNextArgument`.
110105
Callee getCallee() const;
111-
112-
private:
113-
static GenericSignature::RequiredProtocols
114-
findProtocolRequirements(CanSILFunctionType decodeMethodTy) {
115-
auto signature = decodeMethodTy->getInvocationGenericSignature();
116-
auto genericParams = signature.getGenericParams();
117-
118-
// func decodeNextArgument<Arg : #SerializationRequirement#>() throws -> Arg
119-
assert(genericParams.size() == 1);
120-
return signature->getRequiredProtocols(genericParams.front());
121-
}
122106
};
123107

124108
class DistributedAccessor {
@@ -156,10 +140,6 @@ class DistributedAccessor {
156140
llvm::Value *argumentType, const SILParameterInfo &param,
157141
Explosion &arguments);
158142

159-
void lookupWitnessTables(llvm::Value *value,
160-
ArrayRef<ProtocolDecl *> protocols,
161-
Explosion &witnessTables);
162-
163143
/// Load witness table addresses (if any) from the given buffer
164144
/// into the given argument explosion.
165145
///
@@ -417,17 +397,13 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
417397
// substitution Argument -> <argument metadata>
418398
decodeArgs.add(argumentType);
419399

420-
// Lookup witness tables for the requirement on the argument type.
421-
lookupWitnessTables(argumentType, decoder.getProtocolRequirements(),
422-
decodeArgs);
423-
424400
Address calleeErrorSlot;
425401
llvm::Value *decodeError = nullptr;
426402

427403
emission->begin();
428404
{
429405
emission->setArgs(decodeArgs, /*isOutlined=*/false,
430-
/*witnessMetadata=*/nullptr);
406+
/*witnessMetadata=*/decoder.getWitnessMetadata());
431407

432408
Explosion result;
433409
emission->emitToExplosion(result, /*isOutlined=*/false);
@@ -528,37 +504,6 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
528504
}
529505
}
530506

531-
void DistributedAccessor::lookupWitnessTables(
532-
llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
533-
Explosion &witnessTables) {
534-
auto conformsToProtocol = IGM.getConformsToProtocolFunctionPointer();
535-
536-
for (auto *protocol : protocols) {
537-
auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor(protocol);
538-
auto *witnessTable =
539-
IGF.Builder.CreateCall(conformsToProtocol, {value, protocolDescriptor});
540-
541-
auto failBB = IGF.createBasicBlock("missing-witness");
542-
auto contBB = IGF.createBasicBlock("");
543-
544-
auto isNull = IGF.Builder.CreateICmpEQ(
545-
witnessTable, llvm::ConstantPointerNull::get(IGM.WitnessTablePtrTy));
546-
IGF.Builder.CreateCondBr(isNull, failBB, contBB);
547-
548-
// This operation shouldn't fail because runtime should have checked that
549-
// a particular argument type conforms to `SerializationRequirement`
550-
// of the distributed actor the decoder is used for. If it does fail
551-
// then accessor should trap.
552-
{
553-
IGF.Builder.emitBlock(failBB);
554-
IGF.emitTrap("missing witness table", /*EmitUnreachable=*/true);
555-
}
556-
557-
IGF.Builder.emitBlock(contBB);
558-
witnessTables.add(witnessTable);
559-
}
560-
}
561-
562507
void DistributedAccessor::emitLoadOfWitnessTables(llvm::Value *witnessTables,
563508
llvm::Value *numTables,
564509
unsigned expectedWitnessTables,
@@ -803,70 +748,22 @@ DistributedAccessor::getCalleeForDistributedTarget(llvm::Value *self) const {
803748

804749
ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder(
805750
llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
806-
auto *actor = getDistributedActorOf(Target);
807-
auto expansionContext = IGM.getMaximalTypeExpansionContext();
808-
809-
auto *decodeFn = IGM.Context.getDistributedActorArgumentDecodingMethod(actor);
810-
assert(decodeFn && "no suitable decoder?");
811-
812-
auto methodTy = IGM.getSILTypes().getConstantFunctionType(
813-
expansionContext, SILDeclRef(decodeFn));
814-
815-
auto fpKind = FunctionPointerKind::defaultAsync();
816-
auto signature = IGM.getSignature(methodTy, fpKind);
817-
818-
// If the decoder class is `final`, let's emit a direct reference.
819-
auto *decoderDecl = decodeFn->getDeclContext()->getSelfNominalTypeDecl();
820-
821-
// If decoder is a class, need to load it first because generic parameter
822-
// is passed indirectly. This is good for structs and enums because
823-
// `decodeNextArgument` is a mutating method, but not for classes because
824-
// in that case heap object is mutated directly.
825-
bool usesDispatchThunk = false;
751+
auto &C = IGM.Context;
826752

827-
if (auto classDecl = dyn_cast<ClassDecl>(decoderDecl)) {
828-
auto selfTy = methodTy->getSelfParameter().getSILStorageType(
829-
IGM.getSILModule(), methodTy, expansionContext);
753+
auto decoderProtocol = C.getDistributedTargetInvocationDecoderDecl();
754+
SILDeclRef decodeNextArgumentRef(
755+
decoderProtocol->getSingleRequirement(C.Id_decodeNextArgument));
830756

831-
auto &classTI = IGM.getTypeInfo(selfTy).as<ClassTypeInfo>();
832-
auto &classLayout = classTI.getClassLayout(IGM, selfTy,
833-
/*forBackwardDeployment=*/false);
757+
llvm::Constant *fnPtr =
758+
IGM.getAddrOfDispatchThunk(decodeNextArgumentRef, NotForDefinition);
834759

835-
llvm::Value *typedDecoderPtr = IGF.Builder.CreateBitCast(
836-
decoder, classLayout.getType()->getPointerTo()->getPointerTo());
837-
838-
Explosion instance;
839-
840-
classTI.loadAsTake(IGF,
841-
{typedDecoderPtr, classTI.getStorageType(),
842-
classTI.getBestKnownAlignment()},
843-
instance);
844-
845-
decoder = instance.claimNext();
846-
847-
/// When using library evolution functions have another "dispatch thunk"
848-
/// so we must use this instead of the decodeFn directly.
849-
usesDispatchThunk =
850-
getMethodDispatch(decodeFn) == swift::MethodDispatch::Class &&
851-
classDecl->hasResilientMetadata();
852-
}
853-
854-
FunctionPointer methodPtr;
855-
856-
if (usesDispatchThunk) {
857-
auto fnPtr = IGM.getAddrOfDispatchThunk(SILDeclRef(decodeFn), NotForDefinition);
858-
methodPtr = FunctionPointer::createUnsigned(
859-
methodTy, fnPtr, signature, /*useSignature=*/true);
860-
} else {
861-
SILFunction *decodeSILFn = IGM.getSILModule().lookUpFunction(SILDeclRef(decodeFn));
862-
auto fnPtr = IGM.getAddrOfSILFunction(decodeSILFn, NotForDefinition,
863-
/*isDynamicallyReplaceable=*/false);
864-
methodPtr = FunctionPointer::forDirect(
865-
classifyFunctionPointerKind(decodeSILFn), fnPtr,
866-
/*secondaryValue=*/nullptr, signature);
867-
}
760+
auto fnType = IGM.getSILTypes().getConstantFunctionType(
761+
IGM.getMaximalTypeExpansionContext(), decodeNextArgumentRef);
868762

869-
return {decoder, decoderTy, witnessTable, methodPtr, methodTy};
763+
auto sig = IGM.getSignature(fnType);
764+
auto fn = FunctionPointer::forDirect(fnType, fnPtr,
765+
/*secondaryValue=*/nullptr, sig, true);
766+
return {decoder, decoderTy, witnessTable, fn, fnType};
870767
}
871768

872769
SILType DistributedAccessor::getResultType() const {

0 commit comments

Comments
 (0)