Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit fb720c5

Browse files
authored
Add reflection impl for EuclideanDifferentiable. (#1133)
1 parent 51af7c8 commit fb720c5

File tree

3 files changed

+155
-20
lines changed

3 files changed

+155
-20
lines changed

Sources/TensorFlow/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ add_library(TensorFlow SHARED
3636
Core/TensorShape.swift
3737
Core/Threading.swift
3838
Core/Utilities.swift
39+
Core/EuclideanDifferentiable.swift
3940

4041
Epochs/Algorithms.swift
4142
Epochs/Backend.swift
@@ -84,6 +85,7 @@ if(ENABLE_PYTHON_SUPPORT)
8485
endif()
8586
target_compile_definitions(TensorFlow PRIVATE
8687
USING_X10_BACKEND
88+
$<$<BOOL:${TENSORFLOW_USE_STANDARD_TOOLCHAIN}>:TENSORFLOW_USE_STANDARD_TOOLCHAIN>
8789
DEFAULT_BACKEND_EAGER)
8890
target_link_libraries(TensorFlow PRIVATE
8991
CX10
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import _Differentiation
16+
17+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
18+
@_spi(Reflection) import Swift
19+
20+
func listFields<Root>(of type: Root.Type) -> [(String, PartialKeyPath<Root>)] {
21+
var out = [(String, PartialKeyPath<Root>)]()
22+
_forEachFieldWithKeyPath(of: type, options: .ignoreUnknown) { name, kp in
23+
out.append((String(validatingUTF8: name)!, kp))
24+
return true
25+
}
26+
return out
27+
}
28+
29+
extension Differentiable {
30+
static var differentiableFields: [(String, PartialKeyPath<Self>, PartialKeyPath<TangentVector>)]
31+
{
32+
let tangentFields = listFields(of: TangentVector.self)
33+
var i = 0
34+
var out = [(String, PartialKeyPath<Self>, PartialKeyPath<TangentVector>)]()
35+
_forEachFieldWithKeyPath(of: Self.self, options: .ignoreUnknown) { cname, kp in
36+
if i >= tangentFields.count { return false }
37+
let name = String(validatingUTF8: cname)!
38+
if tangentFields[i].0 == name {
39+
out.append((name, kp, tangentFields[i].1))
40+
i += 1
41+
}
42+
return true
43+
}
44+
return out
45+
}
46+
}
47+
48+
public protocol _EuclideanDifferentiable {
49+
static func _copyWeightsToTangentVector<Root: Differentiable>(
50+
_ base: Root, _ out: inout Root.TangentVector,
51+
_ keyPathBase: PartialKeyPath<Root>,
52+
_ keyPathOut: PartialKeyPath<Root.TangentVector>
53+
)
54+
}
55+
56+
public protocol EuclideanDifferentiable: _EuclideanDifferentiable & Differentiable {
57+
var differentiableVectorView: TangentVector { get }
58+
func _copyWeightsToTangentVector(_ out: inout TangentVector)
59+
}
60+
61+
extension EuclideanDifferentiable where TangentVector == Self {
62+
public var differentiableVectorView: TangentVector { _read { yield self } }
63+
public func _copyWeightsToTangentVector(_ out: inout TangentVector) {
64+
out = differentiableVectorView
65+
}
66+
}
67+
68+
extension EuclideanDifferentiable {
69+
public static func _copyWeightsToTangentVector<Root: Differentiable>(
70+
_ base: Root, _ out: inout Root.TangentVector,
71+
_ keyPathBase: PartialKeyPath<Root>,
72+
_ keyPathOut: PartialKeyPath<Root.TangentVector>
73+
) {
74+
guard let keyPathBase = keyPathBase as? WritableKeyPath<Root, Self>,
75+
let keyPathOut = keyPathOut as? WritableKeyPath<Root.TangentVector, Self.TangentVector>
76+
else {
77+
fatalError("Failure to build differentiableVectorView via reflection: \(Self.self)")
78+
}
79+
base[keyPath: keyPathBase]._copyWeightsToTangentVector(&out[keyPath: keyPathOut])
80+
}
81+
public var differentiableVectorView: TangentVector {
82+
var out = TangentVector.zero
83+
_copyWeightsToTangentVector(&out)
84+
return out
85+
}
86+
public func _copyWeightsToTangentVector(_ out: inout TangentVector) {
87+
for (_, keyPathBase, keyPathOut) in Self.differentiableFields {
88+
let valueType = type(of: keyPathBase).valueType
89+
if let valueType = valueType as? _EuclideanDifferentiable.Type {
90+
valueType._copyWeightsToTangentVector(self, &out, keyPathBase, keyPathOut)
91+
} else {
92+
fatalError("Failure to build differentiableVectorView via reflection: \(valueType)")
93+
}
94+
}
95+
}
96+
}
97+
98+
extension Float: EuclideanDifferentiable {}
99+
extension Double: EuclideanDifferentiable {}
100+
101+
extension Array: EuclideanDifferentiable & _EuclideanDifferentiable
102+
where Element: EuclideanDifferentiable {
103+
public func _copyWeightsToTangentVector(_ out: inout TangentVector) {
104+
out = Array.DifferentiableView.TangentVector(self.map { $0.differentiableVectorView })
105+
}
106+
}
107+
extension Array.DifferentiableView: EuclideanDifferentiable & _EuclideanDifferentiable
108+
where Element: EuclideanDifferentiable {
109+
public func _copyWeightsToTangentVector(_ out: inout TangentVector) {
110+
out = Array.DifferentiableView.TangentVector(self.base.map { $0.differentiableVectorView })
111+
}
112+
}
113+
#endif

Sources/TensorFlow/StdlibExtensions.swift

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import _Differentiation
15+
@_exported import _Differentiation
1616
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
1717
import Numerics
1818
#endif
1919

20+
#if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
2021
// MARK: - Array extensions
2122

2223
extension Array: ElementaryFunctions where Element: ElementaryFunctions {
@@ -107,6 +108,7 @@ extension Array: ElementaryFunctions where Element: ElementaryFunctions {
107108
/// For complex types, there is a branch cut along the negative real axis.
108109
public static func root(_ x: Self, _ n: Int) -> Self { x.map { Element.root($0, n) } }
109110
}
111+
#endif
110112

111113
// MARK: - Array derivative extensions
112114

@@ -116,47 +118,48 @@ where Element: Differentiable & ElementaryFunctions {
116118
///
117119
/// For real types, if `x` is negative the result is `.nan`. For complex
118120
/// types there is a branch cut on the negative real axis.
119-
public static func sqrt(_ x: Self) -> Self { .init(Array.sqrt(x.base)) }
121+
public static func sqrt(_ x: Self) -> Self { .init(x.map(Element.sqrt)) }
120122

121123
/// The cosine of `x`, interpreted as an angle in radians.
122-
public static func cos(_ x: Self) -> Self { .init(Array.cos(x.base)) }
124+
public static func cos(_ x: Self) -> Self { .init(x.map(Element.cos)) }
123125

124126
/// The sine of `x`, interpreted as an angle in radians.
125-
public static func sin(_ x: Self) -> Self { .init(Array.sin(x.base)) }
127+
public static func sin(_ x: Self) -> Self { .init(x.map(Element.sin)) }
126128

127129
/// The tangent of `x`, interpreted as an angle in radians.
128-
public static func tan(_ x: Self) -> Self { .init(Array.tan(x.base)) }
130+
public static func tan(_ x: Self) -> Self { .init(x.map(Element.tan)) }
129131

130132
/// The inverse cosine of `x` in radians.
131-
public static func acos(_ x: Self) -> Self { .init(Array.acos(x.base)) }
133+
public static func acos(_ x: Self) -> Self { .init(x.map(Element.acos)) }
132134

133135
/// The inverse sine of `x` in radians.
134-
public static func asin(_ x: Self) -> Self { .init(Array.asin(x.base)) }
136+
public static func asin(_ x: Self) -> Self { .init(x.map(Element.asin)) }
135137

136138
/// The inverse tangent of `x` in radians.
137-
public static func atan(_ x: Self) -> Self { .init(Array.atan(x.base)) }
139+
public static func atan(_ x: Self) -> Self { .init(x.map(Element.atan)) }
138140

139141
/// The hyperbolic cosine of `x`.
140-
public static func cosh(_ x: Self) -> Self { .init(Array.cosh(x.base)) }
142+
public static func cosh(_ x: Self) -> Self { .init(x.map(Element.cosh)) }
141143

142144
/// The hyperbolic sine of `x`.
143-
public static func sinh(_ x: Self) -> Self { .init(Array.sinh(x.base)) }
145+
public static func sinh(_ x: Self) -> Self { .init(x.map(Element.sinh)) }
144146

145147
/// The hyperbolic tangent of `x`.
146-
public static func tanh(_ x: Self) -> Self { .init(Array.tanh(x.base)) }
148+
public static func tanh(_ x: Self) -> Self { .init(x.map(Element.tanh)) }
147149

148150
/// The inverse hyperbolic cosine of `x`.
149-
public static func acosh(_ x: Self) -> Self { .init(Array.acosh(x.base)) }
151+
public static func acosh(_ x: Self) -> Self { .init(x.map(Element.acosh)) }
150152

151153
/// The inverse hyperbolic sine of `x`.
152-
public static func asinh(_ x: Self) -> Self { .init(Array.asinh(x.base)) }
154+
public static func asinh(_ x: Self) -> Self { .init(x.map(Element.asinh)) }
153155

154156
/// The inverse hyperbolic tangent of `x`.
155-
public static func atanh(_ x: Self) -> Self { .init(Array.atanh(x.base)) }
157+
public static func atanh(_ x: Self) -> Self { .init(x.map(Element.atanh)) }
156158

157159
/// The exponential function applied to `x`, or `e**x`.
158-
public static func exp(_ x: Self) -> Self { .init(Array.exp(x.base)) }
160+
public static func exp(_ x: Self) -> Self { .init(x.map(Element.exp)) }
159161

162+
#if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
160163
/// Two raised to to power `x`.
161164
public static func exp2(_ x: Self) -> Self { .init(Array.exp2(x.base)) }
162165

@@ -165,36 +168,51 @@ where Element: Differentiable & ElementaryFunctions {
165168

166169
/// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
167170
public static func expm1(_ x: Self) -> Self { .init(Array.expm1(x.base)) }
171+
#else
172+
173+
/// `exp(x) - 1` evaluated so as to preserve accuracy close to zero.
174+
public static func expMinusOne(_ x: Self) -> Self { .init(x.map(Element.expMinusOne)) }
175+
#endif
168176

169177
/// The natural logarithm of `x`.
170-
public static func log(_ x: Self) -> Self { .init(Array.log(x.base)) }
178+
public static func log(_ x: Self) -> Self { .init(x.map { Element.exp($0) }) }
171179

180+
#if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
172181
/// The base-two logarithm of `x`.
173182
public static func log2(_ x: Self) -> Self { .init(Array.log2(x.base)) }
174183

175184
/// The base-ten logarithm of `x`.
176185
public static func log10(_ x: Self) -> Self { .init(Array.log10(x.base)) }
177186

178187
/// `log(1 + x)` evaluated so as to preserve accuracy close to zero.
179-
public static func log1p(_ x: Self) -> Self { .init(Array.log1p(x.base)) }
188+
public static func log1p(_ x: Self) -> Self {
189+
.init(Array.log1p(x.base))
190+
}
191+
#else
192+
193+
/// The natural logarithm of `x + 1` to preserve accuracy close to zero.
194+
public static func log(onePlus x: Self) -> Self {
195+
.init(x.map { Element.log(onePlus: $0) })
196+
}
197+
#endif
180198

181199
/// `exp(y log(x))` computed without loss of intermediate precision.
182200
///
183201
/// For real types, if `x` is negative the result is NaN, even if `y` has
184202
/// an integral value. For complex types, there is a branch cut on the
185203
/// negative real axis.
186-
public static func pow(_ x: Self, _ y: Self) -> Self { .init(Array.pow(x.base, y.base)) }
204+
public static func pow(_ x: Self, _ y: Self) -> Self { .init(zip(x, y).map(Element.pow)) }
187205

188206
/// `x` raised to the `n`th power.
189207
///
190208
/// The product of `n` copies of `x`.
191-
public static func pow(_ x: Self, _ n: Int) -> Self { .init(Array.pow(x.base, n)) }
209+
public static func pow(_ x: Self, _ n: Int) -> Self { .init(x.map { Element.pow($0, n) }) }
192210

193211
/// The `n`th root of `x`.
194212
///
195213
/// For real types, if `x` is negative and `n` is even, the result is NaN.
196214
/// For complex types, there is a branch cut along the negative real axis.
197-
public static func root(_ x: Self, _ n: Int) -> Self { .init(Array.root(x.base, n)) }
215+
public static func root(_ x: Self, _ n: Int) -> Self { .init(x.map { Element.root($0, n) }) }
198216
}
199217

200218
extension Array.DifferentiableView:
@@ -226,6 +244,7 @@ where Element: Differentiable {
226244
public init() { self.init(.init()) }
227245
}
228246

247+
#if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
229248
extension Array.DifferentiableView: VectorProtocol
230249
where Element: Differentiable & VectorProtocol {
231250
public typealias VectorSpaceScalar = Element.VectorSpaceScalar
@@ -282,6 +301,7 @@ where Element: Differentiable & PointwiseMultiplicative {
282301
}
283302
}
284303
}
304+
#endif
285305

286306
extension Collection {
287307
/// Returns the `n`th position in `self`.

0 commit comments

Comments
 (0)