Skip to content

Commit 2831548

Browse files
authored
[AutoDiff upstream] Serialize derivative function configurations. (#30672)
Serialize derivative function configurations per module. `@differentiable` and `@derivative` attributes register derivatives for `AbstractFunctionDecl`s for a particular "derivative function configuration": parameter indices and dervative generic signature. To find `@derivative` functions registered in other Swift modules, derivative function configurations must be serialized per module. When configurations for a `AbstractFunctionDecl` are requested, all configurations from imported modules are deserialized. This module serialization technique has precedent: it is used for protocol conformances (e.g. extension declarations for a nominal type) and Obj-C members for a class type. Add `AbstractFunctionDecl::getDerivativeFunctionConfigurations` entry point for accessing derivative function configurations. In the differentiation transform: use `AbstractFunctionDecl::getDerivativeFunctionConfigurations` to implement `findMinimalDerivativeConfiguration` for canonical derivative function configuration lookup, replacing `getMinimalASTDifferentiableAttr`. Resolves TF-1100.
1 parent bbe86e9 commit 2831548

15 files changed

+346
-13
lines changed

include/swift/AST/ASTContext.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,19 @@ class ASTContext final {
741741
unsigned previousGeneration,
742742
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods);
743743

744+
/// Load derivative function configurations for the given
745+
/// AbstractFunctionDecl.
746+
///
747+
/// \param originalAFD The declaration whose derivative function
748+
/// configurations should be loaded.
749+
///
750+
/// \param previousGeneration The previous generation number. The AST already
751+
/// contains derivative function configurations loaded from any generation up
752+
/// to and including this one.
753+
void loadDerivativeFunctionConfigurations(
754+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
755+
llvm::SetVector<AutoDiffConfig> &results);
756+
744757
/// Retrieve the Clang module loader for this ASTContext.
745758
///
746759
/// If there is no Clang module loader, returns a null pointer.

include/swift/AST/Decl.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5796,6 +5796,7 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
57965796
private:
57975797
ParameterList *Params;
57985798

5799+
private:
57995800
/// The generation at which we last loaded derivative function configurations.
58005801
unsigned DerivativeFunctionConfigGeneration = 0;
58015802
/// Prepare to traverse the list of derivative function configurations.
@@ -5810,6 +5811,13 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
58105811
struct DerivativeFunctionConfigurationList;
58115812
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;
58125813

5814+
public:
5815+
/// Get all derivative function configurations.
5816+
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();
5817+
5818+
/// Add the given derivative function configuration.
5819+
void addDerivativeFunctionConfiguration(AutoDiffConfig config);
5820+
58135821
protected:
58145822
// If a function has a body at all, we have either a parsed body AST node or
58155823
// we have saved the end location of the unparsed body.
@@ -6129,12 +6137,6 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
61296137
/// constructor.
61306138
bool hasDynamicSelfResult() const;
61316139

6132-
/// Get all derivative function configurations.
6133-
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();
6134-
6135-
/// Add the given derivative function configuration.
6136-
void addDerivativeFunctionConfiguration(AutoDiffConfig config);
6137-
61386140
using DeclContext::operator new;
61396141
using Decl::getASTContext;
61406142
};

include/swift/AST/ModuleLoader.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class DependencyCollector;
3636
namespace swift {
3737

3838
class AbstractFunctionDecl;
39+
struct AutoDiffConfig;
3940
class ClangImporterOptions;
4041
class ClassDecl;
4142
class FileUnit;
@@ -153,6 +154,23 @@ class ModuleLoader {
153154
unsigned previousGeneration,
154155
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) = 0;
155156

157+
/// Load derivative function configurations for the given
158+
/// AbstractFunctionDecl.
159+
///
160+
/// \param originalAFD The declaration whose derivative function
161+
/// configurations should be loaded.
162+
///
163+
/// \param previousGeneration The previous generation number. The AST already
164+
/// contains derivative function configurations loaded from any generation up
165+
/// to and including this one.
166+
///
167+
/// \param results The result list of derivative function configurations.
168+
/// This list will be extended with any methods found in subsequent
169+
/// generations.
170+
virtual void loadDerivativeFunctionConfigurations(
171+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
172+
llvm::SetVector<AutoDiffConfig> &results) {};
173+
156174
/// Verify all modules loaded by this loader.
157175
virtual void verifyAllModules() { }
158176

include/swift/Serialization/SerializedModuleLoader.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ class SerializedModuleLoaderBase : public ModuleLoader {
184184
unsigned previousGeneration,
185185
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) override;
186186

187+
virtual void loadDerivativeFunctionConfigurations(
188+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
189+
llvm::SetVector<AutoDiffConfig> &results) override;
190+
187191
virtual void verifyAllModules() override;
188192
};
189193

lib/AST/ASTContext.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,17 @@ void ASTContext::loadObjCMethods(
14711471
}
14721472
}
14731473

1474+
void ASTContext::loadDerivativeFunctionConfigurations(
1475+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
1476+
llvm::SetVector<AutoDiffConfig> &results) {
1477+
PrettyStackTraceDecl stackTrace(
1478+
"loading derivative function configurations for", originalAFD);
1479+
for (auto &loader : getImpl().ModuleLoaders) {
1480+
loader->loadDerivativeFunctionConfigurations(originalAFD,
1481+
previousGeneration, results);
1482+
}
1483+
}
1484+
14741485
void ASTContext::verifyAllLoadedModules() const {
14751486
#ifndef NDEBUG
14761487
FrontendStatsTracer tracer(Stats, "verify-all-loaded-modules");

lib/AST/Decl.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7100,8 +7100,10 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
71007100
prepareDerivativeFunctionConfigurations();
71017101
auto &ctx = getASTContext();
71027102
if (ctx.getCurrentGeneration() > DerivativeFunctionConfigGeneration) {
7103-
// TODO(TF-1100): Upstream derivative function configuration serialization
7104-
// logic.
7103+
unsigned previousGeneration = DerivativeFunctionConfigGeneration;
7104+
DerivativeFunctionConfigGeneration = ctx.getCurrentGeneration();
7105+
ctx.loadDerivativeFunctionConfigurations(this, previousGeneration,
7106+
*DerivativeFunctionConfigs);
71057107
}
71067108
return DerivativeFunctionConfigs->getArrayRef();
71077109
}

lib/Sema/TypeCheckAttr.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3971,6 +3971,10 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
39713971
return nullptr;
39723972
}
39733973
getterDecl->getAttrs().add(newAttr);
3974+
// Register derivative function configuration.
3975+
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
3976+
getterDecl->addDerivativeFunctionConfiguration(
3977+
{resolvedDiffParamIndices, resultIndices, derivativeGenSig});
39743978
return resolvedDiffParamIndices;
39753979
}
39763980
// Reject duplicate `@differentiable` attributes.
@@ -4342,6 +4346,12 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
43424346
return true;
43434347
}
43444348

4349+
// Register derivative function configuration.
4350+
auto *resultIndices = IndexSubset::get(Ctx, 1, {0});
4351+
originalAFD->addDerivativeFunctionConfiguration(
4352+
{resolvedDiffParamIndices, resultIndices,
4353+
derivative->getGenericSignature()});
4354+
43454355
return false;
43464356
}
43474357

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,11 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
438438
} else {
439439
witness->getAttrs().add(newAttr);
440440
success = true;
441+
// Register derivative function configuration.
442+
auto *resultIndices = IndexSubset::get(ctx, 1, {0});
443+
witnessAFD->addDerivativeFunctionConfiguration(
444+
{newAttr->getParameterIndices(), resultIndices,
445+
newAttr->getDerivativeGenericSignature()});
441446
}
442447
}
443448
if (!success) {

lib/Serialization/DeclTypeRecordNodes.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ OTHER(XREF_OPAQUE_RETURN_TYPE_PATH_PIECE, 252)
192192

193193
OTHER(CLANG_TYPE, 253)
194194

195+
OTHER(DERIVATIVE_FUNCTION_CONFIGURATION, 254)
196+
195197
#undef RECORD
196198
#undef DECLTYPERECORDNODES_HAS_RECORD_VAL
197199
#undef RECORD_VAL

lib/Serialization/ModuleFile.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,66 @@ ModuleFile::readObjCMethodTable(ArrayRef<uint64_t> fields, StringRef blobData) {
910910
base + sizeof(uint32_t), base));
911911
}
912912

913+
/// Used to deserialize entries in the on-disk derivative function configuration
914+
/// table.
915+
class ModuleFile::DerivativeFunctionConfigTableInfo {
916+
public:
917+
using internal_key_type = StringRef;
918+
using external_key_type = internal_key_type;
919+
using data_type = SmallVector<std::pair<std::string, GenericSignatureID>, 8>;
920+
using hash_value_type = uint32_t;
921+
using offset_type = unsigned;
922+
923+
external_key_type GetExternalKey(internal_key_type ID) { return ID; }
924+
925+
internal_key_type GetInternalKey(external_key_type ID) { return ID; }
926+
927+
hash_value_type ComputeHash(internal_key_type key) {
928+
return llvm::djbHash(key, SWIFTMODULE_HASH_SEED);
929+
}
930+
931+
static bool EqualKey(internal_key_type lhs, internal_key_type rhs) {
932+
return lhs == rhs;
933+
}
934+
935+
static std::pair<unsigned, unsigned> ReadKeyDataLength(const uint8_t *&data) {
936+
unsigned keyLength = endian::readNext<uint16_t, little, unaligned>(data);
937+
unsigned dataLength = endian::readNext<uint16_t, little, unaligned>(data);
938+
return {keyLength, dataLength};
939+
}
940+
941+
static internal_key_type ReadKey(const uint8_t *data, unsigned length) {
942+
return StringRef(reinterpret_cast<const char *>(data), length);
943+
}
944+
945+
static data_type ReadData(internal_key_type key, const uint8_t *data,
946+
unsigned length) {
947+
data_type result;
948+
const uint8_t *limit = data + length;
949+
while (data < limit) {
950+
DeclID genSigId = endian::readNext<uint32_t, little, unaligned>(data);
951+
int32_t nameLength = endian::readNext<int32_t, little, unaligned>(data);
952+
StringRef mangledName(reinterpret_cast<const char *>(data), nameLength);
953+
data += nameLength;
954+
result.push_back({mangledName, genSigId});
955+
}
956+
return result;
957+
}
958+
};
959+
960+
std::unique_ptr<ModuleFile::SerializedDerivativeFunctionConfigTable>
961+
ModuleFile::readDerivativeFunctionConfigTable(ArrayRef<uint64_t> fields,
962+
StringRef blobData) {
963+
uint32_t tableOffset;
964+
index_block::DerivativeFunctionConfigTableLayout::readRecord(fields,
965+
tableOffset);
966+
auto base = reinterpret_cast<const uint8_t *>(blobData.data());
967+
968+
using OwnedTable = std::unique_ptr<SerializedDerivativeFunctionConfigTable>;
969+
return OwnedTable(SerializedDerivativeFunctionConfigTable::Create(
970+
base + tableOffset, base + sizeof(uint32_t), base));
971+
}
972+
913973
bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) {
914974
if (llvm::Error Err = cursor.EnterSubBlock(INDEX_BLOCK_ID)) {
915975
// FIXME this drops the error on the floor.
@@ -1015,6 +1075,10 @@ bool ModuleFile::readIndexBlock(llvm::BitstreamCursor &cursor) {
10151075
case index_block::OBJC_METHODS:
10161076
ObjCMethods = readObjCMethodTable(scratch, blobData);
10171077
break;
1078+
case index_block::DERIVATIVE_FUNCTION_CONFIGURATIONS:
1079+
DerivativeFunctionConfigurations =
1080+
readDerivativeFunctionConfigTable(scratch, blobData);
1081+
break;
10181082
case index_block::ENTRY_POINT:
10191083
assert(blobData.empty());
10201084
setEntryPointClassID(scratch.front());
@@ -2405,6 +2469,34 @@ void ModuleFile::loadObjCMethods(
24052469
}
24062470
}
24072471

2472+
void ModuleFile::loadDerivativeFunctionConfigurations(
2473+
AbstractFunctionDecl *originalAFD,
2474+
llvm::SetVector<AutoDiffConfig> &results) {
2475+
if (!DerivativeFunctionConfigurations)
2476+
return;
2477+
auto &ctx = originalAFD->getASTContext();
2478+
Mangle::ASTMangler Mangler;
2479+
auto mangledName = Mangler.mangleDeclAsUSR(originalAFD, "");
2480+
auto configs = DerivativeFunctionConfigurations->find(mangledName);
2481+
if (configs == DerivativeFunctionConfigurations->end())
2482+
return;
2483+
for (auto entry : *configs) {
2484+
auto *parameterIndices = IndexSubset::getFromString(ctx, entry.first);
2485+
auto derivativeGenSigOrError = getGenericSignatureChecked(entry.second);
2486+
if (!derivativeGenSigOrError) {
2487+
if (!getContext().LangOpts.EnableDeserializationRecovery)
2488+
fatal(derivativeGenSigOrError.takeError());
2489+
llvm::consumeError(derivativeGenSigOrError.takeError());
2490+
}
2491+
auto derivativeGenSig = derivativeGenSigOrError.get();
2492+
// NOTE(TF-1038): Result indices are currently unsupported in derivative
2493+
// registration attributes. In the meantime, always use `{0}` (wrt the
2494+
// first and only result).
2495+
auto resultIndices = IndexSubset::get(ctx, 1, {0});
2496+
results.insert({parameterIndices, resultIndices, derivativeGenSig});
2497+
}
2498+
}
2499+
24082500
TinyPtrVector<ValueDecl *>
24092501
ModuleFile::loadNamedMembers(const IterableDeclContext *IDC, DeclBaseName N,
24102502
uint64_t contextData) {

0 commit comments

Comments
 (0)