Skip to content

Commit f07ae89

Browse files
committed
First pass at upstreaming Differentiable conformances and derivatives
1 parent 1d423e4 commit f07ae89

10 files changed

+1941
-8
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4412,14 +4412,6 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
44124412
return true;
44134413
}
44144414

4415-
// Reject different-file derivative registration.
4416-
// TODO(TF-1021): Lift same-file derivative registration restriction.
4417-
if (originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) {
4418-
diags.diagnose(attr->getLocation(),
4419-
diag::derivative_attr_not_in_same_file_as_original);
4420-
return true;
4421-
}
4422-
44234415
// Reject duplicate `@derivative` attributes.
44244416
auto &derivativeAttrs = Ctx.DerivativeAttrs[std::make_tuple(
44254417
originalAFD, resolvedDiffParamIndices, kind)];
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
//===--- ArrayDifferentiation.swift -------------------------*- swift -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
import Swift
14+
15+
// TODO(TF-938): Add 'Element : Differentiable' constraint.
16+
extension Array {
17+
/// The view of an array as the differentiable product manifold of `Element`
18+
/// multiplied with itself `count` times.
19+
@frozen
20+
public struct DifferentiableView {
21+
var _base: [Element]
22+
}
23+
}
24+
25+
extension Array.DifferentiableView : Differentiable where Element : Differentiable {
26+
/// The viewed array.
27+
public var base: [Element] {
28+
get { return _base }
29+
_modify { yield &_base }
30+
}
31+
32+
@usableFromInline
33+
@derivative(of: base)
34+
func _vjpBase() ->
35+
(value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector) {
36+
return (base, { $0 })
37+
}
38+
39+
/// Creates a differentiable view of the given array.
40+
public init(_ base: [Element]) { self._base = base }
41+
42+
@usableFromInline
43+
@derivative(of: init(_:))
44+
static func _vjpInit(_ base: [Element]) ->
45+
(value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector) {
46+
return (Array.DifferentiableView(base), { $0 })
47+
}
48+
49+
public typealias TangentVector =
50+
Array<Element.TangentVector>.DifferentiableView
51+
52+
public mutating func move(along direction: TangentVector) {
53+
precondition(
54+
base.count == direction.base.count,
55+
"cannot move Array.DifferentiableView with count \(base.count) along " +
56+
"direction with different count \(direction.base.count)")
57+
for i in base.indices {
58+
base[i].move(along: direction.base[i])
59+
}
60+
}
61+
}
62+
63+
extension Array.DifferentiableView : Equatable
64+
where Element : Differentiable & Equatable {
65+
public static func == (
66+
lhs: Array.DifferentiableView,
67+
rhs: Array.DifferentiableView
68+
) -> Bool {
69+
return lhs.base == rhs.base
70+
}
71+
}
72+
73+
extension Array.DifferentiableView : ExpressibleByArrayLiteral
74+
where Element : Differentiable {
75+
public init(arrayLiteral elements: Element...) {
76+
self.init(elements)
77+
}
78+
}
79+
80+
extension Array.DifferentiableView : CustomStringConvertible
81+
where Element : Differentiable {
82+
public var description: String {
83+
return base.description
84+
}
85+
}
86+
87+
/// Makes `Array.DifferentiableView` additive as the product space.
88+
///
89+
/// Note that `Array.DifferentiableView([])` is the zero in the product spaces
90+
/// of all counts.
91+
extension Array.DifferentiableView : AdditiveArithmetic
92+
where Element : AdditiveArithmetic & Differentiable {
93+
94+
public static var zero: Array.DifferentiableView {
95+
return Array.DifferentiableView([])
96+
}
97+
98+
public static func + (
99+
lhs: Array.DifferentiableView,
100+
rhs: Array.DifferentiableView
101+
) -> Array.DifferentiableView {
102+
precondition(
103+
lhs.base.count == 0 || rhs.base.count == 0 ||
104+
lhs.base.count == rhs.base.count,
105+
"cannot add Array.DifferentiableViews with different counts: " +
106+
"\(lhs.base.count) and \(rhs.base.count)")
107+
if lhs.base.count == 0 {
108+
return rhs
109+
}
110+
if rhs.base.count == 0 {
111+
return lhs
112+
}
113+
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+))
114+
}
115+
116+
public static func - (
117+
lhs: Array.DifferentiableView,
118+
rhs: Array.DifferentiableView
119+
) -> Array.DifferentiableView {
120+
precondition(
121+
lhs.base.count == 0 || rhs.base.count == 0 ||
122+
lhs.base.count == rhs.base.count,
123+
"cannot subtract Array.DifferentiableViews with different counts: " +
124+
"\(lhs.base.count) and \(rhs.base.count)")
125+
if lhs.base.count == 0 {
126+
return rhs
127+
}
128+
if rhs.base.count == 0 {
129+
return lhs
130+
}
131+
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(-))
132+
}
133+
134+
@inlinable
135+
public subscript(_ index: Int) -> Element {
136+
if index < base.count {
137+
return base[index]
138+
} else {
139+
return Element.zero
140+
}
141+
}
142+
}
143+
144+
/// Makes `Array` differentiable as the product manifold of `Element`
145+
/// multiplied with itself `count` times.
146+
extension Array : Differentiable where Element : Differentiable {
147+
// In an ideal world, `TangentVector` would be `[Element.TangentVector]`.
148+
// Unfortunately, we cannot conform `Array` to `AdditiveArithmetic` for
149+
// `TangentVector` because `Array` already has a static `+` method with
150+
// different semantics from `AdditiveArithmetic.+`. So we use
151+
// `Array.DifferentiableView` for all these associated types.
152+
public typealias TangentVector =
153+
Array<Element.TangentVector>.DifferentiableView
154+
155+
public mutating func move(along direction: TangentVector) {
156+
var view = DifferentiableView(self)
157+
view.move(along: direction)
158+
self = view.base
159+
}
160+
161+
/// A closure that produces a `TangentVector` of zeros with the same
162+
/// `count` as `self`.
163+
public var zeroTangentVectorInitializer: () -> TangentVector {
164+
{ [count = self.count] in
165+
TangentVector(.init(repeating: .zero, count: count))
166+
}
167+
}
168+
}
169+
170+
extension Array where Element : Differentiable {
171+
@usableFromInline
172+
@derivative(of: subscript)
173+
func _vjpSubscript(index: Int) ->
174+
(value: Element, pullback: (Element.TangentVector) -> TangentVector)
175+
{
176+
func pullback(_ gradientIn: Element.TangentVector) -> TangentVector {
177+
var gradientOut = Array<Element.TangentVector>(
178+
repeating: .zero,
179+
count: count)
180+
gradientOut[index] = gradientIn
181+
return TangentVector(gradientOut)
182+
}
183+
return (self[index], pullback)
184+
}
185+
186+
@usableFromInline
187+
@derivative(of: +)
188+
static func _vjpConcatenate(_ lhs: [Element], _ rhs: [Element]) ->
189+
(value: [Element], pullback: (TangentVector) -> (TangentVector, TangentVector)) {
190+
func pullback(_ gradientIn: TangentVector) ->
191+
(TangentVector, TangentVector) {
192+
precondition(
193+
gradientIn.base.count == lhs.count + rhs.count,
194+
"+ should receive gradient with count equal to sum of operand " +
195+
"counts, but counts are: gradient \(gradientIn.base.count), " +
196+
"lhs \(lhs.count), rhs \(rhs.count)")
197+
return (
198+
TangentVector(Array<Element.TangentVector>(
199+
gradientIn.base[0..<lhs.count])),
200+
TangentVector(Array<Element.TangentVector>(
201+
gradientIn.base[lhs.count...])))
202+
}
203+
return (lhs + rhs, pullback)
204+
}
205+
}
206+
207+
extension Array where Element: Differentiable {
208+
@usableFromInline
209+
@derivative(of: append)
210+
mutating func _vjpAppend(_ element: Element) -> (
211+
value: Void, pullback: (inout TangentVector) -> Element.TangentVector
212+
) {
213+
let appendedElementIndex = count
214+
defer { append(element) }
215+
return ((), { dself in dself.base[appendedElementIndex] })
216+
}
217+
218+
@usableFromInline
219+
@derivative(of: append)
220+
mutating func _jvpAppend(_ element: Element) -> (
221+
value: Void, differential: (inout TangentVector, Element.TangentVector) -> Void
222+
) {
223+
append(element)
224+
return ((), { $0.base.append($1) })
225+
}
226+
}
227+
228+
extension Array where Element: Differentiable {
229+
@usableFromInline
230+
@derivative(of: init(repeating:count:))
231+
static func _vjpInit(repeating repeatedValue: Element, count: Int) -> (
232+
value: Self, pullback: (TangentVector) -> Element.TangentVector
233+
) {
234+
(value: Self(repeating: repeatedValue, count: count), pullback: { v in
235+
v.base.reduce(.zero, +)
236+
})
237+
}
238+
}
239+
240+
//===----------------------------------------------------------------------===//
241+
// Differentiable higher order functions for collections
242+
//===----------------------------------------------------------------------===//
243+
244+
public extension Array where Element: Differentiable {
245+
@differentiable(wrt: (self, initialResult))
246+
func differentiableReduce<Result: Differentiable>(
247+
_ initialResult: Result,
248+
_ nextPartialResult: @differentiable (Result, Element) -> Result
249+
) -> Result {
250+
reduce(initialResult, nextPartialResult)
251+
}
252+
253+
@usableFromInline
254+
@derivative(of: differentiableReduce)
255+
internal func _vjpDifferentiableReduce<Result: Differentiable>(
256+
_ initialResult: Result,
257+
_ nextPartialResult: @differentiable (Result, Element) -> Result
258+
) -> (value: Result,
259+
pullback: (Result.TangentVector)
260+
-> (Array.TangentVector, Result.TangentVector)) {
261+
var pullbacks:
262+
[(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)]
263+
= []
264+
let count = self.count
265+
pullbacks.reserveCapacity(count)
266+
var result = initialResult
267+
for element in self {
268+
let (y, pb) =
269+
valueWithPullback(at: result, element, in: nextPartialResult)
270+
result = y
271+
pullbacks.append(pb)
272+
}
273+
return (value: result, pullback: { tangent in
274+
var resultTangent = tangent
275+
var elementTangents = TangentVector([])
276+
elementTangents.base.reserveCapacity(count)
277+
for pullback in pullbacks.reversed() {
278+
let (newResultTangent, elementTangent) = pullback(resultTangent)
279+
resultTangent = newResultTangent
280+
elementTangents.base.append(elementTangent)
281+
}
282+
return (TangentVector(elementTangents.base.reversed()), resultTangent)
283+
})
284+
}
285+
}
286+
287+
public extension Array where Element: Differentiable {
288+
@differentiable(wrt: self)
289+
func differentiableMap<Result: Differentiable>(
290+
_ body: @differentiable (Element) -> Result
291+
) -> [Result] {
292+
map(body)
293+
}
294+
295+
@usableFromInline
296+
@derivative(of: differentiableMap)
297+
internal func _vjpDifferentiableMap<Result: Differentiable>(
298+
_ body: @differentiable (Element) -> Result
299+
) -> (value: [Result],
300+
pullback: (Array<Result>.TangentVector) -> Array.TangentVector) {
301+
var values: [Result] = []
302+
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
303+
for x in self {
304+
let (y, pb) = valueWithPullback(at: x, in: body)
305+
values.append(y)
306+
pullbacks.append(pb)
307+
}
308+
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
309+
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
310+
}
311+
return (value: values, pullback: pullback)
312+
}
313+
}

stdlib/public/Differentiation/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE
1414
Differentiable.swift
1515
DifferentialOperators.swift
1616
DifferentiationUtilities.swift
17+
ArrayDifferentiation.swift
18+
19+
GYB_SOURCES
20+
tgmathDerivatives.swift.gyb
21+
SIMDVectorTypesDerivatives.swift.gyb
22+
FloatingPointTypesDerivatives.swift.gyb
23+
24+
SWIFT_MODULE_DEPENDS_LINUX Glibc
25+
SWIFT_MODULE_DEPENDS_FREEBSD Glibc
26+
SWIFT_MODULE_DEPENDS_CYGWIN Glibc
27+
SWIFT_MODULE_DEPENDS_HAIKU Glibc
28+
SWIFT_MODULE_DEPENDS_WINDOWS MSVCRT
1729

1830
SWIFT_COMPILE_FLAGS
1931
${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}

0 commit comments

Comments
 (0)