Skip to content

Commit 495b571

Browse files
marcrasidan-zheng
andauthored
[AutoDiff upstream] Add @differentiable function reabstraction. (swiftlang#30692)
Add SILGen logic for reabstracting `@differentiable` functions. Resolves TF-1223. Co-authored-by: Dan Zheng <[email protected]>
1 parent d067e7e commit 495b571

File tree

6 files changed

+362
-3
lines changed

6 files changed

+362
-3
lines changed

include/swift/SIL/AbstractionPattern.h

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,65 @@ class AbstractionPattern {
188188
/// The partially-applied curried imported type of a C++ method. OrigType is
189189
/// valid and is a function type. CXXMethod is valid.
190190
PartialCurriedCXXMethodType,
191+
/// A Swift function whose parameters and results are opaque. This is
192+
/// like `AP::Type<T>((T) -> T)`, except that the number of parameters is
193+
/// unspecified.
194+
///
195+
/// This is used to construct the abstraction pattern for the
196+
/// derivative function of a function with opaque abstraction pattern. See
197+
/// `OpaqueDerivativeFunction`.
198+
OpaqueFunction,
199+
/// A Swift function whose parameters are opaque and whose result is the
200+
/// tuple abstraction pattern `(AP::Opaque, AP::OpaqueFunction)`.
201+
///
202+
/// Purpose: when we reabstract `@differentiable` function-typed values
203+
/// using the`AP::Opaque` pattern, we use `AP::Opaque` to reabstract the
204+
/// original function in the bundle and `AP::OpaqueDerivativeFunction` to
205+
/// reabstract the derivative functions in the bundle. This preserves the
206+
/// `@differentiable` function invariant that the derivative type
207+
/// (`SILFunctionType::getAutoDiffDerivativeFunctionType()`) of the original
208+
/// function is equal to the type of the derivative function. For example:
209+
///
210+
/// differentiable_function
211+
/// [parameters 0]
212+
/// %0 : $@callee_guaranteed (Float) -> Float
213+
/// with_derivative {
214+
/// %1 : $@callee_guaranteed (Float) -> (
215+
/// Float,
216+
/// @owned @callee_guaranteed (Float) -> Float
217+
/// ),
218+
/// %2 : $@callee_guaranteed (Float) -> (
219+
/// Float,
220+
/// @owned @callee_guaranteed (Float) -> Float
221+
/// )
222+
/// }
223+
///
224+
/// The invariant-respecting abstraction of this value to `AP::Opaque` is:
225+
///
226+
/// differentiable_function
227+
/// [parameters 0]
228+
/// %3 : $@callee_guaranteed (@in_guaranteed Float) -> @out Float
229+
/// with_derivative {
230+
/// %4 : $@callee_guaranteed (@in_guaranteed Float) -> (
231+
/// @out Float,
232+
/// @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float
233+
/// ),
234+
/// %5 : $@callee_guaranteed (@in_guaranteed Float) -> (
235+
/// @out Float,
236+
/// @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float
237+
/// )
238+
/// }
239+
///
240+
/// In particular:
241+
///
242+
/// - The reabstraction %0 => %3 uses pattern `AP::Opaque`.
243+
/// - The reabstraction %1 => %4 uses pattern
244+
/// `AP::OpaqueDerivativeFunction`, which maximally abstracts all the
245+
/// parameters, and abstracts the result as the tuple
246+
/// `(AP::Opaque, AP::OpaqueFunction)`.
247+
/// - The reabstraction %2 => %5 similarly uses pattern
248+
/// `AP::OpaqueDerivativeFunction`.
249+
OpaqueDerivativeFunction,
191250
};
192251

193252
class EncodedForeignErrorInfo {
@@ -238,7 +297,7 @@ class AbstractionPattern {
238297
static constexpr const unsigned NumOtherDataBits = 28;
239298
static constexpr const unsigned MaxOtherData = (1 << NumOtherDataBits) - 1;
240299

241-
unsigned TheKind : 32 - NumOtherDataBits;
300+
unsigned TheKind : 33 - NumOtherDataBits;
242301
unsigned OtherData : NumOtherDataBits;
243302
CanType OrigType;
244303
union {
@@ -384,6 +443,14 @@ class AbstractionPattern {
384443
return AbstractionPattern(Kind::Invalid);
385444
}
386445

446+
static AbstractionPattern getOpaqueFunction() {
447+
return AbstractionPattern(Kind::OpaqueFunction);
448+
}
449+
450+
static AbstractionPattern getOpaqueDerivativeFunction() {
451+
return AbstractionPattern(Kind::OpaqueDerivativeFunction);
452+
}
453+
387454
bool hasGenericSignature() const {
388455
switch (getKind()) {
389456
case Kind::Type:
@@ -402,6 +469,8 @@ class AbstractionPattern {
402469
case Kind::Invalid:
403470
case Kind::Opaque:
404471
case Kind::Tuple:
472+
case Kind::OpaqueFunction:
473+
case Kind::OpaqueDerivativeFunction:
405474
return false;
406475
}
407476
llvm_unreachable("Unhandled AbstractionPatternKind in switch");
@@ -730,6 +799,10 @@ class AbstractionPattern {
730799
llvm_unreachable("opaque pattern has no type");
731800
case Kind::Tuple:
732801
llvm_unreachable("open-coded tuple pattern has no type");
802+
case Kind::OpaqueFunction:
803+
llvm_unreachable("opaque function pattern has no type");
804+
case Kind::OpaqueDerivativeFunction:
805+
llvm_unreachable("opaque derivative function pattern has no type");
733806
case Kind::ClangType:
734807
case Kind::CurriedObjCMethodType:
735808
case Kind::PartialCurriedObjCMethodType:
@@ -763,6 +836,8 @@ class AbstractionPattern {
763836
case Kind::Invalid:
764837
case Kind::Opaque:
765838
case Kind::Tuple:
839+
case Kind::OpaqueFunction:
840+
case Kind::OpaqueDerivativeFunction:
766841
llvm_unreachable("type cannot be replaced on pattern without type");
767842
case Kind::ClangType:
768843
case Kind::CurriedObjCMethodType:
@@ -798,6 +873,8 @@ class AbstractionPattern {
798873
case Kind::Tuple:
799874
case Kind::Type:
800875
case Kind::Discard:
876+
case Kind::OpaqueFunction:
877+
case Kind::OpaqueDerivativeFunction:
801878
return false;
802879
case Kind::ClangType:
803880
case Kind::PartialCurriedObjCMethodType:
@@ -854,6 +931,11 @@ class AbstractionPattern {
854931
return CXXMethod;
855932
}
856933

934+
bool isOpaqueFunctionOrOpaqueDerivativeFunction() const {
935+
return (getKind() == Kind::OpaqueFunction ||
936+
getKind() == Kind::OpaqueDerivativeFunction);
937+
}
938+
857939
EncodedForeignErrorInfo getEncodedForeignErrorInfo() const {
858940
assert(hasStoredForeignErrorInfo());
859941
return EncodedForeignErrorInfo::fromOpaqueValue(OtherData);
@@ -876,6 +958,8 @@ class AbstractionPattern {
876958
case Kind::CXXMethodType:
877959
case Kind::CurriedCXXMethodType:
878960
case Kind::PartialCurriedCXXMethodType:
961+
case Kind::OpaqueFunction:
962+
case Kind::OpaqueDerivativeFunction:
879963
return false;
880964
case Kind::PartialCurriedObjCMethodType:
881965
case Kind::CurriedObjCMethodType:
@@ -897,6 +981,9 @@ class AbstractionPattern {
897981
return typename CanTypeWrapperTraits<TYPE>::type();
898982
case Kind::Tuple:
899983
return typename CanTypeWrapperTraits<TYPE>::type();
984+
case Kind::OpaqueFunction:
985+
case Kind::OpaqueDerivativeFunction:
986+
return typename CanTypeWrapperTraits<TYPE>::type();
900987
case Kind::ClangType:
901988
case Kind::PartialCurriedObjCMethodType:
902989
case Kind::CurriedObjCMethodType:
@@ -935,6 +1022,8 @@ class AbstractionPattern {
9351022
case Kind::CXXMethodType:
9361023
case Kind::CurriedCXXMethodType:
9371024
case Kind::PartialCurriedCXXMethodType:
1025+
case Kind::OpaqueFunction:
1026+
case Kind::OpaqueDerivativeFunction:
9381027
// We assume that the Clang type might provide additional structure.
9391028
return false;
9401029
case Kind::Type:
@@ -962,6 +1051,8 @@ class AbstractionPattern {
9621051
case Kind::CXXMethodType:
9631052
case Kind::CurriedCXXMethodType:
9641053
case Kind::PartialCurriedCXXMethodType:
1054+
case Kind::OpaqueFunction:
1055+
case Kind::OpaqueDerivativeFunction:
9651056
return false;
9661057
case Kind::Tuple:
9671058
return true;
@@ -987,6 +1078,8 @@ class AbstractionPattern {
9871078
case Kind::CXXMethodType:
9881079
case Kind::CurriedCXXMethodType:
9891080
case Kind::PartialCurriedCXXMethodType:
1081+
case Kind::OpaqueFunction:
1082+
case Kind::OpaqueDerivativeFunction:
9901083
llvm_unreachable("pattern is not a tuple");
9911084
case Kind::Tuple:
9921085
return getNumTupleElements_Stored();
@@ -1022,6 +1115,17 @@ class AbstractionPattern {
10221115
/// it.
10231116
AbstractionPattern getReferenceStorageReferentType() const;
10241117

1118+
/// Given that the value being abstracted is a function type, return the
1119+
/// abstraction pattern for the derivative function.
1120+
///
1121+
/// The arguments are the same as the arguments to
1122+
/// `AnyFunctionType::getAutoDiffDerivativeFunctionType()`.
1123+
AbstractionPattern getAutoDiffDerivativeFunctionType(
1124+
IndexSubset *parameterIndices, AutoDiffDerivativeFunctionKind kind,
1125+
LookupConformanceFn lookupConformance,
1126+
GenericSignature derivativeGenericSignature = GenericSignature(),
1127+
bool makeSelfParamFirst = false);
1128+
10251129
void dump() const LLVM_ATTRIBUTE_USED;
10261130
void print(raw_ostream &OS) const;
10271131
};

lib/SIL/AbstractionPattern.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ AbstractionPattern::getOptional(AbstractionPattern object) {
183183
case Kind::CXXMethodType:
184184
case Kind::CurriedCXXMethodType:
185185
case Kind::PartialCurriedCXXMethodType:
186+
case Kind::OpaqueFunction:
187+
case Kind::OpaqueDerivativeFunction:
186188
llvm_unreachable("cannot add optionality to non-type abstraction");
187189
case Kind::Opaque:
188190
return AbstractionPattern::getOpaque();
@@ -289,6 +291,8 @@ bool AbstractionPattern::matchesTuple(CanTupleType substType) {
289291
case Kind::CXXMethodType:
290292
case Kind::CurriedCXXMethodType:
291293
case Kind::PartialCurriedCXXMethodType:
294+
case Kind::OpaqueFunction:
295+
case Kind::OpaqueDerivativeFunction:
292296
return false;
293297
case Kind::Opaque:
294298
return true;
@@ -359,6 +363,8 @@ AbstractionPattern::getTupleElementType(unsigned index) const {
359363
case Kind::CXXMethodType:
360364
case Kind::CurriedCXXMethodType:
361365
case Kind::PartialCurriedCXXMethodType:
366+
case Kind::OpaqueFunction:
367+
case Kind::OpaqueDerivativeFunction:
362368
llvm_unreachable("function types are not tuples");
363369
case Kind::Opaque:
364370
return *this;
@@ -486,6 +492,12 @@ AbstractionPattern AbstractionPattern::getFunctionResultType() const {
486492
return AbstractionPattern(getGenericSignatureForFunctionComponent(),
487493
getResultType(getType()),
488494
getObjCMethod()->getReturnType().getTypePtr());
495+
case Kind::OpaqueFunction:
496+
return getOpaque();
497+
case Kind::OpaqueDerivativeFunction:
498+
static SmallVector<AbstractionPattern, 2> elements{getOpaque(),
499+
getOpaqueFunction()};
500+
return getTuple(elements);
489501
}
490502
llvm_unreachable("bad kind");
491503
}
@@ -615,6 +627,10 @@ AbstractionPattern::getFunctionParamType(unsigned index) const {
615627
params[index].getParameterType(),
616628
getClangFunctionParameterType(getClangType(), index));
617629
}
630+
case Kind::OpaqueFunction:
631+
return getOpaque();
632+
case Kind::OpaqueDerivativeFunction:
633+
return getOpaque();
618634
default:
619635
llvm_unreachable("does not have function parameters");
620636
}
@@ -644,6 +660,8 @@ AbstractionPattern AbstractionPattern::getOptionalObjectType() const {
644660
case Kind::CurriedCXXMethodType:
645661
case Kind::PartialCurriedCXXMethodType:
646662
case Kind::Tuple:
663+
case Kind::OpaqueFunction:
664+
case Kind::OpaqueDerivativeFunction:
647665
llvm_unreachable("pattern for function or tuple cannot be for optional");
648666

649667
case Kind::Opaque:
@@ -685,6 +703,8 @@ AbstractionPattern AbstractionPattern::getReferenceStorageReferentType() const {
685703
case Kind::CurriedCXXMethodType:
686704
case Kind::PartialCurriedCXXMethodType:
687705
case Kind::Tuple:
706+
case Kind::OpaqueFunction:
707+
case Kind::OpaqueDerivativeFunction:
688708
return *this;
689709
case Kind::Type:
690710
return AbstractionPattern(getGenericSignature(),
@@ -714,6 +734,12 @@ void AbstractionPattern::print(raw_ostream &out) const {
714734
case Kind::Opaque:
715735
out << "AP::Opaque";
716736
return;
737+
case Kind::OpaqueFunction:
738+
out << "AP::OpaqueFunction";
739+
return;
740+
case Kind::OpaqueDerivativeFunction:
741+
out << "AP::OpaqueDerivativeFunction";
742+
return;
717743
case Kind::Type:
718744
case Kind::Discard:
719745
out << (getKind() == Kind::Type
@@ -877,6 +903,12 @@ const {
877903
case Kind::Tuple:
878904
llvm_unreachable("should not have a tuple pattern matching a struct/enum "
879905
"type");
906+
case Kind::OpaqueFunction:
907+
llvm_unreachable("should not have an opaque function pattern matching a "
908+
"struct/enum type");
909+
case Kind::OpaqueDerivativeFunction:
910+
llvm_unreachable("should not have an opaque derivative function pattern "
911+
"matching a struct/enum type");
880912
case Kind::PartialCurriedObjCMethodType:
881913
case Kind::CurriedObjCMethodType:
882914
case Kind::PartialCurriedCFunctionAsMethodType:
@@ -896,3 +928,27 @@ const {
896928
return AbstractionPattern(getGenericSignature(), memberTy);
897929
}
898930
}
931+
932+
AbstractionPattern AbstractionPattern::getAutoDiffDerivativeFunctionType(
933+
IndexSubset *parameterIndices, AutoDiffDerivativeFunctionKind kind,
934+
LookupConformanceFn lookupConformance,
935+
GenericSignature derivativeGenericSignature, bool makeSelfParamFirst) {
936+
switch (getKind()) {
937+
case Kind::Type: {
938+
auto fnTy = dyn_cast<AnyFunctionType>(getType());
939+
if (!fnTy)
940+
return getOpaqueDerivativeFunction();
941+
auto derivativeFnTy = fnTy->getAutoDiffDerivativeFunctionType(
942+
parameterIndices, kind, lookupConformance, derivativeGenericSignature,
943+
makeSelfParamFirst);
944+
assert(derivativeFnTy);
945+
return AbstractionPattern(
946+
getGenericSignature(),
947+
derivativeFnTy->getCanonicalType(getGenericSignature()));
948+
}
949+
case Kind::Opaque:
950+
return getOpaqueDerivativeFunction();
951+
default:
952+
llvm_unreachable("called on unsupported abstraction pattern kind");
953+
}
954+
}

lib/SIL/SILFunctionType.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1082,7 +1082,8 @@ class SubstFunctionTypeCollector {
10821082
// type.
10831083

10841084
// The entire original context could be a generic parameter.
1085-
if (origType.isTypeParameter()) {
1085+
if (origType.isTypeParameter() ||
1086+
origType.isOpaqueFunctionOrOpaqueDerivativeFunction()) {
10861087
return addSubstitution(origType.getLayoutConstraint(), substType,
10871088
nullptr, {});
10881089
}
@@ -1253,6 +1254,10 @@ class DestructureResults {
12531254
|| substTL.isAddressOnly()) {
12541255
return true;
12551256

1257+
// Functions are always returned directly.
1258+
} else if (origType.isOpaqueFunctionOrOpaqueDerivativeFunction()) {
1259+
return false;
1260+
12561261
// If the substitution didn't change the type, then a negative
12571262
// response to the above is determinative as well.
12581263
} else if (origType.getType() == substType &&

lib/SIL/TypeLowering.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ namespace {
268268

269269
RetTy visitAbstractTypeParamType(CanType type,
270270
AbstractionPattern origType) {
271-
if (origType.isTypeParameterOrOpaqueArchetype()) {
271+
if (origType.isTypeParameterOrOpaqueArchetype() ||
272+
origType.isOpaqueFunctionOrOpaqueDerivativeFunction()) {
272273
if (origType.requiresClass()) {
273274
return asImpl().handleReference(type);
274275
} else {

0 commit comments

Comments
 (0)