Skip to content

Commit 82886bf

Browse files
committed
[AutoDiff] Fix mangling of '@noDerivative' in function types.
`@noDerivative` was not mangled in function types, and was resolved incorrectly when there's an ownership specifier. It is fixed by this patch with the following changes: * Add `NoDerivative` demangle node represented by a `k` operator. ``` list-type ::= type identifier? 'k'? 'z'? 'h'? 'n'? 'd'? // type with optional label, '@noDerivative', inout convention, shared convention, owned convention, and variadic specifier ``` * Fix `NoDerivative`'s overflown offset in `ParameterTypeFlags` (`7` -> `6`). * In type decoder and type resolver where attributed type nodes are processed, add support for nested attributed nodes, e.g. `inout @noDerivative T`. * Add `TypeResolverContext::InoutFunctionInput` so that when we resolve an `inout @noDerivative T` parameter, the `@noDerivative T` checking logic won't get a `TypeResolverContext::None` set by the caller. Resolves rdar://75916833.
1 parent 7b5b474 commit 82886bf

File tree

18 files changed

+144
-52
lines changed

18 files changed

+144
-52
lines changed

docs/ABI/Mangling.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ Types
582582
type-list ::= empty-list
583583

584584
// FIXME: Consider replacing 'h' with a two-char code
585-
list-type ::= type identifier? 'z'? 'h'? 'n'? 'd'? // type with optional label, inout convention, shared convention, owned convention, and variadic specifier
585+
list-type ::= type identifier? 'k'? 'z'? 'h'? 'n'? 'd'? // type with optional label, '@noDerivative', inout convention, shared convention, owned convention, and variadic specifier
586586

587587
METATYPE-REPR ::= 't' // Thin metatype representation
588588
METATYPE-REPR ::= 'T' // Thick metatype representation
@@ -666,7 +666,7 @@ mangled in to disambiguate.
666666
COROUTINE-KIND ::= 'A' // yield-once coroutine
667667
COROUTINE-KIND ::= 'G' // yield-many coroutine
668668

669-
SENDABLE ::= 'h' // @Sendable
669+
SENDABLE ::= 'h' // @Sendable
670670
ASYNC ::= 'H' // @async
671671

672672
PARAM-CONVENTION ::= 'i' // indirect in

include/swift/AST/Attr.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2324,8 +2324,6 @@ class TypeAttributes {
23242324

23252325
Optional<Convention> ConventionArguments;
23262326

2327-
// Indicates whether the type's '@differentiable' attribute has a 'linear'
2328-
// argument.
23292327
DifferentiabilityKind differentiabilityKind =
23302328
DifferentiabilityKind::NonDifferentiable;
23312329

include/swift/AST/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1929,7 +1929,7 @@ class ParameterTypeFlags {
19291929
NonEphemeral = 1 << 2,
19301930
OwnershipShift = 3,
19311931
Ownership = 7 << OwnershipShift,
1932-
NoDerivative = 1 << 7,
1932+
NoDerivative = 1 << 6,
19331933
NumBits = 7
19341934
};
19351935
OptionSet<ParameterFlags> value;

include/swift/Demangling/DemangleNodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ NODE(AutoDiffSelfReorderingReabstractionThunk)
312312
NODE(AutoDiffSubsetParametersThunk)
313313
NODE(AutoDiffDerivativeVTableThunk)
314314
NODE(DifferentiabilityWitness)
315+
NODE(NoDerivative)
315316
NODE(IndexSubset)
316317
NODE(AsyncAwaitResumePartialFunction)
317318
NODE(AsyncSuspendResumePartialFunction)

include/swift/Demangling/TypeDecoder.h

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class FunctionParam {
7272
void setValueOwnership(ValueOwnership ownership) {
7373
Flags = Flags.withValueOwnership(ownership);
7474
}
75+
void setNoDerivative() { Flags = Flags.withNoDerivative(true); }
7576
void setFlags(ParameterFlags flags) { Flags = flags; };
7677

7778
FunctionParam withLabel(StringRef label) const {
@@ -1375,28 +1376,39 @@ class TypeDecoder {
13751376
node = node->getFirstChild();
13761377
hasParamFlags = true;
13771378
};
1378-
switch (node->getKind()) {
1379-
case NodeKind::InOut:
1380-
setOwnership(ValueOwnership::InOut);
1381-
break;
13821379

1383-
case NodeKind::Shared:
1384-
setOwnership(ValueOwnership::Shared);
1385-
break;
1380+
bool recurse = true;
1381+
while (recurse) {
1382+
switch (node->getKind()) {
1383+
case NodeKind::InOut:
1384+
setOwnership(ValueOwnership::InOut);
1385+
break;
13861386

1387-
case NodeKind::Owned:
1388-
setOwnership(ValueOwnership::Owned);
1389-
break;
1387+
case NodeKind::Shared:
1388+
setOwnership(ValueOwnership::Shared);
1389+
break;
13901390

1391-
case NodeKind::AutoClosureType:
1392-
case NodeKind::EscapingAutoClosureType: {
1393-
param.setAutoClosure();
1394-
hasParamFlags = true;
1395-
break;
1396-
}
1391+
case NodeKind::Owned:
1392+
setOwnership(ValueOwnership::Owned);
1393+
break;
13971394

1398-
default:
1399-
break;
1395+
case NodeKind::NoDerivative:
1396+
param.setNoDerivative();
1397+
node = node->getFirstChild();
1398+
hasParamFlags = true;
1399+
break;
1400+
1401+
case NodeKind::AutoClosureType:
1402+
case NodeKind::EscapingAutoClosureType:
1403+
param.setAutoClosure();
1404+
hasParamFlags = true;
1405+
recurse = false;
1406+
break;
1407+
1408+
default:
1409+
recurse = false;
1410+
break;
1411+
}
14001412
}
14011413

14021414
auto paramType = decodeMangledType(node);

lib/AST/ASTMangler.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2541,6 +2541,9 @@ void ASTMangler::appendTypeListElement(Identifier name, Type elementType,
25412541
else
25422542
appendType(elementType, forDecl);
25432543

2544+
if (flags.isNoDerivative()) {
2545+
appendOperator("k");
2546+
}
25442547
switch (flags.getValueOwnership()) {
25452548
case ValueOwnership::Default:
25462549
/* nothing */

lib/Demangling/Demangler.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,9 @@ NodePointer Demangler::demangleOperator() {
793793
popTypeAndGetChild()));
794794
case 'i': return demangleSubscript();
795795
case 'j': return demangleDifferentiableFunctionType();
796+
case 'k':
797+
return createType(
798+
createWithChild(Node::Kind::NoDerivative, popTypeAndGetChild()));
796799
case 'l': return demangleGenericSignature(/*hasParamCounts*/ false);
797800
case 'm': return createType(createWithChild(Node::Kind::Metatype,
798801
popNode(Node::Kind::Type)));

lib/Demangling/NodePrinter.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ class NodePrinter {
569569
case Node::Kind::AutoDiffSubsetParametersThunk:
570570
case Node::Kind::AutoDiffFunctionKind:
571571
case Node::Kind::DifferentiabilityWitness:
572+
case Node::Kind::NoDerivative:
572573
case Node::Kind::IndexSubset:
573574
case Node::Kind::AsyncAwaitResumePartialFunction:
574575
case Node::Kind::AsyncSuspendResumePartialFunction:
@@ -1421,6 +1422,10 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
14211422
Printer << "__owned ";
14221423
print(Node->getChild(0));
14231424
return nullptr;
1425+
case Node::Kind::NoDerivative:
1426+
Printer << "@noDerivative ";
1427+
print(Node->getChild(0));
1428+
return nullptr;
14241429
case Node::Kind::NonObjCAttribute:
14251430
Printer << "@nonobjc ";
14261431
return nullptr;

lib/Demangling/OldDemangler.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,6 +2063,14 @@ class OldDemangler {
20632063
inout->addChild(type, Factory);
20642064
return inout;
20652065
}
2066+
if (c == 'k') {
2067+
auto noDerivative = Factory.createNode(Node::Kind::NoDerivative);
2068+
auto type = demangleTypeImpl();
2069+
if (!type)
2070+
return nullptr;
2071+
noDerivative->addChild(type, Factory);
2072+
return noDerivative;
2073+
}
20662074
if (c == 'S') {
20672075
return demangleSubstitutionIndex();
20682076
}

lib/Demangling/OldRemangler.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,6 +1485,11 @@ void Remangler::mangleInOut(Node *node) {
14851485
mangleSingleChildNode(node); // type
14861486
}
14871487

1488+
void Remangler::mangleNoDerivative(Node *node) {
1489+
Buffer << 'k';
1490+
mangleSingleChildNode(node); // type
1491+
}
1492+
14881493
void Remangler::mangleTuple(Node *node) {
14891494
size_t NumElems = node->getNumChildren();
14901495
if (NumElems > 0 &&

0 commit comments

Comments
 (0)