Skip to content

Commit 5e5bcd4

Browse files
committed
Add test for autodiff abi handling
1 parent c018ae5 commit 5e5bcd4

File tree

1 file changed

+209
-0
lines changed

1 file changed

+209
-0
lines changed
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
//@ revisions: debug release
2+
3+
//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat
4+
//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
5+
//@ no-prefer-dynamic
6+
//@ needs-enzyme
7+
8+
// This test checks that Rust types are lowered to LLVM-IR types in a way
9+
// we expect and Enzyme can handle. We explicitly check release mode to
10+
// ensure that LLVM's O3 pipeline doesn't rewrite function signatures
11+
// into forms that Enzyme can't process correctly.
12+
13+
#![feature(autodiff)]
14+
15+
use std::autodiff::{autodiff_forward, autodiff_reverse};
16+
17+
#[derive(Copy, Clone)]
18+
struct Input {
19+
x: f32,
20+
y: f32,
21+
}
22+
23+
#[derive(Copy, Clone)]
24+
struct Wrapper {
25+
z: f32,
26+
}
27+
28+
#[derive(Copy, Clone)]
29+
struct NestedInput {
30+
x: f32,
31+
y: Wrapper,
32+
}
33+
34+
fn square(x: f32) -> f32 {
35+
x * x
36+
}
37+
38+
// CHECK: ; abi_handling::df1
39+
// CHECK-NEXT: Function Attrs
40+
// debug-NEXT: define internal { float, float } @_ZN12abi_handling3df117h55144f697f16c530E
41+
// debug-SAME: (ptr align 4 %x, ptr align 4 %bx_0)
42+
// release-NEXT: define internal fastcc float @_ZN12abi_handling3df117h55144f697f16c530E
43+
// release-SAME: (float %x.0.val, float %x.4.val)
44+
45+
// CHECK: ; abi_handling::f1
46+
// CHECK-NEXT: Function Attrs
47+
// debug-NEXT: define internal float @_ZN12abi_handling2f117hd2edd01111f953c8E
48+
// debug-SAME: (ptr align 4 %x)
49+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f117hd2edd01111f953c8E
50+
// release-SAME: (float %x.0.val, float %x.4.val)
51+
#[autodiff_forward(df1, Dual, Dual)]
52+
#[inline(never)]
53+
fn f1(x: &[f32; 2]) -> f32 {
54+
x[0] + x[1]
55+
}
56+
57+
// CHECK: ; abi_handling::df2
58+
// CHECK-NEXT: Function Attrs
59+
// debug-NEXT: define internal { float, float } @_ZN12abi_handling3df217h8225b3063251a601E
60+
// debug-SAME: (ptr %f, float %x, float %dret)
61+
// release-NEXT: define internal fastcc float @_ZN12abi_handling3df217h8225b3063251a601E
62+
// release-SAME: (float noundef %x)
63+
64+
// CHECK: ; abi_handling::f2
65+
// CHECK-NEXT: Function Attrs
66+
// debug-NEXT: define internal float @_ZN12abi_handling2f217ha5269d6a663e4b66E
67+
// debug-SAME: (ptr %f, float %x)
68+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f217ha5269d6a663e4b66E
69+
// release-SAME: (float noundef %x)
70+
#[autodiff_reverse(df2, Const, Active, Active)]
71+
#[inline(never)]
72+
fn f2(f: fn(f32) -> f32, x: f32) -> f32 {
73+
f(x)
74+
}
75+
76+
// CHECK: ; abi_handling::df3
77+
// CHECK-NEXT: Function Attrs
78+
// debug: define internal { float, float } @_ZN12abi_handling3df317h46666d86d46a2ce1E
79+
// debug-SAME: (ptr align 4 %x, ptr align 4 %bx_0, ptr align 4 %y, ptr align 4 %by_0)
80+
// release-NEXT: define internal fastcc { float, float } @_ZN12abi_handling3df317h46666d86d46a2ce1E
81+
// release-SAME: (float %x.0.val)
82+
83+
// CHECK: ; abi_handling::f3
84+
// CHECK-NEXT: Function Attrs
85+
// debug-NEXT: define internal float @_ZN12abi_handling2f317h3dc8010df407d95eE
86+
// debug-SAME: (ptr align 4 %x, ptr align 4 %y)
87+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f317h3dc8010df407d95eE
88+
// release-SAME: (float %x.0.val)
89+
#[autodiff_forward(df3, Dual, Dual, Dual)]
90+
#[inline(never)]
91+
fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 {
92+
*x * *y
93+
}
94+
95+
// CHECK: ; abi_handling::df4
96+
// CHECK-NEXT: Function Attrs
97+
// debug: define internal { float, float } @_ZN12abi_handling3df417h5cc6e38f26bcfe23E
98+
// debug-SAME: (float %x.0, float %x.1, float %bx_0.0, float %bx_0.1)
99+
// release-NEXT: define internal fastcc { float, float } @_ZN12abi_handling3df417h5cc6e38f26bcfe23E
100+
// release-SAME: (float noundef %x.0, float noundef %x.1)
101+
102+
// CHECK: ; abi_handling::f4
103+
// CHECK-NEXT: Function Attrs
104+
// debug-NEXT: define internal float @_ZN12abi_handling2f417h4dd0d54e71689dcaE
105+
// debug-SAME: (float %x.0, float %x.1)
106+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f417h4dd0d54e71689dcaE
107+
// release-SAME: (float noundef %x.0, float noundef %x.1)
108+
#[autodiff_forward(df4, Dual, Dual)]
109+
#[inline(never)]
110+
fn f4(x: (f32, f32)) -> f32 {
111+
x.0 * x.1
112+
}
113+
114+
// CHECK: ; abi_handling::df5
115+
// CHECK-NEXT: Function Attrs
116+
// debug-NEXT: define internal { float, float } @_ZN12abi_handling3df517hb63274288df3629dE
117+
// debug-SAME: (float %i.0, float %i.1, float %bi_0.0, float %bi_0.1)
118+
// release-NEXT: define internal fastcc { float, float } @_ZN12abi_handling3df517hb63274288df3629dE
119+
// release-SAME: (float noundef %i.0, float noundef %i.1)
120+
121+
// CHECK: ; abi_handling::f5
122+
// CHECK-NEXT: Function Attrs
123+
// debug-NEXT: define internal float @_ZN12abi_handling2f517hd1e1fe9d327743a4E
124+
// debug-SAME: (float %i.0, float %i.1)
125+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f517hd1e1fe9d327743a4E
126+
// release-SAME: (float noundef %i.0, float noundef %i.1)
127+
#[autodiff_forward(df5, Dual, Dual)]
128+
#[inline(never)]
129+
fn f5(i: Input) -> f32 {
130+
i.x + i.y
131+
}
132+
133+
// CHECK: ; abi_handling::df6
134+
// CHECK-NEXT: Function Attrs
135+
// debug-NEXT: define internal { float, float } @_ZN12abi_handling3df617hf53c2881f50cee19E
136+
// debug-SAME: (float %i.0, float %i.1, float %bi_0.0, float %bi_0.1)
137+
// release-NEXT: define internal fastcc { float, float } @_ZN12abi_handling3df617hf53c2881f50cee19E
138+
// release-SAME: (float noundef %i.0, float noundef %i.1, float noundef %bi_0.0, float noundef %bi_0.1)
139+
140+
// CHECK: ; abi_handling::f6
141+
// CHECK-NEXT: Function Attrs
142+
// debug-NEXT: define internal float @_ZN12abi_handling2f617h7249eba95dfb35c5E
143+
// debug-SAME: (float %i.0, float %i.1)
144+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f617h7249eba95dfb35c5E
145+
// release-SAME: (float noundef %i.0, float noundef %i.1)
146+
#[autodiff_forward(df6, Dual, Dual)]
147+
#[inline(never)]
148+
fn f6(i: NestedInput) -> f32 {
149+
i.x + i.y.z * i.y.z
150+
}
151+
152+
// CHECK: ; abi_handling::df7
153+
// CHECK-NEXT: Function Attrs
154+
// debug-NEXT: define internal { float, float } @_ZN12abi_handling3df717h51ca21c37d451463E
155+
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1, ptr align 4 %bx_0.0, ptr align 4 %bx_0.1)
156+
// release-NEXT: define internal fastcc { float, float } @_ZN12abi_handling3df717h51ca21c37d451463E
157+
// release-SAME: (float %x.0.0.val, float %x.1.0.val)
158+
159+
// CHECK: ; abi_handling::f7
160+
// CHECK-NEXT: Function Attrs
161+
// debug-NEXT: define internal float @_ZN12abi_handling2f717hf4d2ab6f1a195af6E
162+
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1)
163+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f717hf4d2ab6f1a195af6E
164+
// release-SAME: (float %x.0.0.val, float %x.1.0.val)
165+
#[autodiff_forward(df7, Dual, Dual)]
166+
#[inline(never)]
167+
fn f7(x: (&f32, &f32)) -> f32 {
168+
x.0 * x.1
169+
}
170+
171+
fn main() {
172+
let x = std::hint::black_box(2.0);
173+
let y = std::hint::black_box(3.0);
174+
let z = std::hint::black_box(4.0);
175+
static Y: f32 = std::hint::black_box(3.2);
176+
177+
let in_f1 = [x, y];
178+
dbg!(f1(&in_f1));
179+
let res_f1 = df1(&in_f1, &[1.0, 0.0]);
180+
dbg!(res_f1);
181+
182+
dbg!(f2(square, x));
183+
let res_f2 = df2(square, x, 1.0);
184+
dbg!(res_f2);
185+
186+
dbg!(f3(&x, &Y));
187+
let res_f3 = df3(&x, &Y, &1.0, &0.0);
188+
dbg!(res_f3);
189+
190+
let in_f4 = (x, y);
191+
dbg!(f4(in_f4));
192+
let res_f4 = df4(in_f4, (1.0, 0.0));
193+
dbg!(res_f4);
194+
195+
let in_f5 = Input { x, y };
196+
dbg!(f5(in_f5));
197+
let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 });
198+
dbg!(res_f5);
199+
200+
let in_f6 = NestedInput { x, y: Wrapper { z: y } };
201+
dbg!(f6(in_f6));
202+
let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } });
203+
dbg!(res_f6);
204+
205+
let in_f7 = (&x, &y);
206+
dbg!(f7(in_f7));
207+
let res_f7 = df7(in_f7, (&1.0, &0.0));
208+
dbg!(res_f7);
209+
}

0 commit comments

Comments
 (0)