Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit e8686f8

Browse files
author
marcrasi
authored
work around SR-13945 segfault (#1141)
1 parent eaa1c8e commit e8686f8

File tree

1 file changed

+53
-19
lines changed

1 file changed

+53
-19
lines changed

Sources/TensorFlow/Layers/Recurrent.swift

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -623,26 +623,10 @@ where Cell.TimeStepOutput: Mergeable {
623623
) -> Output {
624624
let forwardOutputs = forward(
625625
inputs, initialState: initialForwardLayerState)
626-
627-
// TODO: Replace with inputs.reversed() after it become differentiable.
628-
var inputsReversed = Input()
629-
630-
for forwardIndex in 0 ..< withoutDerivative(at: inputs.count) {
631-
let backwardIndex = withoutDerivative(at: inputs.count - 1 - forwardIndex)
632-
inputsReversed.append(inputs[backwardIndex])
633-
}
634-
635626
let backwardOutputs = backward(
636-
inputsReversed, initialState: initialBackwardLayerState)
637-
638-
var outputs = Output()
639-
640-
for forwardIndex in 0 ..< withoutDerivative(at: inputs.count) {
641-
let backwardIndex = withoutDerivative(at: inputs.count - 1 - forwardIndex)
642-
outputs.append(mergeFunction(forwardOutputs[forwardIndex], backwardOutputs[backwardIndex]))
643-
}
644-
645-
return outputs
627+
inputs.differentiableReversed(), initialState: initialBackwardLayerState)
628+
return forwardOutputs.differentiableMerging(
629+
backwardOutputs.differentiableReversed(), mergeFunction: mergeFunction)
646630
}
647631

648632
@differentiable
@@ -703,3 +687,53 @@ public typealias SimpleRNNCell = BasicRNNCell
703687

704688
@available(*, deprecated, renamed: "BasicRNN")
705689
public typealias SimpleRNN = BasicRNN
690+
691+
// - MARK: Workaround helpers.
692+
693+
fileprivate extension Array where Element: Differentiable {
694+
/// Returns a reversed copy of `self`.
695+
///
696+
/// This has a custom derivative, which works around the SR-13945 segfault that you would
697+
/// encounter if you tried to implement this at the callsite using a for loop.
698+
@differentiable
699+
func differentiableReversed() -> Self {
700+
.init(self.reversed())
701+
}
702+
703+
@derivative(of: differentiableReversed)
704+
func vjpDifferentiableReversed()
705+
-> (value: Self, pullback: (TangentVector) -> TangentVector)
706+
{
707+
return (self.differentiableReversed(), { .init(.init($0.base.reversed())) })
708+
}
709+
710+
/// Returns `zip(self, other).map { mergeFunction($0.0, $0.1) }`.
711+
///
712+
/// This has a custom derivative, which works around the SR-13945 segfault that you would
713+
/// encounter if you tried to implement this at the callsite using a for loop.
714+
@differentiable
715+
func differentiableMerging(
716+
_ other: Self, mergeFunction: @differentiable (Element, Element) -> Element
717+
) -> Self {
718+
zip(self, other).map { mergeFunction($0.0, $0.1) }
719+
}
720+
721+
@derivative(of: differentiableMerging)
722+
func vjpDifferentiableMerging(
723+
_ other: Self, mergeFunction: @differentiable (Element, Element) -> Element
724+
) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
725+
let valuesWithPullbacks = zip(self, other).map {
726+
valueWithPullback(at: $0.0, $0.1, in: mergeFunction)
727+
}
728+
let pullbacks = valuesWithPullbacks.map { $0.pullback }
729+
return (
730+
valuesWithPullbacks.map { $0.value },
731+
{ vs in
732+
let resultPairs = zip(vs.base, pullbacks).map { (v, pb) in
733+
pb(v)
734+
}
735+
return (.init(resultPairs.map { $0.0 }), .init(resultPairs.map { $0.1 }))
736+
}
737+
)
738+
}
739+
}

0 commit comments

Comments
 (0)