Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 63a43a4

Browse files
authored
Reflection implementation of ElementaryFunctions. (#1138)
1 parent dba28a9 commit 63a43a4

File tree

9 files changed

+205
-4
lines changed

9 files changed

+205
-4
lines changed

Sources/TensorFlow/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ add_library(TensorFlow SHARED
3737
Core/Threading.swift
3838
Core/Utilities.swift
3939
Core/EuclideanDifferentiable.swift
40+
Core/ElementaryFunctions.swift
4041

4142
Epochs/Algorithms.swift
4243
Epochs/Backend.swift
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
16+
import Numerics
17+
@_spi(Reflection) import Swift
18+
extension ElementaryFunctions {
19+
internal static func visitChildren(
20+
_ body: (PartialKeyPath<Self>, ElementaryFunctionsVisit.Type) -> Void
21+
) {
22+
if !_forEachFieldWithKeyPath(
23+
of: Self.self,
24+
body: { name, kp in
25+
func visitChild<T>(_: T.Type) {
26+
guard let t = ElementaryFunctionsVisitor<T>.self as? ElementaryFunctionsVisit.Type
27+
else {
28+
fatalError("No conformance of \(T.self) to ElementaryFunctions")
29+
}
30+
body(kp, t)
31+
}
32+
let valueType = type(of: kp).valueType
33+
_openExistential(valueType, do: visitChild)
34+
return true
35+
})
36+
{
37+
fatalError("not all children of \(Self.self) conform to ElementaryFunctions")
38+
}
39+
}
40+
}
41+
42+
protocol Functor1 {
43+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T
44+
}
45+
protocol Functor2 {
46+
func callAsFunction<T: ElementaryFunctions>(_ x: T, _ y: T) -> T
47+
}
48+
49+
protocol ElementaryFunctionsVisit {
50+
static func applyFunctor<Root, Fn: Functor1>(
51+
_ out: inout Root, _ kp: PartialKeyPath<Root>, _ fn: Fn)
52+
static func applyFunctor<Root, Fn: Functor2>(
53+
_ out: inout Root, _ y: Root, _ kp: PartialKeyPath<Root>, _ fn: Fn)
54+
}
55+
struct ElementaryFunctionsVisitor<T> {}
56+
extension ElementaryFunctionsVisitor: ElementaryFunctionsVisit where T: ElementaryFunctions {
57+
static func applyFunctor<Root, Fn: Functor1>(
58+
_ out: inout Root, _ kp: PartialKeyPath<Root>, _ fn: Fn
59+
) {
60+
guard let kp = kp as? WritableKeyPath<Root, T> else { fatalError("problem") }
61+
({ (x: inout T) in x = fn(x) })(&out[keyPath: kp])
62+
}
63+
static func applyFunctor<Root, Fn: Functor2>(
64+
_ out: inout Root, _ y: Root, _ kp: PartialKeyPath<Root>, _ fn: Fn
65+
) {
66+
guard let kp = kp as? WritableKeyPath<Root, T> else { fatalError("problem") }
67+
({ (x: inout T) in x = fn(x, y[keyPath: kp]) })(&out[keyPath: kp])
68+
}
69+
}
70+
71+
extension ElementaryFunctions {
72+
internal init<Fn: Functor1>(mapped fn: Fn, _ x: Self) {
73+
self = x
74+
Self.visitChildren { kp, t in t.applyFunctor(&self, kp, fn) }
75+
}
76+
internal init<Fn: Functor2>(mapped fn: Fn, _ x: Self, _ y: Self) {
77+
self = x
78+
Self.visitChildren { kp, t in t.applyFunctor(&self, y, kp, fn) }
79+
}
80+
}
81+
82+
struct Functor_exp: Functor1 {
83+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.exp(x) }
84+
}
85+
struct Functor_expMinusOne: Functor1 {
86+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.expMinusOne(x) }
87+
}
88+
struct Functor_cosh: Functor1 {
89+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.cosh(x) }
90+
}
91+
struct Functor_sinh: Functor1 {
92+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.sinh(x) }
93+
}
94+
struct Functor_tanh: Functor1 {
95+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.tanh(x) }
96+
}
97+
struct Functor_cos: Functor1 {
98+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.cos(x) }
99+
}
100+
struct Functor_sin: Functor1 {
101+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.sin(x) }
102+
}
103+
struct Functor_tan: Functor1 {
104+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.tan(x) }
105+
}
106+
struct Functor_log: Functor1 {
107+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.log(x) }
108+
}
109+
struct Functor_log1p: Functor1 {
110+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.log(onePlus: x) }
111+
}
112+
struct Functor_acosh: Functor1 {
113+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.acosh(x) }
114+
}
115+
struct Functor_asinh: Functor1 {
116+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.asinh(x) }
117+
}
118+
struct Functor_atanh: Functor1 {
119+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.atanh(x) }
120+
}
121+
struct Functor_acos: Functor1 {
122+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.acos(x) }
123+
}
124+
struct Functor_asin: Functor1 {
125+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.asin(x) }
126+
}
127+
struct Functor_atan: Functor1 {
128+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.atan(x) }
129+
}
130+
struct Functor_sqrt: Functor1 {
131+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.sqrt(x) }
132+
}
133+
struct Functor_pow: Functor1 {
134+
var n: Int
135+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.pow(x, n) }
136+
}
137+
struct Functor_pow2: Functor2 {
138+
func callAsFunction<T: ElementaryFunctions>(_ x: T, _ y: T) -> T { T.pow(x, y) }
139+
}
140+
struct Functor_root: Functor1 {
141+
var n: Int
142+
func callAsFunction<T: ElementaryFunctions>(_ x: T) -> T { T.root(x, n) }
143+
}
144+
145+
extension ElementaryFunctions {
146+
public static func exp(_ x: Self) -> Self { .init(mapped: Functor_exp(), x) }
147+
public static func expMinusOne(_ x: Self) -> Self { .init(mapped: Functor_expMinusOne(), x) }
148+
public static func tanh(_ x: Self) -> Self { .init(mapped: Functor_tanh(), x) }
149+
public static func cosh(_ x: Self) -> Self { .init(mapped: Functor_cosh(), x) }
150+
public static func sinh(_ x: Self) -> Self { .init(mapped: Functor_sinh(), x) }
151+
public static func cos(_ x: Self) -> Self { .init(mapped: Functor_cos(), x) }
152+
public static func sin(_ x: Self) -> Self { .init(mapped: Functor_sin(), x) }
153+
public static func tan(_ x: Self) -> Self { .init(mapped: Functor_tan(), x) }
154+
public static func log(_ x: Self) -> Self { .init(mapped: Functor_log(), x) }
155+
public static func log(onePlus x: Self) -> Self { .init(mapped: Functor_log1p(), x) }
156+
public static func acosh(_ x: Self) -> Self { .init(mapped: Functor_acosh(), x) }
157+
public static func asinh(_ x: Self) -> Self { .init(mapped: Functor_asinh(), x) }
158+
public static func atanh(_ x: Self) -> Self { .init(mapped: Functor_atanh(), x) }
159+
public static func acos(_ x: Self) -> Self { .init(mapped: Functor_acos(), x) }
160+
public static func asin(_ x: Self) -> Self { .init(mapped: Functor_asin(), x) }
161+
public static func atan(_ x: Self) -> Self { .init(mapped: Functor_atan(), x) }
162+
public static func sqrt(_ x: Self) -> Self { .init(mapped: Functor_sqrt(), x) }
163+
public static func pow(_ x: Self, _ n: Int) -> Self { .init(mapped: Functor_pow(n: n), x) }
164+
public static func root(_ x: Self, _ n: Int) -> Self { .init(mapped: Functor_root(n: n), x) }
165+
public static func pow(_ x: Self, _ y: Self) -> Self { .init(mapped: Functor_pow2(), x, y) }
166+
}
167+
#endif

Sources/TensorFlow/Layer.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,15 @@ extension Layer {
256256
public func appliedForBackpropagation(to input: Input)
257257
-> (output: Output, backpropagator: Backpropagator)
258258
{
259+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
260+
let (out, pullback) = _Differentiation.valueWithPullback(at: self, input) { layer, input in
261+
return layer(input)
262+
}
263+
#else
259264
let (out, pullback) = Swift.valueWithPullback(at: self, input) { layer, input in
260265
return layer(input)
261266
}
267+
#endif
262268
return (out, pullback)
263269
}
264270
}

Sources/TensorFlow/Operators/Math.swift

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,13 @@ extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
266266
_Raw.expm1(x)
267267
}
268268

269+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
270+
@differentiable
271+
public static func expMinusOne(_ x: Self) -> Self {
272+
return expm1(x)
273+
}
274+
#endif
275+
269276
@inlinable
270277
@derivative(of: expm1)
271278
internal static func _vjpExpm1(
@@ -282,7 +289,7 @@ extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
282289
}
283290

284291
@inlinable
285-
@derivative(of: log)
292+
@derivative(of: log(_:))
286293
internal static func _vjpLog(
287294
_ x: Tensor
288295
) -> (value: Tensor, pullback: (Tensor) -> Tensor) {
@@ -307,6 +314,13 @@ extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
307314
_Raw.log1p(x)
308315
}
309316

317+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
318+
@differentiable
319+
public static func log(onePlus x: Self) -> Self {
320+
return log1p(x)
321+
}
322+
#endif
323+
310324
@inlinable
311325
@derivative(of: log1p)
312326
internal static func _vjpLog1p(

Sources/TensorFlow/StdlibExtensions.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,6 @@ where Element: Differentiable {
244244
public init() { self.init(.init()) }
245245
}
246246

247-
#if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
248247
extension Array.DifferentiableView: VectorProtocol
249248
where Element: Differentiable & VectorProtocol {
250249
public typealias VectorSpaceScalar = Element.VectorSpaceScalar
@@ -301,7 +300,6 @@ where Element: Differentiable & PointwiseMultiplicative {
301300
}
302301
}
303302
}
304-
#endif
305303

306304
extension Collection {
307305
/// Returns the `n`th position in `self`.

Sources/x10/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ target_link_libraries(x10_optimizers_tensor_visitor_plan PUBLIC
1010
add_library(x10_optimizers_optimizer SHARED
1111
swift_bindings/optimizers/Optimizer.swift
1212
swift_bindings/optimizers/Optimizers.swift)
13+
target_compile_definitions(x10_optimizers_optimizer PRIVATE
14+
$<$<BOOL:${TENSORFLOW_USE_STANDARD_TOOLCHAIN}>:TENSORFLOW_USE_STANDARD_TOOLCHAIN>
15+
)
1316
set_target_properties(x10_optimizers_optimizer PROPERTIES
1417
INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
1518
target_link_libraries(x10_optimizers_optimizer PUBLIC

Tests/ExperimentalTests/ComplexTests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import XCTest
1616

17+
import _Differentiation
1718
@testable import Experimental
1819

1920
final class ComplexTests: XCTestCase {

Tests/TensorFlowTests/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,6 @@ target_link_libraries(TensorFlowTests PUBLIC
4242
TensorFlow
4343
Tensor
4444
XCTest)
45+
target_compile_definitions(TensorFlowTests PRIVATE
46+
$<$<BOOL:${TENSORFLOW_USE_STANDARD_TOOLCHAIN}>:TENSORFLOW_USE_STANDARD_TOOLCHAIN>
47+
)

Tests/TensorFlowTests/OperatorTests/MathTests.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,19 @@ final class MathOperatorTests: XCTestCase {
4747
testElementaryFunction(name: "exp", exp, Float.exp)
4848
testElementaryFunction(name: "exp2", exp2, Float.exp2)
4949
testElementaryFunction(name: "exp10", exp10, Float.exp10)
50+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
51+
testElementaryFunction(name: "expm1", expm1, Float.expMinusOne)
52+
#else
5053
testElementaryFunction(name: "expm1", expm1, Float.expm1)
51-
testElementaryFunction(name: "log", log, Float.log)
54+
#endif
55+
testElementaryFunction(name: "log", log, { Float.log($0) })
5256
testElementaryFunction(name: "log2", log2, Float.log2)
5357
testElementaryFunction(name: "log10", log10, Float.log10)
58+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
59+
testElementaryFunction(name: "log1p", log1p, {Float.log(onePlus: $0)})
60+
#else
5461
testElementaryFunction(name: "log1p", log1p, Float.log1p)
62+
#endif
5563
testElementaryFunction(
5664
name: "pow",
5765
{ x in pow(x, x) }, { x in Float.pow(x, x) })

0 commit comments

Comments
 (0)