@@ -76,15 +76,18 @@ fn match_args_from_caller_to_enzyme<'ll>(
7676 outer_pos = 1 ;
7777 }
7878
79+ // Autodiff activities
7980 let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
8081 let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
8182 let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
8283 let enzyme_dupv = cx. create_metadata ( "enzyme_dupv" . to_string ( ) ) . unwrap ( ) ;
8384 let enzyme_dupnoneed = cx. create_metadata ( "enzyme_dupnoneed" . to_string ( ) ) . unwrap ( ) ;
8485 let enzyme_dupnoneedv = cx. create_metadata ( "enzyme_dupnoneedv" . to_string ( ) ) . unwrap ( ) ;
8586
87+ // Batching activities
8688 let enzyme_scalar = cx. create_metadata ( "enzyme_scalar" . to_string ( ) ) . unwrap ( ) ;
8789 let enzyme_vector = cx. create_metadata ( "enzyme_vector" . to_string ( ) ) . unwrap ( ) ;
90+ let enzyme_buffer = cx. create_metadata ( "enzyme_buffer" . to_string ( ) ) . unwrap ( ) ;
8891
8992 while activity_pos < inputs. len ( ) {
9093 let diff_activity = inputs[ activity_pos as usize ] ;
@@ -103,16 +106,17 @@ fn match_args_from_caller_to_enzyme<'ll>(
103106 DiffActivity :: DuplicatedOnly => ( enzyme_dupnoneed, true ) ,
104107 DiffActivity :: FakeActivitySize ( _) => ( enzyme_const, false ) ,
105108 DiffActivity :: Vector => ( enzyme_vector, true ) ,
109+ DiffActivity :: Buffer => ( enzyme_buffer, false ) ,
106110 DiffActivity :: Scalar => ( enzyme_scalar, true ) ,
107111 } ;
108- let no_autodiff_only_batching = matches ! ( diff_activity, DiffActivity :: Scalar | DiffActivity :: Vector ) ;
112+ let no_autodiff_only_batching = matches ! ( diff_activity, DiffActivity :: Scalar | DiffActivity :: Vector | DiffActivity :: Buffer ) ;
109113 let outer_arg = outer_args[ outer_pos] ;
110114 args. push ( cx. get_metadata_value ( activity) ) ;
111115 if matches ! ( diff_activity, DiffActivity :: Dualv ) {
112116 let next_outer_arg = outer_args[ outer_pos + 1 ] ;
113117 let elem_bytes_size: u64 = match inputs[ activity_pos + 1 ] {
114118 DiffActivity :: FakeActivitySize ( Some ( s) ) => s. into ( ) ,
115- _ => bug ! ( "incorrect Dualv handling recognized." ) ,
119+ _ => bug ! ( "incorrect Dualv/Batching handling recognized." ) ,
116120 } ;
117121 // stride: sizeof(T) * n_elems.
118122 // n_elems is the next integer.
@@ -127,7 +131,53 @@ fn match_args_from_caller_to_enzyme<'ll>(
127131 } ;
128132 args. push ( mul) ;
129133 }
134+ if matches ! ( diff_activity, DiffActivity :: Buffer ) {
135+ // There are various cases.
136+ // A) We look at a scalar float.
137+ // B) We look at a Vector/Array of floats (byVal). Not sure if this is valid.
138+ // C) We look at a ptr as part of a slice.
139+ // D) We look at a ptr as part of a raw pointer or reference.
140+
141+ let mut elem_offset = cx. get_const_i64 ( width. into ( ) ) ;
142+ let outer_ty = cx. val_ty ( outer_arg) ;
143+ dbg ! ( & outer_ty) ;
144+ let bit_width = if cx. is_float_type ( outer_ty) {
145+ cx. float_width ( outer_ty)
146+ } else if cx. is_vec_or_array_type ( outer_ty) {
147+ let elem_ty = cx. element_type ( outer_ty) ;
148+ assert ! ( cx. is_float_type( elem_ty) ) ;
149+ let num_vec_elements = cx. vector_length ( outer_ty) ;
150+ assert ! ( num_vec_elements == width as usize ) ;
151+ dbg ! ( & num_vec_elements) ;
152+ cx. float_width ( elem_ty)
153+ } else if cx. is_ptr_type ( outer_ty) {
154+ if is_slice ( activity_pos, inputs) {
155+ elem_offset = outer_args[ outer_pos + 1 ] ;
156+ let elem_bytes_size: u64 = match inputs[ activity_pos + 1 ] {
157+ DiffActivity :: FakeActivitySize ( Some ( s) ) => s. into ( ) ,
158+ _ => bug ! ( "incorrect Dualv/Buffer handling recognized." ) ,
159+ } ;
160+ elem_bytes_size as usize * 8
161+ } else {
162+ // raw pointer or ref, hence `num_elem` = 1
163+ unimplemented ! ( )
164+ }
165+ } else {
166+ bug ! ( "expected float or vector type, found {:?}" , outer_ty) ;
167+ } ;
168+ let elem_bytes_size = bit_width as u64 / 8 ;
169+ let mul = unsafe {
170+ llvm:: LLVMBuildMul (
171+ builder. llbuilder ,
172+ cx. get_const_i64 ( elem_bytes_size) ,
173+ elem_offset,
174+ UNNAMED ,
175+ )
176+ } ;
177+ args. push ( mul) ;
178+ }
130179 args. push ( outer_arg) ;
180+ dbg ! ( & args) ;
131181 if duplicated {
132182 // We know that duplicated args by construction have a following argument,
133183 // so this can not be out of bounds.
@@ -136,17 +186,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
136186 // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
137187 // vectors behind references (&Vec<T>) are already supported. Users can not pass a
138188 // Vec by value for reverse mode, so this would only help forward mode autodiff.
139- let slice = {
140- if activity_pos + 1 >= inputs. len ( ) {
141- // If there is no arg following our ptr, it also can't be a slice,
142- // since that would lead to a ptr, int pair.
143- false
144- } else {
145- let next_activity = inputs[ activity_pos + 1 ] ;
146- // We analyze the MIR types and add this dummy activity if we visit a slice.
147- matches ! ( next_activity, DiffActivity :: FakeActivitySize ( _) )
148- }
149- } ;
189+ let slice = is_slice ( activity_pos, & inputs) ;
150190 if slice {
151191 // A duplicated slice will have the following two outer_fn arguments:
152192 // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
@@ -209,6 +249,19 @@ fn match_args_from_caller_to_enzyme<'ll>(
209249 activity_pos += 1 ;
210250 }
211251 }
252+ dbg ! ( "ending" ) ;
253+ }
254+
255+ fn is_slice ( activity_pos : usize , inputs : & [ DiffActivity ] ) -> bool {
256+ if activity_pos + 1 >= inputs. len ( ) {
257+ // If there is no arg following our ptr, it also can't be a slice,
258+ // since that would lead to a ptr, int pair.
259+ false
260+ } else {
261+ let next_activity = inputs[ activity_pos + 1 ] ;
262+ // We analyze the MIR types and add this dummy activity if we visit a slice.
263+ matches ! ( next_activity, DiffActivity :: FakeActivitySize ( _) )
264+ }
212265}
213266
214267// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
@@ -426,6 +479,7 @@ fn generate_enzyme_call<'ll>(
426479
427480 let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
428481
482+ dbg ! ( & call) ;
429483 // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
430484 // metadata attached to it, but we just created this code oota. Given that the
431485 // differentiated function already has partly confusing metadata, and given that this
@@ -472,6 +526,7 @@ fn generate_enzyme_call<'ll>(
472526 } else {
473527 builder. ret ( call) ;
474528 }
529+ dbg ! ( "Still alive" ) ;
475530
476531 // Let's crash in case that we messed something up above and generated invalid IR.
477532 llvm:: LLVMRustVerifyFunction (
@@ -531,6 +586,7 @@ pub(crate) fn differentiate<'ll>(
531586
532587 generate_enzyme_call ( & cx, fn_def, fn_target, item. attrs . clone ( ) ) ;
533588 }
589+ dbg ! ( "lowered all" ) ;
534590
535591 // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
536592
0 commit comments