Skip to content

Commit 0218844

Browse files
dan-zhengMarc Rasi
authored andcommitted
Fix differentiable function type demangling.
Handle differentiability kind (`@differentiable` and `@differentiable(linear)`) in `ASTBuilder::createImplFunctionType`. Resolves TF-1225.
1 parent 57d228b commit 0218844

File tree

3 files changed

+77
-13
lines changed

3 files changed

+77
-13
lines changed

include/swift/Demangling/TypeDecoder.h

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,33 +181,53 @@ enum class ImplFunctionRepresentation {
181181
Closure
182182
};
183183

184+
enum class ImplFunctionDifferentiabilityKind {
185+
NonDifferentiable,
186+
Normal,
187+
Linear
188+
};
189+
184190
class ImplFunctionTypeFlags {
185191
unsigned Rep : 3;
186192
unsigned Pseudogeneric : 1;
187193
unsigned Escaping : 1;
194+
unsigned DifferentiabilityKind : 2;
188195

189196
public:
190-
ImplFunctionTypeFlags() : Rep(0), Pseudogeneric(0), Escaping(0) {}
197+
ImplFunctionTypeFlags()
198+
: Rep(0), Pseudogeneric(0), Escaping(0), DifferentiabilityKind(0) {}
191199

192-
ImplFunctionTypeFlags(ImplFunctionRepresentation rep,
193-
bool pseudogeneric, bool noescape)
194-
: Rep(unsigned(rep)), Pseudogeneric(pseudogeneric), Escaping(noescape) {}
200+
ImplFunctionTypeFlags(ImplFunctionRepresentation rep, bool pseudogeneric,
201+
bool noescape,
202+
ImplFunctionDifferentiabilityKind diffKind)
203+
: Rep(unsigned(rep)), Pseudogeneric(pseudogeneric), Escaping(noescape),
204+
DifferentiabilityKind(unsigned(diffKind)) {}
195205

196206
ImplFunctionTypeFlags
197207
withRepresentation(ImplFunctionRepresentation rep) const {
198-
return ImplFunctionTypeFlags(rep, Pseudogeneric, Escaping);
208+
return ImplFunctionTypeFlags(
209+
rep, Pseudogeneric, Escaping,
210+
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
199211
}
200212

201213
ImplFunctionTypeFlags
202214
withEscaping() const {
203-
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep),
204-
Pseudogeneric, true);
215+
return ImplFunctionTypeFlags(
216+
ImplFunctionRepresentation(Rep), Pseudogeneric, true,
217+
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
205218
}
206219

207220
ImplFunctionTypeFlags
208221
withPseudogeneric() const {
209-
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep),
210-
true, Escaping);
222+
return ImplFunctionTypeFlags(
223+
ImplFunctionRepresentation(Rep), true, Escaping,
224+
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
225+
}
226+
227+
ImplFunctionTypeFlags
228+
withDifferentiabilityKind(ImplFunctionDifferentiabilityKind diffKind) const {
229+
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep), Pseudogeneric,
230+
Escaping, diffKind);
211231
}
212232

213233
ImplFunctionRepresentation getRepresentation() const {
@@ -217,6 +237,10 @@ class ImplFunctionTypeFlags {
217237
bool isEscaping() const { return Escaping; }
218238

219239
bool isPseudogeneric() const { return Pseudogeneric; }
240+
241+
ImplFunctionDifferentiabilityKind getDifferentiabilityKind() const {
242+
return ImplFunctionDifferentiabilityKind(DifferentiabilityKind);
243+
}
220244
};
221245

222246
#if SWIFT_OBJC_INTEROP
@@ -582,6 +606,14 @@ class TypeDecoder {
582606
flags =
583607
flags.withRepresentation(ImplFunctionRepresentation::Block);
584608
}
609+
} else if (child->getKind() == NodeKind::ImplDifferentiable) {
610+
flags = flags.withDifferentiabilityKind(
611+
ImplFunctionDifferentiabilityKind::Normal);
612+
} else if (child->getKind() == NodeKind::ImplLinear) {
613+
flags = flags.withDifferentiabilityKind(
614+
ImplFunctionDifferentiabilityKind::Linear);
615+
} else if (child->getKind() == NodeKind::ImplEscaping) {
616+
flags = flags.withEscaping();
585617
} else if (child->getKind() == NodeKind::ImplEscaping) {
586618
flags = flags.withEscaping();
587619
} else if (child->getKind() == NodeKind::ImplParameter) {

lib/AST/ASTDemangler.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,11 +500,23 @@ Type ASTBuilder::createImplFunctionType(
500500
break;
501501
}
502502

503+
DifferentiabilityKind diffKind;
504+
switch (flags.getDifferentiabilityKind()) {
505+
case ImplFunctionDifferentiabilityKind::NonDifferentiable:
506+
diffKind = DifferentiabilityKind::NonDifferentiable;
507+
break;
508+
case ImplFunctionDifferentiabilityKind::Normal:
509+
diffKind = DifferentiabilityKind::Normal;
510+
break;
511+
case ImplFunctionDifferentiabilityKind::Linear:
512+
diffKind = DifferentiabilityKind::Linear;
513+
break;
514+
}
515+
503516
// TODO: [store-sil-clang-function-type]
504-
auto einfo = SILFunctionType::ExtInfo(
505-
representation, flags.isPseudogeneric(), !flags.isEscaping(),
506-
DifferentiabilityKind::NonDifferentiable,
507-
/*clangFunctionType*/ nullptr);
517+
auto einfo = SILFunctionType::ExtInfo(representation, flags.isPseudogeneric(),
518+
!flags.isEscaping(), diffKind,
519+
/*clangFunctionType*/ nullptr);
508520

509521
llvm::SmallVector<SILParameterInfo, 8> funcParams;
510522
llvm::SmallVector<SILYieldInfo, 8> funcYields;
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: %target-swift-frontend -enable-experimental-differentiable-programming -emit-ir -g %s
2+
3+
// TF-1225: IRGenDebugInfo crash when reconstructing `@differentiable` and
4+
// `@differentiable(linear)` function types.
5+
6+
import _Differentiation
7+
8+
@inlinable
9+
public func transpose<T, R>(
10+
of body: @escaping @differentiable(linear) (T) -> R
11+
) -> @differentiable(linear) (R) -> T {
12+
fatalError()
13+
}
14+
15+
@inlinable
16+
public func valueWithDifferential<T, R>(
17+
at x: T, in f: @differentiable (T) -> R
18+
) -> (value: R, differential: (T.TangentVector) -> R.TangentVector) {
19+
fatalError()
20+
}

0 commit comments

Comments
 (0)