Skip to content

Commit c1fe0e3

Browse files
authored
[AutoDiff upstream] Add differentiable function type mangling. (swiftlang#30675)
Add mangling scheme for `@differentiable` and `@differentiable(linear)` function types. Mangling support is important for debug information, among other things. Update docs and add tests. Resolves TF-948.
1 parent 9f0f92f commit c1fe0e3

File tree

13 files changed

+268
-11
lines changed

13 files changed

+268
-11
lines changed

docs/ABI/Mangling.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,10 @@ Types
517517
FUNCTION-KIND ::= 'C' // C function pointer type
518518
FUNCTION-KIND ::= 'A' // @auto_closure function type (escaping)
519519
FUNCTION-KIND ::= 'E' // function type (noescape)
520+
FUNCTION-KIND ::= 'F' // @differentiable function type
521+
FUNCTION-KIND ::= 'G' // @differentiable function type (escaping)
522+
FUNCTION-KIND ::= 'H' // @differentiable(linear) function type
523+
FUNCTION-KIND ::= 'I' // @differentiable(linear) function type (escaping)
520524

521525
function-signature ::= params-type params-type throws? // results and parameters
522526

@@ -585,14 +589,18 @@ mangled in to disambiguate.
585589
impl-function-type ::= type* 'I' FUNC-ATTRIBUTES '_'
586590
impl-function-type ::= type* generic-signature 'I' FUNC-ATTRIBUTES '_'
587591

588-
FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? PARAM-CONVENTION* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)?
592+
FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? DIFFERENTIABILITY-KIND? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? PARAM-CONVENTION* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)?
589593

590594
PATTERN-SUBS ::= 's' // has pattern substitutions
591595
INVOCATION-SUB ::= 'I' // has invocation substitutions
592596
PSEUDO-GENERIC ::= 'P'
593597

594598
CALLEE-ESCAPE ::= 'e' // @escaping (inverse of SIL @noescape)
595599

600+
DIFFERENTIABILITY-KIND ::= DIFFERENTIABLE | LINEAR
601+
DIFFERENTIABLE ::= 'd' // @differentiable
602+
LINEAR ::= 'l' // @differentiable(linear)
603+
596604
CALLEE-CONVENTION ::= 'y' // @callee_unowned
597605
CALLEE-CONVENTION ::= 'g' // @callee_guaranteed
598606
CALLEE-CONVENTION ::= 'x' // @callee_owned

include/swift/ABI/MetadataValues.h

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,14 @@ enum class FunctionMetadataConvention: uint8_t {
764764
CFunctionPointer = 3,
765765
};
766766

767+
/// Differentiability kind for function type metadata.
768+
/// Duplicates `DifferentiabilityKind` in AutoDiff.h.
769+
enum class FunctionMetadataDifferentiabilityKind: uint8_t {
770+
NonDifferentiable = 0b00,
771+
Normal = 0b01,
772+
Linear = 0b11
773+
};
774+
767775
/// Flags in a function type metadata record.
768776
template <typename int_type>
769777
class TargetFunctionTypeFlags {
@@ -777,6 +785,8 @@ class TargetFunctionTypeFlags {
777785
ThrowsMask = 0x01000000U,
778786
ParamFlagsMask = 0x02000000U,
779787
EscapingMask = 0x04000000U,
788+
DifferentiableMask = 0x08000000U,
789+
LinearMask = 0x10000000U
780790
};
781791
int_type Data;
782792

@@ -801,6 +811,16 @@ class TargetFunctionTypeFlags {
801811
(throws ? ThrowsMask : 0));
802812
}
803813

814+
constexpr TargetFunctionTypeFlags<int_type> withDifferentiabilityKind(
815+
FunctionMetadataDifferentiabilityKind differentiability) const {
816+
return TargetFunctionTypeFlags<int_type>(
817+
(Data & ~DifferentiableMask & ~LinearMask) |
818+
(differentiability == FunctionMetadataDifferentiabilityKind::Normal
819+
? DifferentiableMask : 0) |
820+
(differentiability == FunctionMetadataDifferentiabilityKind::Linear
821+
? LinearMask : 0));
822+
}
823+
804824
constexpr TargetFunctionTypeFlags<int_type>
805825
withParameterFlags(bool hasFlags) const {
806826
return TargetFunctionTypeFlags<int_type>((Data & ~ParamFlagsMask) |
@@ -829,6 +849,19 @@ class TargetFunctionTypeFlags {
829849

830850
bool hasParameterFlags() const { return bool(Data & ParamFlagsMask); }
831851

852+
bool isDifferentiable() const {
853+
return getDifferentiabilityKind() >=
854+
FunctionMetadataDifferentiabilityKind::Normal;
855+
}
856+
857+
FunctionMetadataDifferentiabilityKind getDifferentiabilityKind() const {
858+
if (bool(Data & DifferentiableMask))
859+
return FunctionMetadataDifferentiabilityKind::Normal;
860+
if (bool(Data & LinearMask))
861+
return FunctionMetadataDifferentiabilityKind::Linear;
862+
return FunctionMetadataDifferentiabilityKind::NonDifferentiable;
863+
}
864+
832865
int_type getIntValue() const {
833866
return Data;
834867
}
@@ -849,9 +882,10 @@ using FunctionTypeFlags = TargetFunctionTypeFlags<size_t>;
849882
template <typename int_type>
850883
class TargetParameterTypeFlags {
851884
enum : int_type {
852-
ValueOwnershipMask = 0x7F,
853-
VariadicMask = 0x80,
854-
AutoClosureMask = 0x100,
885+
ValueOwnershipMask = 0x7F,
886+
VariadicMask = 0x80,
887+
AutoClosureMask = 0x100,
888+
NoDerivativeMask = 0x200
855889
};
856890
int_type Data;
857891

@@ -881,6 +915,7 @@ class TargetParameterTypeFlags {
881915
bool isNone() const { return Data == 0; }
882916
bool isVariadic() const { return Data & VariadicMask; }
883917
bool isAutoClosure() const { return Data & AutoClosureMask; }
918+
bool isNoDerivative() const { return Data & NoDerivativeMask; }
884919

885920
ValueOwnership getValueOwnership() const {
886921
return (ValueOwnership)(Data & ValueOwnershipMask);

include/swift/Demangling/DemangleNodes.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ NODE(DependentProtocolConformanceInherited)
6868
NODE(DependentProtocolConformanceAssociated)
6969
CONTEXT_NODE(Destructor)
7070
CONTEXT_NODE(DidSet)
71+
NODE(DifferentiableFunctionType)
72+
NODE(EscapingDifferentiableFunctionType)
73+
NODE(LinearFunctionType)
74+
NODE(EscapingLinearFunctionType)
7175
NODE(Directness)
7276
NODE(DynamicAttribute)
7377
NODE(DirectMethodReferenceAttribute)
@@ -109,6 +113,8 @@ NODE(Identifier)
109113
NODE(Index)
110114
CONTEXT_NODE(IVarInitializer)
111115
CONTEXT_NODE(IVarDestroyer)
116+
NODE(ImplDifferentiable)
117+
NODE(ImplLinear)
112118
NODE(ImplEscaping)
113119
NODE(ImplConvention)
114120
NODE(ImplFunctionAttribute)

include/swift/Demangling/TypeDecoder.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,10 @@ class TypeDecoder {
494494
case NodeKind::NoEscapeFunctionType:
495495
case NodeKind::AutoClosureType:
496496
case NodeKind::EscapingAutoClosureType:
497+
case NodeKind::DifferentiableFunctionType:
498+
case NodeKind::EscapingDifferentiableFunctionType:
499+
case NodeKind::LinearFunctionType:
500+
case NodeKind::EscapingLinearFunctionType:
497501
case NodeKind::FunctionType: {
498502
if (Node->getNumChildren() < 2)
499503
return BuiltType();
@@ -507,6 +511,15 @@ class TypeDecoder {
507511
flags.withConvention(FunctionMetadataConvention::CFunctionPointer);
508512
} else if (Node->getKind() == NodeKind::ThinFunctionType) {
509513
flags = flags.withConvention(FunctionMetadataConvention::Thin);
514+
} else if (Node->getKind() == NodeKind::DifferentiableFunctionType ||
515+
Node->getKind() ==
516+
NodeKind::EscapingDifferentiableFunctionType) {
517+
flags = flags.withDifferentiabilityKind(
518+
FunctionMetadataDifferentiabilityKind::Normal);
519+
} else if (Node->getKind() == NodeKind::LinearFunctionType ||
520+
Node->getKind() == NodeKind::EscapingLinearFunctionType) {
521+
flags = flags.withDifferentiabilityKind(
522+
FunctionMetadataDifferentiabilityKind::Linear);
510523
}
511524

512525
bool isThrow =
@@ -527,7 +540,11 @@ class TypeDecoder {
527540
.withEscaping(
528541
Node->getKind() == NodeKind::FunctionType ||
529542
Node->getKind() == NodeKind::EscapingAutoClosureType ||
530-
Node->getKind() == NodeKind::EscapingObjCBlock);
543+
Node->getKind() == NodeKind::EscapingObjCBlock ||
544+
Node->getKind() ==
545+
NodeKind::EscapingDifferentiableFunctionType ||
546+
Node->getKind() ==
547+
NodeKind::EscapingLinearFunctionType);
531548

532549
auto result = decodeMangledType(Node->getChild(isThrow ? 2 : 1));
533550
if (!result) return BuiltType();

lib/AST/ASTDemangler.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,8 @@ Type ASTBuilder::createFunctionType(
365365
auto parameterFlags = ParameterTypeFlags()
366366
.withValueOwnership(ownership)
367367
.withVariadic(flags.isVariadic())
368-
.withAutoClosure(flags.isAutoClosure());
368+
.withAutoClosure(flags.isAutoClosure())
369+
.withNoDerivative(flags.isNoDerivative());
369370

370371
funcParams.push_back(AnyFunctionType::Param(type, label, parameterFlags));
371372
}
@@ -386,16 +387,27 @@ Type ASTBuilder::createFunctionType(
386387
break;
387388
}
388389

390+
DifferentiabilityKind diffKind;
391+
switch (flags.getDifferentiabilityKind()) {
392+
case FunctionMetadataDifferentiabilityKind::NonDifferentiable:
393+
diffKind = DifferentiabilityKind::NonDifferentiable;
394+
break;
395+
case FunctionMetadataDifferentiabilityKind::Normal:
396+
diffKind = DifferentiabilityKind::Normal;
397+
break;
398+
case FunctionMetadataDifferentiabilityKind::Linear:
399+
diffKind = DifferentiabilityKind::Linear;
400+
break;
401+
}
402+
389403
auto noescape =
390404
(representation == FunctionTypeRepresentation::Swift
391405
|| representation == FunctionTypeRepresentation::Block)
392406
&& !flags.isEscaping();
393407

394408
FunctionType::ExtInfo incompleteExtInfo(
395409
FunctionTypeRepresentation::Swift,
396-
noescape, flags.throws(),
397-
DifferentiabilityKind::NonDifferentiable,
398-
/*clangFunctionType*/nullptr);
410+
noescape, flags.throws(), diffKind, /*clangFunctionType*/nullptr);
399411

400412
const clang::Type *clangFunctionType = nullptr;
401413
if (representation == FunctionTypeRepresentation::CFunctionPointer)

lib/AST/ASTMangler.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,6 +1517,18 @@ void ASTMangler::appendImplFunctionType(SILFunctionType *fn) {
15171517
if (!fn->isNoEscape())
15181518
OpArgs.push_back('e');
15191519

1520+
// Differentiability kind.
1521+
switch (fn->getExtInfo().getDifferentiabilityKind()) {
1522+
case DifferentiabilityKind::NonDifferentiable:
1523+
break;
1524+
case DifferentiabilityKind::Normal:
1525+
OpArgs.push_back('d');
1526+
break;
1527+
case DifferentiabilityKind::Linear:
1528+
OpArgs.push_back('l');
1529+
break;
1530+
}
1531+
15201532
// <impl-callee-convention>
15211533
if (fn->getExtInfo().hasContext()) {
15221534
OpArgs.push_back(getParamConvention(fn->getCalleeConvention()));
@@ -2117,6 +2129,18 @@ void ASTMangler::appendFunctionType(AnyFunctionType *fn, bool isAutoClosure,
21172129
case AnyFunctionType::Representation::Thin:
21182130
return appendOperator("Xf");
21192131
case AnyFunctionType::Representation::Swift:
2132+
if (fn->getDifferentiabilityKind() == DifferentiabilityKind::Normal) {
2133+
if (fn->isNoEscape())
2134+
return appendOperator("XF");
2135+
else
2136+
return appendOperator("XG");
2137+
}
2138+
if (fn->getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
2139+
if (fn->isNoEscape())
2140+
return appendOperator("XH");
2141+
else
2142+
return appendOperator("XI");
2143+
}
21202144
if (isAutoClosure) {
21212145
if (fn->isNoEscape())
21222146
return appendOperator("XK");

lib/Demangling/Demangler.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,6 +1773,11 @@ NodePointer Demangler::demangleImplFunctionType() {
17731773
if (nextIf('e'))
17741774
type->addChild(createNode(Node::Kind::ImplEscaping), *this);
17751775

1776+
if (nextIf('d'))
1777+
type->addChild(createNode(Node::Kind::ImplDifferentiable), *this);
1778+
if (nextIf('l'))
1779+
type->addChild(createNode(Node::Kind::ImplLinear), *this);
1780+
17761781
const char *CAttr = nullptr;
17771782
switch (nextChar()) {
17781783
case 'y': CAttr = "@callee_unowned"; break;
@@ -2791,6 +2796,14 @@ NodePointer Demangler::demangleSpecialType() {
27912796
return popFunctionType(Node::Kind::ObjCBlock);
27922797
case 'C':
27932798
return popFunctionType(Node::Kind::CFunctionPointer);
2799+
case 'F':
2800+
return popFunctionType(Node::Kind::DifferentiableFunctionType);
2801+
case 'G':
2802+
return popFunctionType(Node::Kind::EscapingDifferentiableFunctionType);
2803+
case 'H':
2804+
return popFunctionType(Node::Kind::LinearFunctionType);
2805+
case 'I':
2806+
return popFunctionType(Node::Kind::EscapingLinearFunctionType);
27942807
case 'o':
27952808
return createType(createWithChild(Node::Kind::Unowned,
27962809
popNode(Node::Kind::Type)));

lib/Demangling/NodePrinter.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,10 @@ class NodePrinter {
351351
case Node::Kind::DependentPseudogenericSignature:
352352
case Node::Kind::Destructor:
353353
case Node::Kind::DidSet:
354+
case Node::Kind::DifferentiableFunctionType:
355+
case Node::Kind::EscapingDifferentiableFunctionType:
356+
case Node::Kind::LinearFunctionType:
357+
case Node::Kind::EscapingLinearFunctionType:
354358
case Node::Kind::DirectMethodReferenceAttribute:
355359
case Node::Kind::Directness:
356360
case Node::Kind::DynamicAttribute:
@@ -386,6 +390,8 @@ class NodePrinter {
386390
case Node::Kind::Index:
387391
case Node::Kind::IVarInitializer:
388392
case Node::Kind::IVarDestroyer:
393+
case Node::Kind::ImplDifferentiable:
394+
case Node::Kind::ImplLinear:
389395
case Node::Kind::ImplEscaping:
390396
case Node::Kind::ImplConvention:
391397
case Node::Kind::ImplFunctionAttribute:
@@ -1234,6 +1240,22 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
12341240
Printer << "@convention(thin) ";
12351241
printFunctionType(nullptr, Node);
12361242
return nullptr;
1243+
case Node::Kind::DifferentiableFunctionType:
1244+
Printer << "@differentiable ";
1245+
printFunctionType(nullptr, Node);
1246+
return nullptr;
1247+
case Node::Kind::EscapingDifferentiableFunctionType:
1248+
Printer << "@escaping @differentiable ";
1249+
printFunctionType(nullptr, Node);
1250+
return nullptr;
1251+
case Node::Kind::LinearFunctionType:
1252+
Printer << "@differentiable(linear) ";
1253+
printFunctionType(nullptr, Node);
1254+
return nullptr;
1255+
case Node::Kind::EscapingLinearFunctionType:
1256+
Printer << "@escaping @differentiable(linear) ";
1257+
printFunctionType(nullptr, Node);
1258+
return nullptr;
12371259
case Node::Kind::FunctionType:
12381260
case Node::Kind::UncurriedFunctionType:
12391261
printFunctionType(nullptr, Node);
@@ -2026,6 +2048,12 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
20262048
return nullptr;
20272049
case Node::Kind::LabelList:
20282050
return nullptr;
2051+
case Node::Kind::ImplDifferentiable:
2052+
Printer << "@differentiable";
2053+
return nullptr;
2054+
case Node::Kind::ImplLinear:
2055+
Printer << "@differentiable(linear)";
2056+
return nullptr;
20292057
case Node::Kind::ImplEscaping:
20302058
Printer << "@escaping";
20312059
return nullptr;
@@ -2527,6 +2555,14 @@ void NodePrinter::printEntityType(NodePointer Entity, NodePointer type,
25272555
Printer << ' ';
25282556
type = dependentType->getFirstChild();
25292557
}
2558+
if (type->getKind() == Node::Kind::DifferentiableFunctionType)
2559+
Printer << "@differentiable ";
2560+
else if (type->getKind() == Node::Kind::EscapingDifferentiableFunctionType)
2561+
Printer << "@escaping @differentiable ";
2562+
else if (type->getKind() == Node::Kind::LinearFunctionType)
2563+
Printer << "@differentiable(linear) ";
2564+
else if (type->getKind() == Node::Kind::EscapingLinearFunctionType)
2565+
Printer << "@escaping @differentiable(linear) ";
25302566
printFunctionType(labelList, type);
25312567
} else {
25322568
print(type);

lib/Demangling/OldRemangler.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,26 @@ void Remangler::mangleThinFunctionType(Node *node) {
11681168
mangleChildNodes(node); // argument tuple, result type
11691169
}
11701170

1171+
void Remangler::mangleDifferentiableFunctionType(Node *node) {
1172+
Buffer << "XF";
1173+
mangleChildNodes(node); // argument tuple, result type
1174+
}
1175+
1176+
void Remangler::mangleEscapingDifferentiableFunctionType(Node *node) {
1177+
Buffer << "XG";
1178+
mangleChildNodes(node); // argument tuple, result type
1179+
}
1180+
1181+
void Remangler::mangleLinearFunctionType(Node *node) {
1182+
Buffer << "XH";
1183+
mangleChildNodes(node); // argument tuple, result type
1184+
}
1185+
1186+
void Remangler::mangleEscapingLinearFunctionType(Node *node) {
1187+
Buffer << "XI";
1188+
mangleChildNodes(node); // argument tuple, result type
1189+
}
1190+
11711191
void Remangler::mangleArgumentTuple(Node *node) {
11721192
mangleSingleChildNode(node);
11731193
}
@@ -1258,6 +1278,16 @@ void Remangler::mangleImplYield(Node *node) {
12581278
mangleChildNodes(node); // impl convention, type
12591279
}
12601280

1281+
void Remangler::mangleImplDifferentiable(Node *node) {
1282+
// TODO(TF-750): Check if this code path actually triggers and add a test.
1283+
Buffer << 'd';
1284+
}
1285+
1286+
void Remangler::mangleImplLinear(Node *node) {
1287+
// TODO(TF-750): Check if this code path actually triggers and add a test.
1288+
Buffer << 'l';
1289+
}
1290+
12611291
void Remangler::mangleImplEscaping(Node *node) {
12621292
// The old mangler does not encode escaping.
12631293
}

0 commit comments

Comments
 (0)