Skip to content

Commit bbaecdf

Browse files
authored
Unrolled build for #147200
Rollup merge of #147200 - ZuseZ4:fix-autodiff-emptry-ret, r=Zalathar. Fix autodiff empty ret regression. Closes #147144. The two GSoC summer projects caused a bit of churn, which was to be expected, especially since we don't run autodiff in CI yet. This adds a void return testcase that we should have had anyway, and fixes the regression. r? `@Zalathar` (Just guessing since I've seen you in a few LLVM PRs and Oli is probably still busy. Feel free to reroll!)
2 parents 1e1a394 + de189fa commit bbaecdf

File tree

6 files changed

+54
-6
lines changed

6 files changed

+54
-6
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,5 +378,12 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
378378

379379
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
380380

381-
builder.store_to_place(call, dest.val);
381+
let fn_ret_ty = builder.cx.val_ty(call);
382+
if fn_ret_ty != builder.cx.type_void() && fn_ret_ty != builder.cx.type_struct(&[], false) {
383+
// If we return void or an empty struct, then our caller (due to how we generated it)
384+
// does not expect a return value. As such, we have no pointer (or place) into which
385+
// we could store our value, and would store into an undef, which would cause UB.
386+
// As such, we just ignore the return value in those cases.
387+
builder.store_to_place(call, dest.val);
388+
}
382389
}

tests/codegen-llvm/autodiff/abi_handling.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//@ revisions: debug release
22

3-
//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat
4-
//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
3+
//@[debug] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=0 -Clto=fat
4+
//@[release] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
55
//@ no-prefer-dynamic
66
//@ needs-enzyme
77

tests/codegen-llvm/autodiff/batched.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
1+
//@ compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
22
//@ no-prefer-dynamic
33
//@ needs-enzyme
44
//

tests/codegen-llvm/autodiff/scalar.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
1+
//@ compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
22
//@ no-prefer-dynamic
33
//@ needs-enzyme
44
#![feature(autodiff)]

tests/codegen-llvm/autodiff/sret.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
1+
//@ compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
22
//@ no-prefer-dynamic
33
//@ needs-enzyme
44

tests/codegen-llvm/autodiff/void_ret.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//@ compile-flags: -Zautodiff=Enable,NoTT,NoPostopt -C no-prepopulate-passes -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
5+
#![feature(autodiff)]
6+
use std::autodiff::*;
7+
8+
// Usually we would store the return value of the differentiated function.
9+
// However, if the return type is void or an empty struct,
10+
// we don't need to store anything. Verify this, since it caused a bug.
11+
12+
// CHECK:; void_ret::main
13+
// CHECK-NEXT: ; Function Attrs:
14+
// CHECK-NEXT: define internal
15+
// CHECK-NOT: store {} undef, ptr undef
16+
// CHECK: ret void
17+
18+
#[autodiff_reverse(bar, Duplicated, Duplicated)]
19+
pub fn foo(r: &[f64; 10], res: &mut f64) {
20+
let mut output = [0.0; 10];
21+
output[0] = r[0];
22+
output[1] = r[1] * r[2];
23+
output[2] = r[4] * r[5];
24+
output[3] = r[2] * r[6];
25+
output[4] = r[1] * r[7];
26+
output[5] = r[2] * r[8];
27+
output[6] = r[1] * r[9];
28+
output[7] = r[5] * r[6];
29+
output[8] = r[5] * r[7];
30+
output[9] = r[4] * r[8];
31+
*res = output.iter().sum();
32+
}
33+
fn main() {
34+
let inputs = Box::new([3.1; 10]);
35+
let mut d_inputs = Box::new([0.0; 10]);
36+
let mut res = Box::new(0.0);
37+
let mut d_res = Box::new(1.0);
38+
39+
bar(&inputs, &mut d_inputs, &mut res, &mut d_res);
40+
dbg!(&d_inputs);
41+
}

0 commit comments

Comments
 (0)