Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,5 +378,12 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(

let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);

builder.store_to_place(call, dest.val);
let fn_ret_ty = builder.cx.val_ty(call);
if fn_ret_ty != builder.cx.type_void() && fn_ret_ty != builder.cx.type_struct(&[], false) {
// If we return void or an empty struct, then our caller (due to how we generated it)
// does not expect a return value. As such, we have no pointer (or place) into which
// we could store our value, and would store into an undef, which would cause UB.
// As such, we just ignore the return value in those cases.
builder.store_to_place(call, dest.val);
}
}
4 changes: 2 additions & 2 deletions tests/codegen-llvm/autodiff/abi_handling.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//@ revisions: debug release

//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat
//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@[debug] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=0 -Clto=fat
//@[release] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme

Expand Down
2 changes: 1 addition & 1 deletion tests/codegen-llvm/autodiff/batched.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@ compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
//
Expand Down
2 changes: 1 addition & 1 deletion tests/codegen-llvm/autodiff/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@ compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
#![feature(autodiff)]
Expand Down
2 changes: 1 addition & 1 deletion tests/codegen-llvm/autodiff/sret.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@ compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme

Expand Down
41 changes: 41 additions & 0 deletions tests/codegen-llvm/autodiff/void_ret.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//@ compile-flags: -Zautodiff=Enable,NoTT,NoPostopt -C no-prepopulate-passes -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme

#![feature(autodiff)]
use std::autodiff::*;

// Usually we would store the return value of the differentiated function.
// However, if the return type is void or an empty struct,
// we don't need to store anything. Verify this, since it caused a bug.

// CHECK:; void_ret::main
// CHECK-NEXT: ; Function Attrs:
// CHECK-NEXT: define internal
// CHECK-NOT: store {} undef, ptr undef
// CHECK: ret void

#[autodiff_reverse(bar, Duplicated, Duplicated)]
pub fn foo(r: &[f64; 10], res: &mut f64) {
let mut output = [0.0; 10];
output[0] = r[0];
output[1] = r[1] * r[2];
output[2] = r[4] * r[5];
output[3] = r[2] * r[6];
output[4] = r[1] * r[7];
output[5] = r[2] * r[8];
output[6] = r[1] * r[9];
output[7] = r[5] * r[6];
output[8] = r[5] * r[7];
output[9] = r[4] * r[8];
*res = output.iter().sum();
}
fn main() {
let inputs = Box::new([3.1; 10]);
let mut d_inputs = Box::new([0.0; 10]);
let mut res = Box::new(0.0);
let mut d_res = Box::new(1.0);

bar(&inputs, &mut d_inputs, &mut res, &mut d_res);
dbg!(&d_inputs);
}
Loading