Skip to content

Commit a856d59

Browse files
authored
[AutoDiff upstream] Add @differentiable attribute serialization. (swiftlang#30605)
Serialize "is linear?" flag, differentiability parameter indices, and differentiability generic signature. Deserialization has some ad-hoc logic for setting the original declaration and parameter indices for `@differentiable` attributes because `DeclDeserializer::deserializeDeclAttributes` does not have access to the original declaration. Resolves TF-836.
1 parent 0873622 commit a856d59

File tree

3 files changed

+83
-24
lines changed

3 files changed

+83
-24
lines changed

lib/Serialization/Deserialization.cpp

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2268,6 +2268,29 @@ static bool attributeChainContains(DeclAttribute *attr) {
22682268
return tempAttrs.hasAttribute<DERIVED>();
22692269
}
22702270

2271+
// Set original declaration and parameter indices in `@differentiable`
2272+
// attributes.
2273+
//
2274+
// Serializing/deserializing the original declaration DeclID in
2275+
// `@differentiable` attributes does not work because it causes
2276+
// `@differentiable` attribute deserialization to enter an infinite loop.
2277+
//
2278+
// Instead, call this ad-hoc function after deserializing a declaration to set
2279+
// the original declaration and parameter indices for its `@differentiable`
2280+
// attributes.
2281+
static void setOriginalDeclarationAndParameterIndicesInDifferentiableAttributes(
2282+
Decl *decl, DeclAttribute *attrs,
2283+
llvm::DenseMap<DifferentiableAttr *, IndexSubset *>
2284+
&diffAttrParamIndicesMap) {
2285+
DeclAttributes tempAttrs;
2286+
tempAttrs.setRawAttributeChain(attrs);
2287+
for (auto *attr : tempAttrs.getAttributes<DifferentiableAttr>()) {
2288+
auto *diffAttr = const_cast<DifferentiableAttr *>(attr);
2289+
diffAttr->setOriginalDeclaration(decl);
2290+
diffAttr->setParameterIndices(diffAttrParamIndicesMap[diffAttr]);
2291+
}
2292+
}
2293+
22712294
Decl *ModuleFile::getDecl(DeclID DID) {
22722295
Expected<Decl *> deserialized = getDeclChecked(DID);
22732296
if (!deserialized) {
@@ -2294,6 +2317,9 @@ class DeclDeserializer {
22942317
unsigned localDiscriminator = 0;
22952318
StringRef filenameForPrivate;
22962319

2320+
// Auxiliary map for deserializing `@differentiable` attributes.
2321+
llvm::DenseMap<DifferentiableAttr *, IndexSubset *> diffAttrParamIndicesMap;
2322+
22972323
void AddAttribute(DeclAttribute *Attr) {
22982324
// Advance the linked list.
22992325
// This isn't just using DeclAttributes because that would result in the
@@ -4257,6 +4283,36 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
42574283
break;
42584284
}
42594285

4286+
case decls_block::Differentiable_DECL_ATTR: {
4287+
bool isImplicit;
4288+
bool linear;
4289+
GenericSignatureID derivativeGenSigId;
4290+
ArrayRef<uint64_t> parameters;
4291+
4292+
serialization::decls_block::DifferentiableDeclAttrLayout::readRecord(
4293+
scratch, isImplicit, linear, derivativeGenSigId, parameters);
4294+
4295+
auto derivativeGenSig = MF.getGenericSignature(derivativeGenSigId);
4296+
llvm::SmallBitVector parametersBitVector(parameters.size());
4297+
for (unsigned i : indices(parameters))
4298+
parametersBitVector[i] = parameters[i];
4299+
auto *indices = IndexSubset::get(ctx, parametersBitVector);
4300+
auto *diffAttr = DifferentiableAttr::create(
4301+
ctx, isImplicit, SourceLoc(), SourceRange(), linear,
4302+
/*parsedParameters*/ {}, /*trailingWhereClause*/ nullptr);
4303+
4304+
// Cache parameter indices so that they can set later.
4305+
// `DifferentiableAttr::setParameterIndices` cannot be called here
4306+
// because it requires `DifferentiableAttr::setOriginalDeclaration` to
4307+
// be called first. `DifferentiableAttr::setOriginalDeclaration` cannot
4308+
// be called here because the original declaration is not accessible in
4309+
// this function (`DeclDeserializer::deserializeDeclAttributes`).
4310+
diffAttrParamIndicesMap[diffAttr] = indices;
4311+
diffAttr->setDerivativeGenericSignature(derivativeGenSig);
4312+
Attr = diffAttr;
4313+
break;
4314+
}
4315+
42604316
case decls_block::Derivative_DECL_ATTR: {
42614317
bool isImplicit;
42624318
uint64_t origNameId;
@@ -4391,8 +4447,18 @@ DeclDeserializer::getDeclCheckedImpl(
43914447

43924448
switch (recordID) {
43934449
#define CASE(RECORD_NAME) \
4394-
case decls_block::RECORD_NAME##Layout::Code: \
4395-
return deserialize##RECORD_NAME(scratch, blobData);
4450+
case decls_block::RECORD_NAME##Layout::Code: {\
4451+
auto decl = deserialize##RECORD_NAME(scratch, blobData); \
4452+
if (decl) { \
4453+
/* \
4454+
// Set original declaration and parameter indices in `@differentiable` \
4455+
// attributes. \
4456+
*/ \
4457+
setOriginalDeclarationAndParameterIndicesInDifferentiableAttributes(\
4458+
decl.get(), DAttrs, diffAttrParamIndicesMap); \
4459+
} \
4460+
return decl; \
4461+
}
43964462

43974463
CASE(TypeAlias)
43984464
CASE(GenericTypeParamDecl)

lib/Serialization/Serialization.cpp

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2395,23 +2395,20 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
23952395
case DAK_Differentiable: {
23962396
auto abbrCode = S.DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
23972397
auto *attr = cast<DifferentiableAttr>(DA);
2398-
2399-
auto paramIndices = attr->getParameterIndices();
2400-
// NOTE(TF-836): `@differentiable` attribute serialization is blocked by
2401-
// `@differentiable` attribute type-checking (TF-828), which resolves
2402-
// parameter indices (`IndexSubset *`).
2403-
if (!paramIndices)
2404-
return;
2398+
assert(attr->getOriginalDeclaration() &&
2399+
"`@differentiable` attribute should have original declaration set "
2400+
"during construction or parsing");
2401+
auto *paramIndices = attr->getParameterIndices();
24052402
assert(paramIndices && "Parameter indices must be resolved");
2406-
SmallVector<bool, 4> indices;
2403+
SmallVector<bool, 4> paramIndicesVector;
24072404
for (unsigned i : range(paramIndices->getCapacity()))
2408-
indices.push_back(paramIndices->contains(i));
2405+
paramIndicesVector.push_back(paramIndices->contains(i));
24092406

24102407
DifferentiableDeclAttrLayout::emitRecord(
24112408
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(),
24122409
attr->isLinear(),
24132410
S.addGenericSignatureRef(attr->getDerivativeGenericSignature()),
2414-
indices);
2411+
paramIndicesVector);
24152412
return;
24162413
}
24172414

@@ -2428,12 +2425,12 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
24282425
getRawStableAutoDiffDerivativeFunctionKind(attr->getDerivativeKind());
24292426
auto *parameterIndices = attr->getParameterIndices();
24302427
assert(parameterIndices && "Parameter indices must be resolved");
2431-
SmallVector<bool, 4> indices;
2428+
SmallVector<bool, 4> paramIndicesVector;
24322429
for (unsigned i : range(parameterIndices->getCapacity()))
2433-
indices.push_back(parameterIndices->contains(i));
2430+
paramIndicesVector.push_back(parameterIndices->contains(i));
24342431
DerivativeDeclAttrLayout::emitRecord(
24352432
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
2436-
origDeclID, derivativeKind, indices);
2433+
origDeclID, derivativeKind, paramIndicesVector);
24372434
return;
24382435
}
24392436

@@ -2453,12 +2450,12 @@ class Serializer::DeclSerializer : public DeclVisitor<DeclSerializer> {
24532450
DeclID origDeclID = S.addDeclRef(attr->getOriginalFunction());
24542451
auto *parameterIndices = attr->getParameterIndices();
24552452
assert(parameterIndices && "Parameter indices must be resolved");
2456-
SmallVector<bool, 4> indices;
2453+
SmallVector<bool, 4> paramIndicesVector;
24572454
for (unsigned i : range(parameterIndices->getCapacity()))
2458-
indices.push_back(parameterIndices->contains(i));
2455+
paramIndicesVector.push_back(parameterIndices->contains(i));
24592456
TransposeDeclAttrLayout::emitRecord(
24602457
S.Out, S.ScratchRecord, abbrCode, attr->isImplicit(), origNameId,
2461-
origDeclID, indices);
2458+
origDeclID, paramIndicesVector);
24622459
return;
24632460
}
24642461
}

test/AutoDiff/Serialization/differentiable_attr.swift

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
// RUN: %empty-directory(%t)
2-
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
2+
// RUN: %target-swift-frontend -enable-experimental-differentiable-programming %s -emit-module -parse-as-library -o %t
33
// RUN: llvm-bcanalyzer %t/differentiable_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
4-
// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s
5-
6-
// TODO(TF-836): Enable this test.
7-
// Blocked by TF-828: `@differentiable` attribute type-checking.
8-
// XFAIL: *
4+
// RUN: %target-sil-opt -enable-experimental-differentiable-programming -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s
95

106
// BCANALYZER-NOT: UnknownCode
117

0 commit comments

Comments
 (0)