|
1 | 1 | // RUN: %empty-directory(%t)
|
2 |
| -// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t |
| 2 | +// RUN: %target-swift-frontend -enable-experimental-differentiable-programming %s -emit-module -parse-as-library -o %t |
3 | 3 | // RUN: llvm-bcanalyzer %t/derivative_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
|
4 |
| -// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/derivative_attr.swiftmodule -o - | %FileCheck %s |
| 4 | +// RUN: %target-sil-opt -enable-experimental-differentiable-programming -disable-sil-linking -enable-sil-verify-all %t/derivative_attr.swiftmodule -o - | %FileCheck %s |
5 | 5 |
|
6 | 6 | // BCANALYZER-NOT: UnknownCode
|
7 | 7 |
|
8 |
| -// TODO(TF-837): Enable this test. |
9 |
| -// Blocked by TF-829: `@derivative` attribute type-checking. |
10 |
| -// XFAIL: * |
| 8 | +// REQUIRES: differentiable_programming |
11 | 9 |
|
12 |
| -func add(x: Float, y: Float) -> Float { |
13 |
| - return x + y |
| 10 | +import _Differentiation |
| 11 | + |
| 12 | +// Dummy `Differentiable`-conforming type. |
| 13 | +struct S: Differentiable & AdditiveArithmetic { |
| 14 | + static var zero: S { S() } |
| 15 | + static func + (_: S, _: S) -> S { S() } |
| 16 | + static func - (_: S, _: S) -> S { S() } |
| 17 | + typealias TangentVector = S |
14 | 18 | }
|
15 |
| -// CHECK: @derivative(of: add, wrt: x) |
16 |
| -@derivative(of: add, wrt: x) |
17 |
| -func jvpAddWrtX(x: Float, y: Float) -> (value: Float, differential: (Float) -> (Float)) { |
18 |
| - return (x + y, { $0 }) |
| 19 | + |
| 20 | +// Test top-level functions. |
| 21 | + |
| 22 | +func top1(_ x: S) -> S { |
| 23 | + x |
19 | 24 | }
|
20 |
| -// CHECK: @derivative(of: add, wrt: (x, y)) |
21 |
| -@derivative(of: add) |
22 |
| -func vjpAdd(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) { |
23 |
| - return (x + y, { ($0, $0) }) |
| 25 | +// CHECK: @derivative(of: top1, wrt: x) |
| 26 | +@derivative(of: top1, wrt: x) |
| 27 | +func derivativeTop1(_ x: S) -> (value: S, differential: (S) -> S) { |
| 28 | + (x, { $0 }) |
24 | 29 | }
|
25 | 30 |
|
26 |
| -func generic<T : Numeric>(x: T) -> T { |
27 |
| - return x |
| 31 | +func top2<T, U>(_ x: T, _ i: Int, _ y: U) -> U { |
| 32 | + y |
28 | 33 | }
|
29 |
| -// CHECK: @derivative(of: generic, wrt: x) |
30 |
| -@derivative(of: generic) |
31 |
| -func vjpGeneric<T>(x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) |
32 |
| - where T : Numeric, T : Differentiable |
33 |
| -{ |
34 |
| - return (x, { v in v }) |
| 34 | +// CHECK: @derivative(of: top2, wrt: (x, y)) |
| 35 | +@derivative(of: top2, wrt: (x, y)) |
| 36 | +func derivativeTop2<T: Differentiable, U: Differentiable>( |
| 37 | + _ x: T, _ i: Int, _ y: U |
| 38 | +) -> (value: U, differential: (T.TangentVector, U.TangentVector) -> U.TangentVector) { |
| 39 | + (y, { (dx, dy) in dy }) |
35 | 40 | }
|
36 | 41 |
|
37 |
| -protocol InstanceMethod : Differentiable { |
38 |
| - func foo(_ x: Self) -> Self |
39 |
| - func bar<T : Differentiable>(_ x: T) -> Self |
| 42 | +// Test instance methods. |
| 43 | + |
| 44 | +extension S { |
| 45 | + func instanceMethod(_ x: S) -> S { |
| 46 | + self |
| 47 | + } |
| 48 | + |
| 49 | + // CHECK: @derivative(of: instanceMethod, wrt: x) |
| 50 | + @derivative(of: instanceMethod, wrt: x) |
| 51 | + func derivativeInstanceMethodWrtX(_ x: S) -> (value: S, differential: (S) -> S) { |
| 52 | + (self, { _ in .zero }) |
| 53 | + } |
| 54 | + |
| 55 | + // CHECK: @derivative(of: instanceMethod, wrt: self) |
| 56 | + @derivative(of: instanceMethod, wrt: self) |
| 57 | + func derivativeInstanceMethodWrtSelf(_ x: S) -> (value: S, differential: (S) -> S) { |
| 58 | + (self, { $0 }) |
| 59 | + } |
| 60 | + |
| 61 | + // CHECK: @derivative(of: instanceMethod, wrt: (self, x)) |
| 62 | + @derivative(of: instanceMethod, wrt: (self, x)) |
| 63 | + func derivativeInstanceMethodWrtAll(_ x: S) -> (value: S, differential: (S, S) -> S) { |
| 64 | + (self, { (dself, dx) in self }) |
| 65 | + } |
40 | 66 | }
|
41 |
| -extension InstanceMethod { |
42 |
| - // CHECK: @derivative(of: foo, wrt: (self, x)) |
43 |
| - @derivative(of: foo) |
44 |
| - func vjpFoo(x: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) { |
45 |
| - return (x, { ($0, $0) }) |
| 67 | + |
| 68 | +// Test static methods. |
| 69 | + |
| 70 | +extension S { |
| 71 | + static func staticMethod(_ x: S) -> S { |
| 72 | + x |
| 73 | + } |
| 74 | + |
| 75 | + // CHECK: @derivative(of: staticMethod, wrt: x) |
| 76 | + @derivative(of: staticMethod, wrt: x) |
| 77 | + static func derivativeStaticMethod(_ x: S) -> (value: S, differential: (S) -> S) { |
| 78 | + (x, { $0 }) |
46 | 79 | }
|
| 80 | +} |
| 81 | + |
| 82 | +// Test computed properties. |
| 83 | + |
| 84 | +extension S { |
| 85 | + var computedProperty: S { |
| 86 | + self |
| 87 | + } |
| 88 | + |
| 89 | + // CHECK: @derivative(of: computedProperty, wrt: self) |
| 90 | + @derivative(of: computedProperty, wrt: self) |
| 91 | + func derivativeProperty() -> (value: S, differential: (S) -> S) { |
| 92 | + (self, { $0 }) |
| 93 | + } |
| 94 | +} |
| 95 | + |
| 96 | +// Test subscripts. |
47 | 97 |
|
48 |
| - // CHECK: @derivative(of: bar, wrt: (self, x)) |
49 |
| - @derivative(of: bar, wrt: (self, x)) |
50 |
| - func jvpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T) -> TangentVector) |
51 |
| - where T == T.TangentVector |
52 |
| - { |
53 |
| - return (self, { dself, dx in dself }) |
| 98 | +extension S { |
| 99 | + subscript<T: Differentiable>(x: T) -> S { |
| 100 | + self |
54 | 101 | }
|
55 | 102 |
|
56 |
| - // CHECK: @derivative(of: bar, wrt: (self, x)) |
57 |
| - @derivative(of: bar, wrt: (self, x)) |
58 |
| - func vjpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, pullback: (TangentVector) -> (TangentVector, T)) |
59 |
| - where T == T.TangentVector |
60 |
| - { |
61 |
| - return (self, { v in (v, .zero) }) |
| 103 | + // CHECK: @derivative(of: subscript, wrt: self) |
| 104 | + @derivative(of: subscript(_:), wrt: self) |
| 105 | + func derivativeSubscript<T: Differentiable>(x: T) -> (value: S, differential: (S) -> S) { |
| 106 | + (self, { $0 }) |
62 | 107 | }
|
63 | 108 | }
|
0 commit comments