Skip to content

Commit c0dc888

Browse files
committed
Revert "[IRGen] Distributed: Always invoke decodeNextArgument through witness thunk"
This reverts commit 4d4c80b.
1 parent 39b81ac commit c0dc888

File tree

3 files changed

+135
-55
lines changed

3 files changed

+135
-55
lines changed

lib/IRGen/GenDistributed.cpp

Lines changed: 119 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ struct ArgumentDecoderInfo {
8181
/// The type of `decodeNextArgument` method.
8282
CanSILFunctionType MethodType;
8383

84+
/// Protocol requirements associated with the generic
85+
/// parameter `Argument` of this decode method.
86+
GenericSignature::RequiredProtocols ProtocolRequirements;
87+
8488
// Witness metadata for conformance to DistributedTargetInvocationDecoder
8589
// protocol.
8690
WitnessMetadata Witness;
@@ -90,19 +94,31 @@ struct ArgumentDecoderInfo {
9094
FunctionPointer decodeNextArgumentPtr,
9195
CanSILFunctionType decodeNextArgumentTy)
9296
: Decoder(decoder), MethodPtr(decodeNextArgumentPtr),
93-
MethodType(decodeNextArgumentTy) {
97+
MethodType(decodeNextArgumentTy),
98+
ProtocolRequirements(findProtocolRequirements(decodeNextArgumentTy)) {
9499
Witness.SelfMetadata = decoderType;
95100
Witness.SelfWitnessTable = decoderWitnessTable;
96101
}
97102

98103
CanSILFunctionType getMethodType() const { return MethodType; }
99104

100-
WitnessMetadata *getWitnessMetadata() const {
101-
return const_cast<WitnessMetadata *>(&Witness);
105+
ArrayRef<ProtocolDecl *> getProtocolRequirements() const {
106+
return ProtocolRequirements;
102107
}
103108

104109
/// Form a callee to a decode method - `decodeNextArgument`.
105110
Callee getCallee() const;
111+
112+
private:
113+
static GenericSignature::RequiredProtocols
114+
findProtocolRequirements(CanSILFunctionType decodeMethodTy) {
115+
auto signature = decodeMethodTy->getInvocationGenericSignature();
116+
auto genericParams = signature.getGenericParams();
117+
118+
// func decodeNextArgument<Arg : #SerializationRequirement#>() throws -> Arg
119+
assert(genericParams.size() == 1);
120+
return signature->getRequiredProtocols(genericParams.front());
121+
}
106122
};
107123

108124
class DistributedAccessor {
@@ -140,6 +156,10 @@ class DistributedAccessor {
140156
llvm::Value *argumentType, const SILParameterInfo &param,
141157
Explosion &arguments);
142158

159+
void lookupWitnessTables(llvm::Value *value,
160+
ArrayRef<ProtocolDecl *> protocols,
161+
Explosion &witnessTables);
162+
143163
/// Load witness table addresses (if any) from the given buffer
144164
/// into the given argument explosion.
145165
///
@@ -385,13 +405,17 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
385405
// substitution Argument -> <argument metadata>
386406
decodeArgs.add(argumentType);
387407

408+
// Lookup witness tables for the requirement on the argument type.
409+
lookupWitnessTables(argumentType, decoder.getProtocolRequirements(),
410+
decodeArgs);
411+
388412
Address calleeErrorSlot;
389413
llvm::Value *decodeError = nullptr;
390414

391415
emission->begin();
392416
{
393417
emission->setArgs(decodeArgs, /*isOutlined=*/false,
394-
/*witnessMetadata=*/decoder.getWitnessMetadata());
418+
/*witnessMetadata=*/nullptr);
395419

396420
Explosion result;
397421
emission->emitToExplosion(result, /*isOutlined=*/false);
@@ -492,6 +516,37 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
492516
}
493517
}
494518

519+
void DistributedAccessor::lookupWitnessTables(
520+
llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
521+
Explosion &witnessTables) {
522+
auto conformsToProtocol = IGM.getConformsToProtocolFunctionPointer();
523+
524+
for (auto *protocol : protocols) {
525+
auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor(protocol);
526+
auto *witnessTable =
527+
IGF.Builder.CreateCall(conformsToProtocol, {value, protocolDescriptor});
528+
529+
auto failBB = IGF.createBasicBlock("missing-witness");
530+
auto contBB = IGF.createBasicBlock("");
531+
532+
auto isNull = IGF.Builder.CreateICmpEQ(
533+
witnessTable, llvm::ConstantPointerNull::get(IGM.WitnessTablePtrTy));
534+
IGF.Builder.CreateCondBr(isNull, failBB, contBB);
535+
536+
// This operation shouldn't fail because runtime should have checked that
537+
// a particular argument type conforms to `SerializationRequirement`
538+
// of the distributed actor the decoder is used for. If it does fail
539+
// then accessor should trap.
540+
{
541+
IGF.Builder.emitBlock(failBB);
542+
IGF.emitTrap("missing witness table", /*EmitUnreachable=*/true);
543+
}
544+
545+
IGF.Builder.emitBlock(contBB);
546+
witnessTables.add(witnessTable);
547+
}
548+
}
549+
495550
void DistributedAccessor::emitLoadOfWitnessTables(llvm::Value *witnessTables,
496551
llvm::Value *numTables,
497552
unsigned expectedWitnessTables,
@@ -730,22 +785,70 @@ DistributedAccessor::getCalleeForDistributedTarget(llvm::Value *self) const {
730785

731786
ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder(
732787
llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
733-
auto &C = IGM.Context;
788+
auto *actor = getDistributedActorOf(Target);
789+
auto expansionContext = IGM.getMaximalTypeExpansionContext();
790+
791+
auto *decodeFn = IGM.Context.getDistributedActorArgumentDecodingMethod(actor);
792+
assert(decodeFn && "no suitable decoder?");
793+
794+
auto methodTy = IGM.getSILTypes().getConstantFunctionType(
795+
expansionContext, SILDeclRef(decodeFn));
796+
797+
auto fpKind = FunctionPointerKind::defaultAsync();
798+
auto signature = IGM.getSignature(methodTy, fpKind);
799+
800+
// If the decoder class is `final`, let's emit a direct reference.
801+
auto *decoderDecl = decodeFn->getDeclContext()->getSelfNominalTypeDecl();
802+
803+
// If decoder is a class, need to load it first because generic parameter
804+
// is passed indirectly. This is good for structs and enums because
805+
// `decodeNextArgument` is a mutating method, but not for classes because
806+
// in that case heap object is mutated directly.
807+
bool usesDispatchThunk = false;
734808

735-
auto decoderProtocol = C.getDistributedTargetInvocationDecoderDecl();
736-
SILDeclRef decodeNextArgumentRef(
737-
decoderProtocol->getSingleRequirement(C.Id_decodeNextArgument));
809+
if (auto classDecl = dyn_cast<ClassDecl>(decoderDecl)) {
810+
auto selfTy = methodTy->getSelfParameter().getSILStorageType(
811+
IGM.getSILModule(), methodTy, expansionContext);
738812

739-
llvm::Constant *fnPtr =
740-
IGM.getAddrOfDispatchThunk(decodeNextArgumentRef, NotForDefinition);
813+
auto &classTI = IGM.getTypeInfo(selfTy).as<ClassTypeInfo>();
814+
auto &classLayout = classTI.getClassLayout(IGM, selfTy,
815+
/*forBackwardDeployment=*/false);
741816

742-
auto fnType = IGM.getSILTypes().getConstantFunctionType(
743-
IGM.getMaximalTypeExpansionContext(), decodeNextArgumentRef);
817+
llvm::Value *typedDecoderPtr = IGF.Builder.CreateBitCast(
818+
decoder, classLayout.getType()->getPointerTo()->getPointerTo());
819+
820+
Explosion instance;
821+
822+
classTI.loadAsTake(IGF,
823+
{typedDecoderPtr, classTI.getStorageType(),
824+
classTI.getBestKnownAlignment()},
825+
instance);
826+
827+
decoder = instance.claimNext();
828+
829+
/// When using library evolution functions have another "dispatch thunk"
830+
/// so we must use this instead of the decodeFn directly.
831+
usesDispatchThunk =
832+
getMethodDispatch(decodeFn) == swift::MethodDispatch::Class &&
833+
classDecl->hasResilientMetadata();
834+
}
835+
836+
FunctionPointer methodPtr;
837+
838+
if (usesDispatchThunk) {
839+
auto fnPtr = IGM.getAddrOfDispatchThunk(SILDeclRef(decodeFn), NotForDefinition);
840+
methodPtr = FunctionPointer::createUnsigned(
841+
methodTy, fnPtr, signature, /*useSignature=*/true);
842+
} else {
843+
SILFunction *decodeSILFn = IGM.getSILModule().lookUpFunction(SILDeclRef(decodeFn));
844+
auto fnPtr = IGM.getAddrOfSILFunction(decodeSILFn, NotForDefinition,
845+
/*isDynamicallyReplaceable=*/false);
846+
methodPtr = FunctionPointer::forDirect(
847+
classifyFunctionPointerKind(decodeSILFn), fnPtr,
848+
/*secondaryValue=*/nullptr, signature);
849+
}
744850

745-
auto sig = IGM.getSignature(fnType);
746-
auto fn = FunctionPointer::forDirect(fnType, fnPtr,
747-
/*secondaryValue=*/nullptr, sig, true);
748-
return {decoder, decoderTy, witnessTable, fn, fnType};
851+
return {decoder, decoderTy, witnessTable, methodPtr, methodTy};
749852
}
750853

751854
SILType DistributedAccessor::getResultType() const {

test/Distributed/Runtime/distributed_actor_localSystem_generic.swift

Lines changed: 0 additions & 38 deletions
This file was deleted.

test/Distributed/distributed_actor_accessor_thunks_64bit.swift

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,25 @@ public distributed actor MyOtherActor {
9898

9999
/// Read the current offset and cast an element to `Int`
100100

101+
// CHECK: [[DECODER:%.*]] = load ptr, ptr %1
102+
// CHECK-NEXT: [[DECODE_NEXT_ARG_REF:%.*]] = getelementptr inbounds ptr, ptr %2, i64
103+
101104
// CHECK: [[ARG_0_SIZE_ADJ:%.*]] = add i64 %size, 15
102105
// CHECK-NEXT: [[ARG_0_SIZE:%.*]] = and i64 [[ARG_0_SIZE_ADJ]], -16
103106
// CHECK-NEXT: [[ARG_0_VALUE_BUF:%.*]] = call swiftcc ptr @swift_task_alloc(i64 [[ARG_0_SIZE]])
104-
// CHECK-NEXT: call swiftcc void @"$s11Distributed0A23TargetInvocationDecoderP18decodeNextArgumentqd__yKlFTj"(ptr noalias sret(%swift.opaque) [[ARG_0_VALUE_BUF]], ptr %arg_type, ptr swiftself %1, ptr noalias nocapture swifterror dereferenceable(8) %swifterror, ptr [[DECODER_TYPE]], ptr [[DECODER_PROTOCOL_WITNESS]])
107+
// CHECK-NEXT: [[ENCODABLE_WITNESS:%.*]] = call ptr @swift_conformsToProtocol{{(2)?}}(ptr %arg_type, ptr @"$sSeMp")
108+
// CHECK-NEXT: [[IS_NULL:%.*]] = icmp eq ptr [[ENCODABLE_WITNESS]], null
109+
// CHECK-NEXT: br i1 [[IS_NULL]], label %missing-witness, label [[CONT:%.*]]
110+
// CHECK: missing-witness:
111+
// CHECK-NEXT: call void @llvm.trap()
112+
// CHECK-NEXT: unreachable
113+
// CHECK: [[DECODABLE_WITNESS:%.*]] = call ptr @swift_conformsToProtocol{{(2)?}}(ptr %arg_type, ptr @"$sSEMp")
114+
// CHECK-NEXT: [[IS_NULL:%.*]] = icmp eq ptr [[DECODABLE_WITNESS]], null
115+
// CHECK-NEXT: br i1 [[IS_NULL]], label %missing-witness1, label [[CONT:%.*]]
116+
// CHECK: missing-witness1:
117+
// CHECK-NEXT: call void @llvm.trap()
118+
// CHECK-NEXT: unreachable
119+
// CHECK: call swiftcc void @"$s27FakeDistributedActorSystems0A17InvocationDecoderC18decodeNextArgumentxyKSeRzSERzlF"(ptr noalias sret(%swift.opaque) [[ARG_0_VALUE_BUF]], ptr %arg_type, ptr [[ENCODABLE_WITNESS]], ptr [[DECODABLE_WITNESS]], ptr swiftself [[DECODER]], ptr noalias nocapture swifterror dereferenceable(8) %swifterror)
105120

106121
// CHECK: store ptr null, ptr %swifterror
107122
// CHECK-NEXT: %._value = getelementptr inbounds %TSi, ptr [[ARG_0_VALUE_BUF]], i32 0, i32 0

0 commit comments

Comments
 (0)