@@ -73,7 +73,9 @@ mod llvm_enzyme {
73
73
}
74
74
75
75
// Get information about the function the macro is applied to
76
- fn extract_item_info ( iitem : & Box < ast:: Item > ) -> Option < ( Visibility , FnSig , Ident , Generics , bool ) > {
76
+ fn extract_item_info (
77
+ iitem : & Box < ast:: Item > ,
78
+ ) -> Option < ( Visibility , FnSig , Ident , Generics , bool ) > {
77
79
match & iitem. kind {
78
80
ItemKind :: Fn ( box ast:: Fn { sig, ident, generics, .. } ) => {
79
81
Some ( ( iitem. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) , generics. clone ( ) , false ) )
@@ -259,7 +261,6 @@ mod llvm_enzyme {
259
261
} ;
260
262
261
263
let has_ret = has_ret ( & sig. decl . output ) ;
262
- let sig_span = ecx. with_call_site_ctxt ( sig. span ) ;
263
264
264
265
// create TokenStream from vec elemtents:
265
266
// meta_item doesn't have a .tokens field
@@ -328,24 +329,13 @@ mod llvm_enzyme {
328
329
}
329
330
let span = ecx. with_def_site_ctxt ( expand_span) ;
330
331
331
- let n_active: u32 = x
332
- . input_activity
333
- . iter ( )
334
- . filter ( |a| * * a == DiffActivity :: Active || * * a == DiffActivity :: ActiveOnly )
335
- . count ( ) as u32 ;
336
- let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
332
+ let ( d_sig, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
337
333
338
- // TODO(Sa4dUs): Remove this and all the related logic
339
334
let d_body = gen_enzyme_body (
340
335
ecx,
341
- & x,
342
- n_active,
343
- & sig,
344
336
& d_sig,
345
337
primal,
346
- & new_args,
347
338
span,
348
- sig_span,
349
339
idents,
350
340
errored,
351
341
first_ident ( & meta_item_vec[ 0 ] ) ,
@@ -358,7 +348,7 @@ mod llvm_enzyme {
358
348
defaultness : ast:: Defaultness :: Final ,
359
349
sig : d_sig,
360
350
ident : first_ident ( & meta_item_vec[ 0 ] ) ,
361
- generics : generics . clone ( ) ,
351
+ generics,
362
352
contract : None ,
363
353
body : Some ( d_body) ,
364
354
define_opaque : None ,
@@ -539,7 +529,7 @@ mod llvm_enzyme {
539
529
vec ! [
540
530
Ident :: from_str( "std" ) ,
541
531
Ident :: from_str( "intrinsics" ) ,
542
- Ident :: from_str ( " enzyme_autodiff" ) ,
532
+ Ident :: with_dummy_span ( sym :: enzyme_autodiff) ,
543
533
] ,
544
534
) ;
545
535
let call_expr = ecx. expr_call (
@@ -552,7 +542,7 @@ mod llvm_enzyme {
552
542
}
553
543
554
544
// Generate turbofish expression from fn name and generics
555
- // Given `foo` and `<A, B, C>`, gen `foo::<A, B, C>`
545
+ // Given `foo` and `<A, B, C>` params , gen `foo::<A, B, C>`
556
546
fn gen_turbofish_expr (
557
547
ecx : & ExtCtxt < ' _ > ,
558
548
ident : Ident ,
@@ -594,43 +584,28 @@ mod llvm_enzyme {
594
584
595
585
// Will generate a body of the type:
596
586
// ```
597
- // {
598
- // unsafe {
599
- // asm!("NOP");
600
- // }
601
- // ::core::hint::black_box(primal(args));
602
- // ::core::hint::black_box((args, ret));
603
- // <This part remains to be done by following function>
587
+ // primal(args);
588
+ // std::intrinsics::enzyme_autodiff(primal, diff, (args))
604
589
// }
605
590
// ```
606
591
fn init_body_helper (
607
592
ecx : & ExtCtxt < ' _ > ,
608
593
span : Span ,
609
594
primal : Ident ,
610
- _new_names : & [ String ] ,
611
- _sig_span : Span ,
612
- new_decl_span : Span ,
613
595
idents : & [ Ident ] ,
614
596
errored : bool ,
615
597
generics : & Generics ,
616
- ) -> ( Box < ast:: Block > , Box < ast:: Expr > , Box < ast:: Expr > , Box < ast:: Expr > ) {
617
- let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
598
+ ) -> Box < ast:: Block > {
618
599
let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
619
600
let primal_call = gen_primal_call ( ecx, span, primal, idents, generics) ;
620
- let black_box_primal_call = ecx. expr_call (
621
- new_decl_span,
622
- blackbox_call_expr. clone ( ) ,
623
- thin_vec ! [ primal_call. clone( ) ] ,
624
- ) ;
625
-
626
601
let mut body = ecx. block ( span, ThinVec :: new ( ) ) ;
627
602
628
603
// This uses primal args which won't be available if we errored before
629
604
if !errored {
630
605
body. stmts . push ( ecx. stmt_semi ( primal_call. clone ( ) ) ) ;
631
606
}
632
607
633
- ( body, primal_call , black_box_primal_call , blackbox_call_expr )
608
+ body
634
609
}
635
610
636
611
/// We only want this function to type-check, since we will replace the body
@@ -643,14 +618,9 @@ mod llvm_enzyme {
643
618
/// from optimizing any arguments away.
644
619
fn gen_enzyme_body (
645
620
ecx : & ExtCtxt < ' _ > ,
646
- _x : & AutoDiffAttrs ,
647
- _n_active : u32 ,
648
- _sig : & ast:: FnSig ,
649
621
d_sig : & ast:: FnSig ,
650
622
primal : Ident ,
651
- new_names : & [ String ] ,
652
623
span : Span ,
653
- sig_span : Span ,
654
624
idents : Vec < Ident > ,
655
625
errored : bool ,
656
626
diff_ident : Ident ,
@@ -661,17 +631,7 @@ mod llvm_enzyme {
661
631
662
632
// Add a call to the primal function to prevent it from being inlined
663
633
// and call `enzyme_autodiff` intrinsic (this also covers the return type)
664
- let ( mut body, _primal_call, _bb_primal_call, _bb_call_expr) = init_body_helper (
665
- ecx,
666
- span,
667
- primal,
668
- new_names,
669
- sig_span,
670
- new_decl_span,
671
- & idents,
672
- errored,
673
- generics,
674
- ) ;
634
+ let mut body = init_body_helper ( ecx, span, primal, & idents, errored, generics) ;
675
635
676
636
body. stmts . push ( call_enzyme_autodiff (
677
637
ecx,
@@ -768,7 +728,7 @@ mod llvm_enzyme {
768
728
sig : & ast:: FnSig ,
769
729
x : & AutoDiffAttrs ,
770
730
span : Span ,
771
- ) -> ( ast:: FnSig , Vec < String > , Vec < Ident > , bool ) {
731
+ ) -> ( ast:: FnSig , Vec < Ident > , bool ) {
772
732
let dcx = ecx. sess . dcx ( ) ;
773
733
let has_ret = has_ret ( & sig. decl . output ) ;
774
734
let sig_args = sig. decl . inputs . len ( ) + if has_ret { 1 } else { 0 } ;
@@ -780,7 +740,7 @@ mod llvm_enzyme {
780
740
found : num_activities,
781
741
} ) ;
782
742
// This is not the right signature, but we can continue parsing.
783
- return ( sig. clone ( ) , vec ! [ ] , vec ! [ ] , true ) ;
743
+ return ( sig. clone ( ) , vec ! [ ] , true ) ;
784
744
}
785
745
assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
786
746
assert ! ( has_ret == x. has_ret_activity( ) ) ;
@@ -823,7 +783,7 @@ mod llvm_enzyme {
823
783
824
784
if errors {
825
785
// This is not the right signature, but we can continue parsing.
826
- return ( sig. clone ( ) , new_inputs , idents, true ) ;
786
+ return ( sig. clone ( ) , idents, true ) ;
827
787
}
828
788
829
789
let unsafe_activities = x
@@ -1037,7 +997,7 @@ mod llvm_enzyme {
1037
997
}
1038
998
let d_sig = FnSig { header : d_header, decl : d_decl, span } ;
1039
999
trace ! ( "Generated signature: {:?}" , d_sig) ;
1040
- ( d_sig, new_inputs , idents, false )
1000
+ ( d_sig, idents, false )
1041
1001
}
1042
1002
}
1043
1003
0 commit comments