@@ -81,6 +81,10 @@ struct ArgumentDecoderInfo {
81
81
// / The type of `decodeNextArgument` method.
82
82
CanSILFunctionType MethodType;
83
83
84
+ // / Protocol requirements associated with the generic
85
+ // / parameter `Argument` of this decode method.
86
+ GenericSignature::RequiredProtocols ProtocolRequirements;
87
+
84
88
// Witness metadata for conformance to DistributedTargetInvocationDecoder
85
89
// protocol.
86
90
WitnessMetadata Witness;
@@ -90,19 +94,31 @@ struct ArgumentDecoderInfo {
90
94
FunctionPointer decodeNextArgumentPtr,
91
95
CanSILFunctionType decodeNextArgumentTy)
92
96
: Decoder(decoder), MethodPtr(decodeNextArgumentPtr),
93
- MethodType (decodeNextArgumentTy) {
97
+ MethodType (decodeNextArgumentTy),
98
+ ProtocolRequirements(findProtocolRequirements(decodeNextArgumentTy)) {
94
99
Witness.SelfMetadata = decoderType;
95
100
Witness.SelfWitnessTable = decoderWitnessTable;
96
101
}
97
102
98
103
CanSILFunctionType getMethodType () const { return MethodType; }
99
104
100
- WitnessMetadata * getWitnessMetadata () const {
101
- return const_cast <WitnessMetadata *>(&Witness) ;
105
+ ArrayRef<ProtocolDecl *> getProtocolRequirements () const {
106
+ return ProtocolRequirements ;
102
107
}
103
108
104
109
// / Form a callee to a decode method - `decodeNextArgument`.
105
110
Callee getCallee () const ;
111
+
112
+ private:
113
+ static GenericSignature::RequiredProtocols
114
+ findProtocolRequirements (CanSILFunctionType decodeMethodTy) {
115
+ auto signature = decodeMethodTy->getInvocationGenericSignature ();
116
+ auto genericParams = signature.getGenericParams ();
117
+
118
+ // func decodeNextArgument<Arg : #SerializationRequirement#>() throws -> Arg
119
+ assert (genericParams.size () == 1 );
120
+ return signature->getRequiredProtocols (genericParams.front ());
121
+ }
106
122
};
107
123
108
124
class DistributedAccessor {
@@ -140,6 +156,10 @@ class DistributedAccessor {
140
156
llvm::Value *argumentType, const SILParameterInfo ¶m,
141
157
Explosion &arguments);
142
158
159
+ void lookupWitnessTables (llvm::Value *value,
160
+ ArrayRef<ProtocolDecl *> protocols,
161
+ Explosion &witnessTables);
162
+
143
163
// / Load witness table addresses (if any) from the given buffer
144
164
// / into the given argument explosion.
145
165
// /
@@ -385,13 +405,17 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
385
405
// substitution Argument -> <argument metadata>
386
406
decodeArgs.add (argumentType);
387
407
408
+ // Lookup witness tables for the requirement on the argument type.
409
+ lookupWitnessTables (argumentType, decoder.getProtocolRequirements (),
410
+ decodeArgs);
411
+
388
412
Address calleeErrorSlot;
389
413
llvm::Value *decodeError = nullptr ;
390
414
391
415
emission->begin ();
392
416
{
393
417
emission->setArgs (decodeArgs, /* isOutlined=*/ false ,
394
- /* witnessMetadata=*/ decoder. getWitnessMetadata () );
418
+ /* witnessMetadata=*/ nullptr );
395
419
396
420
Explosion result;
397
421
emission->emitToExplosion (result, /* isOutlined=*/ false );
@@ -492,6 +516,37 @@ void DistributedAccessor::decodeArgument(unsigned argumentIdx,
492
516
}
493
517
}
494
518
519
+ void DistributedAccessor::lookupWitnessTables (
520
+ llvm::Value *value, ArrayRef<ProtocolDecl *> protocols,
521
+ Explosion &witnessTables) {
522
+ auto conformsToProtocol = IGM.getConformsToProtocolFunctionPointer ();
523
+
524
+ for (auto *protocol : protocols) {
525
+ auto *protocolDescriptor = IGM.getAddrOfProtocolDescriptor (protocol);
526
+ auto *witnessTable =
527
+ IGF.Builder .CreateCall (conformsToProtocol, {value, protocolDescriptor});
528
+
529
+ auto failBB = IGF.createBasicBlock (" missing-witness" );
530
+ auto contBB = IGF.createBasicBlock (" " );
531
+
532
+ auto isNull = IGF.Builder .CreateICmpEQ (
533
+ witnessTable, llvm::ConstantPointerNull::get (IGM.WitnessTablePtrTy ));
534
+ IGF.Builder .CreateCondBr (isNull, failBB, contBB);
535
+
536
+ // This operation shouldn't fail because runtime should have checked that
537
+ // a particular argument type conforms to `SerializationRequirement`
538
+ // of the distributed actor the decoder is used for. If it does fail
539
+ // then accessor should trap.
540
+ {
541
+ IGF.Builder .emitBlock (failBB);
542
+ IGF.emitTrap (" missing witness table" , /* EmitUnreachable=*/ true );
543
+ }
544
+
545
+ IGF.Builder .emitBlock (contBB);
546
+ witnessTables.add (witnessTable);
547
+ }
548
+ }
549
+
495
550
void DistributedAccessor::emitLoadOfWitnessTables (llvm::Value *witnessTables,
496
551
llvm::Value *numTables,
497
552
unsigned expectedWitnessTables,
@@ -730,22 +785,70 @@ DistributedAccessor::getCalleeForDistributedTarget(llvm::Value *self) const {
730
785
731
786
ArgumentDecoderInfo DistributedAccessor::findArgumentDecoder (
732
787
llvm::Value *decoder, llvm::Value *decoderTy, llvm::Value *witnessTable) {
733
- auto &C = IGM.Context ;
788
+ auto *actor = getDistributedActorOf (Target);
789
+ auto expansionContext = IGM.getMaximalTypeExpansionContext ();
790
+
791
+ auto *decodeFn = IGM.Context .getDistributedActorArgumentDecodingMethod (actor);
792
+ assert (decodeFn && " no suitable decoder?" );
793
+
794
+ auto methodTy = IGM.getSILTypes ().getConstantFunctionType (
795
+ expansionContext, SILDeclRef (decodeFn));
796
+
797
+ auto fpKind = FunctionPointerKind::defaultAsync ();
798
+ auto signature = IGM.getSignature (methodTy, fpKind);
799
+
800
+ // If the decoder class is `final`, let's emit a direct reference.
801
+ auto *decoderDecl = decodeFn->getDeclContext ()->getSelfNominalTypeDecl ();
802
+
803
+ // If decoder is a class, need to load it first because generic parameter
804
+ // is passed indirectly. This is good for structs and enums because
805
+ // `decodeNextArgument` is a mutating method, but not for classes because
806
+ // in that case heap object is mutated directly.
807
+ bool usesDispatchThunk = false ;
734
808
735
- auto decoderProtocol = C. getDistributedTargetInvocationDecoderDecl ();
736
- SILDeclRef decodeNextArgumentRef (
737
- decoderProtocol-> getSingleRequirement (C. Id_decodeNextArgument ) );
809
+ if ( auto classDecl = dyn_cast<ClassDecl>(decoderDecl)) {
810
+ auto selfTy = methodTy-> getSelfParameter (). getSILStorageType (
811
+ IGM. getSILModule (), methodTy, expansionContext );
738
812
739
- llvm::Constant *fnPtr =
740
- IGM.getAddrOfDispatchThunk (decodeNextArgumentRef, NotForDefinition);
813
+ auto &classTI = IGM.getTypeInfo (selfTy).as <ClassTypeInfo>();
814
+ auto &classLayout = classTI.getClassLayout (IGM, selfTy,
815
+ /* forBackwardDeployment=*/ false );
741
816
742
- auto fnType = IGM.getSILTypes ().getConstantFunctionType (
743
- IGM.getMaximalTypeExpansionContext (), decodeNextArgumentRef);
817
+ llvm::Value *typedDecoderPtr = IGF.Builder .CreateBitCast (
818
+ decoder, classLayout.getType ()->getPointerTo ()->getPointerTo ());
819
+
820
+ Explosion instance;
821
+
822
+ classTI.loadAsTake (IGF,
823
+ {typedDecoderPtr, classTI.getStorageType (),
824
+ classTI.getBestKnownAlignment ()},
825
+ instance);
826
+
827
+ decoder = instance.claimNext ();
828
+
829
+ // / When using library evolution functions have another "dispatch thunk"
830
+ // / so we must use this instead of the decodeFn directly.
831
+ usesDispatchThunk =
832
+ getMethodDispatch (decodeFn) == swift::MethodDispatch::Class &&
833
+ classDecl->hasResilientMetadata ();
834
+ }
835
+
836
+ FunctionPointer methodPtr;
837
+
838
+ if (usesDispatchThunk) {
839
+ auto fnPtr = IGM.getAddrOfDispatchThunk (SILDeclRef (decodeFn), NotForDefinition);
840
+ methodPtr = FunctionPointer::createUnsigned (
841
+ methodTy, fnPtr, signature, /* useSignature=*/ true );
842
+ } else {
843
+ SILFunction *decodeSILFn = IGM.getSILModule ().lookUpFunction (SILDeclRef (decodeFn));
844
+ auto fnPtr = IGM.getAddrOfSILFunction (decodeSILFn, NotForDefinition,
845
+ /* isDynamicallyReplaceable=*/ false );
846
+ methodPtr = FunctionPointer::forDirect (
847
+ classifyFunctionPointerKind (decodeSILFn), fnPtr,
848
+ /* secondaryValue=*/ nullptr , signature);
849
+ }
744
850
745
- auto sig = IGM.getSignature (fnType);
746
- auto fn = FunctionPointer::forDirect (fnType, fnPtr,
747
- /* secondaryValue=*/ nullptr , sig, true );
748
- return {decoder, decoderTy, witnessTable, fn, fnType};
851
+ return {decoder, decoderTy, witnessTable, methodPtr, methodTy};
749
852
}
750
853
751
854
SILType DistributedAccessor::getResultType () const {
0 commit comments