@@ -26,111 +26,115 @@ import Swift
26
26
#error("Unsupported platform")
27
27
#endif
28
28
29
- @usableFromInline
29
+ % for T in [ 'T', 'Double'] : # Prevents name collisions with system math library
30
+ % generic_signature = '< T: FloatingPoint & Differentiable> ' if T == 'T' else ''
31
+ % constraint = 'where T == T. TangentVector' if T == 'T' else ''
32
+ @inlinable
30
33
@derivative( of: fma)
31
- func _jvpFma< T : FloatingPoint & Differentiable > (
32
- _ x: T ,
33
- _ y: T ,
34
- _ z: T
35
- ) -> ( value: T , differential: ( T , T , T ) -> T ) where T == T . TangentVector {
34
+ func _jvpFma$ { generic_signature } (
35
+ _ x: $ { T } ,
36
+ _ y: $ { T } ,
37
+ _ z: $ { T }
38
+ ) - > ( value: $ { T } , differential: ( $ { T } , $ { T } , $ { T } ) - > $ { T } ) $ { constraint } {
36
39
return ( fma ( x, y, z) , { ( dx, dy, dz) in dx * y + dy * x + dz } )
37
40
}
38
41
39
- @usableFromInline
42
+ @inlinable
40
43
@derivative( of: fma)
41
- func _vjpFma< T : FloatingPoint & Differentiable > (
42
- _ x: T ,
43
- _ y: T ,
44
- _ z: T
45
- ) -> ( value: T , pullback: ( T ) -> ( T , T , T ) ) where T == T . TangentVector {
44
+ func _vjpFma$ { generic_signature } (
45
+ _ x: $ { T } ,
46
+ _ y: $ { T } ,
47
+ _ z: $ { T }
48
+ ) - > ( value: $ { T } , pullback: ( $ { T } ) - > ( $ { T } , $ { T } , $ { T } ) ) $ { constraint } {
46
49
return ( fma ( x, y, z) , { v in ( v * y, v * x, v) } )
47
50
}
48
51
49
- @usableFromInline
52
+ @inlinable
50
53
@derivative( of: remainder)
51
- func _jvpRemainder< T : FloatingPoint & Differentiable > (
52
- _ x: T ,
53
- _ y: T
54
- ) -> ( value: T , differential: ( T , T ) -> T ) where T == T . TangentVector {
54
+ func _jvpRemainder$ { generic_signature } (
55
+ _ x: $ { T } ,
56
+ _ y: $ { T }
57
+ ) - > ( value: $ { T } , differential: ( $ { T } , $ { T } ) - > $ { T } ) $ { constraint } {
55
58
fatalError ( """
56
59
Unimplemented JVP for 'remainder(_:)'. \
57
60
https://bugs.swift.org/browse/TF-1108 tracks this issue
58
61
""" )
59
62
}
60
63
61
- @usableFromInline
64
+ @inlinable
62
65
@derivative( of: remainder)
63
- func _vjpRemainder< T : FloatingPoint & Differentiable > (
64
- _ x: T ,
65
- _ y: T
66
- ) -> ( value: T , pullback: ( T ) -> ( T , T ) ) where T == T . TangentVector {
66
+ func _vjpRemainder$ { generic_signature } (
67
+ _ x: $ { T } ,
68
+ _ y: $ { T }
69
+ ) - > ( value: $ { T } , pullback: ( $ { T } ) - > ( $ { T } , $ { T } ) ) $ { constraint } {
67
70
return ( remainder ( x, y) , { v in ( v, - v * ( ( x / y) . rounded ( . toNearestOrEven) ) ) } )
68
71
}
69
72
70
- @usableFromInline
73
+ @inlinable
71
74
@derivative( of: fmod)
72
- func _jvpFmod< T : FloatingPoint & Differentiable > (
73
- _ x: T ,
74
- _ y: T
75
- ) -> ( value: T , differential: ( T , T ) -> T ) where T == T . TangentVector {
75
+ func _jvpFmod$ { generic_signature } (
76
+ _ x: $ { T } ,
77
+ _ y: $ { T }
78
+ ) - > ( value: $ { T } , differential: ( $ { T } , $ { T } ) - > $ { T } ) $ { constraint } {
76
79
fatalError ( """
77
80
Unimplemented JVP for 'fmod(_:)'. \
78
81
https://bugs.swift.org/browse/TF-1108 tracks this issue
79
82
""" )
80
83
}
81
84
82
- @usableFromInline
85
+ @inlinable
83
86
@derivative( of: fmod)
84
- func _vjpFmod< T : FloatingPoint & Differentiable > (
85
- _ x: T ,
86
- _ y: T
87
- ) -> ( value: T , pullback: ( T ) -> ( T , T ) ) where T == T . TangentVector {
87
+ func _vjpFmod$ { generic_signature } (
88
+ _ x: $ { T } ,
89
+ _ y: $ { T }
90
+ ) - > ( value: $ { T } , pullback: ( $ { T } ) - > ( $ { T } , $ { T } ) ) $ { constraint } {
88
91
return ( fmod ( x, y) , { v in ( v, - v * ( ( x / y) . rounded ( . towardZero) ) ) } )
89
92
}
90
93
91
- % for derivative_kind in [ 'jvp', 'vjp'] :
92
- % linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
93
- @usableFromInline
94
+ % for derivative_kind in [ 'jvp', 'vjp'] :
95
+ % linear_map_kind = 'differential' if derivative_kind == 'jvp' else 'pullback'
96
+ @inlinable
94
97
@derivative( of: sqrt)
95
- func _${ derivative_kind} Sqrt< T : FloatingPoint & Differentiable > (
96
- _ x: T
97
- ) - > ( value: T , ${ linear_map_kind} : ( T ) - > T ) where T == T . TangentVector {
98
+ func _${ derivative_kind} Sqrt$ { generic_signature } (
99
+ _ x: $ { T }
100
+ ) - > ( value: $ { T } , ${ linear_map_kind} : ( $ { T } ) - > $ { T } ) $ { constraint } {
98
101
let value = sqrt ( x)
99
102
return ( value, { v in v / ( 2 * value) } )
100
103
}
101
104
102
- @usableFromInline
105
+ @inlinable
103
106
@derivative( of: ceil)
104
- func _${ derivative_kind} Ceil< T : FloatingPoint & Differentiable > (
105
- _ x: T
106
- ) - > ( value: T , ${ linear_map_kind} : ( T ) - > T ) where T == T . TangentVector {
107
+ func _${ derivative_kind} Ceil$ { generic_signature } (
108
+ _ x: $ { T }
109
+ ) - > ( value: $ { T } , ${ linear_map_kind} : ( $ { T } ) - > $ { T } ) $ { constraint } {
107
110
return ( ceil ( x) , { v in 0 } )
108
111
}
109
112
110
- @usableFromInline
113
+ @inlinable
111
114
@derivative( of: floor)
112
- func _${ derivative_kind} Floor< T : FloatingPoint & Differentiable > (
113
- _ x: T
114
- ) - > ( value: T , ${ linear_map_kind} : ( T ) - > T ) where T == T . TangentVector {
115
+ func _${ derivative_kind} Floor$ { generic_signature } (
116
+ _ x: $ { T }
117
+ ) - > ( value: $ { T } , ${ linear_map_kind} : ( $ { T } ) - > $ { T } ) $ { constraint } {
115
118
return ( floor ( x) , { v in 0 } )
116
119
}
117
120
118
- @usableFromInline
121
+ @inlinable
119
122
@derivative( of: round)
120
- func _${ derivative_kind} Round< T : FloatingPoint & Differentiable > (
121
- _ x: T
122
- ) - > ( value: T , ${ linear_map_kind} : ( T ) - > T ) where T == T . TangentVector {
123
+ func _${ derivative_kind} Round$ { generic_signature } (
124
+ _ x: $ { T }
125
+ ) - > ( value: $ { T } , ${ linear_map_kind} : ( $ { T } ) - > $ { T } ) $ { constraint } {
123
126
return ( round ( x) , { v in 0 } )
124
127
}
125
128
126
- @usableFromInline
129
+ @inlinable
127
130
@derivative( of: trunc)
128
- func _${ derivative_kind} Trunc< T : FloatingPoint & Differentiable > (
129
- _ x: T
130
- ) - > ( value: T , ${ linear_map_kind} : ( T ) - > T ) where T == T . TangentVector {
131
+ func _${ derivative_kind} Trunc$ { generic_signature } (
132
+ _ x: $ { T }
133
+ ) - > ( value: $ { T } , ${ linear_map_kind} : ( $ { T } ) - > $ { T } ) $ { constraint } {
131
134
return ( trunc ( x) , { v in 0 } )
132
135
}
133
- % end # for derivative_kind in [ 'jvp', 'vjp'] :
136
+ % end # for derivative_kind in [ 'jvp', 'vjp'] :
137
+ % end # for T in [ 'T', 'Double'] :
134
138
135
139
// Unary functions
136
140
% for derivative_kind in [ 'jvp', 'vjp'] :
@@ -276,7 +280,7 @@ func _${derivative_kind}Erfc(_ x: ${T}) -> (value: ${T}, ${linear_map_kind}: (${
276
280
% end # for derivative_kind in [ 'jvp', 'vjp'] :
277
281
278
282
// Binary functions
279
- % for T in [ 'Float', 'Float80 '] :
283
+ % for T in [ 'Float', 'Double' , ' Float80 '] :
280
284
% if T == 'Float80 ':
281
285
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
282
286
% end
@@ -300,4 +304,4 @@ func _jvpPow(_ x: ${T}, _ y: ${T}) -> (value: ${T}, differential: (${T}, ${T}) -
300
304
% if T == 'Float80 ':
301
305
#endif
302
306
% end # if T == 'Float80 ':
303
- % end # for T in [ 'Float', 'Float80 '] :
307
+ % end # for T in [ 'Float', 'Double' , ' Float80 '] :
0 commit comments