Skip to content

Commit 493b4a8

Browse files
author
marcrasi
authored
Merge pull request swiftlang#32916 from marcrasi/remove-gen-sig-more-places
[AutoDiff] remove all-concrete gen sig from more places
2 parents fe5ddd4 + 7191c9c commit 493b4a8

File tree

8 files changed

+89
-50
lines changed

8 files changed

+89
-50
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,29 @@ bool getBuiltinDifferentiableOrLinearFunctionConfig(
649649
bool getBuiltinDifferentiableOrLinearFunctionConfig(
650650
StringRef operationName, unsigned &arity, bool &throws);
651651

652+
/// Returns the SIL differentiability witness generic signature given the
653+
/// original declaration's generic signature and the derivative generic
654+
/// signature.
655+
///
656+
/// In general, the differentiability witness generic signature is equal to the
657+
/// derivative generic signature.
658+
///
659+
/// Edge case, if two conditions are satisfied:
660+
/// 1. The derivative generic signature is equal to the original generic
661+
/// signature.
662+
/// 2. The derivative generic signature has *all concrete* generic parameters
663+
/// (i.e. all generic parameters are bound to concrete types via same-type
664+
/// requirements).
665+
///
666+
/// Then the differentiability witness generic signature is `nullptr`.
667+
///
668+
/// Both the original and derivative declarations are lowered to SIL functions
669+
/// with a fully concrete type and no generic signature, so the
670+
/// differentiability witness should similarly have no generic signature.
671+
GenericSignature
672+
getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig,
673+
GenericSignature derivativeGenSig);
674+
652675
} // end namespace autodiff
653676

654677
} // end namespace swift

lib/AST/AutoDiff.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,23 @@ bool autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
372372
return operationName.empty();
373373
}
374374

375+
GenericSignature autodiff::getDifferentiabilityWitnessGenericSignature(
376+
GenericSignature origGenSig, GenericSignature derivativeGenSig) {
377+
// If there is no derivative generic signature, return the original generic
378+
// signature.
379+
if (!derivativeGenSig)
380+
return origGenSig;
381+
// If derivative generic signature has all concrete generic parameters and is
382+
// equal to the original generic signature, return `nullptr`.
383+
auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
384+
auto origCanGenSig = origGenSig.getCanonicalSignature();
385+
if (origCanGenSig == derivativeCanGenSig &&
386+
derivativeCanGenSig->areAllParamsConcrete())
387+
return GenericSignature();
388+
// Otherwise, return the derivative generic signature.
389+
return derivativeGenSig;
390+
}
391+
375392
Type TangentSpace::getType() const {
376393
switch (kind) {
377394
case Kind::TangentVector:

lib/SILGen/SILGen.cpp

Lines changed: 7 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -933,43 +933,6 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
933933
emitDifferentiabilityWitnessesForFunction(constant, F);
934934
}
935935

936-
/// Returns the SIL differentiability witness generic signature given the
937-
/// original declaration's generic signature and the derivative generic
938-
/// signature.
939-
///
940-
/// In general, the differentiability witness generic signature is equal to the
941-
/// derivative generic signature.
942-
///
943-
/// Edge case, if two conditions are satisfied:
944-
/// 1. The derivative generic signature is equal to the original generic
945-
/// signature.
946-
/// 2. The derivative generic signature has *all concrete* generic parameters
947-
/// (i.e. all generic parameters are bound to concrete types via same-type
948-
/// requirements).
949-
///
950-
/// Then the differentiability witness generic signature is `nullptr`.
951-
///
952-
/// Both the original and derivative declarations are lowered to SIL functions
953-
/// with a fully concrete type and no generic signature, so the
954-
/// differentiability witness should similarly have no generic signature.
955-
static GenericSignature
956-
getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig,
957-
GenericSignature derivativeGenSig) {
958-
// If there is no derivative generic signature, return the original generic
959-
// signature.
960-
if (!derivativeGenSig)
961-
return origGenSig;
962-
// If derivative generic signature has all concrete generic parameters and is
963-
// equal to the original generic signature, return `nullptr`.
964-
auto derivativeCanGenSig = derivativeGenSig.getCanonicalSignature();
965-
auto origCanGenSig = origGenSig.getCanonicalSignature();
966-
if (origCanGenSig == derivativeCanGenSig &&
967-
derivativeCanGenSig->areAllParamsConcrete())
968-
return GenericSignature();
969-
// Otherwise, return the derivative generic signature.
970-
return derivativeGenSig;
971-
}
972-
973936
void SILGenModule::emitDifferentiabilityWitnessesForFunction(
974937
SILDeclRef constant, SILFunction *F) {
975938
// Visit `@derivative` attributes and generate SIL differentiability
@@ -990,9 +953,10 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
990953
diffAttr->getDerivativeGenericSignature()) &&
991954
"Type-checking should resolve derivative generic signatures for "
992955
"all original SIL functions with generic signatures");
993-
auto witnessGenSig = getDifferentiabilityWitnessGenericSignature(
994-
AFD->getGenericSignature(),
995-
diffAttr->getDerivativeGenericSignature());
956+
auto witnessGenSig =
957+
autodiff::getDifferentiabilityWitnessGenericSignature(
958+
AFD->getGenericSignature(),
959+
diffAttr->getDerivativeGenericSignature());
996960
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
997961
witnessGenSig);
998962
emitDifferentiabilityWitness(AFD, F, config, /*jvp*/ nullptr,
@@ -1013,8 +977,9 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
1013977
auto origDeclRef =
1014978
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
1015979
auto *origFn = getFunction(origDeclRef, NotForDefinition);
1016-
auto witnessGenSig = getDifferentiabilityWitnessGenericSignature(
1017-
origAFD->getGenericSignature(), AFD->getGenericSignature());
980+
auto witnessGenSig =
981+
autodiff::getDifferentiabilityWitnessGenericSignature(
982+
origAFD->getGenericSignature(), AFD->getGenericSignature());
1018983
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
1019984
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
1020985
witnessGenSig);

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,11 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
452452
silParameterIndices->getNumIndices() <
453453
minimalConfig->parameterIndices->getNumIndices())) {
454454
minimalASTParameterIndices = config.parameterIndices;
455-
minimalConfig = AutoDiffConfig(silParameterIndices, config.resultIndices,
456-
config.derivativeGenericSignature);
455+
minimalConfig =
456+
AutoDiffConfig(silParameterIndices, config.resultIndices,
457+
autodiff::getDifferentiabilityWitnessGenericSignature(
458+
original->getGenericSignature(),
459+
config.derivativeGenericSignature));
457460
}
458461
}
459462
return minimalConfig;

lib/TBDGen/TBDGen.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,10 @@ void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
530530
config.parameterIndices,
531531
original->getInterfaceType()->castTo<AnyFunctionType>());
532532
Mangle::ASTMangler mangler;
533-
AutoDiffConfig silConfig{loweredParamIndices, config.resultIndices,
534-
config.derivativeGenericSignature};
533+
AutoDiffConfig silConfig{
534+
loweredParamIndices, config.resultIndices,
535+
autodiff::getDifferentiabilityWitnessGenericSignature(
536+
original->getGenericSignature(), config.derivativeGenericSignature)};
535537
std::string linearMapName =
536538
mangler.mangleAutoDiffLinearMapHelper(declRef.mangle(), kind, silConfig);
537539
addSymbol(linearMapName);
@@ -542,7 +544,9 @@ void TBDGenVisitor::addAutoDiffDerivativeFunction(
542544
GenericSignature derivativeGenericSignature,
543545
AutoDiffDerivativeFunctionKind kind) {
544546
auto *assocFnId = AutoDiffDerivativeFunctionIdentifier::get(
545-
kind, parameterIndices, derivativeGenericSignature,
547+
kind, parameterIndices,
548+
autodiff::getDifferentiabilityWitnessGenericSignature(
549+
original->getGenericSignature(), derivativeGenericSignature),
546550
original->getASTContext());
547551
auto declRef =
548552
SILDeclRef(original).asForeign(requiresForeignEntryPoint(original));
@@ -569,8 +573,10 @@ void TBDGenVisitor::addDifferentiabilityWitness(
569573
original->getInterfaceType()->castTo<AnyFunctionType>());
570574

571575
auto originalMangledName = declRef.mangle();
572-
AutoDiffConfig config{silParamIndices, resultIndices,
573-
derivativeGenericSignature};
576+
AutoDiffConfig config{
577+
silParamIndices, resultIndices,
578+
autodiff::getDifferentiabilityWitnessGenericSignature(
579+
original->getGenericSignature(), derivativeGenericSignature)};
574580
SILDifferentiabilityWitnessKey key(originalMangledName, config);
575581

576582
Mangle::ASTMangler mangler;

test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,14 @@ class Class: Differentiable {
4141
set {}
4242
}
4343
}
44+
45+
struct S: Differentiable {
46+
var value: Float
47+
}
48+
49+
extension Array where Element == S {
50+
@differentiable
51+
func sum() -> Float {
52+
return 0
53+
}
54+
}

test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,10 @@ func classRequirementSetters(_ x: inout Class, _ newValue: Float) {
5757
x.property = newValue
5858
x[] = newValue
5959
}
60+
61+
// Test cross-file lookup of a derivative function with all-concrete derivative generic signature.
62+
@differentiable
63+
func allConcreteDerivativeGenericSignature(_ a: [S]) -> Float {
64+
// No error expected.
65+
return a.sum()
66+
}

test/AutoDiff/TBD/derivative_symbols.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public func topLevelDerivative<T: Differentiable>(_ x: T) -> (
1919
fatalError()
2020
}
2121

22-
struct Struct: Differentiable {
22+
public struct Struct: Differentiable {
2323
var stored: Float
2424

2525
// Test property.
@@ -54,3 +54,10 @@ struct Struct: Differentiable {
5454
fatalError()
5555
}
5656
}
57+
58+
extension Array where Element == Struct {
59+
@differentiable
60+
public func sum() -> Float {
61+
return 0
62+
}
63+
}

0 commit comments

Comments
 (0)