@@ -3203,15 +3203,25 @@ class AnyFunctionType : public TypeBase {
3203
3203
return getExtInfo ().getRepresentation ();
3204
3204
}
3205
3205
3206
+ // / Appends the parameters indicated by `parameterIndices` to `results`.
3207
+ // /
3208
+ // / For curried function types: if `reverseCurryLevels` is true, append
3209
+ // / the `self` parameter last instead of first.
3210
+ // /
3211
+ // / TODO(TF-874): Simplify logic and remove the `reverseCurryLevels` flag.
3212
+ void getSubsetParameters (IndexSubset *parameterIndices,
3213
+ SmallVectorImpl<AnyFunctionType::Param> &results,
3214
+ bool reverseCurryLevels = false );
3215
+
3206
3216
// / Returns the derivative function type for the given parameter indices,
3207
3217
// / result index, derivative function kind, derivative function generic
3208
3218
// / signature (optional), and other auxiliary parameters.
3209
3219
// /
3210
3220
// / Preconditions:
3211
3221
// / - Parameters corresponding to parameter indices must conform to
3212
3222
// / `Differentiable`.
3213
- // / - The result corresponding to the result index must conform to
3214
- // / `Differentiable`.
3223
+ // / - There is one semantic function result type: either the formal original
3224
+ // / result or an `inout` parameter. It must conform to `Differentiable`.
3215
3225
// /
3216
3226
// / Typing rules, given:
3217
3227
// / - Original function type. Three cases:
@@ -3257,6 +3267,11 @@ class AnyFunctionType : public TypeBase {
3257
3267
// / original result | deriv. wrt result | deriv. wrt params
3258
3268
// / \endverbatim
3259
3269
// /
3270
+ // / The original type may have `inout` parameters. If so, the
3271
+ // / differential/pullback typing rules are more nuanced: see documentation for
3272
+ // / `getAutoDiffDerivativeFunctionLinearMapType` for details. Semantically,
3273
+ // / `inout` parameters behave as both parameters and results.
3274
+ // /
3260
3275
// / By default, if the original type has a `self` parameter list and parameter
3261
3276
// / indices include `self`, the computed derivative function type will return
3262
3277
// / a linear map taking/returning self's tangent *last* instead of first, for
@@ -3267,12 +3282,54 @@ class AnyFunctionType : public TypeBase {
3267
3282
// / derivative function types, e.g. when type-checking `@differentiable` and
3268
3283
// / `@derivative` attributes.
3269
3284
AnyFunctionType *getAutoDiffDerivativeFunctionType (
3270
- IndexSubset *parameterIndices, unsigned resultIndex,
3271
- AutoDiffDerivativeFunctionKind kind,
3285
+ IndexSubset *parameterIndices, AutoDiffDerivativeFunctionKind kind,
3272
3286
LookupConformanceFn lookupConformance,
3273
3287
GenericSignature derivativeGenericSignature = GenericSignature(),
3274
3288
bool makeSelfParamFirst = false);
3275
3289
3290
+ // / Returns the corresponding linear map function type for the given parameter
3291
+ // / indices, linear map function kind, and other auxiliary parameters.
3292
+ // /
3293
+ // / Preconditions:
3294
+ // / - Parameters corresponding to parameter indices must conform to
3295
+ // / `Differentiable`.
3296
+ // / - There is one semantic function result type: either the formal original
3297
+ // / result or an `inout` parameter. It must conform to `Differentiable`.
3298
+ // /
3299
+ // / Differential typing rules: takes "wrt" parameter derivatives and returns a
3300
+ // / "wrt" result derivative.
3301
+ // /
3302
+ // / - Case 1: original function has no `inout` parameters.
3303
+ // / - Original: `(T0, T1, ...) -> R`
3304
+ // / - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
3305
+ // / - Case 2: original function has a non-wrt `inout` parameter.
3306
+ // / - Original: `(T0, inout T1, ...) -> Void`
3307
+ // / - Differential: `(T0.Tan, ...) -> T1.Tan`
3308
+ // / - Case 3: original function has a wrt `inout` parameter.
3309
+ // / - Original: `(T0, inout T1, ...) -> Void`
3310
+ // / - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
3311
+ // /
3312
+ // / Pullback typing rules: takes a "wrt" result derivative and returns "wrt"
3313
+ // / parameter derivatives.
3314
+ // /
3315
+ // / - Case 1: original function has no `inout` parameters.
3316
+ // / - Original: `(T0, T1, ...) -> R`
3317
+ // / - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
3318
+ // / - Case 2: original function has a non-wrt `inout` parameter.
3319
+ // / - Original: `(T0, inout T1, ...) -> Void`
3320
+ // / - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
3321
+ // / - Case 3: original function has a wrt `inout` parameter.
3322
+ // / - Original: `(T0, inout T1, ...) -> Void`
3323
+ // / - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
3324
+ // /
3325
+ // / If `makeSelfParamFirst` is true, `self`'s tangent is reordered to appear
3326
+ // / first. `makeSelfParamFirst` should be true when working with user-facing
3327
+ // / derivative function types, e.g. when type-checking `@differentiable` and
3328
+ // / `@derivative` attributes.
3329
+ AnyFunctionType *getAutoDiffDerivativeFunctionLinearMapType (
3330
+ IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
3331
+ LookupConformanceFn lookupConformance, bool makeSelfParamFirst = false );
3332
+
3276
3333
// / True if the parameter declaration it is attached to is guaranteed
3277
3334
// / to not persist the closure for longer than the duration of the call.
3278
3335
bool isNoEscape () const {
@@ -4404,6 +4461,28 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
4404
4461
return getParameters ().back ();
4405
4462
}
4406
4463
4464
+ struct IndirectMutatingParameterFilter {
4465
+ bool operator ()(SILParameterInfo param) const {
4466
+ return param.isIndirectMutating ();
4467
+ }
4468
+ };
4469
+ using IndirectMutatingParameterIter =
4470
+ llvm::filter_iterator<const SILParameterInfo *,
4471
+ IndirectMutatingParameterFilter>;
4472
+ using IndirectMutatingParameterRange =
4473
+ iterator_range<IndirectMutatingParameterIter>;
4474
+
4475
+ // / A range of SILParameterInfo for all indirect mutating parameters.
4476
+ IndirectMutatingParameterRange getIndirectMutatingParameters () const {
4477
+ return llvm::make_filter_range (getParameters (),
4478
+ IndirectMutatingParameterFilter ());
4479
+ }
4480
+
4481
+ // / Returns the number of indirect mutating parameters.
4482
+ unsigned getNumIndirectMutatingParameters () const {
4483
+ return llvm::count_if (getParameters (), IndirectMutatingParameterFilter ());
4484
+ }
4485
+
4407
4486
// / Get the generic signature used to apply the substitutions of a substituted function type
4408
4487
CanGenericSignature getSubstGenericSignature () const {
4409
4488
return GenericSigAndIsImplied.getPointer ();
@@ -4486,18 +4565,27 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
4486
4565
// / - Returns original results, followed by a differential function, which
4487
4566
// / takes "wrt" parameter derivatives and returns a "wrt" result derivative.
4488
4567
// /
4568
+ // / \verbatim
4489
4569
// / $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan)
4490
4570
// / ^~~~~~~ ^~~~~~~~~~~~~~~~~~~ ^~~~~~
4491
4571
// / original results | derivatives wrt params | derivative wrt result
4572
+ // / \endverbatim
4492
4573
// /
4493
4574
// / VJP derivative type:
4494
4575
// / - Takes original parameters.
4495
4576
// / - Returns original results, followed by a pullback function, which
4496
4577
// / takes a "wrt" result derivative and returns "wrt" parameter derivatives.
4497
4578
// /
4579
+ // / \verbatim
4498
4580
// / $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...))
4499
4581
// / ^~~~~~~ ^~~~~~ ^~~~~~~~~~~~~~~~~~~
4500
4582
// / original results | derivative wrt result | derivatives wrt params
4583
+ // / \endverbatim
4584
+ // /
4585
+ // / The original type may have `inout` parameters. If so, the
4586
+ // / differential/pullback typing rules are more nuanced: see documentation for
4587
+ // / `getAutoDiffDerivativeFunctionLinearMapType` for details. Semantically,
4588
+ // / `inout` parameters behave as both parameters and results.
4501
4589
// /
4502
4590
// / A "constrained derivative generic signature" is computed from
4503
4591
// / `derivativeFunctionGenericSignature`, if specified. Otherwise, it is
0 commit comments