Skip to content

Commit b520ec0

Browse files
authored
Merge pull request #2620 from swiftwasm/main
[pull] swiftwasm from main
2 parents a02c617 + 9aed7ce commit b520ec0

19 files changed

+220
-120
lines changed

include/swift/Demangling/DemangleNodes.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ NODE(CanonicalPrespecializedGenericTypeCachingOnceToken)
309309

310310
// Added in Swift 5.5
311311
NODE(AsyncFunctionPointer)
312-
NODE(AutoDiffFunction)
312+
CONTEXT_NODE(AutoDiffFunction)
313313
NODE(AutoDiffFunctionKind)
314314
NODE(IndexSubset)
315315

include/swift/IRGen/Linking.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ class LinkEntity {
126126
/// or a class.
127127
DispatchThunk,
128128

129+
/// A derivative method dispatch thunk. The pointer is a
130+
/// AbstractFunctionDecl* inside a protocol or a class, and the secondary
131+
/// pointer is an AutoDiffDerivativeFunctionIdentifier*.
132+
DispatchThunkDerivative,
133+
129134
/// A method dispatch thunk for an initializing constructor. The pointer
130135
/// is a ConstructorDecl* inside a class.
131136
DispatchThunkInitializer,
@@ -152,6 +157,11 @@ class LinkEntity {
152157
/// or a class.
153158
MethodDescriptor,
154159

160+
/// A derivative method descriptor. The pointer is a AbstractFunctionDecl*
161+
/// inside a protocol or a class, and the secondary pointer is an
162+
/// AutoDiffDerivativeFunctionIdentifier*.
163+
MethodDescriptorDerivative,
164+
155165
/// A method descriptor for an initializing constructor. The pointer
156166
/// is a ConstructorDecl* inside a class.
157167
MethodDescriptorInitializer,
@@ -618,6 +628,16 @@ class LinkEntity {
618628
static LinkEntity forDispatchThunk(SILDeclRef declRef) {
619629
assert(isValidResilientMethodRef(declRef));
620630

631+
if (declRef.isAutoDiffDerivativeFunction()) {
632+
LinkEntity entity;
633+
// The derivative function for any decl is always a method (not an
634+
// initializer).
635+
entity.setForDecl(Kind::DispatchThunkDerivative, declRef.getDecl());
636+
entity.SecondaryPointer =
637+
declRef.getAutoDiffDerivativeFunctionIdentifier();
638+
return entity;
639+
}
640+
621641
LinkEntity::Kind kind;
622642
switch (declRef.kind) {
623643
case SILDeclRef::Kind::Func:
@@ -641,6 +661,16 @@ class LinkEntity {
641661
static LinkEntity forMethodDescriptor(SILDeclRef declRef) {
642662
assert(isValidResilientMethodRef(declRef));
643663

664+
if (declRef.isAutoDiffDerivativeFunction()) {
665+
LinkEntity entity;
666+
// The derivative function for any decl is always a method (not an
667+
// initializer).
668+
entity.setForDecl(Kind::MethodDescriptorDerivative, declRef.getDecl());
669+
entity.SecondaryPointer =
670+
declRef.getAutoDiffDerivativeFunctionIdentifier();
671+
return entity;
672+
}
673+
644674
LinkEntity::Kind kind;
645675
switch (declRef.kind) {
646676
case SILDeclRef::Kind::Func:
@@ -1263,6 +1293,15 @@ class LinkEntity {
12631293
assert(getKind() == Kind::AssociatedTypeWitnessTableAccessFunction);
12641294
return reinterpret_cast<ProtocolDecl*>(Pointer);
12651295
}
1296+
1297+
AutoDiffDerivativeFunctionIdentifier *
1298+
getAutoDiffDerivativeFunctionIdentifier() const {
1299+
assert(getKind() == Kind::DispatchThunkDerivative ||
1300+
getKind() == Kind::MethodDescriptorDerivative);
1301+
return reinterpret_cast<AutoDiffDerivativeFunctionIdentifier*>(
1302+
SecondaryPointer);
1303+
}
1304+
12661305
bool isDynamicallyReplaceable() const {
12671306
assert(getKind() == Kind::SILFunction);
12681307
return LINKENTITY_GET_FIELD(Data, IsDynamicallyReplaceableImpl);

include/swift/SIL/SILVTableVisitor.h

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,7 @@ template <class T> class SILVTableVisitor {
4545
SILDeclRef constant(fd, SILDeclRef::Kind::Func);
4646
maybeAddEntry(constant);
4747

48-
for (auto *diffAttr : fd->getAttrs().getAttributes<DifferentiableAttr>()) {
49-
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
50-
AutoDiffDerivativeFunctionIdentifier::get(
51-
AutoDiffDerivativeFunctionKind::JVP,
52-
diffAttr->getParameterIndices(),
53-
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
54-
maybeAddEntry(jvpConstant);
55-
56-
auto vjpConstant = constant.asAutoDiffDerivativeFunction(
57-
AutoDiffDerivativeFunctionIdentifier::get(
58-
AutoDiffDerivativeFunctionKind::VJP,
59-
diffAttr->getParameterIndices(),
60-
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
61-
maybeAddEntry(vjpConstant);
62-
}
48+
maybeAddAutoDiffDerivativeMethods(constant);
6349
}
6450

6551
void maybeAddConstructor(ConstructorDecl *cd) {
@@ -72,21 +58,7 @@ template <class T> class SILVTableVisitor {
7258
SILDeclRef constant(cd, SILDeclRef::Kind::Allocator);
7359
maybeAddEntry(constant);
7460

75-
for (auto *diffAttr : cd->getAttrs().getAttributes<DifferentiableAttr>()) {
76-
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
77-
AutoDiffDerivativeFunctionIdentifier::get(
78-
AutoDiffDerivativeFunctionKind::JVP,
79-
diffAttr->getParameterIndices(),
80-
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
81-
maybeAddEntry(jvpConstant);
82-
83-
auto vjpConstant = constant.asAutoDiffDerivativeFunction(
84-
AutoDiffDerivativeFunctionIdentifier::get(
85-
AutoDiffDerivativeFunctionKind::VJP,
86-
diffAttr->getParameterIndices(),
87-
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
88-
maybeAddEntry(vjpConstant);
89-
}
61+
maybeAddAutoDiffDerivativeMethods(constant);
9062
}
9163

9264
void maybeAddAccessors(AbstractStorageDecl *asd) {
@@ -142,6 +114,24 @@ template <class T> class SILVTableVisitor {
142114
asDerived().addPlaceholder(placeholder);
143115
}
144116

117+
void maybeAddAutoDiffDerivativeMethods(SILDeclRef constant) {
118+
auto *D = constant.getDecl();
119+
for (auto *diffAttr : D->getAttrs().getAttributes<DifferentiableAttr>()) {
120+
maybeAddEntry(constant.asAutoDiffDerivativeFunction(
121+
AutoDiffDerivativeFunctionIdentifier::get(
122+
AutoDiffDerivativeFunctionKind::JVP,
123+
diffAttr->getParameterIndices(),
124+
diffAttr->getDerivativeGenericSignature(),
125+
D->getASTContext())));
126+
maybeAddEntry(constant.asAutoDiffDerivativeFunction(
127+
AutoDiffDerivativeFunctionIdentifier::get(
128+
AutoDiffDerivativeFunctionKind::VJP,
129+
diffAttr->getParameterIndices(),
130+
diffAttr->getDerivativeGenericSignature(),
131+
D->getASTContext())));
132+
}
133+
}
134+
145135
protected:
146136
void addVTableEntries(ClassDecl *theClass) {
147137
// Imported classes do not have a vtable.

lib/Demangling/OldRemangler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ void Remangler::mangleReabstractionThunk(Node *node) {
744744
Buffer << "<reabstraction-thunk>";
745745
}
746746

747-
void Remangler::mangleAutoDiffFunction(Node *node) {
747+
void Remangler::mangleAutoDiffFunction(Node *node, EntityContext &ctx) {
748748
Buffer << "<autodiff-function>";
749749
}
750750

lib/IRGen/GenThunk.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,10 @@ void IRGenThunk::prepareArguments() {
165165
}
166166

167167
for (unsigned i = 0, e = asyncLayout->getArgumentCount(); i < e; ++i) {
168-
Address addr = asyncLayout->getArgumentLayout(i).project(
169-
IGF, context, llvm::None);
170-
params.add(IGF.Builder.CreateLoad(addr));
168+
auto layout = asyncLayout->getArgumentLayout(i);
169+
Address addr = layout.project(IGF, context, llvm::None);
170+
auto &ti = cast<LoadableTypeInfo>(layout.getType());
171+
ti.loadAsTake(IGF, addr, params);
171172
}
172173

173174
if (asyncLayout->hasBindings()) {
@@ -329,8 +330,20 @@ void IRGenThunk::emit() {
329330
emission->emitToExplosion(result, /*isOutlined=*/false);
330331
}
331332

333+
llvm::Value *errorValue = nullptr;
334+
335+
if (isAsync && origTy->hasErrorResult()) {
336+
SILType errorType = conv.getSILErrorType(expansionContext);
337+
Address calleeErrorSlot = emission->getCalleeErrorSlot(errorType);
338+
errorValue = IGF.Builder.CreateLoad(calleeErrorSlot);
339+
}
340+
332341
emission->end();
333342

343+
if (isAsync && errorValue) {
344+
IGF.Builder.CreateStore(errorValue, IGF.getCallerErrorResultSlot());
345+
}
346+
334347
if (isAsync) {
335348
emitAsyncReturn(IGF, *asyncLayout, origTy);
336349
IGF.emitCoroutineOrAsyncExit();

lib/IRGen/IRGenMangler.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "IRGenModule.h"
1717
#include "swift/AST/ASTMangler.h"
18+
#include "swift/AST/AutoDiff.h"
1819
#include "swift/AST/ProtocolAssociations.h"
1920
#include "swift/IRGen/ValueWitness.h"
2021
#include "llvm/Support/SaveAndRestore.h"
@@ -51,6 +52,21 @@ class IRGenMangler : public Mangle::ASTMangler {
5152
return finalize();
5253
}
5354

55+
std::string mangleDerivativeDispatchThunk(
56+
const AbstractFunctionDecl *func,
57+
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
58+
beginManglingWithAutoDiffOriginalFunction(func);
59+
auto kindCode =
60+
(char)Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
61+
AutoDiffConfig config(
62+
derivativeId->getParameterIndices(),
63+
IndexSubset::get(func->getASTContext(), 1, {0}),
64+
derivativeId->getDerivativeGenericSignature());
65+
appendAutoDiffFunctionParts(kindCode, config);
66+
appendOperator("Tj");
67+
return finalize();
68+
}
69+
5470
std::string mangleConstructorDispatchThunk(const ConstructorDecl *ctor,
5571
bool isAllocating) {
5672
beginMangling();
@@ -66,6 +82,21 @@ class IRGenMangler : public Mangle::ASTMangler {
6682
return finalize();
6783
}
6884

85+
std::string mangleDerivativeMethodDescriptor(
86+
const AbstractFunctionDecl *func,
87+
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
88+
beginManglingWithAutoDiffOriginalFunction(func);
89+
auto kindCode =
90+
(char)Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
91+
AutoDiffConfig config(
92+
derivativeId->getParameterIndices(),
93+
IndexSubset::get(func->getASTContext(), 1, {0}),
94+
derivativeId->getDerivativeGenericSignature());
95+
appendAutoDiffFunctionParts(kindCode, config);
96+
appendOperator("Tq");
97+
return finalize();
98+
}
99+
69100
std::string mangleConstructorMethodDescriptor(const ConstructorDecl *ctor,
70101
bool isAllocating) {
71102
beginMangling();

lib/IRGen/Linking.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ std::string LinkEntity::mangleAsString() const {
104104
return mangler.mangleDispatchThunk(func);
105105
}
106106

107+
case Kind::DispatchThunkDerivative: {
108+
auto *func = cast<AbstractFunctionDecl>(getDecl());
109+
auto *derivativeId = getAutoDiffDerivativeFunctionIdentifier();
110+
return mangler.mangleDerivativeDispatchThunk(func, derivativeId);
111+
}
112+
107113
case Kind::DispatchThunkInitializer: {
108114
auto *ctor = cast<ConstructorDecl>(getDecl());
109115
return mangler.mangleConstructorDispatchThunk(ctor,
@@ -121,6 +127,12 @@ std::string LinkEntity::mangleAsString() const {
121127
return mangler.mangleMethodDescriptor(func);
122128
}
123129

130+
case Kind::MethodDescriptorDerivative: {
131+
auto *func = cast<AbstractFunctionDecl>(getDecl());
132+
auto *derivativeId = getAutoDiffDerivativeFunctionIdentifier();
133+
return mangler.mangleDerivativeMethodDescriptor(func, derivativeId);
134+
}
135+
124136
case Kind::MethodDescriptorInitializer: {
125137
auto *ctor = cast<ConstructorDecl>(getDecl());
126138
return mangler.mangleConstructorMethodDescriptor(ctor,
@@ -460,9 +472,11 @@ SILLinkage LinkEntity::getLinkage(ForDefinition_t forDefinition) const {
460472

461473
switch (getKind()) {
462474
case Kind::DispatchThunk:
475+
case Kind::DispatchThunkDerivative:
463476
case Kind::DispatchThunkInitializer:
464477
case Kind::DispatchThunkAllocator:
465478
case Kind::MethodDescriptor:
479+
case Kind::MethodDescriptorDerivative:
466480
case Kind::MethodDescriptorInitializer:
467481
case Kind::MethodDescriptorAllocator: {
468482
auto *decl = getDecl();
@@ -742,12 +756,14 @@ bool LinkEntity::isContextDescriptor() const {
742756
case Kind::AsyncFunctionPointerAST:
743757
case Kind::PropertyDescriptor:
744758
case Kind::DispatchThunk:
759+
case Kind::DispatchThunkDerivative:
745760
case Kind::DispatchThunkInitializer:
746761
case Kind::DispatchThunkAllocator:
747762
case Kind::DispatchThunkAsyncFunctionPointer:
748763
case Kind::DispatchThunkInitializerAsyncFunctionPointer:
749764
case Kind::DispatchThunkAllocatorAsyncFunctionPointer:
750765
case Kind::MethodDescriptor:
766+
case Kind::MethodDescriptorDerivative:
751767
case Kind::MethodDescriptorInitializer:
752768
case Kind::MethodDescriptorAllocator:
753769
case Kind::MethodLookupFunction:
@@ -892,6 +908,7 @@ llvm::Type *LinkEntity::getDefaultDeclarationType(IRGenModule &IGM) const {
892908
case Kind::MethodDescriptor:
893909
case Kind::MethodDescriptorInitializer:
894910
case Kind::MethodDescriptorAllocator:
911+
case Kind::MethodDescriptorDerivative:
895912
return IGM.MethodDescriptorStructTy;
896913
case Kind::DynamicallyReplaceableFunctionKey:
897914
case Kind::OpaqueTypeDescriptorAccessorKey:
@@ -1020,9 +1037,11 @@ bool LinkEntity::isWeakImported(ModuleDecl *module) const {
10201037

10211038
case Kind::AsyncFunctionPointerAST:
10221039
case Kind::DispatchThunk:
1040+
case Kind::DispatchThunkDerivative:
10231041
case Kind::DispatchThunkInitializer:
10241042
case Kind::DispatchThunkAllocator:
10251043
case Kind::MethodDescriptor:
1044+
case Kind::MethodDescriptorDerivative:
10261045
case Kind::MethodDescriptorInitializer:
10271046
case Kind::MethodDescriptorAllocator:
10281047
case Kind::MethodLookupFunction:
@@ -1104,9 +1123,11 @@ DeclContext *LinkEntity::getDeclContextForEmission() const {
11041123
switch (getKind()) {
11051124
case Kind::AsyncFunctionPointerAST:
11061125
case Kind::DispatchThunk:
1126+
case Kind::DispatchThunkDerivative:
11071127
case Kind::DispatchThunkInitializer:
11081128
case Kind::DispatchThunkAllocator:
11091129
case Kind::MethodDescriptor:
1130+
case Kind::MethodDescriptorDerivative:
11101131
case Kind::MethodDescriptorInitializer:
11111132
case Kind::MethodDescriptorAllocator:
11121133
case Kind::MethodLookupFunction:

lib/SILGen/SILGenBuiltin.cpp

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,41 +1231,6 @@ static ManagedValue emitBuiltinApplyTranspose(
12311231
arity, throws, SGF, loc, substitutions, args, C);
12321232
}
12331233

1234-
static ManagedValue emitBuiltinDifferentiableFunction(
1235-
SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions,
1236-
ArrayRef<ManagedValue> args, SGFContext C) {
1237-
assert(args.size() == 3);
1238-
auto origFn = args.front();
1239-
auto origType = origFn.getType().castTo<SILFunctionType>();
1240-
auto numResults =
1241-
origType->getNumResults() + origType->getNumIndirectMutatingParameters();
1242-
auto diffFn = SGF.B.createDifferentiableFunction(
1243-
loc,
1244-
IndexSubset::getDefault(SGF.getASTContext(), origType->getNumParameters(),
1245-
/*includeAll*/ true),
1246-
IndexSubset::getDefault(SGF.getASTContext(), numResults,
1247-
/*includeAll*/ true),
1248-
origFn.forward(SGF),
1249-
std::make_pair(args[1].forward(SGF), args[2].forward(SGF)));
1250-
return SGF.emitManagedRValueWithCleanup(diffFn);
1251-
}
1252-
1253-
static ManagedValue emitBuiltinLinearFunction(
1254-
SILGenFunction &SGF, SILLocation loc, SubstitutionMap substitutions,
1255-
ArrayRef<ManagedValue> args, SGFContext C) {
1256-
assert(args.size() == 2);
1257-
auto origFn = args.front();
1258-
auto origType = origFn.getType().castTo<SILFunctionType>();
1259-
auto linearFn = SGF.B.createLinearFunction(
1260-
loc,
1261-
IndexSubset::getDefault(
1262-
SGF.getASTContext(),
1263-
origType->getNumParameters(),
1264-
/*includeAll*/ true),
1265-
origFn.forward(SGF), args[1].forward(SGF));
1266-
return SGF.emitManagedRValueWithCleanup(linearFn);
1267-
}
1268-
12691234
/// Emit SIL for the named builtin: globalStringTablePointer. Unlike the default
12701235
/// ownership convention for named builtins, which is to take (non-trivial)
12711236
/// arguments as Owned, this builtin accepts owned as well as guaranteed

lib/Sema/TypeCheckConcurrency.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2016,6 +2016,12 @@ void swift::checkOverrideActorIsolation(ValueDecl *value) {
20162016
if (isolation == overriddenIsolation)
20172017
return;
20182018

2019+
// If the overridden declaration is from Objective-C with no actor annotation,
2020+
// and the overriding declaration has been placed in a global actor, allow it.
2021+
if (overridden->hasClangNode() && !overriddenIsolation &&
2022+
isolation.getKind() == ActorIsolation::GlobalActor)
2023+
return;
2024+
20192025
// Isolation mismatch. Diagnose it.
20202026
value->diagnose(
20212027
diag::actor_isolation_override_mismatch, isolation,

0 commit comments

Comments
 (0)