@@ -366,27 +366,6 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy,
366
366
diffParams.push_back (valueAndIndex.value ());
367
367
}
368
368
369
- // / Collects the semantic results of the given function type in
370
- // / `originalResults`. The semantic results are formal results followed by
371
- // / semantic result parameters, in type order.
372
- static void
373
- getSemanticResults (SILFunctionType *functionType,
374
- IndexSubset *parameterIndices,
375
- SmallVectorImpl<SILResultInfo> &originalResults) {
376
- // Collect original formal results.
377
- originalResults.append (functionType->getResults ().begin (),
378
- functionType->getResults ().end ());
379
-
380
- // Collect original semantic result parameters.
381
- for (auto i : range (functionType->getNumParameters ())) {
382
- auto param = functionType->getParameters ()[i];
383
- if (!param.isAutoDiffSemanticResult ())
384
- continue ;
385
- if (param.getDifferentiability () != SILParameterDifferentiability::NotDifferentiable)
386
- originalResults.emplace_back (param.getInterfaceType (), ResultConvention::Indirect);
387
- }
388
- }
389
-
390
369
static CanGenericSignature buildDifferentiableGenericSignature (CanGenericSignature sig,
391
370
CanType tanType,
392
371
CanType origTypeOfAbstraction) {
@@ -563,7 +542,7 @@ static CanSILFunctionType getAutoDiffDifferentialType(
563
542
SmallVector<ProtocolConformanceRef, 4 > substConformances;
564
543
565
544
SmallVector<SILResultInfo, 2 > originalResults;
566
- getSemanticResults (originalFnTy, parameterIndices, originalResults);
545
+ autodiff:: getSemanticResults (originalFnTy, parameterIndices, originalResults);
567
546
568
547
SmallVector<SILParameterInfo, 4 > diffParams;
569
548
getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
@@ -647,7 +626,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
647
626
SmallVector<ProtocolConformanceRef, 4 > substConformances;
648
627
649
628
SmallVector<SILResultInfo, 2 > originalResults;
650
- getSemanticResults (originalFnTy, parameterIndices, originalResults);
629
+ autodiff:: getSemanticResults (originalFnTy, parameterIndices, originalResults);
651
630
652
631
// Given a type, returns its formal SIL parameter info.
653
632
auto getTangentParameterConventionForOriginalResult =
@@ -791,9 +770,9 @@ static CanSILFunctionType getAutoDiffPullbackType(
791
770
llvm::makeArrayRef (substConformances));
792
771
}
793
772
return SILFunctionType::get (
794
- GenericSignature (), SILFunctionType::ExtInfo (), SILCoroutineKind::None ,
795
- ParameterConvention::Direct_Guaranteed, pullbackParams, {},
796
- pullbackResults, llvm::None, substitutions,
773
+ GenericSignature (), SILFunctionType::ExtInfo (), originalFnTy-> getCoroutineKind () ,
774
+ ParameterConvention::Direct_Guaranteed,
775
+ pullbackParams, {}, pullbackResults, llvm::None, substitutions,
797
776
/* invocationSubstitutions*/ SubstitutionMap (), ctx);
798
777
}
799
778
@@ -804,7 +783,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
804
783
// / - The invocation generic signature is replaced by the
805
784
// / `constrainedInvocationGenSig` argument.
806
785
static SILFunctionType *getConstrainedAutoDiffOriginalFunctionType (
807
- SILFunctionType *original, IndexSubset *parameterIndices,
786
+ SILFunctionType *original, IndexSubset *parameterIndices, IndexSubset *resultIndices,
808
787
LookupConformanceFn lookupConformance,
809
788
CanGenericSignature constrainedInvocationGenSig) {
810
789
auto originalInvocationGenSig = original->getInvocationGenericSignature ();
@@ -813,6 +792,25 @@ static SILFunctionType *getConstrainedAutoDiffOriginalFunctionType(
813
792
constrainedInvocationGenSig->areAllParamsConcrete () &&
814
793
" derivative function cannot have invocation generic signature "
815
794
" when original function doesn't" );
795
+ if (auto patternSig = original->getPatternGenericSignature ()) {
796
+ auto constrainedPatternSig =
797
+ autodiff::getConstrainedDerivativeGenericSignature (
798
+ original, parameterIndices, resultIndices,
799
+ patternSig, lookupConformance).getCanonicalSignature ();
800
+ auto constrainedPatternSubs =
801
+ SubstitutionMap::get (constrainedPatternSig,
802
+ QuerySubstitutionMap{original->getPatternSubstitutions ()},
803
+ lookupConformance);
804
+ return SILFunctionType::get (GenericSignature (),
805
+ original->getExtInfo (), original->getCoroutineKind (),
806
+ original->getCalleeConvention (),
807
+ original->getParameters (), original->getYields (),
808
+ original->getResults (), original->getOptionalErrorResult (),
809
+ constrainedPatternSubs,
810
+ /* invocationSubstitutions*/ SubstitutionMap (), original->getASTContext (),
811
+ original->getWitnessMethodConformanceOrInvalid ());
812
+ }
813
+
816
814
return original;
817
815
}
818
816
@@ -823,10 +821,10 @@ static SILFunctionType *getConstrainedAutoDiffOriginalFunctionType(
823
821
if (!constrainedInvocationGenSig)
824
822
return original;
825
823
constrainedInvocationGenSig =
826
- autodiff::getConstrainedDerivativeGenericSignature (
827
- original, parameterIndices, constrainedInvocationGenSig ,
828
- lookupConformance)
829
- .getCanonicalSignature ();
824
+ autodiff::getConstrainedDerivativeGenericSignature (
825
+ original, parameterIndices, resultIndices ,
826
+ constrainedInvocationGenSig,
827
+ lookupConformance) .getCanonicalSignature ();
830
828
831
829
SmallVector<SILParameterInfo, 4 > newParameters;
832
830
newParameters.reserve (original->getNumParameters ());
@@ -882,9 +880,10 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
882
880
return cachedResult;
883
881
884
882
SILFunctionType *constrainedOriginalFnTy =
885
- getConstrainedAutoDiffOriginalFunctionType (this , parameterIndices,
883
+ getConstrainedAutoDiffOriginalFunctionType (this , parameterIndices, resultIndices,
886
884
lookupConformance,
887
885
derivativeFnInvocationGenSig);
886
+
888
887
// Compute closure type.
889
888
CanSILFunctionType closureType;
890
889
switch (kind) {
@@ -957,11 +956,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType(
957
956
IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
958
957
LookupConformanceFn lookupConformance,
959
958
CanGenericSignature transposeFnGenSig) {
959
+ auto &ctx = getASTContext ();
960
+
960
961
// Get the "constrained" transpose function generic signature.
961
962
if (!transposeFnGenSig)
962
963
transposeFnGenSig = getSubstGenericSignature ();
963
964
transposeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature (
964
- this , parameterIndices, transposeFnGenSig,
965
+ this , parameterIndices, IndexSubset::getDefault (ctx, 0 ),
966
+ transposeFnGenSig,
965
967
lookupConformance, /* isLinear*/ true )
966
968
.getCanonicalSignature ();
967
969
0 commit comments