// See the License for the specific language governing permissions and
// limitations under the License.

- import _Differentiation
+ @_exported import _Differentiation
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
import Numerics
#endif

+ #if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
// MARK: - Array extensions

extension Array: ElementaryFunctions where Element: ElementaryFunctions {
@@ -107,6 +108,7 @@ extension Array: ElementaryFunctions where Element: ElementaryFunctions {
/// For complex types, there is a branch cut along the negative real axis.
public static func root(_ x: Self, _ n: Int) -> Self { x.map { Element.root($0, n) } }
}
+ #endif

// MARK: - Array derivative extensions

@@ -116,47 +118,48 @@ where Element: Differentiable & ElementaryFunctions {
///
/// For real types, if `x` is negative the result is `.nan`. For complex
/// types there is a branch cut on the negative real axis.
- public static func sqrt(_ x: Self) -> Self { .init(Array.sqrt(x.base)) }
+ public static func sqrt(_ x: Self) -> Self { .init(x.map(Element.sqrt)) }

/// The cosine of `x`, interpreted as an angle in radians.
- public static func cos(_ x: Self) -> Self { .init(Array.cos(x.base)) }
+ public static func cos(_ x: Self) -> Self { .init(x.map(Element.cos)) }

/// The sine of `x`, interpreted as an angle in radians.
- public static func sin(_ x: Self) -> Self { .init(Array.sin(x.base)) }
+ public static func sin(_ x: Self) -> Self { .init(x.map(Element.sin)) }

/// The tangent of `x`, interpreted as an angle in radians.
- public static func tan(_ x: Self) -> Self { .init(Array.tan(x.base)) }
+ public static func tan(_ x: Self) -> Self { .init(x.map(Element.tan)) }

/// The inverse cosine of `x` in radians.
- public static func acos(_ x: Self) -> Self { .init(Array.acos(x.base)) }
+ public static func acos(_ x: Self) -> Self { .init(x.map(Element.acos)) }

/// The inverse sine of `x` in radians.
- public static func asin(_ x: Self) -> Self { .init(Array.asin(x.base)) }
+ public static func asin(_ x: Self) -> Self { .init(x.map(Element.asin)) }

/// The inverse tangent of `x` in radians.
- public static func atan(_ x: Self) -> Self { .init(Array.atan(x.base)) }
+ public static func atan(_ x: Self) -> Self { .init(x.map(Element.atan)) }

/// The hyperbolic cosine of `x`.
- public static func cosh(_ x: Self) -> Self { .init(Array.cosh(x.base)) }
+ public static func cosh(_ x: Self) -> Self { .init(x.map(Element.cosh)) }

/// The hyperbolic sine of `x`.
- public static func sinh(_ x: Self) -> Self { .init(Array.sinh(x.base)) }
+ public static func sinh(_ x: Self) -> Self { .init(x.map(Element.sinh)) }

/// The hyperbolic tangent of `x`.
- public static func tanh(_ x: Self) -> Self { .init(Array.tanh(x.base)) }
+ public static func tanh(_ x: Self) -> Self { .init(x.map(Element.tanh)) }

/// The inverse hyperbolic cosine of `x`.
- public static func acosh(_ x: Self) -> Self { .init(Array.acosh(x.base)) }
+ public static func acosh(_ x: Self) -> Self { .init(x.map(Element.acosh)) }

/// The inverse hyperbolic sine of `x`.
- public static func asinh(_ x: Self) -> Self { .init(Array.asinh(x.base)) }
+ public static func asinh(_ x: Self) -> Self { .init(x.map(Element.asinh)) }

/// The inverse hyperbolic tangent of `x`.
- public static func atanh(_ x: Self) -> Self { .init(Array.atanh(x.base)) }
+ public static func atanh(_ x: Self) -> Self { .init(x.map(Element.atanh)) }

/// The exponential function applied to `x`, or `e**x`.
- public static func exp(_ x: Self) -> Self { .init(Array.exp(x.base)) }
+ public static func exp(_ x: Self) -> Self { .init(x.map(Element.exp)) }

+ #if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
/// Two raised to the power `x`.
public static func exp2(_ x: Self) -> Self { .init(Array.exp2(x.base)) }

@@ -165,36 +168,51 @@ where Element: Differentiable & ElementaryFunctions {

/// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
public static func expm1(_ x: Self) -> Self { .init(Array.expm1(x.base)) }
+ #else
+
+ /// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
+ public static func expMinusOne(_ x: Self) -> Self { .init(x.map(Element.expMinusOne)) }
+ #endif

/// The natural logarithm of `x`.
- public static func log(_ x: Self) -> Self { .init(Array.log(x.base)) }
+ public static func log(_ x: Self) -> Self { .init(x.map { Element.log($0) }) }

+ #if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
/// The base-two logarithm of `x`.
public static func log2(_ x: Self) -> Self { .init(Array.log2(x.base)) }

/// The base-ten logarithm of `x`.
public static func log10(_ x: Self) -> Self { .init(Array.log10(x.base)) }

/// `log(1 + x)` evaluated so as to preserve accuracy close to zero.
- public static func log1p(_ x: Self) -> Self { .init(Array.log1p(x.base)) }
+ public static func log1p(_ x: Self) -> Self {
+   .init(Array.log1p(x.base))
+ }
+ #else
+
+ /// The natural logarithm of `1 + x`, evaluated so as to preserve accuracy close to zero.
+ public static func log(onePlus x: Self) -> Self {
+   .init(x.map { Element.log(onePlus: $0) })
+ }
+ #endif

/// `exp(y log(x))` computed without loss of intermediate precision.
///
/// For real types, if `x` is negative the result is NaN, even if `y` has
/// an integral value. For complex types, there is a branch cut on the
/// negative real axis.
- public static func pow(_ x: Self, _ y: Self) -> Self { .init(Array.pow(x.base, y.base)) }
+ public static func pow(_ x: Self, _ y: Self) -> Self { .init(zip(x, y).map(Element.pow)) }

/// `x` raised to the `n`th power.
///
/// The product of `n` copies of `x`.
- public static func pow(_ x: Self, _ n: Int) -> Self { .init(Array.pow(x.base, n)) }
+ public static func pow(_ x: Self, _ n: Int) -> Self { .init(x.map { Element.pow($0, n) }) }

/// The `n`th root of `x`.
///
/// For real types, if `x` is negative and `n` is even, the result is NaN.
/// For complex types, there is a branch cut along the negative real axis.
- public static func root(_ x: Self, _ n: Int) -> Self { .init(Array.root(x.base, n)) }
+ public static func root(_ x: Self, _ n: Int) -> Self { .init(x.map { Element.root($0, n) }) }
}

extension Array.DifferentiableView:
@@ -226,6 +244,7 @@ where Element: Differentiable {
public init() { self.init(.init()) }
}

+ #if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
extension Array.DifferentiableView: VectorProtocol
where Element: Differentiable & VectorProtocol {
public typealias VectorSpaceScalar = Element.VectorSpaceScalar
@@ -282,6 +301,7 @@ where Element: Differentiable & PointwiseMultiplicative {
}
}
}
+ #endif

extension Collection {
/// Returns the `n`th position in `self`.
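
For reference, a minimal usage sketch of the conformances added above. This assumes a context where `Double` satisfies `ElementaryFunctions` (from the standard library on the TensorFlow toolchain, or swift-numerics on the standard toolchain); the variable names are illustrative only:

import _Differentiation

let view = Array<Double>.DifferentiableView([0.25, 1.0, 4.0])
// Element-wise square root via the ElementaryFunctions conformance above.
let roots = Array<Double>.DifferentiableView.sqrt(view)

#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
// swift-numerics spelling of exp(x) - 1.
let nearZero = Array<Double>.DifferentiableView.expMinusOne(view)
#else
// Legacy toolchain spelling.
let nearZero = Array<Double>.DifferentiableView.expm1(view)
#endif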