|
| 1 | +// RUN: %target-swift-frontend -emit-sil -verify %s |
| 2 | + |
| 3 | +// https://github.com/swiftlang/swift/issues/84365 |
| 4 | +// Ensure autodiff subset thunks for differential correctly |
| 5 | +// handle multiple semantic results and release unwanted |
| 6 | +// result values |
| 7 | + |
| 8 | +import _Differentiation |
| 9 | + |
| 10 | +@differentiable(reverse,wrt: logits) |
| 11 | +public func softSolveForwardWithQ(logits: [Float]) -> ([Float], [Float]) { |
| 12 | + return ([Float](repeating: 0, count: 0), []) |
| 13 | +} |
| 14 | + |
| 15 | +@derivative(of: softSolveForwardWithQ, wrt: logits) |
| 16 | +public func vjpSoftSolveForwardWithQ(logits: [Float]) -> (value: ([Float], [Float]), pullback: ([Float].TangentVector, [Float].TangentVector) -> [Float].TangentVector) { |
| 17 | + let n = logits.count |
| 18 | + let q = [Float](repeating: 0, count: 0) |
| 19 | + let y = [Float](repeating: 0, count: 0) |
| 20 | + |
| 21 | + return ( |
| 22 | + value: (y, q), |
| 23 | + pullback: { _, _ in |
| 24 | + return Array<Float>.DifferentiableView([Float](repeating: 0, count: n)) |
| 25 | + } |
| 26 | + ) |
| 27 | +} |
| 28 | + |
| 29 | +@differentiable(reverse,wrt: logits) |
| 30 | +public func forwardPredict(logits: [Float]) -> ([Float], [Float], [Float]) { |
| 31 | + let (y, q) = softSolveForwardWithQ(logits: logits) |
| 32 | + return (y, q, [0.0]) |
| 33 | +} |
| 34 | + |
| 35 | +@derivative(of: forwardPredict, wrt: logits) |
| 36 | +public func vjpForwardPredict(logits: [Float]) -> ( |
| 37 | + value: ([Float], [Float], [Float]), |
| 38 | + pullback: ([Float].TangentVector, [Float].TangentVector, [Float].TangentVector) -> [Float].TangentVector |
| 39 | +) { |
| 40 | + let (valYQ, pb) = vjpSoftSolveForwardWithQ(logits: logits) |
| 41 | + let (y, q) = valYQ |
| 42 | + return ((y, q, [0.0]), { upY, upQ, _ in pb(upY, upQ) }) |
| 43 | +} |
| 44 | + |
| 45 | +@differentiable(reverse,wrt: logits) |
| 46 | +public func crossEntropyFromForwardPredict(logits: [Float]) -> Float { |
| 47 | + let (_, q, _) = forwardPredict(logits: logits) |
| 48 | + return q[0] + 1e-8 |
| 49 | +} |
0 commit comments