Skip to content

Commit 27b3881

Browse files
authored
Rollup merge of #147390 - ZuseZ4:autodiff-dbg, r=jieyouxu
Use globals instead of metadata for std::autodiff LLVM's Metadata is quite fragile. In debug builds we use incremental compilation, which caused the metadata to be dropped. With this change we use named globals instead of metadata to instruct Enzyme how to differentiate functions. Globals are proper llvm values and thus can't be dropped. Also added an incremental/dbg test which now passes, to unblock the EnzymeAD CI which wants to run Rust autodiff tests. r? compiler
2 parents 291e129 + 218fa60 commit 27b3881

File tree

4 files changed

+76
-83
lines changed

4 files changed

+76
-83
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use tracing::debug;
1212
use crate::builder::{Builder, PlaceRef, UNNAMED};
1313
use crate::context::SimpleCx;
1414
use crate::declare::declare_simple_fn;
15-
use crate::llvm::{self, Metadata, TRUE, Type, Value};
15+
use crate::llvm::{self, TRUE, Type, Value};
1616

1717
pub(crate) fn adjust_activity_to_abi<'tcx>(
1818
tcx: TyCtxt<'tcx>,
@@ -143,9 +143,9 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
143143
cx: &SimpleCx<'ll>,
144144
builder: &mut Builder<'_, 'll, 'tcx>,
145145
width: u32,
146-
args: &mut Vec<&'ll llvm::Value>,
146+
args: &mut Vec<&'ll Value>,
147147
inputs: &[DiffActivity],
148-
outer_args: &[&'ll llvm::Value],
148+
outer_args: &[&'ll Value],
149149
) {
150150
debug!("matching autodiff arguments");
151151
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -157,32 +157,36 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
157157
let mut outer_pos: usize = 0;
158158
let mut activity_pos = 0;
159159

160-
let enzyme_const = cx.create_metadata(b"enzyme_const");
161-
let enzyme_out = cx.create_metadata(b"enzyme_out");
162-
let enzyme_dup = cx.create_metadata(b"enzyme_dup");
163-
let enzyme_dupv = cx.create_metadata(b"enzyme_dupv");
164-
let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed");
165-
let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv");
160+
// We used to use llvm's metadata to instruct enzyme how to differentiate a function.
161+
// In debug mode we would use incremental compilation which caused the metadata to be
162+
// dropped. This is prevented by now using named globals, which are also understood
163+
// by Enzyme.
164+
let global_const = cx.declare_global("enzyme_const", cx.type_ptr());
165+
let global_out = cx.declare_global("enzyme_out", cx.type_ptr());
166+
let global_dup = cx.declare_global("enzyme_dup", cx.type_ptr());
167+
let global_dupv = cx.declare_global("enzyme_dupv", cx.type_ptr());
168+
let global_dupnoneed = cx.declare_global("enzyme_dupnoneed", cx.type_ptr());
169+
let global_dupnoneedv = cx.declare_global("enzyme_dupnoneedv", cx.type_ptr());
166170

167171
while activity_pos < inputs.len() {
168172
let diff_activity = inputs[activity_pos as usize];
169173
// Duplicated arguments received a shadow argument, into which enzyme will write the
170174
// gradient.
171-
let (activity, duplicated): (&Metadata, bool) = match diff_activity {
175+
let (activity, duplicated): (&Value, bool) = match diff_activity {
172176
DiffActivity::None => panic!("not a valid input activity"),
173-
DiffActivity::Const => (enzyme_const, false),
174-
DiffActivity::Active => (enzyme_out, false),
175-
DiffActivity::ActiveOnly => (enzyme_out, false),
176-
DiffActivity::Dual => (enzyme_dup, true),
177-
DiffActivity::Dualv => (enzyme_dupv, true),
178-
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
179-
DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
180-
DiffActivity::Duplicated => (enzyme_dup, true),
181-
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
182-
DiffActivity::FakeActivitySize(_) => (enzyme_const, false),
177+
DiffActivity::Const => (global_const, false),
178+
DiffActivity::Active => (global_out, false),
179+
DiffActivity::ActiveOnly => (global_out, false),
180+
DiffActivity::Dual => (global_dup, true),
181+
DiffActivity::Dualv => (global_dupv, true),
182+
DiffActivity::DualOnly => (global_dupnoneed, true),
183+
DiffActivity::DualvOnly => (global_dupnoneedv, true),
184+
DiffActivity::Duplicated => (global_dup, true),
185+
DiffActivity::DuplicatedOnly => (global_dupnoneed, true),
186+
DiffActivity::FakeActivitySize(_) => (global_const, false),
183187
};
184188
let outer_arg = outer_args[outer_pos];
185-
args.push(cx.get_metadata_value(activity));
189+
args.push(activity);
186190
if matches!(diff_activity, DiffActivity::Dualv) {
187191
let next_outer_arg = outer_args[outer_pos + 1];
188192
let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
@@ -242,7 +246,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
242246
assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
243247
args.push(next_outer_arg2);
244248
}
245-
args.push(cx.get_metadata_value(enzyme_const));
249+
args.push(global_const);
246250
args.push(next_outer_arg);
247251
outer_pos += 2 + 2 * iterations;
248252
activity_pos += 2;
@@ -351,13 +355,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
351355
let mut args = Vec::with_capacity(num_args as usize + 1);
352356
args.push(fn_to_diff);
353357

354-
let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
358+
let global_primal_ret = cx.declare_global("enzyme_primal_return", cx.type_ptr());
355359
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
356-
args.push(cx.get_metadata_value(enzyme_primal_ret));
360+
args.push(global_primal_ret);
357361
}
358362
if attrs.width > 1 {
359-
let enzyme_width = cx.create_metadata(b"enzyme_width");
360-
args.push(cx.get_metadata_value(enzyme_width));
363+
let global_width = cx.declare_global("enzyme_width", cx.type_ptr());
364+
args.push(global_width);
361365
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
362366
}
363367

tests/ui/autodiff/autodiff_illegal.rs

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,6 @@ fn f14(x: f32) -> Foo {
110110

111111
type MyFloat = f32;
112112

113-
// We would like to support type alias to f32/f64 in argument type in the future,
114-
// but that requires us to implement our checks at a later stage
115-
// like THIR which has type information available.
116-
#[autodiff_reverse(df15, Active, Active)]
117-
fn f15(x: MyFloat) -> f32 {
118-
//~^^ ERROR failed to resolve: use of undeclared type `MyFloat` [E0433]
119-
unimplemented!()
120-
}
121-
122113
// We would like to support type alias to f32/f64 in return type in the future
123114
#[autodiff_reverse(df16, Active, Active)]
124115
fn f16(x: f32) -> MyFloat {
@@ -136,13 +127,6 @@ fn f17(x: f64) -> F64Trans {
136127
unimplemented!()
137128
}
138129

139-
// We would like to support `#[repr(transparent)]` f32/f64 wrapper in argument type in the future
140-
#[autodiff_reverse(df18, Active, Active)]
141-
fn f18(x: F64Trans) -> f64 {
142-
//~^^ ERROR failed to resolve: use of undeclared type `F64Trans` [E0433]
143-
unimplemented!()
144-
}
145-
146130
// Invalid return activity
147131
#[autodiff_forward(df19, Dual, Active)]
148132
fn f19(x: f32) -> f32 {
@@ -163,11 +147,4 @@ fn f21(x: f32) -> f32 {
163147
unimplemented!()
164148
}
165149

166-
struct DoesNotImplDefault;
167-
#[autodiff_forward(df22, Dual)]
168-
pub fn f22() -> DoesNotImplDefault {
169-
//~^^ ERROR the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
170-
unimplemented!()
171-
}
172-
173150
fn main() {}

tests/ui/autodiff/autodiff_illegal.stderr

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -107,53 +107,24 @@ LL | #[autodiff_reverse(df13, Reverse)]
107107
| ^^^^^^^
108108

109109
error: invalid return activity Active in Forward Mode
110-
--> $DIR/autodiff_illegal.rs:147:1
110+
--> $DIR/autodiff_illegal.rs:131:1
111111
|
112112
LL | #[autodiff_forward(df19, Dual, Active)]
113113
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
114114

115115
error: invalid return activity Dual in Reverse Mode
116-
--> $DIR/autodiff_illegal.rs:153:1
116+
--> $DIR/autodiff_illegal.rs:137:1
117117
|
118118
LL | #[autodiff_reverse(df20, Active, Dual)]
119119
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
120120

121121
error: invalid return activity Duplicated in Reverse Mode
122-
--> $DIR/autodiff_illegal.rs:160:1
122+
--> $DIR/autodiff_illegal.rs:144:1
123123
|
124124
LL | #[autodiff_reverse(df21, Active, Duplicated)]
125125
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
126126

127-
error[E0433]: failed to resolve: use of undeclared type `MyFloat`
128-
--> $DIR/autodiff_illegal.rs:116:1
129-
|
130-
LL | #[autodiff_reverse(df15, Active, Active)]
131-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `MyFloat`
132-
133-
error[E0433]: failed to resolve: use of undeclared type `F64Trans`
134-
--> $DIR/autodiff_illegal.rs:140:1
135-
|
136-
LL | #[autodiff_reverse(df18, Active, Active)]
137-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `F64Trans`
138-
139-
error[E0599]: the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
140-
--> $DIR/autodiff_illegal.rs:167:1
141-
|
142-
LL | struct DoesNotImplDefault;
143-
| ------------------------- doesn't satisfy `DoesNotImplDefault: Default`
144-
LL | #[autodiff_forward(df22, Dual)]
145-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ function or associated item cannot be called on `(DoesNotImplDefault, DoesNotImplDefault)` due to unsatisfied trait bounds
146-
|
147-
= note: the following trait bounds were not satisfied:
148-
`DoesNotImplDefault: Default`
149-
which is required by `(DoesNotImplDefault, DoesNotImplDefault): Default`
150-
help: consider annotating `DoesNotImplDefault` with `#[derive(Default)]`
151-
|
152-
LL + #[derive(Default)]
153-
LL | struct DoesNotImplDefault;
154-
|
155-
156-
error: aborting due to 21 previous errors
127+
error: aborting due to 18 previous errors
157128

158-
Some errors have detailed explanations: E0428, E0433, E0599, E0658.
129+
Some errors have detailed explanations: E0428, E0658.
159130
For more information about an error, try `rustc --explain E0428`.

tests/ui/autodiff/incremental.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//@ revisions: DEBUG RELEASE
2+
//@[RELEASE] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
3+
//@[DEBUG] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=0 -Clto=fat -C debuginfo=2
4+
//@ needs-enzyme
5+
//@ incremental
6+
//@ no-prefer-dynamic
7+
//@ build-pass
8+
#![crate_type = "bin"]
9+
#![feature(autodiff)]
10+
11+
// We used to use llvm's metadata to instruct enzyme how to differentiate a function.
12+
// In debug mode we would use incremental compilation which caused the metadata to be
13+
// dropped. We now use globals instead and add this test to verify that incremental
14+
// keeps working. Also testing debug mode while at it.
15+
16+
use std::autodiff::autodiff_reverse;
17+
18+
#[autodiff_reverse(bar, Duplicated, Duplicated)]
19+
pub fn foo(r: &[f64; 10], res: &mut f64) {
20+
let mut output = [0.0; 10];
21+
output[0] = r[0];
22+
output[1] = r[1] * r[2];
23+
output[2] = r[4] * r[5];
24+
output[3] = r[2] * r[6];
25+
output[4] = r[1] * r[7];
26+
output[5] = r[2] * r[8];
27+
output[6] = r[1] * r[9];
28+
output[7] = r[5] * r[6];
29+
output[8] = r[5] * r[7];
30+
output[9] = r[4] * r[8];
31+
*res = output.iter().sum();
32+
}
33+
fn main() {
34+
let inputs = Box::new([3.1; 10]);
35+
let mut d_inputs = Box::new([0.0; 10]);
36+
let mut res = Box::new(0.0);
37+
let mut d_res = Box::new(1.0);
38+
39+
bar(&inputs, &mut d_inputs, &mut res, &mut d_res);
40+
dbg!(&d_inputs);
41+
}

0 commit comments

Comments
 (0)