@@ -191,6 +191,56 @@ SILFunctionType::getWitnessMethodClass(SILModule &M) const {
191
191
return nullptr ;
192
192
}
193
193
194
+ IndexSubset *
195
+ SILFunctionType::getDifferentiabilityParameterIndices () {
196
+ assert (isDifferentiable () && " Must be a differentiable function" );
197
+ SmallVector<unsigned , 8 > result;
198
+ for (auto valueAndIndex : enumerate(getParameters ()))
199
+ if (valueAndIndex.value ().getDifferentiability () !=
200
+ SILParameterDifferentiability::NotDifferentiable)
201
+ result.push_back (valueAndIndex.index ());
202
+ return IndexSubset::get (getASTContext (), getNumParameters (), result);
203
+ }
204
+
205
+ CanSILFunctionType
206
+ SILFunctionType::getWithDifferentiability (DifferentiabilityKind kind,
207
+ IndexSubset *parameterIndices) {
208
+ assert (kind != DifferentiabilityKind::NonDifferentiable &&
209
+ " Differentiability kind must be normal or linear" );
210
+ SmallVector<SILParameterInfo, 8 > newParameters;
211
+ for (auto paramAndIndex : enumerate(getParameters ())) {
212
+ auto ¶m = paramAndIndex.value ();
213
+ unsigned index = paramAndIndex.index ();
214
+ newParameters.push_back (param.getWithDifferentiability (
215
+ index < parameterIndices->getCapacity () &&
216
+ parameterIndices->contains (index)
217
+ ? SILParameterDifferentiability::DifferentiableOrNotApplicable
218
+ : SILParameterDifferentiability::NotDifferentiable));
219
+ }
220
+ auto newExtInfo = getExtInfo ().withDifferentiabilityKind (kind);
221
+ return get (getSubstGenericSignature (), newExtInfo, getCoroutineKind (),
222
+ getCalleeConvention (), newParameters, getYields (), getResults (),
223
+ getOptionalErrorResult (), getSubstitutions (),
224
+ isGenericSignatureImplied (), getASTContext (),
225
+ getWitnessMethodConformanceOrInvalid ());
226
+ }
227
+
228
+ CanSILFunctionType SILFunctionType::getWithoutDifferentiability () {
229
+ if (!isDifferentiable ())
230
+ return CanSILFunctionType (this );
231
+ auto nondiffExtInfo = getExtInfo ().withDifferentiabilityKind (
232
+ DifferentiabilityKind::NonDifferentiable);
233
+ SmallVector<SILParameterInfo, 8 > newParams;
234
+ for (auto ¶m : getParameters ())
235
+ newParams.push_back (param.getWithDifferentiability (
236
+ SILParameterDifferentiability::DifferentiableOrNotApplicable));
237
+ return SILFunctionType::get (getSubstGenericSignature (), nondiffExtInfo,
238
+ getCoroutineKind (), getCalleeConvention (),
239
+ newParams, getYields (), getResults (),
240
+ getOptionalErrorResult (), getSubstitutions (),
241
+ isGenericSignatureImplied (), getASTContext ());
242
+ }
243
+
194
244
CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType (
195
245
IndexSubset *parameterIndices, unsigned resultIndex,
196
246
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
0 commit comments