Skip to content

Commit 83b2ebe

Browse files
authored
[AutoDiff] Support forward mode differentiation of functions with inout parameters (swiftlang#33584)
Adds forward mode support for `apply` instruction with `inout` arguments. Example of supported code: ``` func add(_ x: inout Float, _ y: inout Float) -> Float { var result = x result += y return result } print(differential(at: 1, 1, in: add)(1, 1)) // prints "2" ```
1 parent 7c508a0 commit 83b2ebe

File tree

4 files changed

+220
-77
lines changed

4 files changed

+220
-77
lines changed

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -455,18 +455,6 @@ class JVPCloner::Implementation final
455455
return;
456456
}
457457

458-
// Diagnose functions with active inout arguments.
459-
// TODO(TF-129): Support `inout` argument differentiation.
460-
for (auto inoutArg : ai->getInoutArguments()) {
461-
if (activityInfo.isActive(inoutArg, getIndices())) {
462-
context.emitNondifferentiabilityError(
463-
ai, invoker,
464-
diag::autodiff_cannot_differentiate_through_inout_arguments);
465-
errorOccurred = true;
466-
return;
467-
}
468-
}
469-
470458
auto loc = ai->getLoc();
471459
auto &builder = getBuilder();
472460
auto origCallee = getOpValue(ai->getCallee());
@@ -1241,6 +1229,10 @@ class JVPCloner::Implementation final
12411229
SmallVector<SILValue, 8> differentialAllResults;
12421230
collectAllActualResultsInTypeOrder(
12431231
differentialCall, differentialDirectResults, differentialAllResults);
1232+
for (auto inoutArg : ai->getInoutArguments())
1233+
origAllResults.push_back(inoutArg);
1234+
for (auto inoutArg : differentialCall->getInoutArguments())
1235+
differentialAllResults.push_back(inoutArg);
12441236
assert(applyIndices.results->getNumIndices() ==
12451237
differentialAllResults.size());
12461238

@@ -1484,11 +1476,14 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
14841476
auto origIndResults = original->getIndirectResults();
14851477
auto diffIndResults = differential.getIndirectResults();
14861478
#ifndef NDEBUG
1487-
unsigned numInoutParameters = llvm::count_if(
1488-
original->getLoweredFunctionType()->getParameters(),
1489-
[](SILParameterInfo paramInfo) { return paramInfo.isIndirectInOut(); });
1490-
assert(origIndResults.size() + numInoutParameters == diffIndResults.size());
1479+
unsigned numNonWrtInoutParameters = llvm::count_if(
1480+
range(original->getLoweredFunctionType()->getNumParameters()),
1481+
[&] (unsigned i) {
1482+
auto &paramInfo = original->getLoweredFunctionType()->getParameters()[i];
1483+
return paramInfo.isIndirectInOut() && !getIndices().parameters->contains(i);
1484+
});
14911485
#endif
1486+
assert(origIndResults.size() + numNonWrtInoutParameters == diffIndResults.size());
14921487
for (auto &origBB : *original)
14931488
for (auto i : indices(origIndResults))
14941489
setTangentBuffer(&origBB, origIndResults[i], diffIndResults[i]);
@@ -1521,23 +1516,10 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
15211516
auto origParams = origTy->getParameters();
15221517
auto indices = witness->getSILAutoDiffIndices();
15231518

1524-
// Add differential results.
1525-
Optional<SILParameterInfo> inoutDiffParam = None;
1526-
for (auto origParam : origTy->getParameters()) {
1527-
if (!origParam.isIndirectInOut())
1528-
continue;
1529-
inoutDiffParam = origParam;
1530-
}
1531-
1532-
if (inoutDiffParam) {
1533-
dfResults.push_back(
1534-
SILResultInfo(inoutDiffParam->getInterfaceType()
1535-
->getAutoDiffTangentSpace(lookupConformance)
1536-
->getType()
1537-
->getCanonicalType(witnessCanGenSig),
1538-
ResultConvention::Indirect));
1539-
} else {
1540-
for (auto resultIndex : indices.results->getIndices()) {
1519+
1520+
for (auto resultIndex : indices.results->getIndices()) {
1521+
if (resultIndex < origTy->getNumResults()) {
1522+
// Handle formal original result.
15411523
auto origResult = origTy->getResults()[resultIndex];
15421524
origResult = origResult.getWithInterfaceType(
15431525
origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig));
@@ -1548,6 +1530,25 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
15481530
->getCanonicalType(witnessCanGenSig),
15491531
origResult.getConvention()));
15501532
}
1533+
else {
1534+
// Handle original `inout` parameter.
1535+
auto inoutParamIndex = resultIndex - origTy->getNumResults();
1536+
auto inoutParamIt = std::next(
1537+
origTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
1538+
auto paramIndex =
1539+
std::distance(origTy->getParameters().begin(), &*inoutParamIt);
1540+
// If the original `inout` parameter is a differentiability parameter, then
1541+
// it already has a corresponding differential parameter. Skip adding a
1542+
// corresponding differential result.
1543+
if (indices.parameters->contains(paramIndex))
1544+
continue;
1545+
auto inoutParam = origTy->getParameters()[paramIndex];
1546+
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
1547+
lookupConformance);
1548+
assert(paramTan && "Parameter type does not have a tangent space?");
1549+
dfResults.push_back(
1550+
{paramTan->getCanonicalType(), ResultConvention::Indirect});
1551+
}
15511552
}
15521553

15531554
// Add differential parameters for the requested wrt parameters.

stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,9 @@ extension ${Self} {
205205
static func _jvpMultiplyAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
206206
value: Void, differential: (inout ${Self}, ${Self}) -> Void
207207
) {
208+
let oldLhs = lhs
208209
lhs *= rhs
209-
return ((), { $0 *= $1 })
210+
return ((), { $0 = $0 * rhs + oldLhs * $1 })
210211
}
211212

212213
${Availability(bits)}
@@ -251,8 +252,9 @@ extension ${Self} {
251252
static func _jvpDivideAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
252253
value: Void, differential: (inout ${Self}, ${Self}) -> Void
253254
) {
255+
let oldLhs = lhs
254256
lhs /= rhs
255-
return ((), { $0 /= $1 })
257+
return ((), { $0 = ($0 * rhs - oldLhs * $1) / (rhs * rhs) })
256258
}
257259
}
258260

test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil -verify %s
1+
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -verify %s
22

33
// Test forward-mode differentiation transform diagnostics.
44

@@ -46,8 +46,6 @@ func nonVariedResult(_ x: Float) -> Float {
4646
// Multiple results
4747
//===----------------------------------------------------------------------===//
4848

49-
// TODO(TF-983): Support differentiation of multiple results.
50-
/*
5149
func multipleResults(_ x: Float) -> (Float, Float) {
5250
return (x, x)
5351
}
@@ -56,28 +54,21 @@ func usesMultipleResults(_ x: Float) -> Float {
5654
let tuple = multipleResults(x)
5755
return tuple.0 + tuple.1
5856
}
59-
*/
6057

6158
//===----------------------------------------------------------------------===//
6259
// `inout` parameter differentiation
6360
//===----------------------------------------------------------------------===//
6461

65-
// expected-error @+1 {{function is not differentiable}}
6662
@differentiable
67-
// expected-note @+1 {{when differentiating this function definition}}
6863
func activeInoutParamNonactiveInitialResult(_ x: Float) -> Float {
6964
var result: Float = 1
70-
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
7165
result += x
7266
return result
7367
}
7468

75-
// expected-error @+1 {{function is not differentiable}}
7669
@differentiable
77-
// expected-note @+1 {{when differentiating this function definition}}
7870
func activeInoutParamTuple(_ x: Float) -> Float {
7971
var tuple = (x, x)
80-
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
8172
tuple.0 *= x
8273
return x * tuple.0
8374
}
@@ -94,49 +85,37 @@ func activeInoutParamControlFlow(_ array: [Float]) -> Float {
9485
return result
9586
}
9687

97-
struct Mut: Differentiable {}
98-
extension Mut {
99-
@differentiable(wrt: x)
100-
mutating func mutatingMethod(_ x: Mut) {}
101-
}
102-
10388
// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
10489
/*
105-
@differentiable(wrt: x)
106-
func nonActiveInoutParam(_ nonactive: inout Mut, _ x: Mut) -> Mut {
107-
return nonactive.mutatingMethod(x)
90+
struct X: Differentiable {
91+
var x : Float
92+
93+
@differentiable(wrt: y)
94+
mutating func mutate(_ y: X) { self.x = y.x }
10895
}
109-
*/
11096

111-
// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
112-
/*
113-
@differentiable(wrt: x)
114-
func activeInoutParamMutatingMethod(_ x: Mut) -> Mut {
115-
var result = x
116-
result = result.mutatingMethod(result)
117-
return result
97+
@differentiable
98+
func activeMutatingMethod(_ x: Float) -> Float {
99+
let x1 = X.init(x: x)
100+
var x2 = X.init(x: 0)
101+
x2.mutate(x1)
102+
return x1.x
118103
}
119104
*/
120105

121-
// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
122-
/*
123-
@differentiable(wrt: x)
124-
func activeInoutParamMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) -> Mut {
125-
var result = nonactive
126-
result = result.mutatingMethod(x)
127-
return result
106+
107+
struct Mut: Differentiable {}
108+
extension Mut {
109+
@differentiable(wrt: x)
110+
mutating func mutatingMethod(_ x: Mut) {}
128111
}
129-
*/
130112

131-
// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
132-
/*
133113
@differentiable(wrt: x)
134-
func activeInoutParamMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) -> Mut {
135-
var result = (nonactive, x)
136-
let result2 = result.0.mutatingMethod(result.0)
137-
return result2
114+
func activeInoutParamMutatingMethod(_ x: Mut) -> Mut {
115+
var result = x
116+
result.mutatingMethod(result)
117+
return result
138118
}
139-
*/
140119

141120
//===----------------------------------------------------------------------===//
142121
// Subset parameter differentiation thunks

0 commit comments

Comments
 (0)