Skip to content

Commit da36555

Browse files
committed
[AutoDiff upstream] Gardening.
- Standardize filenames: `XXXDifferentiation.swift`. - Use Pascal or snake case consistently. - Formatting changes.
1 parent 36cf566 commit da36555

10 files changed

+409
-313
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 100 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
//===--- ArrayDifferentiation.swift -------------------------*- swift -*-===//
1+
//===--- ArrayDifferentiation.swift ---------------------------*- swift -*-===//
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
@@ -12,7 +12,11 @@
1212

1313
import Swift
1414

15-
// TODO(TF-938): Add 'Element : Differentiable' constraint.
15+
//===----------------------------------------------------------------------===//
16+
// Protocol conformances
17+
//===----------------------------------------------------------------------===//
18+
19+
// TODO(TF-938): Add `Element: Differentiable` requirement.
1620
extension Array {
1721
/// The view of an array as the differentiable product manifold of `Element`
1822
/// multiplied with itself `count` times.
@@ -22,7 +26,8 @@ extension Array {
2226
}
2327
}
2428

25-
extension Array.DifferentiableView : Differentiable where Element : Differentiable {
29+
extension Array.DifferentiableView: Differentiable
30+
where Element: Differentiable {
2631
/// The viewed array.
2732
public var base: [Element] {
2833
get { return _base }
@@ -31,8 +36,9 @@ extension Array.DifferentiableView : Differentiable where Element : Differentiab
3136

3237
@usableFromInline
3338
@derivative(of: base)
34-
func _vjpBase() ->
35-
(value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector) {
39+
func _vjpBase() -> (
40+
value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector
41+
) {
3642
return (base, { $0 })
3743
}
3844

@@ -41,8 +47,9 @@ extension Array.DifferentiableView : Differentiable where Element : Differentiab
4147

4248
@usableFromInline
4349
@derivative(of: init(_:))
44-
static func _vjpInit(_ base: [Element]) ->
45-
(value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector) {
50+
static func _vjpInit(_ base: [Element]) -> (
51+
value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector
52+
) {
4653
return (Array.DifferentiableView(base), { $0 })
4754
}
4855

@@ -52,16 +59,16 @@ extension Array.DifferentiableView : Differentiable where Element : Differentiab
5259
public mutating func move(along direction: TangentVector) {
5360
precondition(
5461
base.count == direction.base.count,
55-
"cannot move Array.DifferentiableView with count \(base.count) along " +
56-
"direction with different count \(direction.base.count)")
62+
"cannot move Array.DifferentiableView with count \(base.count) along "
63+
+ "direction with different count \(direction.base.count)")
5764
for i in base.indices {
5865
base[i].move(along: direction.base[i])
5966
}
6067
}
6168
}
6269

63-
extension Array.DifferentiableView : Equatable
64-
where Element : Differentiable & Equatable {
70+
extension Array.DifferentiableView: Equatable
71+
where Element: Differentiable & Equatable {
6572
public static func == (
6673
lhs: Array.DifferentiableView,
6774
rhs: Array.DifferentiableView
@@ -70,15 +77,15 @@ extension Array.DifferentiableView : Equatable
7077
}
7178
}
7279

73-
extension Array.DifferentiableView : ExpressibleByArrayLiteral
74-
where Element : Differentiable {
80+
extension Array.DifferentiableView: ExpressibleByArrayLiteral
81+
where Element: Differentiable {
7582
public init(arrayLiteral elements: Element...) {
7683
self.init(elements)
7784
}
7885
}
7986

80-
extension Array.DifferentiableView : CustomStringConvertible
81-
where Element : Differentiable {
87+
extension Array.DifferentiableView: CustomStringConvertible
88+
where Element: Differentiable {
8289
public var description: String {
8390
return base.description
8491
}
@@ -88,8 +95,8 @@ extension Array.DifferentiableView : CustomStringConvertible
8895
///
8996
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces
9097
/// of all counts.
91-
extension Array.DifferentiableView : AdditiveArithmetic
92-
where Element : AdditiveArithmetic & Differentiable {
98+
extension Array.DifferentiableView: AdditiveArithmetic
99+
where Element: AdditiveArithmetic & Differentiable {
93100

94101
public static var zero: Array.DifferentiableView {
95102
return Array.DifferentiableView([])
@@ -100,10 +107,10 @@ extension Array.DifferentiableView : AdditiveArithmetic
100107
rhs: Array.DifferentiableView
101108
) -> Array.DifferentiableView {
102109
precondition(
103-
lhs.base.count == 0 || rhs.base.count == 0 ||
104-
lhs.base.count == rhs.base.count,
105-
"cannot add Array.DifferentiableViews with different counts: " +
106-
"\(lhs.base.count) and \(rhs.base.count)")
110+
lhs.base.count == 0 || rhs.base.count == 0
111+
|| lhs.base.count == rhs.base.count,
112+
"cannot add Array.DifferentiableViews with different counts: "
113+
+ "\(lhs.base.count) and \(rhs.base.count)")
107114
if lhs.base.count == 0 {
108115
return rhs
109116
}
@@ -118,10 +125,10 @@ extension Array.DifferentiableView : AdditiveArithmetic
118125
rhs: Array.DifferentiableView
119126
) -> Array.DifferentiableView {
120127
precondition(
121-
lhs.base.count == 0 || rhs.base.count == 0 ||
122-
lhs.base.count == rhs.base.count,
123-
"cannot subtract Array.DifferentiableViews with different counts: " +
124-
"\(lhs.base.count) and \(rhs.base.count)")
128+
lhs.base.count == 0 || rhs.base.count == 0
129+
|| lhs.base.count == rhs.base.count,
130+
"cannot subtract Array.DifferentiableViews with different counts: "
131+
+ "\(lhs.base.count) and \(rhs.base.count)")
125132
if lhs.base.count == 0 {
126133
return rhs
127134
}
@@ -143,7 +150,7 @@ extension Array.DifferentiableView : AdditiveArithmetic
143150

144151
/// Makes `Array` differentiable as the product manifold of `Element`
145152
/// multiplied with itself `count` times.
146-
extension Array : Differentiable where Element : Differentiable {
153+
extension Array: Differentiable where Element: Differentiable {
147154
// In an ideal world, `TangentVector` would be `[Element.TangentVector]`.
148155
// Unfortunately, we cannot conform `Array` to `AdditiveArithmetic` for
149156
// `TangentVector` because `Array` already has a static `+` method with
@@ -167,14 +174,18 @@ extension Array : Differentiable where Element : Differentiable {
167174
}
168175
}
169176

170-
extension Array where Element : Differentiable {
177+
//===----------------------------------------------------------------------===//
178+
// Derivatives
179+
//===----------------------------------------------------------------------===//
180+
181+
extension Array where Element: Differentiable {
171182
@usableFromInline
172183
@derivative(of: subscript)
173-
func _vjpSubscript(index: Int) ->
174-
(value: Element, pullback: (Element.TangentVector) -> TangentVector)
175-
{
184+
func _vjpSubscript(index: Int) -> (
185+
value: Element, pullback: (Element.TangentVector) -> TangentVector
186+
) {
176187
func pullback(_ gradientIn: Element.TangentVector) -> TangentVector {
177-
var gradientOut = Array<Element.TangentVector>(
188+
var gradientOut = [Element.TangentVector](
178189
repeating: .zero,
179190
count: count)
180191
gradientOut[index] = gradientIn
@@ -185,22 +196,27 @@ extension Array where Element : Differentiable {
185196

186197
@usableFromInline
187198
@derivative(of: +)
188-
static func _vjpConcatenate(_ lhs: [Element], _ rhs: [Element]) ->
189-
(value: [Element], pullback: (TangentVector) -> (TangentVector, TangentVector)) {
190-
func pullback(_ gradientIn: TangentVector) ->
191-
(TangentVector, TangentVector) {
192-
precondition(
193-
gradientIn.base.count == lhs.count + rhs.count,
194-
"+ should receive gradient with count equal to sum of operand " +
195-
"counts, but counts are: gradient \(gradientIn.base.count), " +
196-
"lhs \(lhs.count), rhs \(rhs.count)")
197-
return (
198-
TangentVector(Array<Element.TangentVector>(
199+
static func _vjpConcatenate(_ lhs: [Element], _ rhs: [Element]) -> (
200+
value: [Element],
201+
pullback: (TangentVector) -> (TangentVector, TangentVector)
202+
) {
203+
func pullback(_ gradientIn: TangentVector) -> (TangentVector, TangentVector)
204+
{
205+
precondition(
206+
gradientIn.base.count == lhs.count + rhs.count,
207+
"+ should receive gradient with count equal to sum of operand "
208+
+ "counts, but counts are: gradient \(gradientIn.base.count), "
209+
+ "lhs \(lhs.count), rhs \(rhs.count)")
210+
return (
211+
TangentVector(
212+
[Element.TangentVector](
199213
gradientIn.base[0..<lhs.count])),
200-
TangentVector(Array<Element.TangentVector>(
201-
gradientIn.base[lhs.count...])))
202-
}
203-
return (lhs + rhs, pullback)
214+
TangentVector(
215+
[Element.TangentVector](
216+
gradientIn.base[lhs.count...]))
217+
)
218+
}
219+
return (lhs + rhs, pullback)
204220
}
205221
}
206222

@@ -218,7 +234,8 @@ extension Array where Element: Differentiable {
218234
@usableFromInline
219235
@derivative(of: append)
220236
mutating func _jvpAppend(_ element: Element) -> (
221-
value: Void, differential: (inout TangentVector, Element.TangentVector) -> Void
237+
value: Void,
238+
differential: (inout TangentVector, Element.TangentVector) -> Void
222239
) {
223240
append(element)
224241
return ((), { $0.base.append($1) })
@@ -231,19 +248,22 @@ extension Array where Element: Differentiable {
231248
static func _vjpInit(repeating repeatedValue: Element, count: Int) -> (
232249
value: Self, pullback: (TangentVector) -> Element.TangentVector
233250
) {
234-
(value: Self(repeating: repeatedValue, count: count), pullback: { v in
235-
v.base.reduce(.zero, +)
236-
})
251+
(
252+
value: Self(repeating: repeatedValue, count: count),
253+
pullback: { v in
254+
v.base.reduce(.zero, +)
255+
}
256+
)
237257
}
238258
}
239259

240260
//===----------------------------------------------------------------------===//
241261
// Differentiable higher order functions for collections
242262
//===----------------------------------------------------------------------===//
243263

244-
public extension Array where Element: Differentiable {
264+
extension Array where Element: Differentiable {
245265
@differentiable(wrt: (self, initialResult))
246-
func differentiableReduce<Result: Differentiable>(
266+
public func differentiableReduce<Result: Differentiable>(
247267
_ initialResult: Result,
248268
_ nextPartialResult: @differentiable (Result, Element) -> Result
249269
) -> Result {
@@ -255,12 +275,14 @@ public extension Array where Element: Differentiable {
255275
internal func _vjpDifferentiableReduce<Result: Differentiable>(
256276
_ initialResult: Result,
257277
_ nextPartialResult: @differentiable (Result, Element) -> Result
258-
) -> (value: Result,
259-
pullback: (Result.TangentVector)
260-
-> (Array.TangentVector, Result.TangentVector)) {
278+
) -> (
279+
value: Result,
280+
pullback: (Result.TangentVector)
281+
-> (Array.TangentVector, Result.TangentVector)
282+
) {
261283
var pullbacks:
262-
[(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)]
263-
= []
284+
[(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)] =
285+
[]
264286
let count = self.count
265287
pullbacks.reserveCapacity(count)
266288
var result = initialResult
@@ -270,23 +292,26 @@ public extension Array where Element: Differentiable {
270292
result = y
271293
pullbacks.append(pb)
272294
}
273-
return (value: result, pullback: { tangent in
274-
var resultTangent = tangent
275-
var elementTangents = TangentVector([])
276-
elementTangents.base.reserveCapacity(count)
277-
for pullback in pullbacks.reversed() {
278-
let (newResultTangent, elementTangent) = pullback(resultTangent)
279-
resultTangent = newResultTangent
280-
elementTangents.base.append(elementTangent)
295+
return (
296+
value: result,
297+
pullback: { tangent in
298+
var resultTangent = tangent
299+
var elementTangents = TangentVector([])
300+
elementTangents.base.reserveCapacity(count)
301+
for pullback in pullbacks.reversed() {
302+
let (newResultTangent, elementTangent) = pullback(resultTangent)
303+
resultTangent = newResultTangent
304+
elementTangents.base.append(elementTangent)
305+
}
306+
return (TangentVector(elementTangents.base.reversed()), resultTangent)
281307
}
282-
return (TangentVector(elementTangents.base.reversed()), resultTangent)
283-
})
308+
)
284309
}
285310
}
286311

287-
public extension Array where Element: Differentiable {
312+
extension Array where Element: Differentiable {
288313
@differentiable(wrt: self)
289-
func differentiableMap<Result: Differentiable>(
314+
public func differentiableMap<Result: Differentiable>(
290315
_ body: @differentiable (Element) -> Result
291316
) -> [Result] {
292317
map(body)
@@ -296,8 +321,10 @@ public extension Array where Element: Differentiable {
296321
@derivative(of: differentiableMap)
297322
internal func _vjpDifferentiableMap<Result: Differentiable>(
298323
_ body: @differentiable (Element) -> Result
299-
) -> (value: [Result],
300-
pullback: (Array<Result>.TangentVector) -> Array.TangentVector) {
324+
) -> (
325+
value: [Result],
326+
pullback: (Array<Result>.TangentVector) -> Array.TangentVector
327+
) {
301328
var values: [Result] = []
302329
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
303330
for x in self {
@@ -310,4 +337,4 @@ public extension Array where Element: Differentiable {
310337
}
311338
return (value: values, pullback: pullback)
312339
}
313-
}
340+
}

stdlib/public/Differentiation/CMakeLists.txt

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source file is part of the Swift.org open source project
44
#
5-
# Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
5+
# Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
66
# Licensed under Apache License v2.0 with Runtime Library Exception
77
#
88
# See https://swift.org/LICENSE.txt for license information
@@ -17,10 +17,14 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE
1717
ArrayDifferentiation.swift
1818

1919
GYB_SOURCES
20-
tgmathDerivatives.swift.gyb
21-
SIMDVectorTypesDerivatives.swift.gyb
22-
FloatingPointTypesDerivatives.swift.gyb
20+
FloatingPointDifferentiation.swift.gyb
21+
TgmathDerivatives.swift.gyb
22+
SIMDDifferentiation.swift.gyb
2323

24+
SWIFT_MODULE_DEPENDS_OSX Darwin
25+
SWIFT_MODULE_DEPENDS_IOS Darwin
26+
SWIFT_MODULE_DEPENDS_TVOS Darwin
27+
SWIFT_MODULE_DEPENDS_WATCHOS Darwin
2428
SWIFT_MODULE_DEPENDS_LINUX Glibc
2529
SWIFT_MODULE_DEPENDS_FREEBSD Glibc
2630
SWIFT_MODULE_DEPENDS_CYGWIN Glibc

stdlib/public/Differentiation/FloatingPointTypesDerivatives.swift.gyb renamed to stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===--- FloatingPointTypesDerivatives.swift.gyb --------------*- swift -*-===//
1+
//===--- FloatingPointDifferentiation.swift.gyb ---------------*- swift -*-===//
22
//
33
// This source file is part of the Swift.org open source project
44
//
@@ -208,7 +208,6 @@ extension ${Self} {
208208
% end
209209
% end
210210

211-
212211
extension FloatingPoint where Self : Differentiable,
213212
Self == Self.TangentVector {
214213
/// The vector-Jacobian product function of `addingProduct`. Returns the

0 commit comments

Comments
 (0)