@@ -143,9 +143,9 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
143
143
cx : & SimpleCx < ' ll > ,
144
144
builder : & mut Builder < ' _ , ' ll , ' tcx > ,
145
145
width : u32 ,
146
- args : & mut Vec < & ' ll llvm :: Value > ,
146
+ args : & mut Vec < & ' ll Value > ,
147
147
inputs : & [ DiffActivity ] ,
148
- outer_args : & [ & ' ll llvm :: Value ] ,
148
+ outer_args : & [ & ' ll Value ] ,
149
149
) {
150
150
debug ! ( "matching autodiff arguments" ) ;
151
151
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -157,32 +157,36 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
157
157
let mut outer_pos: usize = 0 ;
158
158
let mut activity_pos = 0 ;
159
159
160
- let enzyme_const = cx. create_metadata ( b"enzyme_const" ) ;
161
- let enzyme_out = cx. create_metadata ( b"enzyme_out" ) ;
162
- let enzyme_dup = cx. create_metadata ( b"enzyme_dup" ) ;
163
- let enzyme_dupv = cx. create_metadata ( b"enzyme_dupv" ) ;
164
- let enzyme_dupnoneed = cx. create_metadata ( b"enzyme_dupnoneed" ) ;
165
- let enzyme_dupnoneedv = cx. create_metadata ( b"enzyme_dupnoneedv" ) ;
160
+ // We used to use llvm's metadata to instruct enzyme how to differentiate a function.
161
+ // In debug mode we would use incremental compilation which caused the metadata to be
162
+ // dropped. This is prevented by now using named globals, which are also understood
163
+ // by Enzyme.
164
+ let global_const = cx. declare_global ( "enzyme_const" , cx. type_ptr ( ) ) ;
165
+ let global_out = cx. declare_global ( "enzyme_out" , cx. type_ptr ( ) ) ;
166
+ let global_dup = cx. declare_global ( "enzyme_dup" , cx. type_ptr ( ) ) ;
167
+ let global_dupv = cx. declare_global ( "enzyme_dupv" , cx. type_ptr ( ) ) ;
168
+ let global_dupnoneed = cx. declare_global ( "enzyme_dupnoneed" , cx. type_ptr ( ) ) ;
169
+ let global_dupnoneedv = cx. declare_global ( "enzyme_dupnoneedv" , cx. type_ptr ( ) ) ;
166
170
167
171
while activity_pos < inputs. len ( ) {
168
172
let diff_activity = inputs[ activity_pos as usize ] ;
169
173
// Duplicated arguments received a shadow argument, into which enzyme will write the
170
174
// gradient.
171
- let ( activity, duplicated) : ( & Metadata , bool ) = match diff_activity {
175
+ let ( activity, duplicated) : ( & Value , bool ) = match diff_activity {
172
176
DiffActivity :: None => panic ! ( "not a valid input activity" ) ,
173
- DiffActivity :: Const => ( enzyme_const , false ) ,
174
- DiffActivity :: Active => ( enzyme_out , false ) ,
175
- DiffActivity :: ActiveOnly => ( enzyme_out , false ) ,
176
- DiffActivity :: Dual => ( enzyme_dup , true ) ,
177
- DiffActivity :: Dualv => ( enzyme_dupv , true ) ,
178
- DiffActivity :: DualOnly => ( enzyme_dupnoneed , true ) ,
179
- DiffActivity :: DualvOnly => ( enzyme_dupnoneedv , true ) ,
180
- DiffActivity :: Duplicated => ( enzyme_dup , true ) ,
181
- DiffActivity :: DuplicatedOnly => ( enzyme_dupnoneed , true ) ,
182
- DiffActivity :: FakeActivitySize ( _) => ( enzyme_const , false ) ,
177
+ DiffActivity :: Const => ( global_const , false ) ,
178
+ DiffActivity :: Active => ( global_out , false ) ,
179
+ DiffActivity :: ActiveOnly => ( global_out , false ) ,
180
+ DiffActivity :: Dual => ( global_dup , true ) ,
181
+ DiffActivity :: Dualv => ( global_dupv , true ) ,
182
+ DiffActivity :: DualOnly => ( global_dupnoneed , true ) ,
183
+ DiffActivity :: DualvOnly => ( global_dupnoneedv , true ) ,
184
+ DiffActivity :: Duplicated => ( global_dup , true ) ,
185
+ DiffActivity :: DuplicatedOnly => ( global_dupnoneed , true ) ,
186
+ DiffActivity :: FakeActivitySize ( _) => ( global_const , false ) ,
183
187
} ;
184
188
let outer_arg = outer_args[ outer_pos] ;
185
- args. push ( cx . get_metadata_value ( activity) ) ;
189
+ args. push ( activity) ;
186
190
if matches ! ( diff_activity, DiffActivity :: Dualv ) {
187
191
let next_outer_arg = outer_args[ outer_pos + 1 ] ;
188
192
let elem_bytes_size: u64 = match inputs[ activity_pos + 1 ] {
@@ -242,7 +246,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
242
246
assert_eq ! ( cx. type_kind( next_outer_ty3) , TypeKind :: Integer ) ;
243
247
args. push ( next_outer_arg2) ;
244
248
}
245
- args. push ( cx . get_metadata_value ( enzyme_const ) ) ;
249
+ args. push ( global_const ) ;
246
250
args. push ( next_outer_arg) ;
247
251
outer_pos += 2 + 2 * iterations;
248
252
activity_pos += 2 ;
@@ -351,13 +355,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
351
355
let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
352
356
args. push ( fn_to_diff) ;
353
357
354
- let enzyme_primal_ret = cx. create_metadata ( b "enzyme_primal_return") ;
358
+ let global_primal_ret = cx. declare_global ( "enzyme_primal_return" , cx . type_ptr ( ) ) ;
355
359
if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
356
- args. push ( cx . get_metadata_value ( enzyme_primal_ret ) ) ;
360
+ args. push ( global_primal_ret ) ;
357
361
}
358
362
if attrs. width > 1 {
359
- let enzyme_width = cx. create_metadata ( b "enzyme_width") ;
360
- args. push ( cx . get_metadata_value ( enzyme_width ) ) ;
363
+ let global_width = cx. declare_global ( "enzyme_width" , cx . type_ptr ( ) ) ;
364
+ args. push ( global_width ) ;
361
365
args. push ( cx. get_const_int ( cx. type_i64 ( ) , attrs. width as u64 ) ) ;
362
366
}
363
367
0 commit comments