Skip to content

Commit 75088cd

Browse files
committed
[AutoDiff] Mangle derivative vtable thunks.
Add the following mangling rule: ``` global ::= global generic-signature? 'TJV' AUTODIFF-FUNCTION-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' // autodiff derivative vtable thunk ``` Resolves rdar://74340331.
1 parent 55199f5 commit 75088cd

File tree

17 files changed

+104
-67
lines changed

17 files changed

+104
-67
lines changed

docs/ABI/Mangling.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ types where the metadata itself has unknown layout.)
229229
global ::= type generic-signature 'TH' // key path equality
230230
global ::= type generic-signature 'Th' // key path hasher
231231
global ::= global generic-signature? 'TJ' AUTODIFF-FUNCTION-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' // autodiff function
232+
global ::= global generic-signature? 'TJV' AUTODIFF-FUNCTION-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' // autodiff derivative vtable thunk
232233
global ::= from-type to-type 'TJO' AUTODIFF-FUNCTION-KIND // autodiff self-reordering reabstraction thunk
233234
global ::= from-type 'TJS' AUTODIFF-FUNCTION-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' INDEX-SUBSET 'P' // autodiff linear map subset parameters thunk
234235
global ::= global to-type 'TJS' AUTODIFF-FUNCTION-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' INDEX-SUBSET 'P' // autodiff derivative function subset parameters thunk

include/swift/AST/ASTMangler.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,17 @@ class ASTMangler : public Mangler {
173173
CanType ResultType,
174174
bool predefined);
175175

176-
/// Mangle the derivative function (JVP/VJP) for the given:
176+
/// Mangle the derivative function (JVP/VJP), or optionally its vtable entry
177+
/// thunk, for the given:
177178
/// - Mangled original function declaration.
178179
/// - Derivative function kind.
179180
/// - Derivative function configuration: parameter/result indices and
180181
/// derivative generic signature.
181182
std::string
182183
mangleAutoDiffDerivativeFunction(const AbstractFunctionDecl *originalAFD,
183184
AutoDiffDerivativeFunctionKind kind,
184-
AutoDiffConfig config);
185+
AutoDiffConfig config,
186+
bool isVTableThunk = false);
185187

186188
/// Mangle the linear map (differential/pullback) for the given:
187189
/// - Mangled original function declaration.
@@ -447,7 +449,8 @@ class ASTMangler : public Mangler {
447449

448450
void beginManglingWithAutoDiffOriginalFunction(
449451
const AbstractFunctionDecl *afd);
450-
void appendAutoDiffFunctionParts(char functionKindCode,
452+
void appendAutoDiffFunctionParts(StringRef op,
453+
Demangle::AutoDiffFunctionKind kind,
451454
AutoDiffConfig config);
452455
void appendIndexSubset(IndexSubset *indexSubset);
453456
};

include/swift/Demangling/DemangleNodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ CONTEXT_NODE(AutoDiffFunction)
312312
NODE(AutoDiffFunctionKind)
313313
NODE(AutoDiffSelfReorderingReabstractionThunk)
314314
NODE(AutoDiffSubsetParametersThunk)
315+
NODE(AutoDiffDerivativeVTableThunk)
315316
NODE(IndexSubset)
316317

317318
#undef CONTEXT_NODE

include/swift/Demangling/Demangler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ class Demangler : public NodeFactory {
569569

570570
NodePointer demangleTypeMangling();
571571
NodePointer demangleSymbolicReference(unsigned char rawKind);
572-
NodePointer demangleAutoDiffFunction();
572+
NodePointer demangleAutoDiffFunctionOrSimpleThunk(Node::Kind nodeKind);
573573
NodePointer demangleAutoDiffFunctionKind();
574574
NodePointer demangleAutoDiffSubsetParametersThunk();
575575
NodePointer demangleAutoDiffSelfReorderingReabstractionThunk();

lib/AST/ASTMangler.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,17 +409,19 @@ std::string ASTMangler::mangleObjCAsyncCompletionHandlerImpl(
409409
std::string ASTMangler::mangleAutoDiffDerivativeFunction(
410410
const AbstractFunctionDecl *originalAFD,
411411
AutoDiffDerivativeFunctionKind kind,
412-
AutoDiffConfig config) {
412+
AutoDiffConfig config,
413+
bool isVTableThunk) {
413414
beginManglingWithAutoDiffOriginalFunction(originalAFD);
414-
appendAutoDiffFunctionParts((char)getAutoDiffFunctionKind(kind), config);
415+
appendAutoDiffFunctionParts(
416+
isVTableThunk ? "TJV" : "TJ", getAutoDiffFunctionKind(kind), config);
415417
return finalize();
416418
}
417419

418420
std::string ASTMangler::mangleAutoDiffLinearMap(
419421
const AbstractFunctionDecl *originalAFD, AutoDiffLinearMapKind kind,
420422
AutoDiffConfig config) {
421423
beginManglingWithAutoDiffOriginalFunction(originalAFD);
422-
appendAutoDiffFunctionParts((char)getAutoDiffFunctionKind(kind), config);
424+
appendAutoDiffFunctionParts("TJ", getAutoDiffFunctionKind(kind), config);
423425
return finalize();
424426
}
425427

@@ -437,11 +439,13 @@ void ASTMangler::beginManglingWithAutoDiffOriginalFunction(
437439
appendEntity(afd);
438440
}
439441

440-
void ASTMangler::appendAutoDiffFunctionParts(char functionKindCode,
442+
void ASTMangler::appendAutoDiffFunctionParts(StringRef op,
443+
AutoDiffFunctionKind kind,
441444
AutoDiffConfig config) {
442445
if (auto sig = config.derivativeGenericSignature)
443446
appendGenericSignature(sig);
444-
appendOperator("TJ", StringRef(&functionKindCode, 1));
447+
auto kindCode = (char)kind;
448+
appendOperator(op, StringRef(&kindCode, 1));
445449
appendIndexSubset(config.parameterIndices);
446450
appendOperator("p");
447451
appendIndexSubset(config.resultIndices);

lib/Demangling/Demangler.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2501,18 +2501,27 @@ NodePointer Demangler::demangleThunkOrSpecialization() {
25012501
case 'J':
25022502
switch (peekChar()) {
25032503
case 'S':
2504+
nextChar();
25042505
return demangleAutoDiffSubsetParametersThunk();
25052506
case 'O':
2507+
nextChar();
25062508
return demangleAutoDiffSelfReorderingReabstractionThunk();
2509+
case 'V':
2510+
nextChar();
2511+
return demangleAutoDiffFunctionOrSimpleThunk(
2512+
Node::Kind::AutoDiffDerivativeVTableThunk);
2513+
default:
2514+
return demangleAutoDiffFunctionOrSimpleThunk(
2515+
Node::Kind::AutoDiffFunction);
25072516
}
2508-
return demangleAutoDiffFunction();
25092517
default:
25102518
return nullptr;
25112519
}
25122520
}
25132521

2514-
NodePointer Demangler::demangleAutoDiffFunction() {
2515-
auto result = createNode(Node::Kind::AutoDiffFunction);
2522+
NodePointer
2523+
Demangler::demangleAutoDiffFunctionOrSimpleThunk(Node::Kind nodeKind) {
2524+
auto result = createNode(nodeKind);
25162525
while (auto *originalNode = popNode())
25172526
result = addChild(result, originalNode);
25182527
result->reverseChildren();
@@ -2535,7 +2544,6 @@ NodePointer Demangler::demangleAutoDiffFunctionKind() {
25352544
}
25362545

25372546
NodePointer Demangler::demangleAutoDiffSubsetParametersThunk() {
2538-
nextChar();
25392547
auto result = createNode(Node::Kind::AutoDiffSubsetParametersThunk);
25402548
while (auto *node = popNode())
25412549
result = addChild(result, node);
@@ -2555,7 +2563,6 @@ NodePointer Demangler::demangleAutoDiffSubsetParametersThunk() {
25552563
}
25562564

25572565
NodePointer Demangler::demangleAutoDiffSelfReorderingReabstractionThunk() {
2558-
nextChar();
25592566
auto result = createNode(
25602567
Node::Kind::AutoDiffSelfReorderingReabstractionThunk);
25612568
addChild(result, popNode(Node::Kind::DependentGenericSignature));

lib/Demangling/NodePrinter.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ class NodePrinter {
566566
case Node::Kind::CanonicalPrespecializedGenericTypeCachingOnceToken:
567567
case Node::Kind::AsyncFunctionPointer:
568568
case Node::Kind::AutoDiffFunction:
569+
case Node::Kind::AutoDiffDerivativeVTableThunk:
569570
case Node::Kind::AutoDiffSelfReorderingReabstractionThunk:
570571
case Node::Kind::AutoDiffSubsetParametersThunk:
571572
case Node::Kind::AutoDiffFunctionKind:
@@ -1739,16 +1740,19 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
17391740
print(Node->getChild(idx));
17401741
return nullptr;
17411742
}
1742-
case Node::Kind::AutoDiffFunction: {
1743+
case Node::Kind::AutoDiffFunction:
1744+
case Node::Kind::AutoDiffDerivativeVTableThunk: {
17431745
unsigned prefixEndIndex = 0;
17441746
while (prefixEndIndex != Node->getNumChildren() &&
17451747
Node->getChild(prefixEndIndex)->getKind()
17461748
!= Node::Kind::AutoDiffFunctionKind)
17471749
++prefixEndIndex;
1748-
auto kind = Node->getChild(prefixEndIndex);
1750+
auto funcKind = Node->getChild(prefixEndIndex);
17491751
auto paramIndices = Node->getChild(prefixEndIndex + 1);
17501752
auto resultIndices = Node->getChild(prefixEndIndex + 2);
1751-
print(kind);
1753+
if (kind == Node::Kind::AutoDiffDerivativeVTableThunk)
1754+
Printer << "vtable thunk for ";
1755+
print(funcKind);
17521756
Printer << " of ";
17531757
NodePointer optionalGenSig = nullptr;
17541758
for (unsigned i = 0; i < prefixEndIndex; ++i) {

lib/Demangling/OldRemangler.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,10 @@ void Remangler::mangleAutoDiffFunction(Node *node, EntityContext &ctx) {
748748
Buffer << "<autodiff-function>";
749749
}
750750

751+
void Remangler::mangleAutoDiffDerivativeVTableThunk(Node *node) {
752+
Buffer << "<autodiff-derivative-vtable-thunk>";
753+
}
754+
751755
void Remangler::mangleAutoDiffSelfReorderingReabstractionThunk(Node *node) {
752756
Buffer << "<autodiff-self-reordering-reabstraction-thunk>";
753757
}

lib/Demangling/Remangler.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,8 @@ class Remangler : public RemanglerBase {
311311

312312
void mangleKeyPathThunkHelper(Node *node, StringRef op);
313313

314+
void mangleAutoDiffFunctionOrSimpleThunk(Node *node, StringRef op);
315+
314316
#define NODE(ID) \
315317
void mangle##ID(Node *node);
316318
#define CONTEXT_NODE(ID) \
@@ -2114,19 +2116,27 @@ void Remangler::mangleReabstractionThunkHelperWithSelf(Node *node) {
21142116
Buffer << "Ty";
21152117
}
21162118

2117-
void Remangler::mangleAutoDiffFunction(Node *node) {
2119+
void Remangler::mangleAutoDiffFunctionOrSimpleThunk(Node *node, StringRef op) {
21182120
auto childIt = node->begin();
21192121
while (childIt != node->end() &&
21202122
(*childIt)->getKind() != Node::Kind::AutoDiffFunctionKind)
21212123
mangle(*childIt++);
2122-
Buffer << "TJ";
2124+
Buffer << op;
21232125
mangle(*childIt++); // kind
21242126
mangle(*childIt++); // parameter indices
21252127
Buffer << 'p';
21262128
mangle(*childIt++); // result indices
21272129
Buffer << 'r';
21282130
}
21292131

2132+
void Remangler::mangleAutoDiffFunction(Node *node) {
2133+
mangleAutoDiffFunctionOrSimpleThunk(node, "TJ");
2134+
}
2135+
2136+
void Remangler::mangleAutoDiffDerivativeVTableThunk(Node *node) {
2137+
mangleAutoDiffFunctionOrSimpleThunk(node, "TJV");
2138+
}
2139+
21302140
void Remangler::mangleAutoDiffSelfReorderingReabstractionThunk(Node *node) {
21312141
auto childIt = node->begin();
21322142
mangle(*childIt++); // from type

lib/IRGen/IRGenMangler.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,12 @@ class IRGenMangler : public Mangle::ASTMangler {
5656
const AbstractFunctionDecl *func,
5757
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
5858
beginManglingWithAutoDiffOriginalFunction(func);
59-
auto kindCode =
60-
(char)Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
59+
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
6160
AutoDiffConfig config(
6261
derivativeId->getParameterIndices(),
6362
IndexSubset::get(func->getASTContext(), 1, {0}),
6463
derivativeId->getDerivativeGenericSignature());
65-
appendAutoDiffFunctionParts(kindCode, config);
64+
appendAutoDiffFunctionParts("TJ", kind, config);
6665
appendOperator("Tj");
6766
return finalize();
6867
}
@@ -86,13 +85,12 @@ class IRGenMangler : public Mangle::ASTMangler {
8685
const AbstractFunctionDecl *func,
8786
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
8887
beginManglingWithAutoDiffOriginalFunction(func);
89-
auto kindCode =
90-
(char)Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
88+
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
9189
AutoDiffConfig config(
9290
derivativeId->getParameterIndices(),
9391
IndexSubset::get(func->getASTContext(), 1, {0}),
9492
derivativeId->getDerivativeGenericSignature());
95-
appendAutoDiffFunctionParts(kindCode, config);
93+
appendAutoDiffFunctionParts("TJ", kind, config);
9694
appendOperator("Tq");
9795
return finalize();
9896
}

0 commit comments

Comments
 (0)