Skip to content

Commit 8aac6f9

Browse files
authored
[AutoDiff upstream] Conform floating-point types to Differentiable. (swiftlang#28718)
Add `Differentiable` conformances for floating-point types to the `_Differentiation` module. The `TangentVector` associated type for floating-point types is `Self`. This design adheres to the differentiable programming manifesto: docs/DifferentiableProgramming.md. Partially resolves TF-1052.
1 parent 0d08f8c commit 8aac6f9

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

stdlib/public/Differentiation/Differentiable.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,19 @@ public extension Differentiable where TangentVector == Self {
4242
self += direction
4343
}
4444
}
45+
46+
//===----------------------------------------------------------------------===//
47+
// `Differentiable` conformances
48+
//===----------------------------------------------------------------------===//
49+
50+
extension Float: Differentiable {
51+
public typealias TangentVector = Self
52+
}
53+
extension Double: Differentiable {
54+
public typealias TangentVector = Self
55+
}
56+
#if (arch(i386) || arch(x86_64)) && !(os(Windows) || os(Android))
57+
extension Float80: Differentiable {
58+
public typealias TangentVector = Self
59+
}
60+
#endif

test/AutoDiff/stdlib/differentiable_protocol.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import _Differentiation
55

6-
// Test conformances.
6+
// Test `Differentiable` protocol conformances.
77

88
struct FloatWrapper {
99
var value: Float
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: differentiable_programming
3+
4+
import _Differentiation
5+
6+
// Test `Differentiable` protocol conformances for stdlib types.
7+
8+
func assertConformsToDifferentiable<T>(_: T.Type) where T: Differentiable {}
9+
10+
func assertSelfEqualsTangentVector<T>(_: T.Type)
11+
where T: Differentiable, T == T.TangentVector {}
12+
13+
// Test `FloatingPoint` types.
14+
func testFloatingPointDifferentiableConformance() {
15+
assertSelfEqualsTangentVector(Float.self)
16+
assertSelfEqualsTangentVector(Double.self)
17+
#if (arch(i386) || arch(x86_64)) && !(os(Windows) || os(Android))
18+
assertSelfEqualsTangentVector(Float80.self)
19+
#endif
20+
}

0 commit comments

Comments
 (0)