Skip to content

Commit 2eb28ce

Browse files
committed
Support ZST args
1 parent cf0eafd commit 2eb28ce

File tree

2 files changed

+25
-1
lines changed
  • compiler/rustc_monomorphize/src/partitioning
  • tests/codegen-llvm/autodiff

2 files changed

+25
-1
lines changed

compiler/rustc_monomorphize/src/partitioning/autodiff.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
1919

2020
let mut new_activities = vec![];
2121
let mut new_positions = vec![];
22+
let mut del_activities = 0;
2223
for (i, ty) in sig.inputs().iter().enumerate() {
2324
if let Some(inner_ty) = ty.builtin_deref(true) {
2425
if inner_ty.is_slice() {
@@ -80,12 +81,20 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
8081
}
8182
};
8283

84+
// For ZST, just ignore and don't add its activity, as this arg won't be present
85+
// in the LLVM passed to Enzyme.
86+
// FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const`
87+
if layout.is_zst() {
88+
del_activities += 1;
89+
da.remove(i);
90+
}
91+
8392
// If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
8493
// Otherwise, the number of activities won't match the number of LLVM arguments and
8594
// this will lead to errors when verifying the Enzyme call.
8695
if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
8796
new_activities.push(da[i].clone());
88-
new_positions.push(i + 1);
97+
new_positions.push(i + 1 - del_activities);
8998
}
9099
}
91100
// now add the extra activities coming from slices

tests/codegen-llvm/autodiff/zst.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
#![feature(autodiff)]
5+
6+
// CHECK: ;
7+
#[core::autodiff::autodiff_forward(fd_inner, Const, Dual)]
8+
fn f(_zst: (), _x: &mut f64) {}
9+
10+
#[unsafe(no_mangle)]
11+
pub extern "C" fn fd(x: &mut f64, xd: &mut f64) {
12+
fd_inner((), x, xd);
13+
}
14+
15+
fn main() {}

0 commit comments

Comments
 (0)