@@ -3,13 +3,13 @@ use std::ptr;
3
3
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
4
4
use rustc_codegen_ssa:: ModuleCodegen ;
5
5
use rustc_codegen_ssa:: common:: TypeKind ;
6
- use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
6
+ use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
7
7
use rustc_errors:: FatalError ;
8
8
use rustc_middle:: bug;
9
9
use tracing:: { debug, trace} ;
10
10
11
11
use crate :: back:: write:: llvm_err;
12
- use crate :: builder:: { SBuilder , UNNAMED } ;
12
+ use crate :: builder:: { Builder , OperandRef , PlaceRef , UNNAMED } ;
13
13
use crate :: context:: SimpleCx ;
14
14
use crate :: declare:: declare_simple_fn;
15
15
use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
@@ -18,7 +18,7 @@ use crate::llvm::{Metadata, True};
18
18
use crate :: value:: Value ;
19
19
use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
20
20
21
- fn get_params ( fnc : & Value ) -> Vec < & Value > {
21
+ fn _get_params ( fnc : & Value ) -> Vec < & Value > {
22
22
let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
23
23
let mut fnc_args: Vec < & Value > = vec ! [ ] ;
24
24
fnc_args. reserve ( param_num) ;
@@ -48,9 +48,9 @@ fn has_sret(fnc: &Value) -> bool {
48
48
// need to match those.
49
49
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
50
50
// using iterators and peek()?
51
- fn match_args_from_caller_to_enzyme < ' ll > (
51
+ fn match_args_from_caller_to_enzyme < ' ll , ' tcx > (
52
52
cx : & SimpleCx < ' ll > ,
53
- builder : & SBuilder < ' ll , ' ll > ,
53
+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
54
54
width : u32 ,
55
55
args : & mut Vec < & ' ll llvm:: Value > ,
56
56
inputs : & [ DiffActivity ] ,
@@ -288,11 +288,14 @@ fn compute_enzyme_fn_ty<'ll>(
288
288
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
289
289
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
290
290
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
291
- fn generate_enzyme_call < ' ll > (
291
+ pub ( crate ) fn generate_enzyme_call < ' ll , ' tcx > (
292
+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
292
293
cx : & SimpleCx < ' ll > ,
293
294
fn_to_diff : & ' ll Value ,
294
295
outer_fn : & ' ll Value ,
296
+ fn_args : & [ OperandRef < ' tcx , & ' ll Value > ] ,
295
297
attrs : AutoDiffAttrs ,
298
+ dest : PlaceRef < ' tcx , & ' ll Value > ,
296
299
) {
297
300
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
298
301
let mut ad_name: String = match attrs. mode {
@@ -365,14 +368,6 @@ fn generate_enzyme_call<'ll>(
365
368
let enzyme_marker_attr = llvm:: CreateAttrString ( cx. llcx , "enzyme_marker" ) ;
366
369
attributes:: apply_to_llfn ( outer_fn, Function , & [ enzyme_marker_attr] ) ;
367
370
368
- // first, remove all calls from fnc
369
- let entry = llvm:: LLVMGetFirstBasicBlock ( outer_fn) ;
370
- let br = llvm:: LLVMRustGetTerminator ( entry) ;
371
- llvm:: LLVMRustEraseInstFromParent ( br) ;
372
-
373
- let last_inst = llvm:: LLVMRustGetLastInstruction ( entry) . unwrap ( ) ;
374
- let mut builder = SBuilder :: build ( cx, entry) ;
375
-
376
371
let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
377
372
let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
378
373
args. push ( fn_to_diff) ;
@@ -388,40 +383,20 @@ fn generate_enzyme_call<'ll>(
388
383
}
389
384
390
385
let has_sret = has_sret ( outer_fn) ;
391
- let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn ) ;
386
+ let outer_args: Vec < & llvm:: Value > = fn_args . iter ( ) . map ( |op| op . immediate ( ) ) . collect ( ) ;
392
387
match_args_from_caller_to_enzyme (
393
388
& cx,
394
- & builder,
389
+ builder,
395
390
attrs. width ,
396
391
& mut args,
397
392
& attrs. input_activity ,
398
393
& outer_args,
399
394
has_sret,
400
395
) ;
401
396
402
- let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
403
-
404
- // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
405
- // metadata attached to it, but we just created this code oota. Given that the
406
- // differentiated function already has partly confusing metadata, and given that this
407
- // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
408
- // dummy code which we inserted at a higher level.
409
- // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
410
- // and how to best improve it for enzyme core and rust-enzyme.
411
- let md_ty = cx. get_md_kind_id ( "dbg" ) ;
412
- if llvm:: LLVMRustHasMetadata ( last_inst, md_ty) {
413
- let md = llvm:: LLVMRustDIGetInstMetadata ( last_inst)
414
- . expect ( "failed to get instruction metadata" ) ;
415
- let md_todiff = cx. get_metadata_value ( md) ;
416
- llvm:: LLVMSetMetadata ( call, md_ty, md_todiff) ;
417
- } else {
418
- // We don't panic, since depending on whether we are in debug or release mode, we might
419
- // have no debug info to copy, which would then be ok.
420
- trace ! ( "no dbg info" ) ;
421
- }
397
+ let call = builder. call ( enzyme_ty, None , None , ad_fn, & args, None , None ) ;
422
398
423
- // Now that we copied the metadata, get rid of dummy code.
424
- llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
399
+ builder. store_to_place ( call, dest. val ) ;
425
400
426
401
if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
427
402
if has_sret {
@@ -444,10 +419,10 @@ fn generate_enzyme_call<'ll>(
444
419
llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
445
420
}
446
421
builder. ret_void ( ) ;
447
- } else {
448
- builder. ret ( call) ;
449
422
}
450
423
424
+ builder. store_to_place ( call, dest. val ) ;
425
+
451
426
// Let's crash in case that we messed something up above and generated invalid IR.
452
427
llvm:: LLVMRustVerifyFunction (
453
428
outer_fn,
@@ -461,6 +436,7 @@ pub(crate) fn differentiate<'ll>(
461
436
cgcx : & CodegenContext < LlvmCodegenBackend > ,
462
437
diff_items : Vec < AutoDiffItem > ,
463
438
) -> Result < ( ) , FatalError > {
439
+ // TODO(Sa4dUs): delete all this logic
464
440
for item in & diff_items {
465
441
trace ! ( "{}" , item) ;
466
442
}
@@ -480,7 +456,7 @@ pub(crate) fn differentiate<'ll>(
480
456
for item in diff_items. iter ( ) {
481
457
let name = item. source . clone ( ) ;
482
458
let fn_def: Option < & llvm:: Value > = cx. get_function ( & name) ;
483
- let Some ( fn_def ) = fn_def else {
459
+ let Some ( _fn_def ) = fn_def else {
484
460
return Err ( llvm_err (
485
461
diag_handler. handle ( ) ,
486
462
LlvmError :: PrepareAutoDiff {
@@ -492,7 +468,7 @@ pub(crate) fn differentiate<'ll>(
492
468
} ;
493
469
debug ! ( ?item. target) ;
494
470
let fn_target: Option < & llvm:: Value > = cx. get_function ( & item. target ) ;
495
- let Some ( fn_target ) = fn_target else {
471
+ let Some ( _fn_target ) = fn_target else {
496
472
return Err ( llvm_err (
497
473
diag_handler. handle ( ) ,
498
474
LlvmError :: PrepareAutoDiff {
@@ -503,7 +479,7 @@ pub(crate) fn differentiate<'ll>(
503
479
) ) ;
504
480
} ;
505
481
506
- generate_enzyme_call ( & cx, fn_def, fn_target, item. attrs . clone ( ) ) ;
482
+ // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
507
483
}
508
484
509
485
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
0 commit comments