Skip to content

Commit e6409d7

Browse files
authored
Correctly handle multiple semantic results for autodiff subset differential thunks (swiftlang#84366)
Fixes swiftlang#84365
1 parent 7b8b33e commit e6409d7

File tree

2 files changed

+69
-9
lines changed

2 files changed

+69
-9
lines changed

lib/SILOptimizer/Differentiation/Thunk.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -621,22 +621,33 @@ getOrCreateSubsetParametersThunkForLinearMap(
621621

622622
// If differential thunk, deallocate local allocations and directly return
623623
// `apply` result (if it is desired).
624+
// TODO: Unify with VJP code below
624625
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
625626
SmallVector<SILValue, 8> differentialDirectResults;
626627
extractAllElements(ai, builder, differentialDirectResults);
627628
SmallVector<SILValue, 8> allResults;
628629
collectAllActualResultsInTypeOrder(ai, differentialDirectResults, allResults);
629-
unsigned numResults = thunk->getConventions().getNumDirectSILResults() +
630-
thunk->getConventions().getNumDirectSILResults();
631630
SmallVector<SILValue, 8> results;
632-
for (unsigned idx : *actualConfig.resultIndices) {
633-
if (idx >= numResults)
634-
break;
635631

636-
auto result = allResults[idx];
637-
if (desiredConfig.isWrtResult(idx))
638-
results.push_back(result);
639-
else {
632+
unsigned firstSemanticParamResultIdx = origFnType->getNumResults();
633+
for (unsigned resultIndex : *actualConfig.resultIndices) {
634+
SILValue result;
635+
if (resultIndex >= firstSemanticParamResultIdx) {
636+
auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx;
637+
result =
638+
*std::next(ai->getAutoDiffSemanticResultArguments().begin(),
639+
semanticResultArgIdx);
640+
} else
641+
result = allResults[resultIndex];
642+
643+
// If result is desired:
644+
// - Do nothing if result is indirect.
645+
// (It was already forwarded to the `apply` instruction).
646+
// - Push it to `results` if result is direct.
647+
if (desiredConfig.isWrtResult(resultIndex)) {
648+
if (result->getType().isObject())
649+
results.push_back(result);
650+
} else { // Otherwise, cleanup the unused results.
640651
if (result->getType().isAddress())
641652
builder.emitDestroyAddrAndFold(loc, result);
642653
else
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
// https://github.com/swiftlang/swift/issues/84365
4+
// Ensure that autodiff subset thunks for differentials correctly
5+
// handle multiple semantic results and release unwanted
6+
// result values.
7+
8+
import _Differentiation
9+
10+
/// Test fixture: returns a pair of empty `[Float]` arrays.
///
/// The body never reads `logits`; the function is declared differentiable
/// with respect to `logits`, and `vjpSoftSolveForwardWithQ` is registered
/// as its custom derivative.
@differentiable(reverse,wrt: logits)
11+
public func softSolveForwardWithQ(logits: [Float]) -> ([Float], [Float]) {
12+
return ([Float](repeating: 0, count: 0), [])
13+
}
14+
15+
/// Custom derivative of `softSolveForwardWithQ` with respect to `logits`.
///
/// - Returns: `value`, a pair of empty arrays `(y, q)`, and `pullback`, a
///   closure taking one tangent per original result and producing the
///   tangent for `logits`.
@derivative(of: softSolveForwardWithQ, wrt: logits)
16+
public func vjpSoftSolveForwardWithQ(logits: [Float]) -> (value: ([Float], [Float]), pullback: ([Float].TangentVector, [Float].TangentVector) -> [Float].TangentVector) {
17+
// Captured by the pullback so the returned tangent has one entry per logit.
let n = logits.count
18+
let q = [Float](repeating: 0, count: 0)
19+
let y = [Float](repeating: 0, count: 0)
20+
21+
return (
22+
value: (y, q),
23+
// Both incoming tangents are ignored; the result is always an all-zero
// tangent of length `n`.
pullback: { _, _ in
24+
return Array<Float>.DifferentiableView([Float](repeating: 0, count: n))
25+
}
26+
)
27+
}
28+
29+
/// Test fixture with three results: the two results of
/// `softSolveForwardWithQ(logits:)` followed by the constant array `[0.0]`.
///
/// Declared differentiable with respect to `logits`; `vjpForwardPredict` is
/// registered as its custom derivative.
@differentiable(reverse,wrt: logits)
30+
public func forwardPredict(logits: [Float]) -> ([Float], [Float], [Float]) {
31+
let (y, q) = softSolveForwardWithQ(logits: logits)
32+
// The third element is a literal and does not depend on `logits`.
return (y, q, [0.0])
33+
}
34+
35+
/// Custom derivative of `forwardPredict` with respect to `logits`.
///
/// - Returns: `value`, the triple `(y, q, [0.0])` built from
///   `vjpSoftSolveForwardWithQ`'s value, and `pullback`, a closure taking one
///   tangent per original result and producing the tangent for `logits`.
@derivative(of: forwardPredict, wrt: logits)
36+
public func vjpForwardPredict(logits: [Float]) -> (
37+
value: ([Float], [Float], [Float]),
38+
pullback: ([Float].TangentVector, [Float].TangentVector, [Float].TangentVector) -> [Float].TangentVector
39+
) {
40+
// Reuse the inner function's derivative: `valYQ` is its value pair and
// `pb` its pullback.
let (valYQ, pb) = vjpSoftSolveForwardWithQ(logits: logits)
41+
let (y, q) = valYQ
42+
// The pullback forwards the first two upstream tangents to `pb` and ignores
// the third, which corresponds to the constant `[0.0]` result.
return ((y, q, [0.0]), { upY, upQ, _ in pb(upY, upQ) })
43+
}
44+
45+
/// Differentiable entry point that uses only the second of
/// `forwardPredict`'s three results and returns `q[0] + 1e-8`.
///
/// NOTE(review): presumably this partial use of a multi-result function is
/// what makes the compiler generate the subset thunk this test targets —
/// confirm against the linked issue.
/// NOTE(review): `q` appears to be empty as the fixtures are written, so
/// `q[0]` would trap if executed; the RUN line only emits SIL, so the
/// function is compiled but never run.
@differentiable(reverse,wrt: logits)
46+
public func crossEntropyFromForwardPredict(logits: [Float]) -> Float {
47+
// The first and third results are discarded.
let (_, q, _) = forwardPredict(logits: logits)
48+
return q[0] + 1e-8
49+
}

0 commit comments

Comments
 (0)