@@ -623,26 +623,10 @@ where Cell.TimeStepOutput: Mergeable {
623
623
) -> Output {
624
624
let forwardOutputs = forward (
625
625
inputs, initialState: initialForwardLayerState)
626
-
627
- // TODO: Replace with inputs.reversed() after it become differentiable.
628
- var inputsReversed = Input ( )
629
-
630
- for forwardIndex in 0 ..< withoutDerivative ( at: inputs. count) {
631
- let backwardIndex = withoutDerivative ( at: inputs. count - 1 - forwardIndex)
632
- inputsReversed. append ( inputs [ backwardIndex] )
633
- }
634
-
635
626
let backwardOutputs = backward (
636
- inputsReversed, initialState: initialBackwardLayerState)
637
-
638
- var outputs = Output ( )
639
-
640
- for forwardIndex in 0 ..< withoutDerivative ( at: inputs. count) {
641
- let backwardIndex = withoutDerivative ( at: inputs. count - 1 - forwardIndex)
642
- outputs. append ( mergeFunction ( forwardOutputs [ forwardIndex] , backwardOutputs [ backwardIndex] ) )
643
- }
644
-
645
- return outputs
627
+ inputs. differentiableReversed ( ) , initialState: initialBackwardLayerState)
628
+ return forwardOutputs. differentiableMerging (
629
+ backwardOutputs. differentiableReversed ( ) , mergeFunction: mergeFunction)
646
630
}
647
631
648
632
@differentiable
@@ -703,3 +687,53 @@ public typealias SimpleRNNCell = BasicRNNCell
703
687
704
688
@available ( * , deprecated, renamed: " BasicRNN " )
705
689
public typealias SimpleRNN = BasicRNN
690
+
691
+ // - MARK: Workaround helpers.
692
+
693
+ fileprivate extension Array where Element: Differentiable {
694
+ /// Returns a reversed copy of `self`.
695
+ ///
696
+ /// This has a custom derivative, which works around the SR-13945 segfault that you would
697
+ /// encounter if you tried to implement this at the callsite using a for loop.
698
+ @differentiable
699
+ func differentiableReversed( ) -> Self {
700
+ . init( self . reversed ( ) )
701
+ }
702
+
703
+ @derivative ( of: differentiableReversed)
704
+ func vjpDifferentiableReversed( )
705
+ -> ( value: Self , pullback: ( TangentVector ) -> TangentVector )
706
+ {
707
+ return ( self . differentiableReversed ( ) , { . init( . init( $0. base. reversed ( ) ) ) } )
708
+ }
709
+
710
+ /// Returns `zip(self, other).map { mergeFunction($0.0, $0.1) }`.
711
+ ///
712
+ /// This has a custom derivative, which works around the SR-13945 segfault that you would
713
+ /// encounter if you tried to implement this at the callsite using a for loop.
714
+ @differentiable
715
+ func differentiableMerging(
716
+ _ other: Self , mergeFunction: @differentiable ( Element , Element ) -> Element
717
+ ) -> Self {
718
+ zip ( self , other) . map { mergeFunction ( $0. 0 , $0. 1 ) }
719
+ }
720
+
721
+ @derivative ( of: differentiableMerging)
722
+ func vjpDifferentiableMerging(
723
+ _ other: Self , mergeFunction: @differentiable ( Element , Element ) -> Element
724
+ ) -> ( value: Self , pullback: ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
725
+ let valuesWithPullbacks = zip ( self , other) . map {
726
+ valueWithPullback ( at: $0. 0 , $0. 1 , in: mergeFunction)
727
+ }
728
+ let pullbacks = valuesWithPullbacks. map { $0. pullback }
729
+ return (
730
+ valuesWithPullbacks. map { $0. value } ,
731
+ { vs in
732
+ let resultPairs = zip ( vs. base, pullbacks) . map { ( v, pb) in
733
+ pb ( v)
734
+ }
735
+ return ( . init( resultPairs. map { $0. 0 } ) , . init( resultPairs. map { $0. 1 } ) )
736
+ }
737
+ )
738
+ }
739
+ }
0 commit comments