Skip to content

Commit eefe9a0

Browse files
author
marcrasi
authored
Merge pull request swiftlang#30851 from apple/derivative-attr-serialization
[AutoDiff] SR-12526: cross-module @Derivative deserialization
2 parents 99993a9 + 7abf8ae commit eefe9a0

File tree

16 files changed

+116
-30
lines changed

16 files changed

+116
-30
lines changed

include/swift/AST/ASTTypeIDZone.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ SWIFT_TYPEID(Type)
2929
SWIFT_TYPEID(TypePair)
3030
SWIFT_TYPEID(TypeWitnessAndDecl)
3131
SWIFT_TYPEID(Witness)
32+
SWIFT_TYPEID_NAMED(AbstractFunctionDecl *, AbstractFunctionDecl)
3233
SWIFT_TYPEID_NAMED(ClosureExpr *, ClosureExpr)
3334
SWIFT_TYPEID_NAMED(CodeCompletionCallbacksFactory *,
3435
CodeCompletionCallbacksFactory)

include/swift/AST/Attr.h

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,6 +1865,7 @@ class DerivativeAttr final
18651865
: public DeclAttribute,
18661866
private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
18671867
friend TrailingObjects;
1868+
friend class DerivativeAttrOriginalDeclRequest;
18681869

18691870
/// The base type for the referenced original declaration. This field is
18701871
/// non-null only for parsed attributes that reference a qualified original
@@ -1873,8 +1874,24 @@ class DerivativeAttr final
18731874
TypeRepr *BaseTypeRepr;
18741875
/// The original function name.
18751876
DeclNameRefWithLoc OriginalFunctionName;
1876-
/// The original function declaration, resolved by the type checker.
1877-
AbstractFunctionDecl *OriginalFunction = nullptr;
1877+
/// The original function.
1878+
///
1879+
/// The states are:
1880+
/// - nullptr:
1881+
/// The original function is unknown. The typechecker is responsible for
1882+
/// eventually resolving it.
1883+
/// - AbstractFunctionDecl:
1884+
/// The original function is known to be this `AbstractFunctionDecl`.
1885+
/// - LazyMemberLoader:
1886+
/// This `LazyMemberLoader` knows how to resolve the original function.
1887+
/// `ResolverContextData` is an additional piece of data that the
1888+
/// `LazyMemberLoader` needs.
1889+
// TODO(TF-1235): Making `DerivativeAttr` immutable will simplify this by
1890+
// removing the `AbstractFunctionDecl` state.
1891+
llvm::PointerUnion<AbstractFunctionDecl *, LazyMemberLoader *> OriginalFunction;
1892+
/// Data representing the original function declaration. See doc comment for
1893+
/// `OriginalFunction`.
1894+
uint64_t ResolverContextData = 0;
18781895
/// The number of parsed differentiability parameters specified in 'wrt:'.
18791896
unsigned NumParsedParameters = 0;
18801897
/// The differentiability parameter indices, resolved by the type checker.
@@ -1907,12 +1924,10 @@ class DerivativeAttr final
19071924
DeclNameRefWithLoc getOriginalFunctionName() const {
19081925
return OriginalFunctionName;
19091926
}
1910-
AbstractFunctionDecl *getOriginalFunction() const {
1911-
return OriginalFunction;
1912-
}
1913-
void setOriginalFunction(AbstractFunctionDecl *decl) {
1914-
OriginalFunction = decl;
1915-
}
1927+
AbstractFunctionDecl *getOriginalFunction(ASTContext &context) const;
1928+
void setOriginalFunction(AbstractFunctionDecl *decl);
1929+
void setOriginalFunctionResolver(LazyMemberLoader *resolver,
1930+
uint64_t resolverContextData);
19161931

19171932
AutoDiffDerivativeFunctionKind getDerivativeKind() const {
19181933
assert(Kind && "Derivative function kind has not yet been resolved");

include/swift/AST/LazyResolver.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ class alignas(void*) LazyMemberLoader {
106106
loadDynamicallyReplacedFunctionDecl(const DynamicReplacementAttr *DRA,
107107
uint64_t contextData) = 0;
108108

109+
/// Returns the referenced original declaration for a `@derivative(of:)`
110+
/// attribute.
111+
virtual AbstractFunctionDecl *
112+
loadReferencedFunctionDecl(const DerivativeAttr *DA,
113+
uint64_t contextData) = 0;
114+
109115
/// Returns the type for a given @_typeEraser() attribute.
110116
virtual Type loadTypeEraserType(const TypeEraserAttr *TRA,
111117
uint64_t contextData) = 0;

include/swift/AST/TypeCheckRequests.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2130,6 +2130,26 @@ class DifferentiableAttributeTypeCheckRequest
21302130
void cacheResult(IndexSubset *value) const;
21312131
};
21322132

2133+
/// Resolves the referenced original declaration for a `@derivative` attribute.
2134+
class DerivativeAttrOriginalDeclRequest
2135+
: public SimpleRequest<DerivativeAttrOriginalDeclRequest,
2136+
AbstractFunctionDecl *(DerivativeAttr *),
2137+
RequestFlags::Cached> {
2138+
public:
2139+
using SimpleRequest::SimpleRequest;
2140+
2141+
private:
2142+
friend SimpleRequest;
2143+
2144+
// Evaluation.
2145+
AbstractFunctionDecl *evaluate(Evaluator &evaluator,
2146+
DerivativeAttr *attr) const;
2147+
2148+
public:
2149+
// Caching.
2150+
bool isCached() const { return true; }
2151+
};
2152+
21332153
/// Checks whether a type eraser has a viable initializer.
21342154
class TypeEraserHasViableInitRequest
21352155
: public SimpleRequest<TypeEraserHasViableInitRequest,

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ SWIFT_REQUEST(TypeChecker, DefaultTypeRequest,
4949
SWIFT_REQUEST(TypeChecker, DifferentiableAttributeTypeCheckRequest,
5050
IndexSubset *(DifferentiableAttr *),
5151
SeparatelyCached, NoLocationInfo)
52+
SWIFT_REQUEST(TypeChecker, DerivativeAttrOriginalDeclRequest,
53+
AbstractFunctionDecl *(DerivativeAttr *),
54+
Cached, NoLocationInfo)
5255
SWIFT_REQUEST(TypeChecker, TypeEraserHasViableInitRequest,
5356
bool(TypeEraserAttr *, ProtocolDecl *),
5457
Cached, NoLocationInfo)

lib/AST/Attr.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "swift/AST/Expr.h"
2222
#include "swift/AST/GenericEnvironment.h"
2323
#include "swift/AST/IndexSubset.h"
24+
#include "swift/AST/LazyResolver.h"
2425
#include "swift/AST/Module.h"
2526
#include "swift/AST/ParameterList.h"
2627
#include "swift/AST/TypeCheckRequests.h"
@@ -1750,6 +1751,26 @@ DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
17501751
std::move(originalName), parameterIndices);
17511752
}
17521753

1754+
AbstractFunctionDecl *
1755+
DerivativeAttr::getOriginalFunction(ASTContext &context) const {
1756+
return evaluateOrDefault(
1757+
context.evaluator,
1758+
DerivativeAttrOriginalDeclRequest{const_cast<DerivativeAttr *>(this)},
1759+
nullptr);
1760+
}
1761+
1762+
void DerivativeAttr::setOriginalFunction(AbstractFunctionDecl *decl) {
1763+
assert(!OriginalFunction && "cannot overwrite original function");
1764+
OriginalFunction = decl;
1765+
}
1766+
1767+
void DerivativeAttr::setOriginalFunctionResolver(
1768+
LazyMemberLoader *resolver, uint64_t resolverContextData) {
1769+
assert(!OriginalFunction && "cannot overwrite original function");
1770+
OriginalFunction = resolver;
1771+
ResolverContextData = resolverContextData;
1772+
}
1773+
17531774
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
17541775
SourceRange baseRange, TypeRepr *baseTypeRepr,
17551776
DeclNameRefWithLoc originalName,

lib/ClangImporter/ImporterImpl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,12 @@ class LLVM_LIBRARY_VISIBILITY ClangImporter::Implementation
12901290
llvm_unreachable("unimplemented for ClangImporter");
12911291
}
12921292

1293+
AbstractFunctionDecl *
1294+
loadReferencedFunctionDecl(const DerivativeAttr *DA,
1295+
uint64_t contextData) override {
1296+
llvm_unreachable("unimplemented for ClangImporter");
1297+
}
1298+
12931299
Type loadTypeEraserType(const TypeEraserAttr *TRA,
12941300
uint64_t contextData) override {
12951301
llvm_unreachable("unimplemented for ClangImporter");

lib/SILGen/SILGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,7 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
791791
vjp = F;
792792
break;
793793
}
794-
auto *origAFD = derivAttr->getOriginalFunction();
794+
auto *origAFD = derivAttr->getOriginalFunction(getASTContext());
795795
auto origDeclRef =
796796
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
797797
auto *origFn = getFunction(origDeclRef, NotForDefinition);

lib/Sema/TypeCheckAttr.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3705,8 +3705,6 @@ getTransposeOriginalFunctionType(AnyFunctionType *transposeFnType,
37053705
return originalType;
37063706
}
37073707

3708-
3709-
37103708
/// Given a `@differentiable` attribute, attempts to resolve the original
37113709
/// `AbstractFunctionDecl` for which it is registered, using the declaration
37123710
/// on which it is actually declared. On error, emits diagnostic and returns
@@ -4454,6 +4452,21 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
44544452
attr->setInvalid();
44554453
}
44564454

4455+
AbstractFunctionDecl *
4456+
DerivativeAttrOriginalDeclRequest::evaluate(Evaluator &evaluator,
4457+
DerivativeAttr *attr) const {
4458+
// If the typechecker has resolved the original function, return it.
4459+
if (auto *FD = attr->OriginalFunction.dyn_cast<AbstractFunctionDecl *>())
4460+
return FD;
4461+
4462+
// If the function can be lazily resolved, do so now.
4463+
if (auto *Resolver = attr->OriginalFunction.dyn_cast<LazyMemberLoader *>())
4464+
return Resolver->loadReferencedFunctionDecl(attr,
4465+
attr->ResolverContextData);
4466+
4467+
return nullptr;
4468+
}
4469+
44574470
/// Returns true if the given type's `TangentVector` is equal to itself in the
44584471
/// given module.
44594472
static bool tangentVectorEqualsSelf(Type type, DeclContext *DC) {

lib/Serialization/Deserialization.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4379,7 +4379,6 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
43794379

43804380
DeclNameRefWithLoc origName{
43814381
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
4382-
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
43834382
auto derivativeKind =
43844383
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
43854384
if (!derivativeKind)
@@ -4392,7 +4391,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
43924391
auto *derivativeAttr =
43934392
DerivativeAttr::create(ctx, isImplicit, SourceLoc(), SourceRange(),
43944393
/*baseType*/ nullptr, origName, indices);
4395-
derivativeAttr->setOriginalFunction(origDecl);
4394+
derivativeAttr->setOriginalFunctionResolver(&MF, origDeclId);
43964395
derivativeAttr->setDerivativeKind(*derivativeKind);
43974396
Attr = derivativeAttr;
43984397
break;
@@ -5941,6 +5940,12 @@ ValueDecl *ModuleFile::loadDynamicallyReplacedFunctionDecl(
59415940
return cast<ValueDecl>(getDecl(contextData));
59425941
}
59435942

5943+
AbstractFunctionDecl *
5944+
ModuleFile::loadReferencedFunctionDecl(const DerivativeAttr *DA,
5945+
uint64_t contextData) {
5946+
return cast<AbstractFunctionDecl>(getDecl(contextData));
5947+
}
5948+
59445949
Type ModuleFile::loadTypeEraserType(const TypeEraserAttr *TRA,
59455950
uint64_t contextData) {
59465951
return getType(contextData);

0 commit comments

Comments
 (0)