@@ -262,7 +262,6 @@ mod llvm_enzyme {
262
262
} ;
263
263
264
264
let has_ret = has_ret ( & sig. decl . output ) ;
265
- let sig_span = ecx. with_call_site_ctxt ( sig. span ) ;
266
265
267
266
// create TokenStream from vec elemtents:
268
267
// meta_item doesn't have a .tokens field
@@ -331,24 +330,13 @@ mod llvm_enzyme {
331
330
}
332
331
let span = ecx. with_def_site_ctxt ( expand_span) ;
333
332
334
- let n_active: u32 = x
335
- . input_activity
336
- . iter ( )
337
- . filter ( |a| * * a == DiffActivity :: Active || * * a == DiffActivity :: ActiveOnly )
338
- . count ( ) as u32 ;
339
- let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
333
+ let ( d_sig, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
340
334
341
- // TODO(Sa4dUs): Remove this and all the related logic
342
335
let d_body = gen_enzyme_body (
343
336
ecx,
344
- & x,
345
- n_active,
346
- & sig,
347
337
& d_sig,
348
338
primal,
349
- & new_args,
350
339
span,
351
- sig_span,
352
340
idents,
353
341
errored,
354
342
first_ident ( & meta_item_vec[ 0 ] ) ,
@@ -361,7 +349,7 @@ mod llvm_enzyme {
361
349
defaultness : ast:: Defaultness :: Final ,
362
350
sig : d_sig,
363
351
ident : first_ident ( & meta_item_vec[ 0 ] ) ,
364
- generics : generics . clone ( ) ,
352
+ generics,
365
353
contract : None ,
366
354
body : Some ( d_body) ,
367
355
define_opaque : None ,
@@ -542,7 +530,7 @@ mod llvm_enzyme {
542
530
vec ! [
543
531
Ident :: from_str( "std" ) ,
544
532
Ident :: from_str( "intrinsics" ) ,
545
- Ident :: from_str ( " enzyme_autodiff" ) ,
533
+ Ident :: with_dummy_span ( sym :: enzyme_autodiff) ,
546
534
] ,
547
535
) ;
548
536
let call_expr = ecx. expr_call (
@@ -555,7 +543,7 @@ mod llvm_enzyme {
555
543
}
556
544
557
545
// Generate turbofish expression from fn name and generics
558
- // Given `foo` and `<A, B, C>`, gen `foo::<A, B, C>`
546
+ // Given `foo` and `<A, B, C>` params , gen `foo::<A, B, C>`
559
547
fn gen_turbofish_expr (
560
548
ecx : & ExtCtxt < ' _ > ,
561
549
ident : Ident ,
@@ -597,43 +585,27 @@ mod llvm_enzyme {
597
585
598
586
// Will generate a body of the type:
599
587
// ```
600
- // {
601
- // unsafe {
602
- // asm!("NOP");
603
- // }
604
- // ::core::hint::black_box(primal(args));
605
- // ::core::hint::black_box((args, ret));
606
- // <This part remains to be done by following function>
588
+ // primal(args);
589
+ // std::intrinsics::enzyme_autodiff(primal, diff, (args))
607
590
// }
608
591
// ```
609
592
fn init_body_helper (
610
593
ecx : & ExtCtxt < ' _ > ,
611
594
span : Span ,
612
595
primal : Ident ,
613
- _new_names : & [ String ] ,
614
- _sig_span : Span ,
615
- new_decl_span : Span ,
616
596
idents : & [ Ident ] ,
617
597
errored : bool ,
618
598
generics : & Generics ,
619
- ) -> ( P < ast:: Block > , P < ast:: Expr > , P < ast:: Expr > , P < ast:: Expr > ) {
620
- let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
621
- let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
599
+ ) -> P < ast:: Block > {
622
600
let primal_call = gen_primal_call ( ecx, span, primal, idents, generics) ;
623
- let black_box_primal_call = ecx. expr_call (
624
- new_decl_span,
625
- blackbox_call_expr. clone ( ) ,
626
- thin_vec ! [ primal_call. clone( ) ] ,
627
- ) ;
628
-
629
601
let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
630
602
631
603
// This uses primal args which won't be available if we errored before
632
604
if !errored {
633
605
body. stmts . push ( ecx. stmt_semi ( primal_call. clone ( ) ) ) ;
634
606
}
635
607
636
- ( body, primal_call , black_box_primal_call , blackbox_call_expr )
608
+ body
637
609
}
638
610
639
611
/// We only want this function to type-check, since we will replace the body
@@ -646,14 +618,9 @@ mod llvm_enzyme {
646
618
/// from optimizing any arguments away.
647
619
fn gen_enzyme_body (
648
620
ecx : & ExtCtxt < ' _ > ,
649
- _x : & AutoDiffAttrs ,
650
- _n_active : u32 ,
651
- _sig : & ast:: FnSig ,
652
621
d_sig : & ast:: FnSig ,
653
622
primal : Ident ,
654
- new_names : & [ String ] ,
655
623
span : Span ,
656
- sig_span : Span ,
657
624
idents : Vec < Ident > ,
658
625
errored : bool ,
659
626
diff_ident : Ident ,
@@ -664,17 +631,7 @@ mod llvm_enzyme {
664
631
665
632
// Add a call to the primal function to prevent it from being inlined
666
633
// and call `enzyme_autodiff` intrinsic (this also covers the return type)
667
- let ( mut body, _primal_call, _bb_primal_call, _bb_call_expr) = init_body_helper (
668
- ecx,
669
- span,
670
- primal,
671
- new_names,
672
- sig_span,
673
- new_decl_span,
674
- & idents,
675
- errored,
676
- generics,
677
- ) ;
634
+ let mut body = init_body_helper ( ecx, span, primal, & idents, errored, generics) ;
678
635
679
636
body. stmts . push ( call_enzyme_autodiff (
680
637
ecx,
@@ -771,7 +728,7 @@ mod llvm_enzyme {
771
728
sig : & ast:: FnSig ,
772
729
x : & AutoDiffAttrs ,
773
730
span : Span ,
774
- ) -> ( ast:: FnSig , Vec < String > , Vec < Ident > , bool ) {
731
+ ) -> ( ast:: FnSig , Vec < Ident > , bool ) {
775
732
let dcx = ecx. sess . dcx ( ) ;
776
733
let has_ret = has_ret ( & sig. decl . output ) ;
777
734
let sig_args = sig. decl . inputs . len ( ) + if has_ret { 1 } else { 0 } ;
@@ -783,7 +740,7 @@ mod llvm_enzyme {
783
740
found : num_activities,
784
741
} ) ;
785
742
// This is not the right signature, but we can continue parsing.
786
- return ( sig. clone ( ) , vec ! [ ] , vec ! [ ] , true ) ;
743
+ return ( sig. clone ( ) , vec ! [ ] , true ) ;
787
744
}
788
745
assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
789
746
assert ! ( has_ret == x. has_ret_activity( ) ) ;
@@ -826,7 +783,7 @@ mod llvm_enzyme {
826
783
827
784
if errors {
828
785
// This is not the right signature, but we can continue parsing.
829
- return ( sig. clone ( ) , new_inputs , idents, true ) ;
786
+ return ( sig. clone ( ) , idents, true ) ;
830
787
}
831
788
832
789
let unsafe_activities = x
@@ -1034,7 +991,7 @@ mod llvm_enzyme {
1034
991
}
1035
992
let d_sig = FnSig { header : d_header, decl : d_decl, span } ;
1036
993
trace ! ( "Generated signature: {:?}" , d_sig) ;
1037
- ( d_sig, new_inputs , idents, false )
994
+ ( d_sig, idents, false )
1038
995
}
1039
996
}
1040
997
0 commit comments