@@ -14,7 +14,7 @@ use crate::context::SimpleCx;
14
14
use crate :: declare:: declare_simple_fn;
15
15
use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
16
16
use crate :: llvm:: AttributePlace :: Function ;
17
- use crate :: llvm:: { Metadata , True } ;
17
+ use crate :: llvm:: { Metadata , True , Type } ;
18
18
use crate :: value:: Value ;
19
19
use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
20
20
@@ -29,7 +29,7 @@ fn _get_params(fnc: &Value) -> Vec<&Value> {
29
29
fnc_args
30
30
}
31
31
32
- fn has_sret ( fnc : & Value ) -> bool {
32
+ fn _has_sret ( fnc : & Value ) -> bool {
33
33
let num_args = llvm:: LLVMCountParams ( fnc) as usize ;
34
34
if num_args == 0 {
35
35
false
@@ -55,7 +55,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
55
55
args : & mut Vec < & ' ll llvm:: Value > ,
56
56
inputs : & [ DiffActivity ] ,
57
57
outer_args : & [ & ' ll llvm:: Value ] ,
58
- has_sret : bool ,
59
58
) {
60
59
debug ! ( "matching autodiff arguments" ) ;
61
60
// We now handle the issue that Rust-level arguments do not always match the llvm-ir level
@@ -67,20 +66,12 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
67
66
let mut outer_pos: usize = 0 ;
68
67
let mut activity_pos = 0 ;
69
68
70
- if has_sret {
71
- // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
72
- // inner function will still return something. We increase our outer_pos by one,
73
- // and once we're done with all other args we will take the return of the inner call and
74
- // update the sret pointer with it
75
- outer_pos = 1 ;
76
- }
77
-
78
- let enzyme_const = cx. create_metadata ( b"enzyme_const" ) ;
79
- let enzyme_out = cx. create_metadata ( b"enzyme_out" ) ;
80
- let enzyme_dup = cx. create_metadata ( b"enzyme_dup" ) ;
81
- let enzyme_dupv = cx. create_metadata ( b"enzyme_dupv" ) ;
82
- let enzyme_dupnoneed = cx. create_metadata ( b"enzyme_dupnoneed" ) ;
83
- let enzyme_dupnoneedv = cx. create_metadata ( b"enzyme_dupnoneedv" ) ;
69
+ let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
70
+ let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
71
+ let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
72
+ let enzyme_dupv = cx. create_metadata ( "enzyme_dupv" . to_string ( ) ) . unwrap ( ) ;
73
+ let enzyme_dupnoneed = cx. create_metadata ( "enzyme_dupnoneed" . to_string ( ) ) . unwrap ( ) ;
74
+ let enzyme_dupnoneedv = cx. create_metadata ( "enzyme_dupnoneedv" . to_string ( ) ) . unwrap ( ) ;
84
75
85
76
while activity_pos < inputs. len ( ) {
86
77
let diff_activity = inputs[ activity_pos as usize ] ;
@@ -193,92 +184,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
193
184
}
194
185
}
195
186
196
- // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
197
- // arguments. We do however need to declare them with their correct return type.
198
- // We already figured the correct return type out in our frontend, when generating the outer_fn,
199
- // so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
200
- // Beyond sret, this article describes our challenges nicely:
201
- // <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
202
- // I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
203
- fn compute_enzyme_fn_ty < ' ll > (
204
- cx : & SimpleCx < ' ll > ,
205
- attrs : & AutoDiffAttrs ,
206
- fn_to_diff : & ' ll Value ,
207
- outer_fn : & ' ll Value ,
208
- ) -> & ' ll llvm:: Type {
209
- let fn_ty = cx. get_type_of_global ( outer_fn) ;
210
- let mut ret_ty = cx. get_return_type ( fn_ty) ;
211
-
212
- let has_sret = has_sret ( outer_fn) ;
213
-
214
- if has_sret {
215
- // Now we don't just forward the return type, so we have to figure it out based on the
216
- // primal return type, in combination with the autodiff settings.
217
- let fn_ty = cx. get_type_of_global ( fn_to_diff) ;
218
- let inner_ret_ty = cx. get_return_type ( fn_ty) ;
219
-
220
- let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( cx. llcx ) } ;
221
- if inner_ret_ty == void_ty {
222
- // This indicates that even the inner function has an sret.
223
- // Right now I only look for an sret in the outer function.
224
- // This *probably* needs some extra handling, but I never ran
225
- // into such a case. So I'll wait for user reports to have a test case.
226
- bug ! ( "sret in inner function" ) ;
227
- }
228
-
229
- if attrs. width == 1 {
230
- // Enzyme returns a struct of style:
231
- // `{ original_ret(if requested), float, float, ... }`
232
- let mut struct_elements = vec ! [ ] ;
233
- if attrs. has_primal_ret ( ) {
234
- struct_elements. push ( inner_ret_ty) ;
235
- }
236
- // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
237
- // and therefore part of the return struct.
238
- let param_tys = cx. func_params_types ( fn_ty) ;
239
- for ( act, param_ty) in attrs. input_activity . iter ( ) . zip ( param_tys) {
240
- if matches ! ( act, DiffActivity :: Active ) {
241
- // Now find the float type at position i based on the fn_ty,
242
- // to know what (f16/f32/f64/...) to add to the struct.
243
- struct_elements. push ( param_ty) ;
244
- }
245
- }
246
- ret_ty = cx. type_struct ( & struct_elements, false ) ;
247
- } else {
248
- // First we check if we also have to deal with the primal return.
249
- match attrs. mode {
250
- DiffMode :: Forward => match attrs. ret_activity {
251
- DiffActivity :: Dual => {
252
- let arr_ty =
253
- unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
254
- ret_ty = arr_ty;
255
- }
256
- DiffActivity :: DualOnly => {
257
- let arr_ty =
258
- unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 ) } ;
259
- ret_ty = arr_ty;
260
- }
261
- DiffActivity :: Const => {
262
- todo ! ( "Not sure, do we need to do something here?" ) ;
263
- }
264
- _ => {
265
- bug ! ( "unreachable" ) ;
266
- }
267
- } ,
268
- DiffMode :: Reverse => {
269
- todo ! ( "Handle sret for reverse mode" ) ;
270
- }
271
- _ => {
272
- bug ! ( "unreachable" ) ;
273
- }
274
- }
275
- }
276
- }
277
-
278
- // LLVM can figure out the input types on its own, so we take a shortcut here.
279
- unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) }
280
- }
281
-
282
187
/// When differentiating `fn_to_diff`, take an `outer_fn` and generate another
283
188
/// function with expected naming and calling conventions[^1] which will be
284
189
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -292,7 +197,8 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
292
197
builder : & mut Builder < ' _ , ' ll , ' tcx > ,
293
198
cx : & SimpleCx < ' ll > ,
294
199
fn_to_diff : & ' ll Value ,
295
- outer_fn : & ' ll Value ,
200
+ outer_name : & str ,
201
+ ret_ty : & ' ll Type ,
296
202
fn_args : & [ OperandRef < ' tcx , & ' ll Value > ] ,
297
203
attrs : AutoDiffAttrs ,
298
204
dest : PlaceRef < ' tcx , & ' ll Value > ,
@@ -305,11 +211,9 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
305
211
}
306
212
. to_string ( ) ;
307
213
308
- // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
214
+ // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
309
215
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
310
- let name = llvm:: get_value_name ( outer_fn) ;
311
- let outer_fn_name = std:: str:: from_utf8 ( & name) . unwrap ( ) ;
312
- ad_name. push_str ( outer_fn_name) ;
216
+ ad_name. push_str ( outer_name) ;
313
217
314
218
// Let us assume the user wrote the following function square:
315
219
//
@@ -320,13 +224,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
320
224
// ret double %0
321
225
// }
322
226
// ```
323
- //
324
- // The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
325
- // Our macro generates the following placeholder code (slightly simplified):
326
- //
327
- // ```llvm
328
227
// define double @dsquare(double %x) {
329
- // ; placeholder code
330
228
// return 0.0;
331
229
// }
332
230
// ```
@@ -343,92 +241,54 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
343
241
// ret double %0
344
242
// }
345
243
// ```
346
- unsafe {
347
- let enzyme_ty = compute_enzyme_fn_ty ( cx, & attrs, fn_to_diff, outer_fn) ;
348
-
349
- // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
350
- // think a bit more about what should go here.
351
- let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
352
- let ad_fn = declare_simple_fn (
353
- cx,
354
- & ad_name,
355
- llvm:: CallConv :: try_from ( cc) . expect ( "invalid callconv" ) ,
356
- llvm:: UnnamedAddr :: No ,
357
- llvm:: Visibility :: Default ,
358
- enzyme_ty,
359
- ) ;
360
-
361
- // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
362
- // do its work.
363
- let attr = llvm:: AttributeKind :: NoInline . create_attr ( cx. llcx ) ;
364
- attributes:: apply_to_llfn ( ad_fn, Function , & [ attr] ) ;
365
-
366
- // We add a made-up attribute just such that we can recognize it after AD to update
367
- // (no)-inline attributes. We'll then also remove this attribute.
368
- let enzyme_marker_attr = llvm:: CreateAttrString ( cx. llcx , "enzyme_marker" ) ;
369
- attributes:: apply_to_llfn ( outer_fn, Function , & [ enzyme_marker_attr] ) ;
370
-
371
- let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
372
- let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
373
- args. push ( fn_to_diff) ;
374
-
375
- let enzyme_primal_ret = cx. create_metadata ( b"enzyme_primal_return" ) ;
376
- if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
377
- args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
378
- }
379
- if attrs. width > 1 {
380
- let enzyme_width = cx. create_metadata ( b"enzyme_width" ) ;
381
- args. push ( cx. get_metadata_value ( enzyme_width) ) ;
382
- args. push ( cx. get_const_int ( cx. type_i64 ( ) , attrs. width as u64 ) ) ;
383
- }
384
-
385
- let has_sret = has_sret ( outer_fn) ;
386
- let outer_args: Vec < & llvm:: Value > = fn_args. iter ( ) . map ( |op| op. immediate ( ) ) . collect ( ) ;
387
- match_args_from_caller_to_enzyme (
388
- & cx,
389
- builder,
390
- attrs. width ,
391
- & mut args,
392
- & attrs. input_activity ,
393
- & outer_args,
394
- has_sret,
395
- ) ;
396
-
397
- let call = builder. call ( enzyme_ty, None , None , ad_fn, & args, None , None ) ;
398
-
399
- builder. store_to_place ( call, dest. val ) ;
244
+ let enzyme_ty = unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) } ;
245
+
246
+ // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
247
+ // think a bit more about what should go here.
248
+ // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now
249
+ let cc = 8 ;
250
+ let ad_fn = declare_simple_fn (
251
+ cx,
252
+ & ad_name,
253
+ llvm:: CallConv :: try_from ( cc) . expect ( "invalid callconv" ) ,
254
+ llvm:: UnnamedAddr :: No ,
255
+ llvm:: Visibility :: Default ,
256
+ enzyme_ty,
257
+ ) ;
258
+
259
+ // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
260
+ // do its work.
261
+ let attr = llvm:: AttributeKind :: NoInline . create_attr ( cx. llcx ) ;
262
+ attributes:: apply_to_llfn ( ad_fn, Function , & [ attr] ) ;
263
+
264
+ let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
265
+ let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
266
+ args. push ( fn_to_diff) ;
267
+
268
+ let enzyme_primal_ret = cx. create_metadata ( "enzyme_primal_return" . to_string ( ) ) . unwrap ( ) ;
269
+ if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
270
+ args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
271
+ }
272
+ if attrs. width > 1 {
273
+ let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
274
+ args. push ( cx. get_metadata_value ( enzyme_width) ) ;
275
+ args. push ( cx. get_const_int ( cx. type_i64 ( ) , attrs. width as u64 ) ) ;
276
+ }
400
277
401
- if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
402
- if has_sret {
403
- // This is what we already have in our outer_fn (shortened):
404
- // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
405
- // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
406
- // <Here we are, we want to add the following two lines>
407
- // store [4 x double] %7, ptr %0, align 8
408
- // ret void
409
- // }
278
+ let outer_args: Vec < & llvm:: Value > = fn_args. iter ( ) . map ( |op| op. immediate ( ) ) . collect ( ) ;
410
279
411
- // now store the result of the enzyme call into the sret pointer.
412
- let sret_ptr = outer_args[ 0 ] ;
413
- let call_ty = cx. val_ty ( call) ;
414
- if attrs. width == 1 {
415
- assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Struct ) ;
416
- } else {
417
- assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Array ) ;
418
- }
419
- llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
420
- }
421
- builder. ret_void ( ) ;
422
- }
280
+ match_args_from_caller_to_enzyme (
281
+ & cx,
282
+ builder,
283
+ attrs. width ,
284
+ & mut args,
285
+ & attrs. input_activity ,
286
+ & outer_args,
287
+ ) ;
423
288
424
- builder. store_to_place ( call, dest . val ) ;
289
+ let call = builder. call ( enzyme_ty , None , None , ad_fn , & args , None , None ) ;
425
290
426
- // Let's crash in case that we messed something up above and generated invalid IR.
427
- llvm:: LLVMRustVerifyFunction (
428
- outer_fn,
429
- llvm:: LLVMRustVerifierFailureAction :: LLVMAbortProcessAction ,
430
- ) ;
431
- }
291
+ builder. store_to_place ( call, dest. val ) ;
432
292
}
433
293
434
294
pub ( crate ) fn differentiate < ' ll > (
0 commit comments