Skip to content

Commit 9602657

Browse files
authored
[AutoDiff upstream] Upstream serialization changes. (swiftlang#30720)
Upstream random serialization code from tensorflow branch: - `SerializedSILLoader::getAllDifferentiabilityWitnesses()` - Add diff. witness serialization for functions with `inout` parameters. - `deserializeSILFunctionType`: fix assertion for differentiability kind.
1 parent 3d379d2 commit 9602657

File tree

6 files changed

+30
-16
lines changed

6 files changed

+30
-16
lines changed

include/swift/Serialization/SerializedSILLoader.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef SWIFT_SERIALIZATION_SILLOADER_H
1414
#define SWIFT_SERIALIZATION_SILLOADER_H
1515

16+
#include "swift/AST/AutoDiff.h"
1617
#include "swift/AST/Decl.h"
1718
#include "swift/AST/Identifier.h"
1819
#include "swift/SIL/Notifications.h"
@@ -102,6 +103,9 @@ class SerializedSILLoader {
102103
/// Deserialize all Properties in all SILModules.
103104
void getAllProperties();
104105

106+
/// Deserialize all DifferentiabilityWitnesses in all SILModules.
107+
void getAllDifferentiabilityWitnesses();
108+
105109
SerializedSILLoader(const SerializedSILLoader &) = delete;
106110
SerializedSILLoader(SerializedSILLoader &&) = delete;
107111
SerializedSILLoader &operator=(const SerializedSILLoader &) = delete;

lib/Serialization/Deserialization.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5411,8 +5411,11 @@ class TypeDeserializer {
54115411
};
54125412

54135413
// Bounds check. FIXME: overflow
5414-
if (2 * numParams + 2 * numResults + 2 * unsigned(hasErrorResult)
5415-
> variableData.size()) {
5414+
unsigned entriesPerParam =
5415+
diffKind != DifferentiabilityKind::NonDifferentiable ? 3 : 2;
5416+
if (entriesPerParam * numParams + 2 * numResults +
5417+
2 * unsigned(hasErrorResult) >
5418+
variableData.size()) {
54165419
MF.fatal();
54175420
}
54185421

lib/Serialization/DeserializeSIL.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3573,19 +3573,22 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
35733573
}
35743574
auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID);
35753575

3576+
auto originalFnType = original->getLoweredFunctionType();
35763577
SmallVector<unsigned, 8> parameterAndResultIndices(
35773578
rawParameterAndResultIndices.begin(), rawParameterAndResultIndices.end());
35783579
assert(parameterAndResultIndices.size() ==
35793580
numParameterIndices + numResultIndices &&
35803581
"Parameter/result indices count mismatch");
3581-
auto *parameterIndices = IndexSubset::get(
3582-
MF->getContext(), original->getLoweredFunctionType()->getNumParameters(),
3583-
ArrayRef<unsigned>(parameterAndResultIndices)
3584-
.take_front(numParameterIndices));
3585-
auto *resultIndices = IndexSubset::get(
3586-
MF->getContext(), original->getLoweredFunctionType()->getNumResults(),
3587-
ArrayRef<unsigned>(parameterAndResultIndices)
3588-
.take_back(numResultIndices));
3582+
auto *parameterIndices =
3583+
IndexSubset::get(MF->getContext(), originalFnType->getNumParameters(),
3584+
ArrayRef<unsigned>(parameterAndResultIndices)
3585+
.take_front(numParameterIndices));
3586+
auto numResults = originalFnType->getNumResults() +
3587+
originalFnType->getNumIndirectMutatingParameters();
3588+
auto *resultIndices =
3589+
IndexSubset::get(MF->getContext(), numResults,
3590+
ArrayRef<unsigned>(parameterAndResultIndices)
3591+
.take_back(numResultIndices));
35893592

35903593
AutoDiffConfig config(parameterIndices, resultIndices, derivativeGenSig);
35913594
auto *diffWitness =

lib/Serialization/Serialization.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2438,11 +2438,6 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
24382438
case DAK_Transpose: {
24392439
auto abbrCode = S.DeclTypeAbbrCodes[TransposeDeclAttrLayout::Code];
24402440
auto *attr = cast<TransposeAttr>(DA);
2441-
// NOTE(TF-838): `@transpose` attribute serialization is blocked by
2442-
// `@transpose` attribute type-checking (TF-830), which resolves
2443-
// the original declaration.
2444-
if (!attr->getOriginalFunction())
2445-
return;
24462441
assert(attr->getOriginalFunction() &&
24472442
"`@transpose` attribute should have original declaration set "
24482443
"during construction or parsing");

lib/Serialization/SerializeSIL.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2542,7 +2542,11 @@ void SILSerializer::writeSILDifferentiabilityWitness(
25422542
dw.getParameterIndices()->getCapacity() &&
25432543
"Original function parameter count should match differentiability "
25442544
"witness parameter indices capacity");
2545-
assert(originalFnType->getNumResults() ==
2545+
unsigned numInoutParameters = llvm::count_if(
2546+
originalFnType->getParameters(), [](SILParameterInfo paramInfo) {
2547+
return paramInfo.isIndirectMutating();
2548+
});
2549+
assert(originalFnType->getNumResults() + numInoutParameters ==
25462550
dw.getResultIndices()->getCapacity() &&
25472551
"Original function result count should match differentiability "
25482552
"witness result indices capacity");

lib/Serialization/SerializedSILLoader.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,8 @@ void SerializedSILLoader::getAllProperties() {
202202
Des->getAllProperties();
203203
}
204204

205+
/// Deserialize all DifferentiabilityWitnesses in all SILModules.
206+
void SerializedSILLoader::getAllDifferentiabilityWitnesses() {
207+
for (auto &Des : LoadedSILSections)
208+
Des->getAllDifferentiabilityWitnesses();
209+
}

0 commit comments

Comments
 (0)