Skip to content

Commit d3b6b89

Browse files
authored
[AutoDiff] Support multiple differentiability result indices in SIL. (swiftlang#32206)
`DifferentiableFunctionInst` now stores result indices. `SILAutoDiffIndices` now stores result indices instead of a source index. `@differentiable` SIL function types may now have multiple differentiability result indices and `@noDerivative` resutls. `@differentiable` AST function types do not have `@noDerivative` results (yet), so this functionality is not exposed to users. Resolves TF-689 and TF-1256. Infrastructural support for TF-983: supporting differentiation of `apply` instructions with multiple active semantic results.
1 parent 754f21d commit d3b6b89

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+925
-616
lines changed

docs/ABI/Mangling.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ mangled in to disambiguate.
589589
impl-function-type ::= type* 'I' FUNC-ATTRIBUTES '_'
590590
impl-function-type ::= type* generic-signature 'I' FUNC-ATTRIBUTES '_'
591591

592-
FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? DIFFERENTIABILITY-KIND? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? (PARAM-CONVENTION PARAM-DIFFERENTIABILITY?)* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)?
592+
FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? DIFFERENTIABILITY-KIND? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? (PARAM-CONVENTION PARAM-DIFFERENTIABILITY?)* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION RESULT-DIFFERENTIABILITY?)?
593593

594594
PATTERN-SUBS ::= 's' // has pattern substitutions
595595
INVOCATION-SUB ::= 'I' // has invocation substitutions
@@ -634,6 +634,8 @@ mangled in to disambiguate.
634634
RESULT-CONVENTION ::= 'u' // unowned inner pointer
635635
RESULT-CONVENTION ::= 'a' // auto-released
636636

637+
RESULT-DIFFERENTIABILITY ::= 'w' // @noDerivative
638+
637639
For the most part, manglings follow the structure of formal language
638640
types. However, in some cases it is more useful to encode the exact
639641
implementation details of a function type.

include/swift/AST/AutoDiff.h

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -173,18 +173,21 @@ enum class AutoDiffGeneratedDeclarationKind : uint8_t {
173173
};
174174

175175
/// SIL-level automatic differentiation indices. Consists of:
176-
/// - Parameter indices: indices of parameters to differentiate with respect to.
177-
/// - Result index: index of the result to differentiate from.
176+
/// - The differentiability parameter indices.
177+
/// - The differentiability result indices.
178178
// TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
179-
// `AutoDiffConfig` supports multiple result indices.
179+
// `AutoDiffConfig` additionally stores a derivative generic signature.
180180
struct SILAutoDiffIndices {
181-
/// The index of the dependent result to differentiate from.
182-
unsigned source;
183-
/// The indices for independent parameters to differentiate with respect to.
181+
/// The indices of independent parameters to differentiate with respect to.
184182
IndexSubset *parameters;
183+
/// The indices of dependent results to differentiate from.
184+
IndexSubset *results;
185185

186-
/*implicit*/ SILAutoDiffIndices(unsigned source, IndexSubset *parameters)
187-
: source(source), parameters(parameters) {}
186+
/*implicit*/ SILAutoDiffIndices(IndexSubset *parameters, IndexSubset *results)
187+
: parameters(parameters), results(results) {
188+
assert(parameters && "Parameter indices must be non-null");
189+
assert(results && "Result indices must be non-null");
190+
}
188191

189192
bool operator==(const SILAutoDiffIndices &other) const;
190193

@@ -202,7 +205,12 @@ struct SILAutoDiffIndices {
202205
SWIFT_DEBUG_DUMP;
203206

204207
std::string mangle() const {
205-
std::string result = "src_" + llvm::utostr(source) + "_wrt_";
208+
std::string result = "src_";
209+
interleave(
210+
results->getIndices(),
211+
[&](unsigned idx) { result += llvm::utostr(idx); },
212+
[&] { result += '_'; });
213+
result += "_wrt_";
206214
llvm::interleave(
207215
parameters->getIndices(),
208216
[&](unsigned idx) { result += llvm::utostr(idx); },

include/swift/AST/Types.h

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3942,13 +3942,34 @@ inline bool isIndirectFormalResult(ResultConvention convention) {
39423942
return convention == ResultConvention::Indirect;
39433943
}
39443944

3945+
/// The differentiability of a SIL function type result.
3946+
enum class SILResultDifferentiability : unsigned {
3947+
/// Either differentiable or not applicable.
3948+
///
3949+
/// - If the function type is not `@differentiable`, result
3950+
/// differentiability is not applicable. This case is the default value.
3951+
/// - If the function type is `@differentiable`, the function is
3952+
/// differentiable with respect to this result.
3953+
DifferentiableOrNotApplicable,
3954+
3955+
/// Not differentiable: a `@noDerivative` result.
3956+
///
3957+
/// May be applied only to result of `@differentiable` function types.
3958+
/// The function type is not differentiable with respect to this result.
3959+
NotDifferentiable,
3960+
};
3961+
39453962
/// A result type and the rules for returning it.
39463963
class SILResultInfo {
39473964
llvm::PointerIntPair<CanType, 3, ResultConvention> TypeAndConvention;
3965+
SILResultDifferentiability Differentiability : 1;
3966+
39483967
public:
39493968
SILResultInfo() = default;
3950-
SILResultInfo(CanType type, ResultConvention conv)
3951-
: TypeAndConvention(type, conv) {
3969+
SILResultInfo(CanType type, ResultConvention conv,
3970+
SILResultDifferentiability differentiability =
3971+
SILResultDifferentiability::DifferentiableOrNotApplicable)
3972+
: TypeAndConvention(type, conv), Differentiability(differentiability) {
39523973
assert(type->isLegalSILType() && "SILResultInfo has illegal SIL type");
39533974
}
39543975

@@ -3969,6 +3990,17 @@ class SILResultInfo {
39693990
ResultConvention getConvention() const {
39703991
return TypeAndConvention.getInt();
39713992
}
3993+
3994+
SILResultDifferentiability getDifferentiability() const {
3995+
return Differentiability;
3996+
}
3997+
3998+
SILResultInfo
3999+
getWithDifferentiability(SILResultDifferentiability differentiability) const {
4000+
return SILResultInfo(getInterfaceType(), getConvention(),
4001+
differentiability);
4002+
}
4003+
39724004
/// The SIL storage type determines the ABI for arguments based purely on the
39734005
/// formal result conventions. The actual SIL type for the result values may
39744006
/// differ in canonical SIL. In particular, opaque values require indirect
@@ -4025,6 +4057,7 @@ class SILResultInfo {
40254057

40264058
void profile(llvm::FoldingSetNodeID &id) {
40274059
id.AddPointer(TypeAndConvention.getOpaqueValue());
4060+
id.AddInteger((unsigned)getDifferentiability());
40284061
}
40294062

40304063
SWIFT_DEBUG_DUMP;
@@ -4714,24 +4747,31 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
47144747
/// `@noDerivative` ones).
47154748
IndexSubset *getDifferentiabilityParameterIndices();
47164749

4750+
/// Given that `this` is a `@differentiable` or `@differentiable(linear)`
4751+
/// function type, returns an `IndexSubset` corresponding to the
4752+
/// differentiability/linearity results (e.g. all results except the
4753+
/// `@noDerivative` ones).
4754+
IndexSubset *getDifferentiabilityResultIndices();
4755+
47174756
/// Returns the `@differentiable` or `@differentiable(linear)` function type
4718-
/// for the given differentiability kind and parameter indices representing
4719-
/// differentiability/linearity parameters.
4757+
/// for the given differentiability kind and differentiability/linearity
4758+
/// parameter/result indices.
47204759
CanSILFunctionType getWithDifferentiability(DifferentiabilityKind kind,
4721-
IndexSubset *parameterIndices);
4760+
IndexSubset *parameterIndices,
4761+
IndexSubset *resultIndices);
47224762

47234763
/// Returns the SIL function type stripping differentiability kind and
47244764
/// differentiability from all parameters.
47254765
CanSILFunctionType getWithoutDifferentiability();
47264766

47274767
/// Returns the type of the derivative function for the given parameter
4728-
/// indices, result index, derivative function kind, derivative function
4768+
/// indices, result indices, derivative function kind, derivative function
47294769
/// generic signature (optional), and other auxiliary parameters.
47304770
///
47314771
/// Preconditions:
47324772
/// - Parameters corresponding to parameter indices must conform to
47334773
/// `Differentiable`.
4734-
/// - The result corresponding to the result index must conform to
4774+
/// - Results corresponding to result indices must conform to
47354775
/// `Differentiable`.
47364776
///
47374777
/// Typing rules, given:
@@ -4803,14 +4843,8 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
48034843
/// function - this is more direct. It may be possible to implement
48044844
/// reabstraction thunk derivatives using "reabstraction thunks for
48054845
/// the original function's derivative", avoiding extra code generation.
4806-
///
4807-
/// Caveats:
4808-
/// - We may support multiple result indices instead of a single result index
4809-
/// eventually. At the SIL level, this enables differentiating wrt multiple
4810-
/// function results. At the Swift level, this enables differentiating wrt
4811-
/// multiple tuple elements for tuple-returning functions.
48124846
CanSILFunctionType getAutoDiffDerivativeFunctionType(
4813-
IndexSubset *parameterIndices, unsigned resultIndex,
4847+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
48144848
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
48154849
LookupConformanceFn lookupConformance,
48164850
CanGenericSignature derivativeFunctionGenericSignature = nullptr,

include/swift/Demangling/TypeDecoder.h

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,13 @@ enum class ImplParameterDifferentiability {
9494
NotDifferentiable
9595
};
9696

97-
static inline llvm::Optional<ImplParameterDifferentiability>
98-
getDifferentiabilityFromString(StringRef string) {
99-
if (string.empty())
100-
return ImplParameterDifferentiability::DifferentiableOrNotApplicable;
101-
if (string == "@noDerivative")
102-
return ImplParameterDifferentiability::NotDifferentiable;
103-
return None;
104-
}
105-
10697
/// Describe a lowered function parameter, parameterized on the type
10798
/// representation.
10899
template <typename BuiltType>
109100
class ImplFunctionParam {
101+
BuiltType Type;
110102
ImplParameterConvention Convention;
111103
ImplParameterDifferentiability Differentiability;
112-
BuiltType Type;
113104

114105
public:
115106
using ConventionType = ImplParameterConvention;
@@ -137,9 +128,18 @@ class ImplFunctionParam {
137128
return None;
138129
}
139130

140-
ImplFunctionParam(ImplParameterConvention convention,
141-
ImplParameterDifferentiability diffKind, BuiltType type)
142-
: Convention(convention), Differentiability(diffKind), Type(type) {}
131+
static llvm::Optional<DifferentiabilityType>
132+
getDifferentiabilityFromString(StringRef string) {
133+
if (string.empty())
134+
return DifferentiabilityType::DifferentiableOrNotApplicable;
135+
if (string == "@noDerivative")
136+
return DifferentiabilityType::NotDifferentiable;
137+
return None;
138+
}
139+
140+
ImplFunctionParam(BuiltType type, ImplParameterConvention convention,
141+
ImplParameterDifferentiability diffKind)
142+
: Type(type), Convention(convention), Differentiability(diffKind) {}
143143

144144
ImplParameterConvention getConvention() const { return Convention; }
145145

@@ -158,15 +158,22 @@ enum class ImplResultConvention {
158158
Autoreleased,
159159
};
160160

161+
enum class ImplResultDifferentiability {
162+
DifferentiableOrNotApplicable,
163+
NotDifferentiable
164+
};
165+
161166
/// Describe a lowered function result, parameterized on the type
162167
/// representation.
163168
template <typename BuiltType>
164169
class ImplFunctionResult {
165-
ImplResultConvention Convention;
166170
BuiltType Type;
171+
ImplResultConvention Convention;
172+
ImplResultDifferentiability Differentiability;
167173

168174
public:
169175
using ConventionType = ImplResultConvention;
176+
using DifferentiabilityType = ImplResultDifferentiability;
170177

171178
static llvm::Optional<ConventionType>
172179
getConventionFromString(StringRef conventionString) {
@@ -184,11 +191,27 @@ class ImplFunctionResult {
184191
return None;
185192
}
186193

187-
ImplFunctionResult(ImplResultConvention convention, BuiltType type)
188-
: Convention(convention), Type(type) {}
194+
static llvm::Optional<DifferentiabilityType>
195+
getDifferentiabilityFromString(StringRef string) {
196+
if (string.empty())
197+
return DifferentiabilityType::DifferentiableOrNotApplicable;
198+
if (string == "@noDerivative")
199+
return DifferentiabilityType::NotDifferentiable;
200+
return None;
201+
}
202+
203+
ImplFunctionResult(
204+
BuiltType type, ImplResultConvention convention,
205+
ImplResultDifferentiability diffKind =
206+
ImplResultDifferentiability::DifferentiableOrNotApplicable)
207+
: Type(type), Convention(convention), Differentiability(diffKind) {}
189208

190209
ImplResultConvention getConvention() const { return Convention; }
191210

211+
ImplResultDifferentiability getDifferentiability() const {
212+
return Differentiability;
213+
}
214+
192215
BuiltType getType() const { return Type; }
193216
};
194217

@@ -640,7 +663,7 @@ class TypeDecoder {
640663
if (decodeImplFunctionParam(child, parameters))
641664
return BuiltType();
642665
} else if (child->getKind() == NodeKind::ImplResult) {
643-
if (decodeImplFunctionPart(child, results))
666+
if (decodeImplFunctionParam(child, results))
644667
return BuiltType();
645668
} else if (child->getKind() == NodeKind::ImplErrorResult) {
646669
if (decodeImplFunctionPart(child, errorResults))
@@ -913,13 +936,13 @@ class TypeDecoder {
913936
if (!type)
914937
return true;
915938

916-
results.emplace_back(*convention, type);
939+
results.emplace_back(type, *convention);
917940
return false;
918941
}
919942

920-
bool decodeImplFunctionParam(
921-
Demangle::NodePointer node,
922-
llvm::SmallVectorImpl<ImplFunctionParam<BuiltType>> &results) {
943+
template <typename T>
944+
bool decodeImplFunctionParam(Demangle::NodePointer node,
945+
llvm::SmallVectorImpl<T> &results) {
923946
// Children: `convention, differentiability?, type`
924947
if (node->getNumChildren() != 2 && node->getNumChildren() != 3)
925948
return true;
@@ -931,28 +954,26 @@ class TypeDecoder {
931954
return true;
932955

933956
StringRef conventionString = conventionNode->getText();
934-
auto convention =
935-
ImplFunctionParam<BuiltType>::getConventionFromString(conventionString);
957+
auto convention = T::getConventionFromString(conventionString);
936958
if (!convention)
937959
return true;
938960
BuiltType type = decodeMangledType(typeNode);
939961
if (!type)
940962
return true;
941963

942-
auto diffKind =
943-
ImplParameterDifferentiability::DifferentiableOrNotApplicable;
964+
auto diffKind = T::DifferentiabilityType::DifferentiableOrNotApplicable;
944965
if (node->getNumChildren() == 3) {
945966
auto diffKindNode = node->getChild(1);
946967
if (diffKindNode->getKind() != Node::Kind::ImplDifferentiability)
947968
return true;
948969
auto optDiffKind =
949-
getDifferentiabilityFromString(diffKindNode->getText());
970+
T::getDifferentiabilityFromString(diffKindNode->getText());
950971
if (!optDiffKind)
951972
return true;
952973
diffKind = *optDiffKind;
953974
}
954975

955-
results.emplace_back(*convention, diffKind, type);
976+
results.emplace_back(type, *convention, diffKind);
956977
return false;
957978
}
958979

include/swift/SIL/SILBuilder.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2179,10 +2179,11 @@ class SILBuilder {
21792179
//===--------------------------------------------------------------------===//
21802180

21812181
DifferentiableFunctionInst *createDifferentiableFunction(
2182-
SILLocation Loc, IndexSubset *ParameterIndices, SILValue OriginalFunction,
2182+
SILLocation Loc, IndexSubset *ParameterIndices,
2183+
IndexSubset *ResultIndices, SILValue OriginalFunction,
21832184
Optional<std::pair<SILValue, SILValue>> JVPAndVJPFunctions = None) {
21842185
return insert(DifferentiableFunctionInst::create(
2185-
getModule(), getSILDebugLocation(Loc), ParameterIndices,
2186+
getModule(), getSILDebugLocation(Loc), ParameterIndices, ResultIndices,
21862187
OriginalFunction, JVPAndVJPFunctions, hasOwnership()));
21872188
}
21882189

include/swift/SIL/SILCloner.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2853,6 +2853,7 @@ void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
28532853
recordClonedInstruction(
28542854
Inst, getBuilder().createDifferentiableFunction(
28552855
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
2856+
Inst->getResultIndices(),
28562857
getOpValue(Inst->getOriginalFunction()), derivativeFns));
28572858
}
28582859

0 commit comments

Comments
 (0)