@@ -34,6 +34,17 @@ mod llvm_enzyme {
3434 }
3535 }
3636 fn first_ident ( x : & MetaItemInner ) -> rustc_span:: Ident {
37+ if x. lit ( ) . is_some ( ) {
38+ let l = x. lit ( ) . unwrap ( ) ;
39+ match l. kind {
40+ ast:: LitKind :: Int ( val, _) => {
41+ // get an Ident from a lit
42+ return rustc_span:: Ident :: from_str ( val. get ( ) . to_string ( ) . as_str ( ) ) ;
43+ }
44+ _ => { }
45+ }
46+ }
47+
3748 let segments = & x. meta_item ( ) . unwrap ( ) . path . segments ;
3849 assert ! ( segments. len( ) == 1 ) ;
3950 segments[ 0 ] . ident
@@ -43,6 +54,14 @@ mod llvm_enzyme {
4354 first_ident ( x) . name . to_string ( )
4455 }
4556
57+ fn width ( x : & MetaItemInner ) -> Option < u128 > {
58+ let lit = x. lit ( ) ?;
59+ match lit. kind {
60+ ast:: LitKind :: Int ( x, _) => Some ( x. get ( ) ) ,
61+ _ => return None ,
62+ }
63+ }
64+
4665 pub ( crate ) fn from_ast (
4766 ecx : & mut ExtCtxt < ' _ > ,
4867 meta_item : & ThinVec < MetaItemInner > ,
@@ -54,9 +73,28 @@ mod llvm_enzyme {
5473 dcx. emit_err ( errors:: AutoDiffInvalidMode { span : meta_item[ 1 ] . span ( ) , mode } ) ;
5574 return AutoDiffAttrs :: error ( ) ;
5675 } ;
76+
77+ // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
78+ // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
79+ let mut first_activity = 2 ;
80+ let width: u32 = match width ( & meta_item[ 2 ] ) {
81+ Some ( x) => {
82+ first_activity = 3 ;
83+ match x. try_into ( ) {
84+ Ok ( x) => x,
85+ Err ( _) => {
86+ dcx. emit_err ( errors:: AutoDiffInvalidWidth { span : meta_item[ 2 ] . span ( ) , width : x } ) ;
87+ return AutoDiffAttrs :: error ( ) ;
88+ }
89+ }
90+ } ,
91+ None => 1 ,
92+ } ;
93+
94+ dbg ! ( & first_activity) ;
5795 let mut activities: Vec < DiffActivity > = vec ! [ ] ;
5896 let mut errors = false ;
59- for x in & meta_item[ 2 ..] {
97+ for x in & meta_item[ first_activity ..] {
6098 let activity_str = name ( & x) ;
6199 let res = DiffActivity :: from_str ( & activity_str) ;
62100 match res {
@@ -87,7 +125,7 @@ mod llvm_enzyme {
87125 ( & DiffActivity :: None , activities. as_slice ( ) )
88126 } ;
89127
90- AutoDiffAttrs { mode, ret_activity : * ret_activity, input_activity : input_activity. to_vec ( ) }
128+ AutoDiffAttrs { mode, width , ret_activity : * ret_activity, input_activity : input_activity. to_vec ( ) }
91129 }
92130
93131 /// We expand the autodiff macro to generate a new placeholder function which passes
@@ -193,13 +231,13 @@ mod llvm_enzyme {
193231 // input and output args.
194232 dcx. emit_err ( errors:: AutoDiffMissingConfig { span : item. span ( ) } ) ;
195233 return vec ! [ item] ;
196- } else {
197- for t in meta_item_vec . clone ( ) [ 1 .. ] . iter ( ) {
198- let val = first_ident ( t ) ;
199- let t = Token :: from_ast_ident ( val ) ;
200- ts . push ( TokenTree :: Token ( t , Spacing :: Joint ) ) ;
201- ts. push ( TokenTree :: Token ( comma . clone ( ) , Spacing :: Alone ) ) ;
202- }
234+ }
235+
236+ for t in meta_item_vec . clone ( ) [ 1 .. ] . iter ( ) {
237+ let val = first_ident ( t ) ;
238+ let t = Token :: from_ast_ident ( val ) ;
239+ ts. push ( TokenTree :: Token ( t , Spacing :: Joint ) ) ;
240+ ts . push ( TokenTree :: Token ( comma . clone ( ) , Spacing :: Alone ) ) ;
203241 }
204242 if !has_ret {
205243 // We don't want users to provide a return activity if the function doesn't return anything.
0 commit comments