Skip to content

Commit f81fd98

Browse files
committed
Update codegen tests
1 parent 5c13692 commit f81fd98

File tree

7 files changed

+70
-49
lines changed

7 files changed

+70
-49
lines changed

tests/codegen/autodiff/batched.rs

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// reduce this test to only match the first lines and the ret instructions.
1111

1212
#![feature(autodiff)]
13-
#![feature(intrinsics)]
13+
#![feature(core_intrinsics)]
1414

1515
use std::autodiff::autodiff_forward;
1616

@@ -22,7 +22,7 @@ fn square(x: &f32) -> f32 {
2222
x * x
2323
}
2424

25-
// d_sqaure2
25+
// d_square2
2626
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
2727
// CHECK-NEXT: start:
2828
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
@@ -33,24 +33,20 @@ fn square(x: &f32) -> f32 {
3333
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
3434
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
3535
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
36-
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
37-
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
38-
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
39-
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
40-
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
41-
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
42-
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
43-
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
44-
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
45-
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
46-
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
47-
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
48-
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
49-
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
50-
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
51-
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
52-
// CHECK-NEXT: ret [4 x float] %19
53-
// CHECK-NEXT: }
36+
// CHECK-NEXT: %4 = fadd fast float %"_2'ipl", %"_2'ipl"
37+
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
38+
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
39+
// CHECK-NEXT: %7 = fadd fast float %"_2'ipl1", %"_2'ipl1"
40+
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
41+
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
42+
// CHECK-NEXT: %10 = fadd fast float %"_2'ipl2", %"_2'ipl2"
43+
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
44+
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
45+
// CHECK-NEXT: %13 = fadd fast float %"_2'ipl3", %"_2'ipl3"
46+
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
47+
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
48+
// CHECK-NEXT: ret [4 x float] %15
49+
// CHECK-NEXT: }
5450

5551
// d_square3, the extra float is the original return value (x * x)
5652
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
@@ -64,26 +60,22 @@ fn square(x: &f32) -> f32 {
6460
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
6561
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
6662
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
67-
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
68-
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
69-
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
70-
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
71-
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
72-
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
73-
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
74-
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
75-
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
76-
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
77-
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
78-
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
79-
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
80-
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
81-
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
82-
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
83-
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
84-
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
85-
// CHECK-NEXT: ret { float, [4 x float] } %21
86-
// CHECK-NEXT: }
63+
// CHECK-NEXT: %4 = fadd fast float %"_2'ipl", %"_2'ipl"
64+
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
65+
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
66+
// CHECK-NEXT: %7 = fadd fast float %"_2'ipl1", %"_2'ipl1"
67+
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
68+
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
69+
// CHECK-NEXT: %10 = fadd fast float %"_2'ipl2", %"_2'ipl2"
70+
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
71+
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
72+
// CHECK-NEXT: %13 = fadd fast float %"_2'ipl3", %"_2'ipl3"
73+
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
74+
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
75+
// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0
76+
// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1
77+
// CHECK-NEXT: ret { float, [4 x float] } %17
78+
// CHECK-NEXT: }
8779

8880
fn main() {
8981
let x = std::hint::black_box(3.0);

tests/codegen/autodiff/generic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//@ no-prefer-dynamic
33
//@ needs-enzyme
44
#![feature(autodiff)]
5-
#![feature(intrinsics)]
5+
#![feature(core_intrinsics)]
66

77
use std::autodiff::autodiff_reverse;
88

tests/codegen/autodiff/identical_fnc.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// We also explicetly test that we keep running merge_function after AD, by checking for two
1111
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
1212
#![feature(autodiff)]
13-
#![feature(intrinsics)]
13+
#![feature(core_intrinsics)]
1414

1515
use std::autodiff::autodiff_reverse;
1616

@@ -30,10 +30,8 @@ fn square2(x: &f64) -> f64 {
3030
// CHECK-NEXT:start:
3131
// CHECK-NOT:br
3232
// CHECK-NOT:ret
33-
// CHECK:; call identical_fnc::d_square
34-
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
35-
// CHECK-NEXT:; call identical_fnc::d_square
36-
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
33+
// CHECK:call fastcc void @diffe_ZN13identical_fnc6square17hdfa1c645848284b7E(double %x.val, ptr %dx1)
34+
// CHECK-NEXT:call fastcc void @diffe_ZN13identical_fnc6square17hdfa1c645848284b7E(double %x.val, ptr %dx2)
3735

3836
fn main() {
3937
let x = std::hint::black_box(3.0);

tests/codegen/autodiff/inline.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//@ needs-enzyme
44

55
#![feature(autodiff)]
6-
#![feature(intrinsics)]
6+
#![feature(core_intrinsics)]
77

88
use std::autodiff::autodiff_reverse;
99

tests/codegen/autodiff/scalar.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//@ no-prefer-dynamic
33
//@ needs-enzyme
44
#![feature(autodiff)]
5-
#![feature(intrinsics)]
5+
#![feature(core_intrinsics)]
66

77
use std::autodiff::autodiff_reverse;
88

tests/codegen/autodiff/sret.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
// We therefore use this test to verify some of our sret handling.
99

1010
#![feature(autodiff)]
11-
#![feature(intrinsics)]
11+
#![feature(core_intrinsics)]
1212

1313
use std::autodiff::autodiff_reverse;
1414

tests/codegen/autodiff/trait.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
5+
// Just check it does not crash for now
6+
// CHECK: ;
7+
#![feature(autodiff)]
8+
#![feature(core_intrinsics)]
9+
10+
use std::autodiff::autodiff_reverse;
11+
12+
struct Foo {
13+
a: f64,
14+
}
15+
16+
trait MyTrait {
17+
fn f(&self, x: f64) -> f64;
18+
fn df(&self, x: f64, seed: f64) -> (f64, f64);
19+
}
20+
21+
impl MyTrait for Foo {
22+
#[autodiff_reverse(df, Const, Active, Active)]
23+
fn f(&self, x: f64) -> f64 {
24+
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
25+
}
26+
}
27+
28+
fn main() {
29+
let foo = Foo { a: 3.0f64 };
30+
dbg!(foo.df(1.0, 1.0));
31+
}

0 commit comments

Comments
 (0)