Skip to content

Commit 30861c5

Browse files
committed
compiles and parses width, but doesn't use it yet
1 parent 8c39296 commit 30861c5

File tree

6 files changed

+128
-12
lines changed

6 files changed

+128
-12
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ pub struct AutoDiffAttrs {
7777
/// e.g. in the [JAX
7878
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
7979
pub mode: DiffMode,
80+
pub width: u32,
8081
pub ret_activity: DiffActivity,
8182
pub input_activity: Vec<DiffActivity>,
8283
}
@@ -222,13 +223,15 @@ impl AutoDiffAttrs {
222223
pub const fn error() -> Self {
223224
AutoDiffAttrs {
224225
mode: DiffMode::Error,
226+
width: 0,
225227
ret_activity: DiffActivity::None,
226228
input_activity: Vec::new(),
227229
}
228230
}
229231
pub fn source() -> Self {
230232
AutoDiffAttrs {
231233
mode: DiffMode::Source,
234+
width: 0,
232235
ret_activity: DiffActivity::None,
233236
input_activity: Vec::new(),
234237
}

compiler/rustc_builtin_macros/messages.ftl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
7676
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
7777
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
7878
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
79+
builtin_macros_autodiff_width = autodiff width must fit u32, but is {$width}
7980
8081
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
8182
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ mod llvm_enzyme {
3434
}
3535
}
3636
fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
37+
if x.lit().is_some() {
38+
let l = x.lit().unwrap();
39+
match l.kind {
40+
ast::LitKind::Int(val, _) => {
41+
// get an Ident from a lit
42+
return rustc_span::Ident::from_str(val.get().to_string().as_str());
43+
}
44+
_ => {}
45+
}
46+
}
47+
3748
let segments = &x.meta_item().unwrap().path.segments;
3849
assert!(segments.len() == 1);
3950
segments[0].ident
@@ -43,6 +54,14 @@ mod llvm_enzyme {
4354
first_ident(x).name.to_string()
4455
}
4556

57+
fn width(x: &MetaItemInner) -> Option<u128> {
58+
let lit = x.lit()?;
59+
match lit.kind {
60+
ast::LitKind::Int(x, _) => Some(x.get()),
61+
_ => return None,
62+
}
63+
}
64+
4665
pub(crate) fn from_ast(
4766
ecx: &mut ExtCtxt<'_>,
4867
meta_item: &ThinVec<MetaItemInner>,
@@ -54,9 +73,28 @@ mod llvm_enzyme {
5473
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
5574
return AutoDiffAttrs::error();
5675
};
76+
77+
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
78+
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
79+
let mut first_activity = 2;
80+
let width: u32 = match width(&meta_item[2]) {
81+
Some(x) => {
82+
first_activity = 3;
83+
match x.try_into() {
84+
Ok(x) => x,
85+
Err(_) => {
86+
dcx.emit_err(errors::AutoDiffInvalidWidth { span: meta_item[2].span(), width: x });
87+
return AutoDiffAttrs::error();
88+
}
89+
}
90+
},
91+
None => 1,
92+
};
93+
94+
dbg!(&first_activity);
5795
let mut activities: Vec<DiffActivity> = vec![];
5896
let mut errors = false;
59-
for x in &meta_item[2..] {
97+
for x in &meta_item[first_activity..] {
6098
let activity_str = name(&x);
6199
let res = DiffActivity::from_str(&activity_str);
62100
match res {
@@ -87,7 +125,7 @@ mod llvm_enzyme {
87125
(&DiffActivity::None, activities.as_slice())
88126
};
89127

90-
AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
128+
AutoDiffAttrs { mode, width, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
91129
}
92130

93131
/// We expand the autodiff macro to generate a new placeholder function which passes
@@ -193,13 +231,13 @@ mod llvm_enzyme {
193231
// input and output args.
194232
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
195233
return vec![item];
196-
} else {
197-
for t in meta_item_vec.clone()[1..].iter() {
198-
let val = first_ident(t);
199-
let t = Token::from_ast_ident(val);
200-
ts.push(TokenTree::Token(t, Spacing::Joint));
201-
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
202-
}
234+
}
235+
236+
for t in meta_item_vec.clone()[1..].iter() {
237+
let val = first_ident(t);
238+
let t = Token::from_ast_ident(val);
239+
ts.push(TokenTree::Token(t, Spacing::Joint));
240+
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
203241
}
204242
if !has_ret {
205243
// We don't want users to provide a return activity if the function doesn't return anything.

compiler/rustc_builtin_macros/src/errors.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,14 @@ mod autodiff {
193193
pub(crate) mode: String,
194194
}
195195

196+
#[derive(Diagnostic)]
197+
#[diag(builtin_macros_autodiff_width)]
198+
pub(crate) struct AutoDiffInvalidWidth {
199+
#[primary_span]
200+
pub(crate) span: Span,
201+
pub(crate) width: u128,
202+
}
203+
196204
#[derive(Diagnostic)]
197205
#[diag(builtin_macros_autodiff)]
198206
pub(crate) struct AutoDiffInvalidApplication {

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -875,8 +875,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
875875
return Some(AutoDiffAttrs::source());
876876
}
877877

878-
let [mode, input_activities @ .., ret_activity] = &list[..] else {
879-
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode and activities");
878+
let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else {
879+
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities");
880880
};
881881
let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode {
882882
p1.segments.first().unwrap().ident
@@ -893,6 +893,39 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
893893
}
894894
};
895895

896+
dbg!(&width_meta);
897+
let w = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = width_meta {
898+
p1.segments.first().unwrap().ident
899+
} else {
900+
span_bug!(attr.span(), "rustc_autodiff attribute must contain width");
901+
};
902+
903+
let width: u32 = match w.as_str().parse() {
904+
Ok(val) => val,
905+
Err(_) => {
906+
span_bug!(w.span, "rustc_autodiff width should fit u32");
907+
}
908+
};
909+
910+
//let width;
911+
//let width = rustc_ast::name(width_meta);
912+
//if let MetaItemInner::Lit(lit) = width_meta {
913+
// match lit.kind {
914+
// rustc_ast::LitKind::Int(val, _) => {width = val.get();}
915+
// _ => {
916+
// span_bug!(lit.span, "rustc_autodiff attribute must contain a width");
917+
// }
918+
// }
919+
//} else {
920+
// span_bug!(attr.span(), "rustc_autodiff attribute must contain a width");
921+
//};
922+
//let width: u32 = match width.try_into() {
923+
// Ok(val) => val,
924+
// Err(_) => {
925+
// span_bug!(width_meta.span(), "rustc_autodiff width should fit u32");
926+
// }
927+
//};
928+
896929
// First read the ret symbol from the attribute
897930
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity {
898931
p1.segments.first().unwrap().ident
@@ -939,7 +972,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
939972
span_bug!(attr.span(), "Invalid return activity {} for {} mode", ret_activity, mode);
940973
}
941974

942-
Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
975+
Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
943976
}
944977

945978
pub(crate) fn provide(providers: &mut Providers) {

tests/codegen/autodiffv.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
#![feature(autodiff)]
5+
6+
use std::autodiff::autodiff;
7+
8+
#[autodiff(d_square, Reverse, 4, Duplicated, Active)]
9+
#[no_mangle]
10+
fn square(x: &f64) -> f64 {
11+
x * x
12+
}
13+
14+
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture align 8 %"x'"
15+
// CHECK-NEXT:invertstart:
16+
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
17+
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
18+
// CHECK-NEXT: %1 = load double, ptr %"x'", align 8
19+
// CHECK-NEXT: %2 = fadd fast double %1, %0
20+
// CHECK-NEXT: store double %2, ptr %"x'", align 8
21+
// CHECK-NEXT: ret double %_0
22+
// CHECK-NEXT:}
23+
24+
fn main() {
25+
let x = 3.0;
26+
let output = square(&x);
27+
assert_eq!(9.0, output);
28+
29+
let mut df_dx = 0.0;
30+
let output_ = d_square(&x, &mut df_dx, 1.0);
31+
assert_eq!(output, output_);
32+
assert_eq!(6.0, df_dx);
33+
}

0 commit comments

Comments
 (0)