Skip to content

Commit 9115b9f

Browse files
committed
Update autodiff tests for the new intrinsics impl
1 parent ed1e5c2 commit 9115b9f

File tree

13 files changed

+152
-221
lines changed

13 files changed

+152
-221
lines changed

tests/codegen-llvm/autodiffv2.rs renamed to tests/codegen-llvm/autodiff/autodiffv2.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626

2727
#![feature(autodiff)]
2828

29-
use std::autodiff::autodiff;
29+
use std::autodiff::autodiff_forward;
3030

31+
// CHECK: ;
3132
#[no_mangle]
3233
//#[autodiff(d_square1, Forward, Dual, Dual)]
33-
#[autodiff(d_square2, Forward, 4, Dualv, Dualv)]
34-
#[autodiff(d_square3, Forward, 4, Dual, Dual)]
34+
#[autodiff_forward(d_square2, 4, Dualv, Dualv)]
35+
#[autodiff_forward(d_square3, 4, Dual, Dual)]
3536
fn square(x: &[f32], y: &mut [f32]) {
3637
assert!(x.len() >= 4);
3738
assert!(y.len() >= 5);

tests/codegen-llvm/autodiff/batched.rs

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ use std::autodiff::autodiff_forward;
1717
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
1818
#[autodiff_forward(d_square1, 4, Dual, Dual)]
1919
#[no_mangle]
20+
#[inline(never)]
2021
fn square(x: &f32) -> f32 {
2122
x * x
2223
}
2324

24-
// d_sqaure2
25+
// d_square2
2526
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
2627
// CHECK-NEXT: start:
2728
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
@@ -32,24 +33,20 @@ fn square(x: &f32) -> f32 {
3233
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
3334
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
3435
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
35-
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
36-
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
37-
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
38-
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
39-
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
40-
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
41-
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
42-
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
43-
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
44-
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
45-
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
46-
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
47-
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
48-
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
49-
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
50-
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
51-
// CHECK-NEXT: ret [4 x float] %19
52-
// CHECK-NEXT: }
36+
// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00
37+
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
38+
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
39+
// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00
40+
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
41+
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
42+
// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00
43+
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
44+
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
45+
// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00
46+
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
47+
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
48+
// CHECK-NEXT: ret [4 x float] %15
49+
// CHECK-NEXT: }
5350

5451
// d_square3, the extra float is the original return value (x * x)
5552
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
@@ -63,26 +60,22 @@ fn square(x: &f32) -> f32 {
6360
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
6461
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
6562
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
66-
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
67-
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
68-
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
69-
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
70-
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
71-
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
72-
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
73-
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
74-
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
75-
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
76-
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
77-
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
78-
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
79-
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
80-
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
81-
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
82-
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
83-
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
84-
// CHECK-NEXT: ret { float, [4 x float] } %21
85-
// CHECK-NEXT: }
63+
// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00
64+
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
65+
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
66+
// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00
67+
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
68+
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
69+
// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00
70+
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
71+
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
72+
// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00
73+
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
74+
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
75+
// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0
76+
// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1
77+
// CHECK-NEXT: ret { float, [4 x float] } %17
78+
// CHECK-NEXT: }
8679

8780
fn main() {
8881
let x = std::hint::black_box(3.0);

tests/codegen-llvm/autodiff/generic.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,28 @@
66
use std::autodiff::autodiff_reverse;
77

88
#[autodiff_reverse(d_square, Duplicated, Active)]
9+
#[inline(never)]
910
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
1011
*x * *x
1112
}
1213

13-
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
14+
// Ensure that `d_square::<f32>` code is generated
1415
//
1516
// CHECK: ; generic::square
16-
// CHECK-NEXT: ; Function Attrs:
17-
// CHECK-NEXT: define internal {{.*}} double
17+
// CHECK-NEXT: ; Function Attrs: {{.*}}
18+
// CHECK-NEXT: define internal {{.*}} float
1819
// CHECK-NEXT: start:
1920
// CHECK-NOT: ret
20-
// CHECK: fmul double
21+
// CHECK: fmul float
2122

22-
// Ensure that `d_square::<f32>` code is generated
23+
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
2324
//
2425
// CHECK: ; generic::square
25-
// CHECK-NEXT: ; Function Attrs: {{.*}}
26-
// CHECK-NEXT: define internal {{.*}} float
26+
// CHECK-NEXT: ; Function Attrs:
27+
// CHECK-NEXT: define internal {{.*}} double
2728
// CHECK-NEXT: start:
2829
// CHECK-NOT: ret
29-
// CHECK: fmul float
30+
// CHECK: fmul double
3031

3132
fn main() {
3233
let xf32: f32 = std::hint::black_box(3.0);

tests/codegen-llvm/autodiff/identical_fnc.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,27 @@
1414
use std::autodiff::autodiff_reverse;
1515

1616
#[autodiff_reverse(d_square, Duplicated, Active)]
17+
#[inline(never)]
1718
fn square(x: &f64) -> f64 {
1819
x * x
1920
}
2021

2122
#[autodiff_reverse(d_square2, Duplicated, Active)]
23+
#[inline(never)]
2224
fn square2(x: &f64) -> f64 {
2325
x * x
2426
}
2527

2628
// CHECK:; identical_fnc::main
2729
// CHECK-NEXT:; Function Attrs:
28-
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
30+
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17h6009e4f751bf9407E()
2931
// CHECK-NEXT:start:
3032
// CHECK-NOT:br
3133
// CHECK-NOT:ret
3234
// CHECK:; call identical_fnc::d_square
33-
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
34-
// CHECK-NEXT:; call identical_fnc::d_square
35-
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
35+
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1)
36+
// CHECK:; call identical_fnc::d_square
37+
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2)
3638

3739
fn main() {
3840
let x = std::hint::black_box(3.0);

tests/codegen-llvm/autodiff/inline.rs

Lines changed: 0 additions & 23 deletions
This file was deleted.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
5+
// In the past, we just checked for correct macro hygiene information.
6+
7+
#![feature(autodiff)]
8+
9+
// CHECK: ;
10+
macro_rules! demo {
11+
() => {
12+
#[std::autodiff::autodiff_reverse(fd, Active, Active)]
13+
fn f(x: f64) -> f64 {
14+
x * x
15+
}
16+
};
17+
}
18+
demo!();
19+
20+
fn main() {
21+
dbg!(f(2.0f64));
22+
}

tests/codegen-llvm/autodiff/scalar.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ use std::autodiff::autodiff_reverse;
77

88
#[autodiff_reverse(d_square, Duplicated, Active)]
99
#[no_mangle]
10+
#[inline(never)]
1011
fn square(x: &f64) -> f64 {
1112
x * x
1213
}
1314

14-
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
15+
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nonnull align 8 captures(none) %"x'")
1516
// CHECK-NEXT:invertstart:
1617
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
1718
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val

tests/codegen-llvm/autodiff/sret.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,30 @@ use std::autodiff::autodiff_reverse;
1313

1414
#[no_mangle]
1515
#[autodiff_reverse(df, Active, Active, Active)]
16+
#[inline(never)]
1617
fn primal(x: f32, y: f32) -> f64 {
1718
(x * x * y) as f64
1819
}
1920

20-
// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y)
21-
// CHECK-NEXT:start:
22-
// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y)
23-
// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0
24-
// CHECK-NEXT: store double %.elt, ptr %_0, align 8
25-
// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8
26-
// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1
27-
// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8
28-
// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12
29-
// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2
30-
// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4
31-
// CHECK-NEXT: ret void
32-
// CHECK-NEXT:}
21+
// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y)
22+
// CHECK-NEXT: invertstart:
23+
// CHECK-NEXT: %_4 = fmul float %x, %x
24+
// CHECK-NEXT: %_3 = fmul float %_4, %y
25+
// CHECK-NEXT: %_0 = fpext float %_3 to double
26+
// CHECK-NEXT: %0 = fadd fast float %y, %y
27+
// CHECK-NEXT: %1 = fmul fast float %0, %x
28+
// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0
29+
// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1
30+
// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2
31+
// CHECK-NEXT: ret { double, float, float } %4
32+
// CHECK-NEXT: }
3333

3434
fn main() {
3535
let x = std::hint::black_box(3.0);
3636
let y = std::hint::black_box(2.5);
3737
let scalar = std::hint::black_box(1.0);
3838
let (r1, r2, r3) = df(x, y, scalar);
39-
// 3*3*1.5 = 22.5
39+
// 3*3*2.5 = 22.5
4040
assert_eq!(r1, 22.5);
4141
// 2*x*y = 2*3*2.5 = 15.0
4242
assert_eq!(r2, 15.0);

tests/codegen-llvm/autodiff/trait.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
5+
// Just check it does not crash for now
6+
// CHECK: ;
7+
#![feature(autodiff)]
8+
9+
use std::autodiff::autodiff_reverse;
10+
11+
struct Foo {
12+
a: f64,
13+
}
14+
15+
trait MyTrait {
16+
fn f(&self, x: f64) -> f64;
17+
fn df(&self, x: f64, seed: f64) -> (f64, f64);
18+
}
19+
20+
impl MyTrait for Foo {
21+
#[autodiff_reverse(df, Const, Active, Active)]
22+
fn f(&self, x: f64) -> f64 {
23+
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
24+
}
25+
}
26+
27+
fn main() {
28+
let foo = Foo { a: 3.0f64 };
29+
dbg!(foo.df(1.0, 1.0));
30+
}

0 commit comments

Comments
 (0)