Skip to content

Commit 4d391f0

Browse files
authored
Add Differentiable requirements to pattern substitutions / pattern generic signature (#68777)
Add `Differentiable` requirements to pattern substitutions / pattern generic signature when calculating constrained function type. Also, add requirements for differentiable results as well. Fixes #65487
1 parent 5ad504e commit 4d391f0

File tree

5 files changed

+137
-44
lines changed

5 files changed

+137
-44
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class AbstractFunctionDecl;
3636
class AnyFunctionType;
3737
class SourceFile;
3838
class SILFunctionType;
39+
class SILResultInfo;
3940
class TupleType;
4041
class VarDecl;
4142

@@ -621,18 +622,28 @@ IndexSubset *getFunctionSemanticResultIndices(
621622
IndexSubset *getLoweredParameterIndices(IndexSubset *astParameterIndices,
622623
AnyFunctionType *functionType);
623624

625+
/// Collects the semantic results of the given function type in
626+
/// `originalResults`. The semantic results are formal results followed by
627+
/// semantic result parameters, in type order.
628+
void
629+
getSemanticResults(SILFunctionType *functionType,
630+
IndexSubset *parameterIndices,
631+
SmallVectorImpl<SILResultInfo> &originalResults);
632+
624633
/// "Constrained" derivative generic signatures require all differentiability
625-
/// parameters to conform to the `Differentiable` protocol.
634+
/// parameters / results to conform to the `Differentiable` protocol.
626635
///
627636
/// "Constrained" transpose generic signatures additionally require all
628637
/// linearity parameters to satisfy `Self == Self.TangentVector`.
629638
///
630639
/// Returns the "constrained" derivative/transpose generic signature given:
631640
/// - An original SIL function type.
632641
/// - Differentiability/linearity parameter indices.
642+
/// - Differentiability/linearity result indices.
633643
/// - A possibly "unconstrained" derivative/transpose generic signature.
634644
GenericSignature getConstrainedDerivativeGenericSignature(
635-
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
645+
SILFunctionType *originalFnTy,
646+
IndexSubset *diffParamIndices, IndexSubset *diffResultIndices,
636647
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
637648
bool isTranspose = false);
638649

lib/AST/AutoDiff.cpp

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,30 @@ autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices,
297297
loweredSILIndices);
298298
}
299299

300+
/// Collects the semantic results of the given function type in
301+
/// `originalResults`. The semantic results are formal results followed by
302+
/// semantic result parameters, in type order.
303+
void
304+
autodiff::getSemanticResults(SILFunctionType *functionType,
305+
IndexSubset *parameterIndices,
306+
SmallVectorImpl<SILResultInfo> &originalResults) {
307+
// Collect original formal results.
308+
originalResults.append(functionType->getResults().begin(),
309+
functionType->getResults().end());
310+
311+
// Collect original semantic result parameters.
312+
for (auto i : range(functionType->getNumParameters())) {
313+
auto param = functionType->getParameters()[i];
314+
if (!param.isAutoDiffSemanticResult())
315+
continue;
316+
if (param.getDifferentiability() != SILParameterDifferentiability::NotDifferentiable)
317+
originalResults.emplace_back(param.getInterfaceType(), ResultConvention::Indirect);
318+
}
319+
}
320+
300321
GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
301-
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
322+
SILFunctionType *originalFnTy,
323+
IndexSubset *diffParamIndices, IndexSubset *diffResultIndices,
302324
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
303325
bool isTranspose) {
304326
if (!derivativeGenSig)
@@ -308,21 +330,48 @@ GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
308330
auto &ctx = originalFnTy->getASTContext();
309331
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
310332
SmallVector<Requirement, 4> requirements;
311-
for (unsigned paramIdx : diffParamIndices->getIndices()) {
312-
// Require differentiability parameters to conform to `Differentiable`.
313-
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
314-
Requirement req(RequirementKind::Conformance, paramType,
333+
334+
auto addRequirement = [&](CanType type) {
335+
Requirement req(RequirementKind::Conformance, type,
315336
diffableProto->getDeclaredInterfaceType());
316337
requirements.push_back(req);
317338
if (isTranspose) {
318339
// Require linearity parameters to additionally satisfy
319340
// `Self == Self.TangentVector`.
320-
auto tanSpace = paramType->getAutoDiffTangentSpace(lookupConformance);
321-
auto paramTanType = tanSpace->getCanonicalType();
322-
Requirement req(RequirementKind::SameType, paramType, paramTanType);
341+
auto tanSpace = type->getAutoDiffTangentSpace(lookupConformance);
342+
auto tanType = tanSpace->getCanonicalType();
343+
Requirement req(RequirementKind::SameType, type, tanType);
323344
requirements.push_back(req);
324345
}
346+
};
347+
348+
// Require differentiability parameters to conform to `Differentiable`.
349+
for (unsigned paramIdx : diffParamIndices->getIndices()) {
350+
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
351+
addRequirement(paramType);
352+
}
353+
354+
// Require differentiability results to conform to `Differentiable`.
355+
SmallVector<SILResultInfo, 2> originalResults;
356+
getSemanticResults(originalFnTy, diffParamIndices, originalResults);
357+
for (unsigned resultIdx : diffResultIndices->getIndices()) {
358+
// Handle formal original result.
359+
if (resultIdx < originalFnTy->getNumResults()) {
360+
auto resultType = originalResults[resultIdx].getInterfaceType();
361+
addRequirement(resultType);
362+
continue;
363+
}
364+
// Handle original semantic result parameters.
365+
// FIXME: Constraint generic yields when we will start supporting them
366+
auto resultParamIndex = resultIdx - originalFnTy->getNumResults();
367+
auto resultParamIt = std::next(
368+
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
369+
resultParamIndex);
370+
auto paramIndex =
371+
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
372+
addRequirement(originalFnTy->getParameters()[paramIndex].getInterfaceType());
325373
}
374+
326375
return buildGenericSignature(ctx, derivativeGenSig,
327376
/*addedGenericParams*/ {},
328377
std::move(requirements));

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -366,27 +366,6 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy,
366366
diffParams.push_back(valueAndIndex.value());
367367
}
368368

369-
/// Collects the semantic results of the given function type in
370-
/// `originalResults`. The semantic results are formal results followed by
371-
/// semantic result parameters, in type order.
372-
static void
373-
getSemanticResults(SILFunctionType *functionType,
374-
IndexSubset *parameterIndices,
375-
SmallVectorImpl<SILResultInfo> &originalResults) {
376-
// Collect original formal results.
377-
originalResults.append(functionType->getResults().begin(),
378-
functionType->getResults().end());
379-
380-
// Collect original semantic result parameters.
381-
for (auto i : range(functionType->getNumParameters())) {
382-
auto param = functionType->getParameters()[i];
383-
if (!param.isAutoDiffSemanticResult())
384-
continue;
385-
if (param.getDifferentiability() != SILParameterDifferentiability::NotDifferentiable)
386-
originalResults.emplace_back(param.getInterfaceType(), ResultConvention::Indirect);
387-
}
388-
}
389-
390369
static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignature sig,
391370
CanType tanType,
392371
CanType origTypeOfAbstraction) {
@@ -563,7 +542,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
563542
SmallVector<ProtocolConformanceRef, 4> substConformances;
564543

565544
SmallVector<SILResultInfo, 2> originalResults;
566-
getSemanticResults(originalFnTy, parameterIndices, originalResults);
545+
autodiff::getSemanticResults(originalFnTy, parameterIndices, originalResults);
567546

568547
SmallVector<SILParameterInfo, 4> diffParams;
569548
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
@@ -647,7 +626,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
647626
SmallVector<ProtocolConformanceRef, 4> substConformances;
648627

649628
SmallVector<SILResultInfo, 2> originalResults;
650-
getSemanticResults(originalFnTy, parameterIndices, originalResults);
629+
autodiff::getSemanticResults(originalFnTy, parameterIndices, originalResults);
651630

652631
// Given a type, returns its formal SIL parameter info.
653632
auto getTangentParameterConventionForOriginalResult =
@@ -791,9 +770,9 @@ static CanSILFunctionType getAutoDiffPullbackType(
791770
llvm::makeArrayRef(substConformances));
792771
}
793772
return SILFunctionType::get(
794-
GenericSignature(), SILFunctionType::ExtInfo(), SILCoroutineKind::None,
795-
ParameterConvention::Direct_Guaranteed, pullbackParams, {},
796-
pullbackResults, llvm::None, substitutions,
773+
GenericSignature(), SILFunctionType::ExtInfo(), originalFnTy->getCoroutineKind(),
774+
ParameterConvention::Direct_Guaranteed,
775+
pullbackParams, {}, pullbackResults, llvm::None, substitutions,
797776
/*invocationSubstitutions*/ SubstitutionMap(), ctx);
798777
}
799778

@@ -804,7 +783,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
804783
/// - The invocation generic signature is replaced by the
805784
/// `constrainedInvocationGenSig` argument.
806785
static SILFunctionType *getConstrainedAutoDiffOriginalFunctionType(
807-
SILFunctionType *original, IndexSubset *parameterIndices,
786+
SILFunctionType *original, IndexSubset *parameterIndices, IndexSubset *resultIndices,
808787
LookupConformanceFn lookupConformance,
809788
CanGenericSignature constrainedInvocationGenSig) {
810789
auto originalInvocationGenSig = original->getInvocationGenericSignature();
@@ -813,6 +792,25 @@ static SILFunctionType *getConstrainedAutoDiffOriginalFunctionType(
813792
constrainedInvocationGenSig->areAllParamsConcrete() &&
814793
"derivative function cannot have invocation generic signature "
815794
"when original function doesn't");
795+
if (auto patternSig = original->getPatternGenericSignature()) {
796+
auto constrainedPatternSig =
797+
autodiff::getConstrainedDerivativeGenericSignature(
798+
original, parameterIndices, resultIndices,
799+
patternSig, lookupConformance).getCanonicalSignature();
800+
auto constrainedPatternSubs =
801+
SubstitutionMap::get(constrainedPatternSig,
802+
QuerySubstitutionMap{original->getPatternSubstitutions()},
803+
lookupConformance);
804+
return SILFunctionType::get(GenericSignature(),
805+
original->getExtInfo(), original->getCoroutineKind(),
806+
original->getCalleeConvention(),
807+
original->getParameters(), original->getYields(),
808+
original->getResults(), original->getOptionalErrorResult(),
809+
constrainedPatternSubs,
810+
/*invocationSubstitutions*/ SubstitutionMap(), original->getASTContext(),
811+
original->getWitnessMethodConformanceOrInvalid());
812+
}
813+
816814
return original;
817815
}
818816

@@ -823,10 +821,10 @@ static SILFunctionType *getConstrainedAutoDiffOriginalFunctionType(
823821
if (!constrainedInvocationGenSig)
824822
return original;
825823
constrainedInvocationGenSig =
826-
autodiff::getConstrainedDerivativeGenericSignature(
827-
original, parameterIndices, constrainedInvocationGenSig,
828-
lookupConformance)
829-
.getCanonicalSignature();
824+
autodiff::getConstrainedDerivativeGenericSignature(
825+
original, parameterIndices, resultIndices,
826+
constrainedInvocationGenSig,
827+
lookupConformance).getCanonicalSignature();
830828

831829
SmallVector<SILParameterInfo, 4> newParameters;
832830
newParameters.reserve(original->getNumParameters());
@@ -882,9 +880,10 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
882880
return cachedResult;
883881

884882
SILFunctionType *constrainedOriginalFnTy =
885-
getConstrainedAutoDiffOriginalFunctionType(this, parameterIndices,
883+
getConstrainedAutoDiffOriginalFunctionType(this, parameterIndices, resultIndices,
886884
lookupConformance,
887885
derivativeFnInvocationGenSig);
886+
888887
// Compute closure type.
889888
CanSILFunctionType closureType;
890889
switch (kind) {
@@ -957,11 +956,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType(
957956
IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
958957
LookupConformanceFn lookupConformance,
959958
CanGenericSignature transposeFnGenSig) {
959+
auto &ctx = getASTContext();
960+
960961
// Get the "constrained" transpose function generic signature.
961962
if (!transposeFnGenSig)
962963
transposeFnGenSig = getSubstGenericSignature();
963964
transposeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature(
964-
this, parameterIndices, transposeFnGenSig,
965+
this, parameterIndices, IndexSubset::getDefault(ctx, 0),
966+
transposeFnGenSig,
965967
lookupConformance, /*isLinear*/ true)
966968
.getCanonicalSignature();
967969

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,8 @@ emitDerivativeFunctionReference(
576576
.second->getDerivativeGenericSignature();
577577
auto derivativeConstrainedGenSig =
578578
autodiff::getConstrainedDerivativeGenericSignature(
579-
originalFn->getLoweredFunctionType(), desiredParameterIndices,
579+
originalFn->getLoweredFunctionType(),
580+
desiredParameterIndices, desiredResultIndices,
580581
contextualDerivativeGenSig,
581582
LookUpConformanceInModule(context.getModule().getSwiftModule()));
582583
minimalWitness = SILDifferentiabilityWitness::createDefinition(
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
// https://github.com/apple/swift/pull/68777
4+
// We used not to constrain the derivative type properly,
5+
// the `Differentiable` requirement was missed if the function type had
6+
// pattern substitution resulting in assertion:
7+
//
8+
// Invalid type parameter in getReducedType()
9+
// Original type: Optional<τ_0_0.TangentVector>
10+
// Simplified term: τ_0_0.[Differentiable:TangentVector]
11+
// Longest valid prefix: τ_0_0
12+
// Prefix type: τ_0_0
13+
//
14+
// Requirement machine for <τ_0_0, τ_0_1>
15+
// Rewrite system: {
16+
// }
17+
// }
18+
// Property map: {
19+
// }
20+
// Conformance paths: {
21+
// }
22+
//
23+
// Note that the generic signature <τ_0_0, τ_0_1> should be
24+
// <τ_0_0 : Differentiable, τ_0_1 : Differentiable>, otherwise
25+
// there is no way to derive associated type τ_0_0.TangentVector
26+
27+
import _Differentiation;
28+
29+
public struct D<T>: Differentiable {}
30+
extension D {@differentiable(reverse, wrt: self) mutating func m(r: @differentiable(reverse) (T?) -> T?) {}}

0 commit comments

Comments
 (0)