Skip to content

Commit 697c722

Browse files
authored
[AutoDiff] Type-checking support for inout parameter differentiation. (swiftlang#29959)
Semantically, an `inout` parameter is both a parameter and a result. `@differentiable` and `@derivative` attributes now support original functions with one "semantic result": either a formal result or an `inout` parameter. Derivative typing rules for functions with `inout` parameters are now defined. The differential/pullback type of a function with `inout` differentiability parameters also has `inout` parameters. This is ideal for performance. Differential typing rules: - Case 1: original function has no `inout` parameters. - Original: `(T0, T1, ...) -> R` - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan` - Case 2: original function has a non-wrt `inout` parameter. - Original: `(T0, inout T1, ...) -> Void` - Differential: `(T0.Tan, ...) -> T1.Tan` - Case 3: original function has a wrt `inout` parameter. - Original: `(T0, inout T1, ...) -> Void` - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` Pullback typing rules: - Case 1: original function has no `inout` parameters. - Original: `(T0, T1, ...) -> R` - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)` - Case 2: original function has a non-wrt `inout` parameter. - Original: `(T0, inout T1, ...) -> Void` - Pullback: `(T1.Tan) -> (T0.Tan, ...)` - Case 3: original function has a wrt `inout` parameter. - Original: `(T0, inout T1, ...) -> Void` - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)` Resolves TF-1164.
1 parent 1cefebe commit 697c722

File tree

9 files changed

+598
-202
lines changed

9 files changed

+598
-202
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ struct AutoDiffConfig {
114114
SWIFT_DEBUG_DUMP;
115115
};
116116

117+
/// A semantic function result type: either a formal function result type or
118+
/// an `inout` parameter type. Used in derivative function type calculation.
119+
struct AutoDiffSemanticFunctionResultType {
120+
Type type;
121+
bool isInout;
122+
};
123+
117124
/// Key for caching SIL derivative function types.
118125
struct SILAutoDiffDerivativeFunctionKey {
119126
SILFunctionType *originalType;
@@ -263,11 +270,17 @@ using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
263270
/// Automatic differentiation utility namespace.
264271
namespace autodiff {
265272

266-
/// Appends the subset's parameter's types to `results`, in the order in
267-
/// which they appear in the function type.
268-
void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
269-
SmallVectorImpl<Type> &results,
270-
bool reverseCurryLevels = false);
273+
/// Given a function type, collects its semantic result types in type order
274+
/// into `result`: first, the formal result type (if non-`Void`), followed by
275+
/// `inout` parameter types.
276+
///
277+
/// The function type may have at most two parameter lists.
278+
///
279+
/// Remaps the original semantic result using `genericEnv`, if specified.
280+
void getFunctionSemanticResultTypes(
281+
AnyFunctionType *functionType,
282+
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
283+
GenericEnvironment *genericEnv = nullptr);
271284

272285
/// "Constrained" derivative generic signatures require all differentiability
273286
/// parameters to conform to the `Differentiable` protocol.

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2898,8 +2898,6 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
28982898
(DeclName, DeclName))
28992899

29002900
// @differentiable
2901-
ERROR(differentiable_attr_void_result,none,
2902-
"cannot differentiate void function %0", (DeclName))
29032901
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
29042902
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
29052903
"attribute for transpose registration instead", ())
@@ -2996,6 +2994,11 @@ ERROR(autodiff_attr_original_decl_none_valid_found,none,
29962994
"could not find function %0 with expected type %1", (DeclNameRef, Type))
29972995
ERROR(autodiff_attr_original_decl_not_same_type_context,none,
29982996
"%0 is not defined in the current type context", (DeclNameRef))
2997+
ERROR(autodiff_attr_original_void_result,none,
2998+
"cannot differentiate void function %0", (DeclName))
2999+
ERROR(autodiff_attr_original_multiple_semantic_results,none,
3000+
"cannot differentiate functions with both an 'inout' parameter and a "
3001+
"result", ())
29993002

30003003
// differentiation `wrt` parameters clause
30013004
ERROR(diff_function_no_parameters,none,

include/swift/AST/Types.h

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3203,15 +3203,25 @@ class AnyFunctionType : public TypeBase {
32033203
return getExtInfo().getRepresentation();
32043204
}
32053205

3206+
/// Appends the parameters indicated by `parameterIndices` to `results`.
3207+
///
3208+
/// For curried function types: if `reverseCurryLevels` is true, append
3209+
/// the `self` parameter last instead of first.
3210+
///
3211+
/// TODO(TF-874): Simplify logic and remove the `reverseCurryLevels` flag.
3212+
void getSubsetParameters(IndexSubset *parameterIndices,
3213+
SmallVectorImpl<AnyFunctionType::Param> &results,
3214+
bool reverseCurryLevels = false);
3215+
32063216
/// Returns the derivative function type for the given parameter indices,
32073217
/// result index, derivative function kind, derivative function generic
32083218
/// signature (optional), and other auxiliary parameters.
32093219
///
32103220
/// Preconditions:
32113221
/// - Parameters corresponding to parameter indices must conform to
32123222
/// `Differentiable`.
3213-
/// - The result corresponding to the result index must conform to
3214-
/// `Differentiable`.
3223+
/// - There is one semantic function result type: either the formal original
3224+
/// result or an `inout` parameter. It must conform to `Differentiable`.
32153225
///
32163226
/// Typing rules, given:
32173227
/// - Original function type. Three cases:
@@ -3257,6 +3267,11 @@ class AnyFunctionType : public TypeBase {
32573267
/// original result | deriv. wrt result | deriv. wrt params
32583268
/// \endverbatim
32593269
///
3270+
/// The original type may have `inout` parameters. If so, the
3271+
/// differential/pullback typing rules are more nuanced: see documentation for
3272+
/// `getAutoDiffDerivativeFunctionLinearMapType` for details. Semantically,
3273+
/// `inout` parameters behave as both parameters and results.
3274+
///
32603275
/// By default, if the original type has a `self` parameter list and parameter
32613276
/// indices include `self`, the computed derivative function type will return
32623277
/// a linear map taking/returning self's tangent *last* instead of first, for
@@ -3267,12 +3282,54 @@ class AnyFunctionType : public TypeBase {
32673282
/// derivative function types, e.g. when type-checking `@differentiable` and
32683283
/// `@derivative` attributes.
32693284
AnyFunctionType *getAutoDiffDerivativeFunctionType(
3270-
IndexSubset *parameterIndices, unsigned resultIndex,
3271-
AutoDiffDerivativeFunctionKind kind,
3285+
IndexSubset *parameterIndices, AutoDiffDerivativeFunctionKind kind,
32723286
LookupConformanceFn lookupConformance,
32733287
GenericSignature derivativeGenericSignature = GenericSignature(),
32743288
bool makeSelfParamFirst = false);
32753289

3290+
/// Returns the corresponding linear map function type for the given parameter
3291+
/// indices, linear map function kind, and other auxiliary parameters.
3292+
///
3293+
/// Preconditions:
3294+
/// - Parameters corresponding to parameter indices must conform to
3295+
/// `Differentiable`.
3296+
/// - There is one semantic function result type: either the formal original
3297+
/// result or an `inout` parameter. It must conform to `Differentiable`.
3298+
///
3299+
/// Differential typing rules: takes "wrt" parameter derivatives and returns a
3300+
/// "wrt" result derivative.
3301+
///
3302+
/// - Case 1: original function has no `inout` parameters.
3303+
/// - Original: `(T0, T1, ...) -> R`
3304+
/// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
3305+
/// - Case 2: original function has a non-wrt `inout` parameter.
3306+
/// - Original: `(T0, inout T1, ...) -> Void`
3307+
/// - Differential: `(T0.Tan, ...) -> T1.Tan`
3308+
/// - Case 3: original function has a wrt `inout` parameter.
3309+
/// - Original: `(T0, inout T1, ...) -> Void`
3310+
/// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
3311+
///
3312+
/// Pullback typing rules: takes a "wrt" result derivative and returns "wrt"
3313+
/// parameter derivatives.
3314+
///
3315+
/// - Case 1: original function has no `inout` parameters.
3316+
/// - Original: `(T0, T1, ...) -> R`
3317+
/// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
3318+
/// - Case 2: original function has a non-wrt `inout` parameter.
3319+
/// - Original: `(T0, inout T1, ...) -> Void`
3320+
/// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
3321+
/// - Case 3: original function has a wrt `inout` parameter.
3322+
/// - Original: `(T0, inout T1, ...) -> Void`
3323+
/// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
3324+
///
3325+
/// If `makeSelfParamFirst` is true, `self`'s tangent is reordered to appear
3326+
/// first. `makeSelfParamFirst` should be true when working with user-facing
3327+
/// derivative function types, e.g. when type-checking `@differentiable` and
3328+
/// `@derivative` attributes.
3329+
AnyFunctionType *getAutoDiffDerivativeFunctionLinearMapType(
3330+
IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
3331+
LookupConformanceFn lookupConformance, bool makeSelfParamFirst = false);
3332+
32763333
/// True if the parameter declaration it is attached to is guaranteed
32773334
/// to not persist the closure for longer than the duration of the call.
32783335
bool isNoEscape() const {
@@ -4404,6 +4461,28 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
44044461
return getParameters().back();
44054462
}
44064463

4464+
struct IndirectMutatingParameterFilter {
4465+
bool operator()(SILParameterInfo param) const {
4466+
return param.isIndirectMutating();
4467+
}
4468+
};
4469+
using IndirectMutatingParameterIter =
4470+
llvm::filter_iterator<const SILParameterInfo *,
4471+
IndirectMutatingParameterFilter>;
4472+
using IndirectMutatingParameterRange =
4473+
iterator_range<IndirectMutatingParameterIter>;
4474+
4475+
/// A range of SILParameterInfo for all indirect mutating parameters.
4476+
IndirectMutatingParameterRange getIndirectMutatingParameters() const {
4477+
return llvm::make_filter_range(getParameters(),
4478+
IndirectMutatingParameterFilter());
4479+
}
4480+
4481+
/// Returns the number of indirect mutating parameters.
4482+
unsigned getNumIndirectMutatingParameters() const {
4483+
return llvm::count_if(getParameters(), IndirectMutatingParameterFilter());
4484+
}
4485+
44074486
/// Get the generic signature used to apply the substitutions of a substituted function type
44084487
CanGenericSignature getSubstGenericSignature() const {
44094488
return GenericSigAndIsImplied.getPointer();
@@ -4486,18 +4565,27 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
44864565
/// - Returns original results, followed by a differential function, which
44874566
/// takes "wrt" parameter derivatives and returns a "wrt" result derivative.
44884567
///
4568+
/// \verbatim
44894569
/// $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan)
44904570
/// ^~~~~~~ ^~~~~~~~~~~~~~~~~~~ ^~~~~~
44914571
/// original results | derivatives wrt params | derivative wrt result
4572+
/// \endverbatim
44924573
///
44934574
/// VJP derivative type:
44944575
/// - Takes original parameters.
44954576
/// - Returns original results, followed by a pullback function, which
44964577
/// takes a "wrt" result derivative and returns "wrt" parameter derivatives.
44974578
///
4579+
/// \verbatim
44984580
/// $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...))
44994581
/// ^~~~~~~ ^~~~~~ ^~~~~~~~~~~~~~~~~~~
45004582
/// original results | derivative wrt result | derivatives wrt params
4583+
/// \endverbatim
4584+
///
4585+
/// The original type may have `inout` parameters. If so, the
4586+
/// differential/pullback typing rules are more nuanced: see documentation for
4587+
/// `getAutoDiffDerivativeFunctionLinearMapType` for details. Semantically,
4588+
/// `inout` parameters behave as both parameters and results.
45014589
///
45024590
/// A "constrained derivative generic signature" is computed from
45034591
/// `derivativeFunctionGenericSignature`, if specified. Otherwise, it is

lib/AST/AutoDiff.cpp

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "swift/AST/AutoDiff.h"
1414
#include "swift/AST/ASTContext.h"
15+
#include "swift/AST/GenericEnvironment.h"
1516
#include "swift/AST/Module.h"
1617
#include "swift/AST/TypeCheckRequests.h"
1718
#include "swift/AST/Types.h"
@@ -72,13 +73,11 @@ static unsigned countNumFlattenedElementTypes(Type type) {
7273
}
7374

7475
// TODO(TF-874): Simplify this helper and remove the `reverseCurryLevels` flag.
75-
// See TF-874 for WIP.
76-
void autodiff::getSubsetParameterTypes(IndexSubset *subset,
77-
AnyFunctionType *type,
78-
SmallVectorImpl<Type> &results,
79-
bool reverseCurryLevels) {
76+
void AnyFunctionType::getSubsetParameters(
77+
IndexSubset *parameterIndices,
78+
SmallVectorImpl<AnyFunctionType::Param> &results, bool reverseCurryLevels) {
8079
SmallVector<AnyFunctionType *, 2> curryLevels;
81-
unwrapCurryLevels(type, curryLevels);
80+
unwrapCurryLevels(this, curryLevels);
8281

8382
SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
8483
unsigned currentOffset = 0;
@@ -99,8 +98,43 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
9998
unsigned parameterIndexOffset =
10099
curryLevelParameterIndexOffsets[curryLevelIndex];
101100
for (unsigned paramIndex : range(curryLevel->getNumParams()))
102-
if (subset->contains(parameterIndexOffset + paramIndex))
103-
results.push_back(curryLevel->getParams()[paramIndex].getOldType());
101+
if (parameterIndices->contains(parameterIndexOffset + paramIndex))
102+
results.push_back(curryLevel->getParams()[paramIndex]);
103+
}
104+
}
105+
106+
void autodiff::getFunctionSemanticResultTypes(
107+
AnyFunctionType *functionType,
108+
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
109+
GenericEnvironment *genericEnv) {
110+
auto &ctx = functionType->getASTContext();
111+
112+
// Remap type in `genericEnv`, if specified.
113+
auto remap = [&](Type type) {
114+
if (!genericEnv)
115+
return type;
116+
return genericEnv->mapTypeIntoContext(type);
117+
};
118+
119+
// Collect formal result type as a semantic result, unless it is
120+
// `Void`.
121+
auto formalResultType = functionType->getResult();
122+
if (auto *resultFunctionType =
123+
functionType->getResult()->getAs<AnyFunctionType>()) {
124+
formalResultType = resultFunctionType->getResult();
125+
}
126+
if (!formalResultType->isEqual(ctx.TheEmptyTupleType))
127+
result.push_back({remap(formalResultType), /*isInout*/ false});
128+
129+
// Collect `inout` parameters as semantic results.
130+
for (auto param : functionType->getParams())
131+
if (param.isInOut())
132+
result.push_back({remap(param.getPlainType()), /*isInout*/ true});
133+
if (auto *resultFunctionType =
134+
functionType->getResult()->getAs<AnyFunctionType>()) {
135+
for (auto param : resultFunctionType->getParams())
136+
if (param.isInOut())
137+
result.push_back({remap(param.getPlainType()), /*isInout*/ true});
104138
}
105139
}
106140

0 commit comments

Comments
 (0)