Skip to content

Commit f3f0b93

Browse files
committed
[AutoDiff] Fix reflection and metadata IRGen for differentiable function types
1 parent 9778e1a commit f3f0b93

File tree

5 files changed

+66
-4
lines changed

5 files changed

+66
-4
lines changed

lib/IRGen/MetadataRequest.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,9 +1453,8 @@ namespace {
14531453
llvm::Value *diffKindVal = nullptr;
14541454
if (type->isDifferentiable()) {
14551455
assert(metadataDifferentiabilityKind.isDifferentiable());
1456-
// FIXME: Shouldn't this use metadataDifferentiabilityKind?
1457-
diffKindVal = llvm::ConstantInt::get(IGF.IGM.SizeTy,
1458-
flags.getIntValue());
1456+
diffKindVal = llvm::ConstantInt::get(
1457+
IGF.IGM.SizeTy, metadataDifferentiabilityKind.getIntValue());
14591458
} else if (type->getGlobalActor()) {
14601459
diffKindVal = llvm::ConstantInt::get(
14611460
IGF.IGM.SizeTy,

stdlib/public/Reflection/TypeRef.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,27 @@ class PrintTypeRef : public TypeRefVisitor<PrintTypeRef, void> {
137137
break;
138138
}
139139

140+
switch (F->getDifferentiabilityKind().Value) {
141+
case FunctionMetadataDifferentiabilityKind::NonDifferentiable:
142+
break;
143+
144+
case FunctionMetadataDifferentiabilityKind::Forward:
145+
printField("differentiable", "forward");
146+
break;
147+
148+
case FunctionMetadataDifferentiabilityKind::Reverse:
149+
printField("differentiable", "reverse");
150+
break;
151+
152+
case FunctionMetadataDifferentiabilityKind::Normal:
153+
printField("differentiable", "normal");
154+
break;
155+
156+
case FunctionMetadataDifferentiabilityKind::Linear:
157+
printField("differentiable", "linear");
158+
break;
159+
}
160+
140161
if (auto globalActor = F->getGlobalActor()) {
141162
fprintf(file, "\n");
142163
Indent += 2;
@@ -683,7 +704,27 @@ class DemanglingForTypeRef
683704
funcNode->addChild(node, Dem);
684705
}
685706

686-
// FIXME: Differentiability is missing
707+
if (F->getFlags().isDifferentiable()) {
708+
MangledDifferentiabilityKind mangledKind;
709+
switch (F->getDifferentiabilityKind().Value) {
710+
#define CASE(X) case FunctionMetadataDifferentiabilityKind::X: \
711+
mangledKind = MangledDifferentiabilityKind::X; break;
712+
713+
CASE(NonDifferentiable)
714+
CASE(Forward)
715+
CASE(Reverse)
716+
CASE(Normal)
717+
CASE(Linear)
718+
#undef CASE
719+
}
720+
721+
funcNode->addChild(
722+
Dem.createNode(
723+
Node::Kind::DifferentiableFunctionType,
724+
(Node::IndexType)mangledKind),
725+
Dem);
726+
}
727+
687728
if (F->getFlags().isThrowing())
688729
funcNode->addChild(Dem.createNode(Node::Kind::ThrowsAnnotation), Dem);
689730
if (F->getFlags().isSendable()) {
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import _Differentiation
2+
3+
public struct HasAutoDiffTypes {
4+
public var aFunction: @differentiable(reverse) (Float) -> Float
5+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
// this file intentionally left blank
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// REQUIRES: no_asan
2+
// RUN: %empty-directory(%t)
3+
import _Differentiation
4+
5+
// RUN: %target-build-swift -Xfrontend -enable-anonymous-context-mangled-names %S/Inputs/AutoDiffTypes.swift -parse-as-library -emit-module -emit-library -module-name TypesToReflect -o %t/%target-library-name(TypesToReflect)
6+
// RUN: %target-build-swift -Xfrontend -enable-anonymous-context-mangled-names %S/Inputs/AutoDiffTypes.swift %S/Inputs/main.swift -emit-module -emit-executable -module-name TypesToReflect -o %t/TypesToReflect
7+
8+
// RUN: %target-swift-reflection-dump -binary-filename %t/%target-library-name(TypesToReflect) | %FileCheck %s
9+
// RUN: %target-swift-reflection-dump -binary-filename %t/TypesToReflect | %FileCheck %s
10+
11+
// CHECK: FIELDS:
12+
// CHECK: =======
13+
// CHECK: TypesToReflect.HasAutoDiffTypes
14+
// CHECK: aFunction: @differentiable(reverse) (Swift.Float) -> Swift.Float
15+
// CHECK: (function differentiable=reverse
16+

0 commit comments

Comments
 (0)