@@ -119,10 +119,11 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
119
119
SILDeclRef::SILDeclRef (ValueDecl *vd, SILDeclRef::Kind kind, bool isForeign,
120
120
AutoDiffDerivativeFunctionIdentifier *derivativeId)
121
121
: loc(vd), kind(kind), isForeign(isForeign), defaultArgIndex(0 ),
122
- derivativeFunctionIdentifier (derivativeId) {}
122
+ pointer (derivativeId) {}
123
123
124
124
SILDeclRef::SILDeclRef (SILDeclRef::Loc baseLoc, bool asForeign)
125
- : defaultArgIndex(0 ), derivativeFunctionIdentifier(nullptr ) {
125
+ : defaultArgIndex(0 ),
126
+ pointer((AutoDiffDerivativeFunctionIdentifier *)nullptr) {
126
127
if (auto *vd = baseLoc.dyn_cast <ValueDecl*>()) {
127
128
if (auto *fd = dyn_cast<FuncDecl>(vd)) {
128
129
// Map FuncDecls directly to Func SILDeclRefs.
@@ -164,7 +165,7 @@ SILDeclRef::SILDeclRef(SILDeclRef::Loc baseLoc, bool asForeign)
164
165
SILDeclRef::SILDeclRef (SILDeclRef::Loc baseLoc,
165
166
GenericSignature prespecializedSig)
166
167
: SILDeclRef(baseLoc, false ) {
167
- specializedSignature = prespecializedSig;
168
+ pointer = prespecializedSig. getPointer () ;
168
169
}
169
170
170
171
Optional<AnyFunctionRef> SILDeclRef::getAnyFunctionRef () const {
@@ -232,7 +233,7 @@ bool SILDeclRef::isImplicit() const {
232
233
SILLinkage SILDeclRef::getLinkage (ForDefinition_t forDefinition) const {
233
234
234
235
// Prespecializations are public.
235
- if (specializedSignature ) {
236
+ if (getSpecializedSignature () ) {
236
237
return SILLinkage::Public;
237
238
}
238
239
@@ -678,6 +679,7 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
678
679
using namespace Mangle ;
679
680
ASTMangler mangler;
680
681
682
+ auto *derivativeFunctionIdentifier = getDerivativeFunctionIdentifier ();
681
683
if (derivativeFunctionIdentifier) {
682
684
std::string originalMangled = asAutoDiffOriginalFunction ().mangle (MKind);
683
685
auto *silParameterIndices = autodiff::getLoweredParameterIndices (
@@ -716,14 +718,15 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
716
718
}
717
719
718
720
// Mangle prespecializations.
719
- if (specializedSignature ) {
721
+ if (getSpecializedSignature () ) {
720
722
SILDeclRef nonSpecializedDeclRef = *this ;
721
- nonSpecializedDeclRef.specializedSignature = GenericSignature ();
723
+ nonSpecializedDeclRef.pointer =
724
+ (AutoDiffDerivativeFunctionIdentifier *)nullptr ;
722
725
auto mangledNonSpecializedString = nonSpecializedDeclRef.mangle ();
723
726
auto *funcDecl = cast<AbstractFunctionDecl>(getDecl ());
724
727
auto genericSig = funcDecl->getGenericSignature ();
725
728
return GenericSpecializationMangler::manglePrespecialization (
726
- mangledNonSpecializedString, genericSig, specializedSignature );
729
+ mangledNonSpecializedString, genericSig, getSpecializedSignature () );
727
730
}
728
731
729
732
ASTMangler::SymbolKind SKind = ASTMangler::SymbolKind::Default;
@@ -818,7 +821,7 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
818
821
// Returns true if the given JVP/VJP SILDeclRef requires a new vtable entry.
819
822
// FIXME(TF-1213): Also consider derived declaration `@derivative` attributes.
820
823
static bool derivativeFunctionRequiresNewVTableEntry (SILDeclRef declRef) {
821
- assert (declRef.derivativeFunctionIdentifier &&
824
+ assert (declRef.getDerivativeFunctionIdentifier () &&
822
825
" Expected a derivative function SILDeclRef" );
823
826
auto overridden = declRef.getOverridden ();
824
827
if (!overridden)
@@ -828,7 +831,7 @@ static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) {
828
831
declRef.getDecl ()->getAttrs ().getAttributes <DifferentiableAttr>(),
829
832
[&](const DifferentiableAttr *derivedDiffAttr) {
830
833
return derivedDiffAttr->getParameterIndices () ==
831
- declRef.derivativeFunctionIdentifier ->getParameterIndices ();
834
+ declRef.getDerivativeFunctionIdentifier () ->getParameterIndices ();
832
835
});
833
836
assert (derivedDiffAttr && " Expected `@differentiable` attribute" );
834
837
// Otherwise, if the base `@differentiable` attribute specifies a derivative
@@ -838,7 +841,7 @@ static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) {
838
841
overridden.getDecl ()->getAttrs ().getAttributes <DifferentiableAttr>();
839
842
for (auto *baseDiffAttr : baseDiffAttrs) {
840
843
if (baseDiffAttr->getParameterIndices () ==
841
- declRef.derivativeFunctionIdentifier ->getParameterIndices ())
844
+ declRef.getDerivativeFunctionIdentifier () ->getParameterIndices ())
842
845
return false ;
843
846
}
844
847
// Otherwise, if there is no base `@differentiable` attribute exists, then a
@@ -847,7 +850,7 @@ static bool derivativeFunctionRequiresNewVTableEntry(SILDeclRef declRef) {
847
850
}
848
851
849
852
bool SILDeclRef::requiresNewVTableEntry () const {
850
- if (derivativeFunctionIdentifier )
853
+ if (getDerivativeFunctionIdentifier () )
851
854
if (derivativeFunctionRequiresNewVTableEntry (*this ))
852
855
return true ;
853
856
if (!hasDecl ())
@@ -928,15 +931,16 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const {
928
931
929
932
// JVPs/VJPs are overridden only if the base declaration has a
930
933
// `@differentiable` attribute with the same parameter indices.
931
- if (derivativeFunctionIdentifier ) {
934
+ if (getDerivativeFunctionIdentifier () ) {
932
935
auto overriddenAttrs =
933
936
overridden.getDecl ()->getAttrs ().getAttributes <DifferentiableAttr>();
934
937
for (const auto *attr : overriddenAttrs) {
935
938
if (attr->getParameterIndices () !=
936
- derivativeFunctionIdentifier ->getParameterIndices ())
939
+ getDerivativeFunctionIdentifier () ->getParameterIndices ())
937
940
continue ;
938
- auto *overriddenDerivativeId = overridden.derivativeFunctionIdentifier ;
939
- overridden.derivativeFunctionIdentifier =
941
+ auto *overriddenDerivativeId =
942
+ overridden.getDerivativeFunctionIdentifier ();
943
+ overridden.pointer =
940
944
AutoDiffDerivativeFunctionIdentifier::get (
941
945
overriddenDerivativeId->getKind (),
942
946
overriddenDerivativeId->getParameterIndices (),
0 commit comments