@@ -329,17 +329,22 @@ mod llvm_enzyme {
329
329
. filter ( |a| * * a == DiffActivity :: Active || * * a == DiffActivity :: ActiveOnly )
330
330
. count ( ) as u32 ;
331
331
let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
332
- let d_body = gen_enzyme_body (
332
+
333
+ // TODO(Sa4dUs): Remove this and all the related logic
334
+ let _d_body = gen_enzyme_body (
333
335
ecx, & x, n_active, & sig, & d_sig, primal, & new_args, span, sig_span, idents, errored,
334
336
& generics,
335
337
) ;
336
338
339
+ let d_body =
340
+ call_autodiff ( ecx, primal, first_ident ( & meta_item_vec[ 0 ] ) , span, & d_sig) ;
341
+
337
342
// The first element of it is the name of the function to be generated
338
343
let asdf = Box :: new ( ast:: Fn {
339
344
defaultness : ast:: Defaultness :: Final ,
340
345
sig : d_sig,
341
346
ident : first_ident ( & meta_item_vec[ 0 ] ) ,
342
- generics,
347
+ generics : generics . clone ( ) ,
343
348
contract : None ,
344
349
body : Some ( d_body) ,
345
350
define_opaque : None ,
@@ -428,12 +433,15 @@ mod llvm_enzyme {
428
433
tokens : ts,
429
434
} ) ;
430
435
436
+ let vis_clone = vis. clone ( ) ;
437
+
438
+ let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
431
439
let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
432
440
let d_annotatable = match & item {
433
441
Annotatable :: AssocItem ( _, _) => {
434
442
let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
435
443
let d_fn = Box :: new ( ast:: AssocItem {
436
- attrs : thin_vec ! [ d_attr, inline_never ] ,
444
+ attrs : thin_vec ! [ d_attr] ,
437
445
id : ast:: DUMMY_NODE_ID ,
438
446
span,
439
447
vis,
@@ -443,13 +451,13 @@ mod llvm_enzyme {
443
451
Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
444
452
}
445
453
Annotatable :: Item ( _) => {
446
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never ] , ItemKind :: Fn ( asdf) ) ;
454
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf) ) ;
447
455
d_fn. vis = vis;
448
456
449
457
Annotatable :: Item ( d_fn)
450
458
}
451
459
Annotatable :: Stmt ( _) => {
452
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never ] , ItemKind :: Fn ( asdf) ) ;
460
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf) ) ;
453
461
d_fn. vis = vis;
454
462
455
463
Annotatable :: Stmt ( Box :: new ( ast:: Stmt {
@@ -463,7 +471,9 @@ mod llvm_enzyme {
463
471
}
464
472
} ;
465
473
466
- return vec ! [ orig_annotatable, d_annotatable] ;
474
+ let dummy_const_annotatable = gen_dummy_const ( ecx, span, primal, sig, generics, vis_clone) ;
475
+
476
+ return vec ! [ orig_annotatable, dummy_const_annotatable, d_annotatable] ;
467
477
}
468
478
469
479
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -484,6 +494,123 @@ mod llvm_enzyme {
484
494
ty
485
495
}
486
496
497
+ // Generate `autodiff` intrinsic call
498
+ // ```
499
+ // std::intrinsics::autodiff(source, diff, (args))
500
+ // ```
501
+ fn call_autodiff (
502
+ ecx : & ExtCtxt < ' _ > ,
503
+ primal : Ident ,
504
+ diff : Ident ,
505
+ span : Span ,
506
+ d_sig : & FnSig ,
507
+ ) -> P < ast:: Block > {
508
+ let primal_path_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
509
+ let diff_path_expr = ecx. expr_path ( ecx. path_ident ( span, diff) ) ;
510
+
511
+ let tuple_expr = ecx. expr_tuple (
512
+ span,
513
+ d_sig
514
+ . decl
515
+ . inputs
516
+ . iter ( )
517
+ . map ( |arg| match arg. pat . kind {
518
+ PatKind :: Ident ( _, ident, _) => ecx. expr_path ( ecx. path_ident ( span, ident) ) ,
519
+ _ => todo ! ( ) ,
520
+ } )
521
+ . collect :: < ThinVec < _ > > ( )
522
+ . into ( ) ,
523
+ ) ;
524
+
525
+ let enzyme_path = ecx. path (
526
+ span,
527
+ vec ! [
528
+ Ident :: from_str( "std" ) ,
529
+ Ident :: from_str( "intrinsics" ) ,
530
+ Ident :: from_str( "autodiff" ) ,
531
+ ] ,
532
+ ) ;
533
+ let call_expr = ecx. expr_call (
534
+ span,
535
+ ecx. expr_path ( enzyme_path) ,
536
+ vec ! [ primal_path_expr, diff_path_expr, tuple_expr] . into ( ) ,
537
+ ) ;
538
+
539
+ let block = ecx. block_expr ( call_expr) ;
540
+
541
+ block
542
+ }
543
+
544
+ // Generate dummy const to prevent primal function
545
+ // from being optimized away before applying enzyme
546
+ // ```
547
+ // const _: () =
548
+ // {
549
+ // #[used]
550
+ // pub static DUMMY_PTR: fn_type = primal_fn;
551
+ // };
552
+ // ```
553
+ fn gen_dummy_const (
554
+ ecx : & ExtCtxt < ' _ > ,
555
+ span : Span ,
556
+ primal : Ident ,
557
+ sig : FnSig ,
558
+ generics : Generics ,
559
+ vis : Visibility ,
560
+ ) -> Annotatable {
561
+ // #[used]
562
+ let used_attr = P ( ast:: NormalAttr :: from_ident ( Ident :: with_dummy_span ( sym:: used) ) ) ;
563
+ let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
564
+ let used_attr = outer_normal_attr ( & used_attr, new_id, span) ;
565
+
566
+ // static DUMMY_PTR: <fn_type> = <primal_ident>
567
+ let static_ident = Ident :: from_str_and_span ( "DUMMY_PTR" , span) ;
568
+ let fn_ptr_ty = ast:: TyKind :: BareFn ( Box :: new ( ast:: BareFnTy {
569
+ safety : sig. header . safety ,
570
+ ext : sig. header . ext ,
571
+ generic_params : generics. params ,
572
+ decl : sig. decl ,
573
+ decl_span : sig. span ,
574
+ } ) ) ;
575
+ let static_ty = ecx. ty ( span, fn_ptr_ty) ;
576
+
577
+ let static_expr = ecx. expr_path ( ecx. path ( span, vec ! [ primal] ) ) ;
578
+ let static_item_kind = ast:: ItemKind :: Static ( Box :: new ( ast:: StaticItem {
579
+ ident : static_ident,
580
+ ty : static_ty,
581
+ safety : ast:: Safety :: Default ,
582
+ mutability : ast:: Mutability :: Not ,
583
+ expr : Some ( static_expr) ,
584
+ define_opaque : None ,
585
+ } ) ) ;
586
+
587
+ let static_item = ast:: Item {
588
+ attrs : thin_vec ! [ used_attr] ,
589
+ id : ast:: DUMMY_NODE_ID ,
590
+ span,
591
+ vis,
592
+ kind : static_item_kind,
593
+ tokens : None ,
594
+ } ;
595
+
596
+ let block_expr = ecx. expr_block ( Box :: new ( ast:: Block {
597
+ stmts : thin_vec ! [ ecx. stmt_item( span, P ( static_item) ) ] ,
598
+ id : ast:: DUMMY_NODE_ID ,
599
+ rules : ast:: BlockCheckMode :: Default ,
600
+ span,
601
+ tokens : None ,
602
+ } ) ) ;
603
+
604
+ let const_item = ecx. item_const (
605
+ span,
606
+ Ident :: from_str_and_span ( "_" , span) ,
607
+ ecx. ty ( span, ast:: TyKind :: Tup ( thin_vec ! [ ] ) ) ,
608
+ block_expr,
609
+ ) ;
610
+
611
+ Annotatable :: Item ( const_item)
612
+ }
613
+
487
614
// Will generate a body of the type:
488
615
// ```
489
616
// {
0 commit comments