Skip to content

Commit dfe1e69

Browse files
JaapWijnenJaap Wijnen
andauthored
[AutoDiff] Add missing vjp, jvp functions to existing FloatingPoint initializers (#70688)
Adds derivatives to already existing initializers that allow converting between floating point type. For example converting a Float to a Double. Co-authored-by: Jaap Wijnen <[email protected]>
1 parent 43d566e commit dfe1e69

File tree

2 files changed

+43
-12
lines changed

2 files changed

+43
-12
lines changed

stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,21 +54,42 @@ extension ${Self}: Differentiable {
5454
/// Derivatives of constructors.
5555
${Availability(bits)}
5656
extension ${Self} {
57+
58+
% for other_type in all_floating_point_types():
59+
%{
60+
Other = other_type.stdlib_name
61+
other_bits = other_type.bits
62+
}%
63+
64+
% if other_bits == 80:
65+
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
66+
% end
67+
% if other_bits == 16:
68+
#if !os(macOS) && !(os(iOS) && targetEnvironment(macCatalyst))
69+
% end
70+
71+
${Availability(other_bits)}
5772
@inlinable
5873
@_transparent
5974
@derivative(of: init(_:))
60-
static func _vjpInit(x: ${Self})
61-
-> (value: ${Self}, pullback: (${Self}) -> ${Self}) {
62-
return (x, { v in v })
75+
static func _vjpInit(x: ${Other})
76+
-> (value: ${Self}, pullback: (${Self}) -> ${Other}) {
77+
return (value: ${Self}(x), pullback: { v in ${Other}(v) })
6378
}
64-
79+
80+
${Availability(other_bits)}
6581
@inlinable
6682
@_transparent
6783
@derivative(of: init(_:))
68-
static func _jvpInit(x: ${Self})
69-
-> (value: ${Self}, differential: (${Self}) -> ${Self}) {
70-
return (x, { dx in dx })
84+
static func _jvpInit(x: ${Other})
85+
-> (value: ${Self}, differential: (${Other}) -> ${Self}) {
86+
return (value: ${Self}(x), differential: { dx in ${Self}(dx) })
7187
}
88+
89+
% if other_bits == 80 or other_bits == 16:
90+
#endif
91+
% end
92+
% end
7293
}
7394

7495
/// Derivatives of standard unary operators.

test/AutoDiff/stdlib/floating_point.swift.gyb

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,24 @@ func expectEqualWithTolerance<T>(_ expected: TestLiteralType, _ actual: T,
3434
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
3535
%end
3636

37-
FloatingPointDerivativeTests.test("${Self}.init") {
38-
expectEqual(1, gradient(at: ${Self}(4), of: ${Self}.init(_:)))
39-
expectEqual(10, pullback(at: ${Self}(4), of: ${Self}.init(_:))(${Self}(10)))
37+
%for Other in ['Float', 'Double', 'Float80']:
38+
%if Other == 'Float80':
39+
#if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
40+
%end
41+
42+
FloatingPointDerivativeTests.test("${Self}.init(_:${Other})") {
43+
expectEqual(1, gradient(at: ${Other}(4), of: ${Self}.init(_:)))
44+
expectEqual(10, pullback(at: ${Other}(4), of: ${Self}.init(_:))(${Self}(10)))
4045

41-
expectEqual(1, derivative(at: ${Self}(4), of: ${Self}.init(_:)))
42-
expectEqual(10, differential(at: ${Self}(4), of: ${Self}.init(_:))(${Self}(10)))
46+
expectEqual(1, derivative(at: ${Other}(4), of: ${Self}.init(_:)))
47+
expectEqual(10, differential(at: ${Other}(4), of: ${Self}.init(_:))(${Other}(10)))
4348
}
4449

50+
%if Other == 'Float80':
51+
#endif
52+
%end
53+
%end # for Other in ['Float', 'Double', 'Float80']:
54+
4555
FloatingPointDerivativeTests.test("${Self}.+") {
4656
expectEqual((1, 1), gradient(at: ${Self}(4), ${Self}(5), of: +))
4757
expectEqual((10, 10), pullback(at: ${Self}(4), ${Self}(5), of: +)(${Self}(10)))

0 commit comments

Comments
 (0)