Skip to content

Commit 9082dd4

Browse files
author
ematejska
authored
[Autodiff upstream] Finishing @transpose attr serialization (swiftlang#30683)
* Adding @transpose attr deserialization support * Turning on the transpose serialization test
1 parent 791312f commit 9082dd4

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

lib/Serialization/Deserialization.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4369,6 +4369,30 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
43694369
break;
43704370
}
43714371

4372+
case decls_block::Transpose_DECL_ATTR: {
4373+
bool isImplicit;
4374+
uint64_t origNameId;
4375+
DeclID origDeclId;
4376+
ArrayRef<uint64_t> parameters;
4377+
4378+
serialization::decls_block::TransposeDeclAttrLayout::readRecord(
4379+
scratch, isImplicit, origNameId, origDeclId, parameters);
4380+
4381+
DeclNameRefWithLoc origName{
4382+
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
4383+
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
4384+
llvm::SmallBitVector parametersBitVector(parameters.size());
4385+
for (unsigned i : indices(parameters))
4386+
parametersBitVector[i] = parameters[i];
4387+
auto *indices = IndexSubset::get(ctx, parametersBitVector);
4388+
auto *transposeAttr =
4389+
TransposeAttr::create(ctx, isImplicit, SourceLoc(), SourceRange(),
4390+
/*baseTypeRepr*/ nullptr, origName, indices);
4391+
transposeAttr->setOriginalFunction(origDecl);
4392+
Attr = transposeAttr;
4393+
break;
4394+
}
4395+
43724396
case decls_block::SPIAccessControl_DECL_ATTR: {
43734397
ArrayRef<uint64_t> spiIds;
43744398
serialization::decls_block::SPIAccessControlDeclAttrLayout::readRecord(

test/AutoDiff/Serialization/transpose_attr.swift

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55

66
// BCANALYZER-NOT: UnknownCode
77

8-
// TODO(TF-838): Enable this test.
9-
// Blocked by TF-830: `@transpose` attribute type-checking.
10-
// XFAIL: *
11-
128
import _Differentiation
139

1410
// Dummy `Differentiable`-conforming type.
@@ -50,14 +46,14 @@ extension S {
5046

5147
// CHECK: @transpose(of: instanceMethod, wrt: 0)
5248
@transpose(of: instanceMethod, wrt: 0)
53-
func transposeInstanceMethod(v: S) -> (S, S) {
54-
(v, v)
49+
func transposeInstanceMethod(t: S) -> S {
50+
self + t
5551
}
5652

5753
// CHECK: @transpose(of: instanceMethod, wrt: self)
5854
@transpose(of: instanceMethod, wrt: self)
59-
func transposeInstanceMethodWrtSelf(v: S) -> (S, S) {
60-
(v, v)
55+
static func transposeInstanceMethodWrtSelf(_ other: S, t: S) -> S {
56+
other + t
6157
}
6258
}
6359

@@ -70,8 +66,8 @@ extension S {
7066

7167
// CHECK: @transpose(of: staticMethod, wrt: 0)
7268
@transpose(of: staticMethod, wrt: 0)
73-
func transposeStaticMethod(_: S.Type) -> S {
74-
self
69+
static func transposeStaticMethod(t: S) -> S {
70+
t
7571
}
7672
}
7773

@@ -81,8 +77,8 @@ extension S {
8177

8278
// CHECK: @transpose(of: computedProperty, wrt: self)
8379
@transpose(of: computedProperty, wrt: self)
84-
func transposeProperty() -> Self {
85-
self
80+
static func transposeProperty(t: Self) -> Self {
81+
t
8682
}
8783
}
8884

@@ -92,7 +88,7 @@ extension S {
9288

9389
// CHECK: @transpose(of: subscript, wrt: self)
9490
@transpose(of: subscript(_:), wrt: self)
95-
func transposeSubscript<T: Differentiable>(x: T) -> Self {
96-
self
91+
static func transposeSubscript<T: Differentiable>(x: T, t: Self) -> Self {
92+
t
9793
}
9894
}

0 commit comments

Comments
 (0)