@@ -2268,6 +2268,29 @@ static bool attributeChainContains(DeclAttribute *attr) {
2268
2268
return tempAttrs.hasAttribute <DERIVED>();
2269
2269
}
2270
2270
2271
+ // Set original declaration and parameter indices in `@differentiable`
2272
+ // attributes.
2273
+ //
2274
+ // Serializing/deserializing the original declaration DeclID in
2275
+ // `@differentiable` attributes does not work because it causes
2276
+ // `@differentiable` attribute deserialization to enter an infinite loop.
2277
+ //
2278
+ // Instead, call this ad-hoc function after deserializing a declaration to set
2279
+ // the original declaration and parameter indices for its `@differentiable`
2280
+ // attributes.
2281
+ static void setOriginalDeclarationAndParameterIndicesInDifferentiableAttributes (
2282
+ Decl *decl, DeclAttribute *attrs,
2283
+ llvm::DenseMap<DifferentiableAttr *, IndexSubset *>
2284
+ &diffAttrParamIndicesMap) {
2285
+ DeclAttributes tempAttrs;
2286
+ tempAttrs.setRawAttributeChain (attrs);
2287
+ for (auto *attr : tempAttrs.getAttributes <DifferentiableAttr>()) {
2288
+ auto *diffAttr = const_cast <DifferentiableAttr *>(attr);
2289
+ diffAttr->setOriginalDeclaration (decl);
2290
+ diffAttr->setParameterIndices (diffAttrParamIndicesMap[diffAttr]);
2291
+ }
2292
+ }
2293
+
2271
2294
Decl *ModuleFile::getDecl (DeclID DID) {
2272
2295
Expected<Decl *> deserialized = getDeclChecked (DID);
2273
2296
if (!deserialized) {
@@ -2294,6 +2317,9 @@ class DeclDeserializer {
2294
2317
unsigned localDiscriminator = 0 ;
2295
2318
StringRef filenameForPrivate;
2296
2319
2320
+ // Auxiliary map for deserializing `@differentiable` attributes.
2321
+ llvm::DenseMap<DifferentiableAttr *, IndexSubset *> diffAttrParamIndicesMap;
2322
+
2297
2323
void AddAttribute (DeclAttribute *Attr) {
2298
2324
// Advance the linked list.
2299
2325
// This isn't just using DeclAttributes because that would result in the
@@ -4257,6 +4283,36 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
4257
4283
break ;
4258
4284
}
4259
4285
4286
+ case decls_block::Differentiable_DECL_ATTR: {
4287
+ bool isImplicit;
4288
+ bool linear;
4289
+ GenericSignatureID derivativeGenSigId;
4290
+ ArrayRef<uint64_t > parameters;
4291
+
4292
+ serialization::decls_block::DifferentiableDeclAttrLayout::readRecord (
4293
+ scratch, isImplicit, linear, derivativeGenSigId, parameters);
4294
+
4295
+ auto derivativeGenSig = MF.getGenericSignature (derivativeGenSigId);
4296
+ llvm::SmallBitVector parametersBitVector (parameters.size ());
4297
+ for (unsigned i : indices (parameters))
4298
+ parametersBitVector[i] = parameters[i];
4299
+ auto *indices = IndexSubset::get (ctx, parametersBitVector);
4300
+ auto *diffAttr = DifferentiableAttr::create (
4301
+ ctx, isImplicit, SourceLoc (), SourceRange (), linear,
4302
+ /* parsedParameters*/ {}, /* trailingWhereClause*/ nullptr );
4303
+
4304
+ // Cache parameter indices so that they can set later.
4305
+ // `DifferentiableAttr::setParameterIndices` cannot be called here
4306
+ // because it requires `DifferentiableAttr::setOriginalDeclaration` to
4307
+ // be called first. `DifferentiableAttr::setOriginalDeclaration` cannot
4308
+ // be called here because the original declaration is not accessible in
4309
+ // this function (`DeclDeserializer::deserializeDeclAttributes`).
4310
+ diffAttrParamIndicesMap[diffAttr] = indices;
4311
+ diffAttr->setDerivativeGenericSignature (derivativeGenSig);
4312
+ Attr = diffAttr;
4313
+ break ;
4314
+ }
4315
+
4260
4316
case decls_block::Derivative_DECL_ATTR: {
4261
4317
bool isImplicit;
4262
4318
uint64_t origNameId;
@@ -4391,8 +4447,18 @@ DeclDeserializer::getDeclCheckedImpl(
4391
4447
4392
4448
switch (recordID) {
4393
4449
#define CASE (RECORD_NAME ) \
4394
- case decls_block::RECORD_NAME##Layout::Code: \
4395
- return deserialize##RECORD_NAME (scratch, blobData);
4450
+ case decls_block::RECORD_NAME##Layout::Code: {\
4451
+ auto decl = deserialize##RECORD_NAME (scratch, blobData); \
4452
+ if (decl) { \
4453
+ /* \
4454
+ // Set original declaration and parameter indices in `@differentiable` \
4455
+ // attributes. \
4456
+ */ \
4457
+ setOriginalDeclarationAndParameterIndicesInDifferentiableAttributes (\
4458
+ decl.get (), DAttrs, diffAttrParamIndicesMap); \
4459
+ } \
4460
+ return decl; \
4461
+ }
4396
4462
4397
4463
CASE (TypeAlias)
4398
4464
CASE (GenericTypeParamDecl)
0 commit comments