@@ -81,16 +81,22 @@ struct ArgumentDecoderInfo {
81
81
// / The type of `decodeNextArgument` method.
82
82
CanSILFunctionType MethodType;
83
83
84
- // Witness metadata for conformance to DistributedTargetInvocationDecoder
85
- // protocol.
84
+ // / Witness metadata for conformance to DistributedTargetInvocationDecoder
85
+ // / protocol.
86
86
WitnessMetadata Witness;
87
87
88
+ // / Indicates whether `decodeNextArgument` is referenced through
89
+ // / a protocol witness thunk.
90
+ bool UsesWitnessDispatch;
91
+
88
92
ArgumentDecoderInfo (llvm::Value *decoder, llvm::Value *decoderType,
89
93
llvm::Value *decoderWitnessTable,
90
94
FunctionPointer decodeNextArgumentPtr,
91
- CanSILFunctionType decodeNextArgumentTy)
95
+ CanSILFunctionType decodeNextArgumentTy,
96
+ bool usesWitnessDispatch)
92
97
: Decoder(decoder), MethodPtr(decodeNextArgumentPtr),
93
- MethodType (decodeNextArgumentTy) {
98
+ MethodType (decodeNextArgumentTy),
99
+ UsesWitnessDispatch(usesWitnessDispatch) {
94
100
Witness.SelfMetadata = decoderType;
95
101
Witness.SelfWitnessTable = decoderWitnessTable;
96
102
}
@@ -101,6 +107,20 @@ struct ArgumentDecoderInfo {
101
107
return const_cast <WitnessMetadata *>(&Witness);
102
108
}
103
109
110
+ // / Protocol requirements associated with the generic
111
+ // / parameter `Argument` of this decode method.
112
+ GenericSignature::RequiredProtocols getProtocolRequirements () const {
113
+ if (UsesWitnessDispatch)
114
+ return {};
115
+
116
+ auto signature = MethodType->getInvocationGenericSignature ();
117
+ auto genericParams = signature.getGenericParams ();
118
+
119
+ // func decodeNextArgument<Arg : #SerializationRequirement#>() throws -> Arg
120
+ assert (genericParams.size () == 1 );
121
+ return signature->getRequiredProtocols (genericParams.front ());
122
+ }
123
+
104
124
// / Form a callee to a decode method - `decodeNextArgument`.
105
125
Callee getCallee () const ;
106
126
};
@@ -140,6 +160,10 @@ class DistributedAccessor {
140
160
llvm::Value *argumentType, const SILParameterInfo ¶m,
141
161
Explosion &arguments);
142
162
163
+ void lookupWitnessTables (llvm::Value *value,
164
+ ArrayRef<ProtocolDecl *> protocols,
165
+ Explosion &witnessTables);
166
+
143
167
// / Load witness table addresses (if any) from the given buffer
144
168
// / into the given argument explosion.
145
169
// /
@@ -385,13 +409,18 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
385
409
// substitution Argument -> <argument metadata>
386
410
decodeArgs.add (argumentType);
387
411
412
+ // Lookup witness tables for the requirement on the argument type.
413
+ lookupWitnessTables (argumentType, decoder.getProtocolRequirements (),
414
+ decodeArgs);
415
+
388
416
Address calleeErrorSlot;
389
417
llvm::Value *decodeError = nullptr ;
390
418
391
419
emission->begin ();
392
420
{
393
421
emission->setArgs (decodeArgs, /* isOutlined=*/ false ,
394
- /* witnessMetadata=*/ decoder.getWitnessMetadata ());
422
+ decoder.UsesWitnessDispatch ? decoder.getWitnessMetadata ()
423
+ : nullptr );
395
424
396
425
Explosion result;
397
426
emission->emitToExplosion (result, /* isOutlined=*/ false );
@@ -492,6 +521,43 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
492
521
}
493
522
}
494
523
524
+ void DistributedAccessor::lookupWitnessTables (
525
+ llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
526
+ Explosion &witnessTables) {
527
+ if (protocols.empty ())
528
+ return ;
529
+
530
+ auto conformsToProtocol = IGM.getConformsToProtocolFunctionPointer ();
531
+
532
+ for (auto *protocol : protocols) {
533
+ if (!Lowering::TypeConverter::protocolRequiresWitnessTable (protocol))
534
+ continue ;
535
+
536
+ auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor (protocol);
537
+ auto *witnessTable =
538
+ IGF.Builder .CreateCall (conformsToProtocol, {value, protocolDescriptor});
539
+
540
+ auto failBB = IGF.createBasicBlock (" missing-witness" );
541
+ auto contBB = IGF.createBasicBlock (" " );
542
+
543
+ auto isNull = IGF.Builder .CreateICmpEQ (
544
+ witnessTable, llvm::ConstantPointerNull::get (IGM.WitnessTablePtrTy ));
545
+ IGF.Builder .CreateCondBr (isNull, failBB, contBB);
546
+
547
+ // This operation shouldn't fail because runtime should have checked that
548
+ // a particular argument type conforms to `SerializationRequirement`
549
+ // of the distributed actor the decoder is used for. If it does fail
550
+ // then accessor should trap.
551
+ {
552
+ IGF.Builder .emitBlock (failBB);
553
+ IGF.emitTrap (" missing witness table" , /* EmitUnreachable=*/ true );
554
+ }
555
+
556
+ IGF.Builder .emitBlock (contBB);
557
+ witnessTables.add (witnessTable);
558
+ }
559
+ }
560
+
495
561
void DistributedAccessor::emitLoadOfWitnessTables (llvm::Value *witnessTables,
496
562
llvm::Value *numTables,
497
563
unsigned expectedWitnessTables,
@@ -731,21 +797,91 @@ DistributedAccessor::getCalleeForDistributedTarget(llvm::Value *self) const {
731
797
ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder (
732
798
llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
733
799
auto &C = IGM.Context ;
800
+ auto *actor = getDistributedActorOf (Target);
801
+ auto expansionContext = IGM.getMaximalTypeExpansionContext ();
802
+
803
+ auto *decodeFn = C.getDistributedActorArgumentDecodingMethod (actor);
804
+
805
+ // If distributed actor is generic over actor system, we have to
806
+ // use witness to reference `decodeNextArgument`.
807
+ if (!decodeFn) {
808
+ auto decoderProtocol = C.getDistributedTargetInvocationDecoderDecl ();
809
+ auto decodeNextArgRequirement =
810
+ decoderProtocol->getSingleRequirement (C.Id_decodeNextArgument );
811
+ assert (decodeNextArgRequirement);
812
+ SILDeclRef decodeNextArgumentRef (decodeNextArgRequirement);
813
+
814
+ llvm::Constant *fnPtr =
815
+ IGM.getAddrOfDispatchThunk (decodeNextArgumentRef, NotForDefinition);
816
+ auto fnType = IGM.getSILTypes ().getConstantFunctionType (
817
+ IGM.getMaximalTypeExpansionContext (), decodeNextArgumentRef);
818
+
819
+ auto sig = IGM.getSignature (fnType);
820
+ auto fn = FunctionPointer::forDirect (fnType, fnPtr,
821
+ /* secondaryValue=*/ nullptr , sig, true );
822
+ return {decoder, decoderTy, witnessTable,
823
+ fn, fnType, /* usesWitnessDispatch=*/ true };
824
+ }
825
+
826
+ auto methodTy = IGM.getSILTypes ().getConstantFunctionType (
827
+ expansionContext, SILDeclRef (decodeFn));
828
+
829
+ auto fpKind = FunctionPointerKind::defaultAsync ();
830
+ auto signature = IGM.getSignature (methodTy, fpKind);
831
+
832
+ // If the decoder class is `final`, let's emit a direct reference.
833
+ auto *decoderDecl = decodeFn->getDeclContext ()->getSelfNominalTypeDecl ();
734
834
735
- auto decoderProtocol = C.getDistributedTargetInvocationDecoderDecl ();
736
- SILDeclRef decodeNextArgumentRef (
737
- decoderProtocol->getSingleRequirement (C.Id_decodeNextArgument ));
835
+ // If decoder is a class, need to load it first because generic parameter
836
+ // is passed indirectly. This is good for structs and enums because
837
+ // `decodeNextArgument` is a mutating method, but not for classes because
838
+ // in that case heap object is mutated directly.
839
+ bool usesDispatchThunk = false ;
738
840
739
- llvm::Constant *fnPtr =
740
- IGM.getAddrOfDispatchThunk (decodeNextArgumentRef, NotForDefinition);
841
+ if (auto classDecl = dyn_cast<ClassDecl>(decoderDecl)) {
842
+ auto selfTy = methodTy->getSelfParameter ().getSILStorageType (
843
+ IGM.getSILModule (), methodTy, expansionContext);
741
844
742
- auto fnType = IGM.getSILTypes ().getConstantFunctionType (
743
- IGM.getMaximalTypeExpansionContext (), decodeNextArgumentRef);
845
+ auto &classTI = IGM.getTypeInfo (selfTy).as <ClassTypeInfo>();
846
+ auto &classLayout = classTI.getClassLayout (IGM, selfTy,
847
+ /* forBackwardDeployment=*/ false );
848
+
849
+ llvm::Value *typedDecoderPtr = IGF.Builder .CreateBitCast (
850
+ decoder, classLayout.getType ()->getPointerTo ()->getPointerTo ());
851
+
852
+ Explosion instance;
853
+
854
+ classTI.loadAsTake (IGF,
855
+ {typedDecoderPtr, classTI.getStorageType (),
856
+ classTI.getBestKnownAlignment ()},
857
+ instance);
858
+
859
+ decoder = instance.claimNext ();
860
+
861
+ // / When using library evolution functions have another "dispatch thunk"
862
+ // / so we must use this instead of the decodeFn directly.
863
+ usesDispatchThunk =
864
+ getMethodDispatch (decodeFn) == swift::MethodDispatch::Class &&
865
+ classDecl->hasResilientMetadata ();
866
+ }
867
+
868
+ FunctionPointer methodPtr;
869
+
870
+ if (usesDispatchThunk) {
871
+ auto fnPtr = IGM.getAddrOfDispatchThunk (SILDeclRef (decodeFn), NotForDefinition);
872
+ methodPtr = FunctionPointer::createUnsigned (
873
+ methodTy, fnPtr, signature, /* useSignature=*/ true );
874
+ } else {
875
+ SILFunction *decodeSILFn = IGM.getSILModule ().lookUpFunction (SILDeclRef (decodeFn));
876
+ auto fnPtr = IGM.getAddrOfSILFunction (decodeSILFn, NotForDefinition,
877
+ /* isDynamicallyReplaceable=*/ false );
878
+ methodPtr = FunctionPointer::forDirect (
879
+ classifyFunctionPointerKind (decodeSILFn), fnPtr,
880
+ /* secondaryValue=*/ nullptr , signature);
881
+ }
744
882
745
- auto sig = IGM.getSignature (fnType);
746
- auto fn = FunctionPointer::forDirect (fnType, fnPtr,
747
- /* secondaryValue=*/ nullptr , sig, true );
748
- return {decoder, decoderTy, witnessTable, fn, fnType};
883
+ return {decoder, decoderTy, witnessTable,
884
+ methodPtr, methodTy, /* usesWitnessDispatch=*/ false };
749
885
}
750
886
751
887
SILType DistributedAccessor::getResultType () const {
0 commit comments