@@ -188,6 +188,65 @@ class AbstractionPattern {
188
188
// / The partially-applied curried imported type of a C++ method. OrigType is
189
189
// / valid and is a function type. CXXMethod is valid.
190
190
PartialCurriedCXXMethodType,
191
+ // / A Swift function whose parameters and results are opaque. This is
192
+ // / like `AP::Type<T>((T) -> T)`, except that the number of parameters is
193
+ // / unspecified.
194
+ // /
195
+ // / This is used to construct the abstraction pattern for the
196
+ // / derivative function of a function with opaque abstraction pattern. See
197
+ // / `OpaqueDerivativeFunction`.
198
+ OpaqueFunction,
199
+ // / A Swift function whose parameters are opaque and whose result is the
200
+ // / tuple abstraction pattern `(AP::Opaque, AP::OpaqueFunction)`.
201
+ // /
202
+ // / Purpose: when we reabstract `@differentiable` function-typed values
203
+ // / using the`AP::Opaque` pattern, we use `AP::Opaque` to reabstract the
204
+ // / original function in the bundle and `AP::OpaqueDerivativeFunction` to
205
+ // / reabstract the derivative functions in the bundle. This preserves the
206
+ // / `@differentiable` function invariant that the derivative type
207
+ // / (`SILFunctionType::getAutoDiffDerivativeFunctionType()`) of the original
208
+ // / function is equal to the type of the derivative function. For example:
209
+ // /
210
+ // / differentiable_function
211
+ // / [parameters 0]
212
+ // / %0 : $@callee_guaranteed (Float) -> Float
213
+ // / with_derivative {
214
+ // / %1 : $@callee_guaranteed (Float) -> (
215
+ // / Float,
216
+ // / @owned @callee_guaranteed (Float) -> Float
217
+ // / ),
218
+ // / %2 : $@callee_guaranteed (Float) -> (
219
+ // / Float,
220
+ // / @owned @callee_guaranteed (Float) -> Float
221
+ // / )
222
+ // / }
223
+ // /
224
+ // / The invariant-respecting abstraction of this value to `AP::Opaque` is:
225
+ // /
226
+ // / differentiable_function
227
+ // / [parameters 0]
228
+ // / %3 : $@callee_guaranteed (@in_guaranteed Float) -> @out Float
229
+ // / with_derivative {
230
+ // / %4 : $@callee_guaranteed (@in_guaranteed Float) -> (
231
+ // / @out Float,
232
+ // / @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float
233
+ // / ),
234
+ // / %5 : $@callee_guaranteed (@in_guaranteed Float) -> (
235
+ // / @out Float,
236
+ // / @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float
237
+ // / )
238
+ // / }
239
+ // /
240
+ // / In particular:
241
+ // /
242
+ // / - The reabstraction %0 => %3 uses pattern `AP::Opaque`.
243
+ // / - The reabstraction %1 => %4 uses pattern
244
+ // / `AP::OpaqueDerivativeFunction`, which maximally abstracts all the
245
+ // / parameters, and abstracts the result as the tuple
246
+ // / `(AP::Opaque, AP::OpaqueFunction)`.
247
+ // / - The reabstraction %2 => %5 similarly uses pattern
248
+ // / `AP::OpaqueDerivativeFunction`.
249
+ OpaqueDerivativeFunction,
191
250
};
192
251
193
252
class EncodedForeignErrorInfo {
@@ -238,7 +297,7 @@ class AbstractionPattern {
238
297
static constexpr const unsigned NumOtherDataBits = 28 ;
239
298
static constexpr const unsigned MaxOtherData = (1 << NumOtherDataBits) - 1 ;
240
299
241
- unsigned TheKind : 32 - NumOtherDataBits;
300
+ unsigned TheKind : 33 - NumOtherDataBits;
242
301
unsigned OtherData : NumOtherDataBits;
243
302
CanType OrigType;
244
303
union {
@@ -384,6 +443,14 @@ class AbstractionPattern {
384
443
return AbstractionPattern (Kind::Invalid);
385
444
}
386
445
446
+ static AbstractionPattern getOpaqueFunction () {
447
+ return AbstractionPattern (Kind::OpaqueFunction);
448
+ }
449
+
450
+ static AbstractionPattern getOpaqueDerivativeFunction () {
451
+ return AbstractionPattern (Kind::OpaqueDerivativeFunction);
452
+ }
453
+
387
454
bool hasGenericSignature () const {
388
455
switch (getKind ()) {
389
456
case Kind::Type:
@@ -402,6 +469,8 @@ class AbstractionPattern {
402
469
case Kind::Invalid:
403
470
case Kind::Opaque:
404
471
case Kind::Tuple:
472
+ case Kind::OpaqueFunction:
473
+ case Kind::OpaqueDerivativeFunction:
405
474
return false ;
406
475
}
407
476
llvm_unreachable (" Unhandled AbstractionPatternKind in switch" );
@@ -730,6 +799,10 @@ class AbstractionPattern {
730
799
llvm_unreachable (" opaque pattern has no type" );
731
800
case Kind::Tuple:
732
801
llvm_unreachable (" open-coded tuple pattern has no type" );
802
+ case Kind::OpaqueFunction:
803
+ llvm_unreachable (" opaque function pattern has no type" );
804
+ case Kind::OpaqueDerivativeFunction:
805
+ llvm_unreachable (" opaque derivative function pattern has no type" );
733
806
case Kind::ClangType:
734
807
case Kind::CurriedObjCMethodType:
735
808
case Kind::PartialCurriedObjCMethodType:
@@ -763,6 +836,8 @@ class AbstractionPattern {
763
836
case Kind::Invalid:
764
837
case Kind::Opaque:
765
838
case Kind::Tuple:
839
+ case Kind::OpaqueFunction:
840
+ case Kind::OpaqueDerivativeFunction:
766
841
llvm_unreachable (" type cannot be replaced on pattern without type" );
767
842
case Kind::ClangType:
768
843
case Kind::CurriedObjCMethodType:
@@ -798,6 +873,8 @@ class AbstractionPattern {
798
873
case Kind::Tuple:
799
874
case Kind::Type:
800
875
case Kind::Discard:
876
+ case Kind::OpaqueFunction:
877
+ case Kind::OpaqueDerivativeFunction:
801
878
return false ;
802
879
case Kind::ClangType:
803
880
case Kind::PartialCurriedObjCMethodType:
@@ -854,6 +931,11 @@ class AbstractionPattern {
854
931
return CXXMethod;
855
932
}
856
933
934
+ bool isOpaqueFunctionOrOpaqueDerivativeFunction () const {
935
+ return (getKind () == Kind::OpaqueFunction ||
936
+ getKind () == Kind::OpaqueDerivativeFunction);
937
+ }
938
+
857
939
EncodedForeignErrorInfo getEncodedForeignErrorInfo () const {
858
940
assert (hasStoredForeignErrorInfo ());
859
941
return EncodedForeignErrorInfo::fromOpaqueValue (OtherData);
@@ -876,6 +958,8 @@ class AbstractionPattern {
876
958
case Kind::CXXMethodType:
877
959
case Kind::CurriedCXXMethodType:
878
960
case Kind::PartialCurriedCXXMethodType:
961
+ case Kind::OpaqueFunction:
962
+ case Kind::OpaqueDerivativeFunction:
879
963
return false ;
880
964
case Kind::PartialCurriedObjCMethodType:
881
965
case Kind::CurriedObjCMethodType:
@@ -897,6 +981,9 @@ class AbstractionPattern {
897
981
return typename CanTypeWrapperTraits<TYPE>::type ();
898
982
case Kind::Tuple:
899
983
return typename CanTypeWrapperTraits<TYPE>::type ();
984
+ case Kind::OpaqueFunction:
985
+ case Kind::OpaqueDerivativeFunction:
986
+ return typename CanTypeWrapperTraits<TYPE>::type ();
900
987
case Kind::ClangType:
901
988
case Kind::PartialCurriedObjCMethodType:
902
989
case Kind::CurriedObjCMethodType:
@@ -935,6 +1022,8 @@ class AbstractionPattern {
935
1022
case Kind::CXXMethodType:
936
1023
case Kind::CurriedCXXMethodType:
937
1024
case Kind::PartialCurriedCXXMethodType:
1025
+ case Kind::OpaqueFunction:
1026
+ case Kind::OpaqueDerivativeFunction:
938
1027
// We assume that the Clang type might provide additional structure.
939
1028
return false ;
940
1029
case Kind::Type:
@@ -962,6 +1051,8 @@ class AbstractionPattern {
962
1051
case Kind::CXXMethodType:
963
1052
case Kind::CurriedCXXMethodType:
964
1053
case Kind::PartialCurriedCXXMethodType:
1054
+ case Kind::OpaqueFunction:
1055
+ case Kind::OpaqueDerivativeFunction:
965
1056
return false ;
966
1057
case Kind::Tuple:
967
1058
return true ;
@@ -987,6 +1078,8 @@ class AbstractionPattern {
987
1078
case Kind::CXXMethodType:
988
1079
case Kind::CurriedCXXMethodType:
989
1080
case Kind::PartialCurriedCXXMethodType:
1081
+ case Kind::OpaqueFunction:
1082
+ case Kind::OpaqueDerivativeFunction:
990
1083
llvm_unreachable (" pattern is not a tuple" );
991
1084
case Kind::Tuple:
992
1085
return getNumTupleElements_Stored ();
@@ -1022,6 +1115,17 @@ class AbstractionPattern {
1022
1115
// / it.
1023
1116
AbstractionPattern getReferenceStorageReferentType () const ;
1024
1117
1118
+ // / Given that the value being abstracted is a function type, return the
1119
+ // / abstraction pattern for the derivative function.
1120
+ // /
1121
+ // / The arguments are the same as the arguments to
1122
+ // / `AnyFunctionType::getAutoDiffDerivativeFunctionType()`.
1123
+ AbstractionPattern getAutoDiffDerivativeFunctionType (
1124
+ IndexSubset *parameterIndices, AutoDiffDerivativeFunctionKind kind,
1125
+ LookupConformanceFn lookupConformance,
1126
+ GenericSignature derivativeGenericSignature = GenericSignature(),
1127
+ bool makeSelfParamFirst = false);
1128
+
1025
1129
void dump () const LLVM_ATTRIBUTE_USED;
1026
1130
void print (raw_ostream &OS) const ;
1027
1131
};
0 commit comments