Skip to content

Commit ece4fb7

Browse files
committed
[AutoDiff upstream] Add Float16 derivatives.
Add `@available` attributes for declarations using `Float16`.
1 parent e377436 commit ece4fb7

File tree

1 file changed

+50
-12
lines changed

1 file changed

+50
-12
lines changed

stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,37 @@ import SwiftShims
1515

1616
% from SwiftFloatingPointTypes import all_floating_point_types
1717
% for self_type in all_floating_point_types():
18-
% Self = self_type.stdlib_name
18+
%{
19+
Self = self_type.stdlib_name
20+
bits = self_type.bits
1921

20-
% if self_type.bits == 80:
22+
def Availability(bits):
23+
if bits == 16:
24+
return '@available(macOS 9999, iOS 9999, tvOS 9999, watchOS 9999, *)'
25+
return ''
26+
}%
27+
28+
% if bits == 80:
2129
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
2230
% end
2331

32+
//===----------------------------------------------------------------------===//
33+
// Protocol conformances
34+
//===----------------------------------------------------------------------===//
35+
36+
${Availability(bits)}
37+
extension ${Self}: Differentiable {
38+
${Availability(bits)}
39+
public typealias TangentVector = ${Self}
40+
41+
${Availability(bits)}
42+
public mutating func move(along direction: TangentVector) {
43+
self += direction
44+
}
45+
}
46+
2447
/// Derivatives of standard unary operators.
48+
${Availability(bits)}
2549
extension ${Self} {
2650
@usableFromInline
2751
@_transparent
@@ -41,7 +65,9 @@ extension ${Self} {
4165
}
4266

4367
/// Derivatives of standard binary operators.
68+
${Availability(bits)}
4469
extension ${Self} {
70+
${Availability(bits)}
4571
@inlinable // FIXME(sil-serialize-all)
4672
@_transparent
4773
@derivative(of: +)
@@ -51,6 +77,7 @@ extension ${Self} {
5177
return (lhs + rhs, { v in (v, v) })
5278
}
5379

80+
${Availability(bits)}
5481
@inlinable // FIXME(sil-serialize-all)
5582
@_transparent
5683
@derivative(of: +)
@@ -60,6 +87,7 @@ extension ${Self} {
6087
return (lhs + rhs, { (dlhs, drhs) in dlhs + drhs })
6188
}
6289

90+
${Availability(bits)}
6391
@inlinable // FIXME(sil-serialize-all)
6492
@_transparent
6593
@derivative(of: +=)
@@ -70,6 +98,7 @@ extension ${Self} {
7098
return ((), { v in v })
7199
}
72100

101+
${Availability(bits)}
73102
@inlinable // FIXME(sil-serialize-all)
74103
@_transparent
75104
@derivative(of: +=)
@@ -80,6 +109,7 @@ extension ${Self} {
80109
return ((), { $0 += $1 })
81110
}
82111

112+
${Availability(bits)}
83113
@inlinable // FIXME(sil-serialize-all)
84114
@_transparent
85115
@derivative(of: -)
@@ -89,6 +119,7 @@ extension ${Self} {
89119
return (lhs - rhs, { v in (v, -v) })
90120
}
91121

122+
${Availability(bits)}
92123
@inlinable // FIXME(sil-serialize-all)
93124
@_transparent
94125
@derivative(of: -)
@@ -98,6 +129,7 @@ extension ${Self} {
98129
return (lhs - rhs, { (dlhs, drhs) in dlhs - drhs })
99130
}
100131

132+
${Availability(bits)}
101133
@inlinable // FIXME(sil-serialize-all)
102134
@_transparent
103135
@derivative(of: -=)
@@ -108,6 +140,7 @@ extension ${Self} {
108140
return ((), { v in -v })
109141
}
110142

143+
${Availability(bits)}
111144
@inlinable // FIXME(sil-serialize-all)
112145
@_transparent
113146
@derivative(of: -=)
@@ -118,6 +151,7 @@ extension ${Self} {
118151
return ((), { $0 -= $1 })
119152
}
120153

154+
${Availability(bits)}
121155
@inlinable // FIXME(sil-serialize-all)
122156
@_transparent
123157
@derivative(of: *)
@@ -127,6 +161,7 @@ extension ${Self} {
127161
return (lhs * rhs, { v in (rhs * v, lhs * v) })
128162
}
129163

164+
${Availability(bits)}
130165
@inlinable // FIXME(sil-serialize-all)
131166
@_transparent
132167
@derivative(of: *)
@@ -136,6 +171,7 @@ extension ${Self} {
136171
return (lhs * rhs, { (dlhs, drhs) in lhs * drhs + rhs * dlhs })
137172
}
138173

174+
${Availability(bits)}
139175
@inlinable // FIXME(sil-serialize-all)
140176
@_transparent
141177
@derivative(of: *=)
@@ -150,6 +186,7 @@ extension ${Self} {
150186
})
151187
}
152188

189+
${Availability(bits)}
153190
@inlinable // FIXME(sil-serialize-all)
154191
@_transparent
155192
@derivative(of: *=)
@@ -160,6 +197,7 @@ extension ${Self} {
160197
return ((), { $0 *= $1 })
161198
}
162199

200+
${Availability(bits)}
163201
@inlinable // FIXME(sil-serialize-all)
164202
@_transparent
165203
@derivative(of: /)
@@ -169,6 +207,7 @@ extension ${Self} {
169207
return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) })
170208
}
171209

210+
${Availability(bits)}
172211
@inlinable // FIXME(sil-serialize-all)
173212
@_transparent
174213
@derivative(of: /)
@@ -178,6 +217,7 @@ extension ${Self} {
178217
return (lhs / rhs, { (dlhs, drhs) in dlhs / rhs - lhs / (rhs * rhs) * drhs })
179218
}
180219

220+
${Availability(bits)}
181221
@inlinable // FIXME(sil-serialize-all)
182222
@_transparent
183223
@derivative(of: /=)
@@ -192,6 +232,7 @@ extension ${Self} {
192232
})
193233
}
194234

235+
${Availability(bits)}
195236
@inlinable // FIXME(sil-serialize-all)
196237
@_transparent
197238
@derivative(of: /=)
@@ -203,31 +244,28 @@ extension ${Self} {
203244
}
204245
}
205246

206-
% if self_type.bits == 80:
247+
% if bits == 80:
207248
#endif
208249
% end
209250
% end
210251

211-
extension FloatingPoint where Self : Differentiable,
212-
Self == Self.TangentVector {
213-
/// The vector-Jacobian product function of `addingProduct`. Returns the
214-
/// original result and pullback of `addingProduct` with respect to `self`,
215-
/// `lhs` and `rhs`.
216-
@inlinable
252+
extension FloatingPoint
253+
where
254+
Self: Differentiable,
255+
Self == Self.TangentVector
256+
{
257+
@inlinable // FIXME(sil-serialize-all)
217258
@derivative(of: addingProduct)
218259
func _vjpAddingProduct(
219260
_ lhs: Self, _ rhs: Self
220261
) -> (value: Self, pullback: (Self) -> (Self, Self, Self)) {
221262
return (addingProduct(lhs, rhs), { _ in (1, rhs, lhs) })
222263
}
223264

224-
/// The vector-Jacobian product function of `squareRoot`. Returns the original
225-
/// result and pullback of `squareRoot` with respect to `self`.
226265
@inlinable // FIXME(sil-serialize-all)
227266
@derivative(of: squareRoot)
228267
func _vjpSquareRoot() -> (value: Self, pullback: (Self) -> Self) {
229268
let y = squareRoot()
230269
return (y, { v in v / (2 * y) })
231270
}
232271
}
233-

0 commit comments

Comments
 (0)