@@ -20,7 +20,7 @@ mod llvm_enzyme {
20
20
MetaItemInner , PatKind , Path , PathSegment , TyKind , Visibility ,
21
21
} ;
22
22
use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
23
- use rustc_span:: { Ident , Span , Symbol , kw , sym} ;
23
+ use rustc_span:: { Ident , Span , Symbol , sym} ;
24
24
use thin_vec:: { ThinVec , thin_vec} ;
25
25
use tracing:: { debug, trace} ;
26
26
@@ -73,9 +73,7 @@ mod llvm_enzyme {
73
73
}
74
74
75
75
// Get information about the function the macro is applied to
76
- fn extract_item_info (
77
- iitem : & Box < ast:: Item > ,
78
- ) -> Option < ( Visibility , FnSig , Ident , Generics , bool ) > {
76
+ fn extract_item_info ( iitem : & Box < ast:: Item > ) -> Option < ( Visibility , FnSig , Ident , Generics ) > {
79
77
match & iitem. kind {
80
78
ItemKind :: Fn ( box ast:: Fn { sig, ident, generics, .. } ) => {
81
79
Some ( ( iitem. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) , generics. clone ( ) , false ) )
@@ -182,11 +180,8 @@ mod llvm_enzyme {
182
180
}
183
181
184
182
/// We expand the autodiff macro to generate a new placeholder function which passes
185
- /// type-checking and can be called by users. The function body of the placeholder function will
186
- /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
187
- /// should just prevent early inlining and optimizations which alter the function signature.
188
- /// The exact signature of the generated function depends on the configuration provided by the
189
- /// user, but here is an example:
183
+ /// type-checking and can be called by users. The exact signature of the generated function
184
+ /// depends on the configuration provided by the user, but here is an example:
190
185
///
191
186
/// ```
192
187
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
@@ -202,14 +197,8 @@ mod llvm_enzyme {
202
197
/// f32::sin(**x)
203
198
/// }
204
199
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
205
- /// #[inline(never)]
206
200
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
207
- /// unsafe {
208
- /// asm!("NOP");
209
- /// };
210
- /// ::core::hint::black_box(sin(x));
211
- /// ::core::hint::black_box((dx, dret));
212
- /// ::core::hint::black_box(sin(x))
201
+ /// std::intrinsics::enzyme_autodiff(sin::<>, cos_box::<>, (x, dx, dret))
213
202
/// }
214
203
/// ```
215
204
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@@ -329,22 +318,20 @@ mod llvm_enzyme {
329
318
}
330
319
let span = ecx. with_def_site_ctxt ( expand_span) ;
331
320
332
- let ( d_sig, idents , errored ) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
321
+ let d_sig = gen_enzyme_decl ( ecx, & sig, & x, span) ;
333
322
334
323
let d_body = gen_enzyme_body (
335
324
ecx,
336
325
& d_sig,
337
326
primal,
338
327
span,
339
- idents,
340
- errored,
341
328
first_ident ( & meta_item_vec[ 0 ] ) ,
342
329
& generics,
343
330
impl_of_trait,
344
331
) ;
345
332
346
333
// The first element of it is the name of the function to be generated
347
- let asdf = Box :: new ( ast:: Fn {
334
+ let d_fn = Box :: new ( ast:: Fn {
348
335
defaultness : ast:: Defaultness :: Final ,
349
336
sig : d_sig,
350
337
ident : first_ident ( & meta_item_vec[ 0 ] ) ,
@@ -453,13 +440,13 @@ mod llvm_enzyme {
453
440
Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
454
441
}
455
442
Annotatable :: Item ( _) => {
456
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf ) ) ;
443
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( d_fn ) ) ;
457
444
d_fn. vis = vis;
458
445
459
446
Annotatable :: Item ( d_fn)
460
447
}
461
448
Annotatable :: Stmt ( _) => {
462
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf ) ) ;
449
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( d_fn ) ) ;
463
450
d_fn. vis = vis;
464
451
465
452
Annotatable :: Stmt ( Box :: new ( ast:: Stmt {
@@ -524,14 +511,8 @@ mod llvm_enzyme {
524
511
. into ( ) ,
525
512
) ;
526
513
527
- let enzyme_path = ecx. path (
528
- span,
529
- vec ! [
530
- Ident :: from_str( "std" ) ,
531
- Ident :: from_str( "intrinsics" ) ,
532
- Ident :: with_dummy_span( sym:: enzyme_autodiff) ,
533
- ] ,
534
- ) ;
514
+ let enzyme_path_idents = ecx. std_path ( & [ sym:: intrinsics, sym:: enzyme_autodiff] ) ;
515
+ let enzyme_path = ecx. path ( span, enzyme_path_idents) ;
535
516
let call_expr = ecx. expr_call (
536
517
span,
537
518
ecx. expr_path ( enzyme_path) ,
@@ -549,7 +530,7 @@ mod llvm_enzyme {
549
530
generics : & Generics ,
550
531
span : Span ,
551
532
is_impl : bool ,
552
- ) -> P < ast:: Expr > {
533
+ ) -> Box < ast:: Expr > {
553
534
let generic_args = generics
554
535
. params
555
536
. iter ( )
@@ -573,7 +554,7 @@ mod llvm_enzyme {
573
554
let segment = PathSegment {
574
555
ident,
575
556
id : ast:: DUMMY_NODE_ID ,
576
- args : Some ( P ( GenericArgs :: AngleBracketed ( args) ) ) ,
557
+ args : Some ( Box :: new ( GenericArgs :: AngleBracketed ( args) ) ) ,
577
558
} ;
578
559
579
560
let segments = if is_impl {
@@ -590,25 +571,6 @@ mod llvm_enzyme {
590
571
ecx. expr_path ( path)
591
572
}
592
573
593
- // Will generate a body of the type:
594
- // ```
595
- // primal(args);
596
- // std::intrinsics::enzyme_autodiff(primal, diff, (args))
597
- // }
598
- // ```
599
- fn init_body_helper (
600
- ecx : & ExtCtxt < ' _ > ,
601
- span : Span ,
602
- primal : Ident ,
603
- idents : & [ Ident ] ,
604
- _errored : bool ,
605
- generics : & Generics ,
606
- ) -> Box < ast:: Block > {
607
- let _primal_call = gen_primal_call ( ecx, span, primal, idents, generics) ;
608
- let body = ecx. block ( span, ThinVec :: new ( ) ) ;
609
- body
610
- }
611
-
612
574
/// We only want this function to type-check, since we will replace the body
613
575
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
614
576
/// so instead we manually build something that should pass the type checker.
@@ -622,8 +584,6 @@ mod llvm_enzyme {
622
584
d_sig : & ast:: FnSig ,
623
585
primal : Ident ,
624
586
span : Span ,
625
- idents : Vec < Ident > ,
626
- errored : bool ,
627
587
diff_ident : Ident ,
628
588
generics : & Generics ,
629
589
is_impl : bool ,
@@ -632,87 +592,22 @@ mod llvm_enzyme {
632
592
633
593
// Add a call to the primal function to prevent it from being inlined
634
594
// and call `enzyme_autodiff` intrinsic (this also covers the return type)
635
- let mut body = init_body_helper ( ecx, span, primal, & idents, errored, generics) ;
636
-
637
- body. stmts . push ( call_enzyme_autodiff (
638
- ecx,
639
- primal,
640
- diff_ident,
641
- new_decl_span,
642
- d_sig,
643
- generics,
644
- is_impl,
645
- ) ) ;
595
+ let body = ecx. block (
596
+ span,
597
+ thin_vec ! [ call_enzyme_autodiff(
598
+ ecx,
599
+ primal,
600
+ diff_ident,
601
+ new_decl_span,
602
+ d_sig,
603
+ generics,
604
+ is_impl,
605
+ ) ] ,
606
+ ) ;
646
607
647
608
body
648
609
}
649
610
650
- fn gen_primal_call (
651
- ecx : & ExtCtxt < ' _ > ,
652
- span : Span ,
653
- primal : Ident ,
654
- idents : & [ Ident ] ,
655
- generics : & Generics ,
656
- ) -> Box < ast:: Expr > {
657
- let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
658
-
659
- if has_self {
660
- let args: ThinVec < _ > =
661
- idents[ 1 ..] . iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
662
- let self_expr = ecx. expr_self ( span) ;
663
- ecx. expr_method_call ( span, self_expr, primal, args)
664
- } else {
665
- let args: ThinVec < _ > =
666
- idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
667
- let mut primal_path = ecx. path_ident ( span, primal) ;
668
-
669
- let is_generic = !generics. params . is_empty ( ) ;
670
-
671
- match ( is_generic, primal_path. segments . last_mut ( ) ) {
672
- ( true , Some ( function_path) ) => {
673
- let primal_generic_types = generics
674
- . params
675
- . iter ( )
676
- . filter ( |param| matches ! ( param. kind, ast:: GenericParamKind :: Type { .. } ) ) ;
677
-
678
- let generated_generic_types = primal_generic_types
679
- . map ( |type_param| {
680
- let generic_param = TyKind :: Path (
681
- None ,
682
- ast:: Path {
683
- span,
684
- segments : thin_vec ! [ ast:: PathSegment {
685
- ident: type_param. ident,
686
- args: None ,
687
- id: ast:: DUMMY_NODE_ID ,
688
- } ] ,
689
- tokens : None ,
690
- } ,
691
- ) ;
692
-
693
- ast:: AngleBracketedArg :: Arg ( ast:: GenericArg :: Type ( Box :: new ( ast:: Ty {
694
- id : type_param. id ,
695
- span,
696
- kind : generic_param,
697
- tokens : None ,
698
- } ) ) )
699
- } )
700
- . collect ( ) ;
701
-
702
- function_path. args =
703
- Some ( Box :: new ( ast:: GenericArgs :: AngleBracketed ( ast:: AngleBracketedArgs {
704
- span,
705
- args : generated_generic_types,
706
- } ) ) ) ;
707
- }
708
- _ => { }
709
- }
710
-
711
- let primal_call_expr = ecx. expr_path ( primal_path) ;
712
- ecx. expr_call ( span, primal_call_expr, args)
713
- }
714
- }
715
-
716
611
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
717
612
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
718
613
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
@@ -729,7 +624,7 @@ mod llvm_enzyme {
729
624
sig : & ast:: FnSig ,
730
625
x : & AutoDiffAttrs ,
731
626
span : Span ,
732
- ) -> ( ast:: FnSig , Vec < Ident > , bool ) {
627
+ ) -> ast:: FnSig {
733
628
let dcx = ecx. sess . dcx ( ) ;
734
629
let has_ret = has_ret ( & sig. decl . output ) ;
735
630
let sig_args = sig. decl . inputs . len ( ) + if has_ret { 1 } else { 0 } ;
@@ -741,7 +636,7 @@ mod llvm_enzyme {
741
636
found : num_activities,
742
637
} ) ;
743
638
// This is not the right signature, but we can continue parsing.
744
- return ( sig. clone ( ) , vec ! [ ] , true ) ;
639
+ return sig. clone ( ) ;
745
640
}
746
641
assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
747
642
assert ! ( has_ret == x. has_ret_activity( ) ) ;
@@ -784,7 +679,7 @@ mod llvm_enzyme {
784
679
785
680
if errors {
786
681
// This is not the right signature, but we can continue parsing.
787
- return ( sig. clone ( ) , idents , true ) ;
682
+ return sig. clone ( ) ;
788
683
}
789
684
790
685
let unsafe_activities = x
@@ -998,7 +893,7 @@ mod llvm_enzyme {
998
893
}
999
894
let d_sig = FnSig { header : d_header, decl : d_decl, span } ;
1000
895
trace ! ( "Generated signature: {:?}" , d_sig) ;
1001
- ( d_sig, idents , false )
896
+ d_sig
1002
897
}
1003
898
}
1004
899
0 commit comments