Skip to content

Commit 09ece3b

Browse files
authored
[AutoDiff] Allow tgmath differentiation of Double (swiftlang#40991)
This allows automatic differentiation of concrete calls to math functions operating on `Double`. Such calls were not differentiable because the concrete math functions were imported from C, not the generic ones defined in Swift. The fix is to define derivatives for these C functions. `test/AutoDiff/stdlib/tgmath_derivatives.swift.gyb` was modified to test the new functions.
1 parent 97359ec commit 09ece3b

File tree

2 files changed

+63
-59
lines changed

2 files changed

+63
-59
lines changed

stdlib/public/Differentiation/TgmathDerivatives.swift.gyb

Lines changed: 61 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -26,111 +26,115 @@ import Swift
2626
#error("Unsupported platform")
2727
#endif
2828

29-
@usableFromInline
29+
%for T in ['T', 'Double']: # Prevents name collisions with system math library
30+
% generic_signature = '<T: FloatingPoint & Differentiable>' if T == 'T' else ''
31+
% constraint = 'where T == T.TangentVector' if T == 'T' else ''
32+
@inlinable
3033
@derivative(of: fma)
31-
func _jvpFma<T: FloatingPoint & Differentiable> (
32-
_ x: T,
33-
_ y: T,
34-
_ z: T
35-
) -> (value: T, differential: (T, T, T) -> T) where T == T.TangentVector {
34+
func _jvpFma${generic_signature} (
35+
_ x: ${T},
36+
_ y: ${T},
37+
_ z: ${T}
38+
) -> (value: ${T}, differential: (${T}, ${T}, ${T}) -> ${T}) ${constraint} {
3639
return (fma(x, y, z), { (dx, dy, dz) in dx * y + dy * x + dz })
3740
}
3841

39-
@usableFromInline
42+
@inlinable
4043
@derivative(of: fma)
41-
func _vjpFma<T: FloatingPoint & Differentiable> (
42-
_ x: T,
43-
_ y: T,
44-
_ z: T
45-
) -> (value: T, pullback: (T) -> (T, T, T)) where T == T.TangentVector {
44+
func _vjpFma${generic_signature} (
45+
_ x: ${T},
46+
_ y: ${T},
47+
_ z: ${T}
48+
) -> (value: ${T}, pullback: (${T}) -> (${T}, ${T}, ${T})) ${constraint} {
4649
return (fma(x, y, z), { v in (v * y, v * x, v) })
4750
}
4851

49-
@usableFromInline
52+
@inlinable
5053
@derivative(of: remainder)
51-
func _jvpRemainder<T: FloatingPoint & Differentiable> (
52-
_ x: T,
53-
_ y: T
54-
) -> (value: T, differential: (T, T) -> T) where T == T.TangentVector {
54+
func _jvpRemainder${generic_signature} (
55+
_ x: ${T},
56+
_ y: ${T}
57+
) -> (value: ${T}, differential: (${T}, ${T}) -> ${T}) ${constraint} {
5558
fatalError("""
5659
Unimplemented JVP for 'remainder(_:)'. \
5760
https://bugs.swift.org/browse/TF-1108 tracks this issue
5861
""")
5962
}
6063

61-
@usableFromInline
64+
@inlinable
6265
@derivative(of: remainder)
63-
func _vjpRemainder<T: FloatingPoint & Differentiable> (
64-
_ x: T,
65-
_ y: T
66-
) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector {
66+
func _vjpRemainder${generic_signature} (
67+
_ x: ${T},
68+
_ y: ${T}
69+
) -> (value: ${T}, pullback: (${T}) -> (${T}, ${T})) ${constraint} {
6770
return (remainder(x, y), { v in (v, -v * ((x / y).rounded(.toNearestOrEven))) })
6871
}
6972

70-
@usableFromInline
73+
@inlinable
7174
@derivative(of: fmod)
72-
func _jvpFmod<T: FloatingPoint & Differentiable> (
73-
_ x: T,
74-
_ y: T
75-
) -> (value: T, differential: (T, T) -> T) where T == T.TangentVector {
75+
func _jvpFmod${generic_signature} (
76+
_ x: ${T},
77+
_ y: ${T}
78+
) -> (value: ${T}, differential: (${T}, ${T}) -> ${T}) ${constraint} {
7679
fatalError("""
7780
Unimplemented JVP for 'fmod(_:)'. \
7881
https://bugs.swift.org/browse/TF-1108 tracks this issue
7982
""")
8083
}
8184

82-
@usableFromInline
85+
@inlinable
8386
@derivative(of: fmod)
84-
func _vjpFmod<T: FloatingPoint & Differentiable> (
85-
_ x: T,
86-
_ y: T
87-
) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector {
87+
func _vjpFmod${generic_signature} (
88+
_ x: ${T},
89+
_ y: ${T}
90+
) -> (value: ${T}, pullback: (${T}) -> (${T}, ${T})) ${constraint} {
8891
return (fmod(x, y), { v in (v, -v * ((x / y).rounded(.towardZero))) })
8992
}
9093

91-
%for derivative_kind in ['jvp', 'vjp']:
92-
% linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
93-
@usableFromInline
94+
% for derivative_kind in ['jvp', 'vjp']:
95+
% linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
96+
@inlinable
9497
@derivative(of: sqrt)
95-
func _${derivative_kind}Sqrt<T: FloatingPoint & Differentiable> (
96-
_ x: T
97-
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
98+
func _${derivative_kind}Sqrt${generic_signature} (
99+
_ x: ${T}
100+
) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) ${constraint} {
98101
let value = sqrt(x)
99102
return (value, { v in v / (2 * value) })
100103
}
101104

102-
@usableFromInline
105+
@inlinable
103106
@derivative(of: ceil)
104-
func _${derivative_kind}Ceil<T: FloatingPoint & Differentiable> (
105-
_ x: T
106-
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
107+
func _${derivative_kind}Ceil${generic_signature} (
108+
_ x: ${T}
109+
) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) ${constraint} {
107110
return (ceil(x), { v in 0 })
108111
}
109112

110-
@usableFromInline
113+
@inlinable
111114
@derivative(of: floor)
112-
func _${derivative_kind}Floor<T: FloatingPoint & Differentiable> (
113-
_ x: T
114-
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
115+
func _${derivative_kind}Floor${generic_signature} (
116+
_ x: ${T}
117+
) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) ${constraint} {
115118
return (floor(x), { v in 0 })
116119
}
117120

118-
@usableFromInline
121+
@inlinable
119122
@derivative(of: round)
120-
func _${derivative_kind}Round<T: FloatingPoint & Differentiable> (
121-
_ x: T
122-
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
123+
func _${derivative_kind}Round${generic_signature} (
124+
_ x: ${T}
125+
) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) ${constraint} {
123126
return (round(x), { v in 0 })
124127
}
125128

126-
@usableFromInline
129+
@inlinable
127130
@derivative(of: trunc)
128-
func _${derivative_kind}Trunc<T: FloatingPoint & Differentiable> (
129-
_ x: T
130-
) -> (value: T, ${linear_map_kind}: (T) -> T) where T == T.TangentVector {
131+
func _${derivative_kind}Trunc${generic_signature} (
132+
_ x: ${T}
133+
) -> (value: ${T}, ${linear_map_kind}: (${T}) -> ${T}) ${constraint} {
131134
return (trunc(x), { v in 0 })
132135
}
133-
%end # for derivative_kind in ['jvp', 'vjp']:
136+
% end # for derivative_kind in ['jvp', 'vjp']:
137+
%end # for T in ['T', 'Double']:
134138

135139
// Unary functions
136140
%for derivative_kind in ['jvp', 'vjp']:
@@ -276,7 +280,7 @@ func _${derivative_kind}Erfc(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${
276280
%end # for derivative_kind in ['jvp', 'vjp']:
277281

278282
// Binary functions
279-
%for T in ['Float', 'Float80']:
283+
%for T in ['Float', 'Double', 'Float80']:
280284
% if T == 'Float80':
281285
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
282286
% end
@@ -300,4 +304,4 @@ func _jvpPow(_ x: ${T}, _ y: ${T}) -> (value: ${T}, differential: (${T}, ${T}) -
300304
% if T == 'Float80':
301305
#endif
302306
% end # if T == 'Float80':
303-
%end # for T in ['Float', 'Float80']:
307+
%end # for T in ['Float', 'Double', 'Float80']:

test/AutoDiff/stdlib/tgmath_derivatives.swift.gyb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ where T == T.TangentVector {
7575
}
7676

7777
%for op in ['derivative', 'gradient']:
78-
%for T in ['Float', 'Float80']:
78+
%for T in ['Float', 'Double', 'Float80']:
7979

8080
%if T == 'Float80':
8181
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
@@ -205,7 +205,7 @@ DerivativeTests.test("${op}_${T}") {
205205
#endif
206206
%end
207207

208-
%end # for T in ['Float', 'Float80']:
208+
%end # for T in ['Float', 'Double', 'Float80']:
209209
%end # for op in ['derivative', 'gradient']:
210210

211211
runAllTests()

0 commit comments

Comments
 (0)