@@ -10,7 +10,7 @@ use rustc_middle::bug;
1010use tracing:: { debug, trace} ;
1111
1212use crate :: back:: write:: llvm_err;
13- use crate :: builder:: SBuilder ;
13+ use crate :: builder:: { SBuilder , UNNAMED } ;
1414use crate :: context:: SimpleCx ;
1515use crate :: declare:: declare_simple_fn;
1616use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
@@ -51,6 +51,7 @@ fn has_sret(fnc: &Value) -> bool {
5151// using iterators and peek()?
5252fn match_args_from_caller_to_enzyme < ' ll > (
5353 cx : & SimpleCx < ' ll > ,
54+ builder : & SBuilder < ' ll , ' ll > ,
5455 width : u32 ,
5556 args : & mut Vec < & ' ll llvm:: Value > ,
5657 inputs : & [ DiffActivity ] ,
@@ -78,6 +79,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
7879 let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
7980 let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
8081 let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
82+ let enzyme_dupv = cx. create_metadata ( "enzyme_dupv" . to_string ( ) ) . unwrap ( ) ;
8183 let enzyme_dupnoneed = cx. create_metadata ( "enzyme_dupnoneed" . to_string ( ) ) . unwrap ( ) ;
8284
8385 while activity_pos < inputs. len ( ) {
@@ -90,13 +92,26 @@ fn match_args_from_caller_to_enzyme<'ll>(
9092 DiffActivity :: Active => ( enzyme_out, false ) ,
9193 DiffActivity :: ActiveOnly => ( enzyme_out, false ) ,
9294 DiffActivity :: Dual => ( enzyme_dup, true ) ,
95+ DiffActivity :: Dualv => ( enzyme_dupv, true ) ,
9396 DiffActivity :: DualOnly => ( enzyme_dupnoneed, true ) ,
9497 DiffActivity :: Duplicated => ( enzyme_dup, true ) ,
9598 DiffActivity :: DuplicatedOnly => ( enzyme_dupnoneed, true ) ,
9699 DiffActivity :: FakeActivitySize => ( enzyme_const, false ) ,
97100 } ;
98101 let outer_arg = outer_args[ outer_pos] ;
99102 args. push ( cx. get_metadata_value ( activity) ) ;
103+ if matches ! ( diff_activity, DiffActivity :: Dualv ) {
104+ let next_outer_arg = outer_args[ outer_pos + 1 ] ;
105+ // stride: sizeof(T) * n_elems.
106+ // T=f32 => 4 bytes
107+ // n_elems is the next integer.
108+ // Now we multiply `4 * next_outer_arg` to get the stride.
109+ //let mul = builder
110+ // .build_mul(cx.get_const_i64(4), next_outer_arg)
111+ // .unwrap();
112+ let mul = unsafe { llvm:: LLVMBuildMul ( builder. llbuilder , cx. get_const_i64 ( 4 ) , next_outer_arg, UNNAMED ) } ;
113+ args. push ( mul) ;
114+ }
100115 args. push ( outer_arg) ;
101116 if duplicated {
102117 // We know that duplicated args by construction have a following argument,
@@ -125,7 +140,13 @@ fn match_args_from_caller_to_enzyme<'ll>(
125140 // int2 >= int1, which means the shadow vector is large enough to store the gradient.
126141 assert_eq ! ( cx. type_kind( next_outer_ty) , TypeKind :: Integer ) ;
127142
128- for i in 0 ..( width as usize ) {
143+ let iterations = if matches ! ( diff_activity, DiffActivity :: Dualv ) {
144+ 1
145+ } else {
146+ width as usize
147+ } ;
148+
149+ for i in 0 ..iterations {
129150 let next_outer_arg2 = outer_args[ outer_pos + 2 * ( i + 1 ) ] ;
130151 let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131152 assert_eq ! ( cx. type_kind( next_outer_ty2) , TypeKind :: Pointer ) ;
@@ -136,7 +157,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
136157 }
137158 args. push ( cx. get_metadata_value ( enzyme_const) ) ;
138159 args. push ( next_outer_arg) ;
139- outer_pos += 2 + 2 * width as usize ;
160+ outer_pos += 2 + 2 * iterations ;
140161 activity_pos += 2 ;
141162 } else {
142163 // A duplicated pointer will have the following two outer_fn arguments:
@@ -344,6 +365,7 @@ fn generate_enzyme_call<'ll>(
344365 let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
345366 match_args_from_caller_to_enzyme (
346367 & cx,
368+ & builder,
347369 attrs. width ,
348370 & mut args,
349371 & attrs. input_activity ,
0 commit comments