@@ -16,8 +16,9 @@ mod llvm_enzyme {
16
16
use rustc_ast:: tokenstream:: * ;
17
17
use rustc_ast:: visit:: AssocCtxt :: * ;
18
18
use rustc_ast:: {
19
- self as ast, AssocItemKind , BindingMode , ExprKind , FnRetTy , FnSig , Generics , ItemKind ,
20
- MetaItemInner , PatKind , QSelf , TyKind , Visibility ,
19
+ self as ast, AngleBracketedArg , AngleBracketedArgs , AssocItemKind , BindingMode , ExprKind ,
20
+ FnRetTy , FnSig , GenericArg , GenericArgs , Generics , ItemKind , MetaItemInner , PatKind , Path ,
21
+ PathSegment , QSelf , TyKind , Visibility ,
21
22
} ;
22
23
use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
23
24
use rustc_span:: { Ident , Span , Symbol , kw, sym} ;
@@ -337,8 +338,14 @@ mod llvm_enzyme {
337
338
& generics,
338
339
) ;
339
340
340
- let d_body =
341
- call_enzyme_autodiff ( ecx, primal, first_ident ( & meta_item_vec[ 0 ] ) , span, & d_sig) ;
341
+ let d_body = call_enzyme_autodiff (
342
+ ecx,
343
+ primal,
344
+ first_ident ( & meta_item_vec[ 0 ] ) ,
345
+ span,
346
+ & d_sig,
347
+ & generics,
348
+ ) ;
342
349
343
350
// The first element of it is the name of the function to be generated
344
351
let asdf = Box :: new ( ast:: Fn {
@@ -505,9 +512,10 @@ mod llvm_enzyme {
505
512
diff : Ident ,
506
513
span : Span ,
507
514
d_sig : & FnSig ,
515
+ generics : & Generics ,
508
516
) -> P < ast:: Block > {
509
- let primal_path_expr = ecx . expr_path ( ecx. path_ident ( span , primal) ) ;
510
- let diff_path_expr = ecx . expr_path ( ecx. path_ident ( span , diff) ) ;
517
+ let primal_path_expr = gen_turbofish_expr ( ecx, primal, generics , span ) ;
518
+ let diff_path_expr = gen_turbofish_expr ( ecx, diff, generics , span ) ;
511
519
512
520
let tuple_expr = ecx. expr_tuple (
513
521
span,
@@ -542,6 +550,37 @@ mod llvm_enzyme {
542
550
block
543
551
}
544
552
553
+ // Generate turbofish expression from fn name and generics
554
+ // Given `foo` and `<A, B, C>`, gen `foo::<A, B, C>`
555
+ fn gen_turbofish_expr (
556
+ ecx : & ExtCtxt < ' _ > ,
557
+ ident : Ident ,
558
+ generics : & Generics ,
559
+ span : Span ,
560
+ ) -> P < ast:: Expr > {
561
+ let generic_args = generics
562
+ . params
563
+ . iter ( )
564
+ . map ( |p| {
565
+ let path = ast:: Path :: from_ident ( p. ident ) ;
566
+ let ty = ecx. ty_path ( path) ;
567
+ AngleBracketedArg :: Arg ( GenericArg :: Type ( ty) )
568
+ } )
569
+ . collect :: < ThinVec < _ > > ( ) ;
570
+
571
+ let args = AngleBracketedArgs { span, args : generic_args } ;
572
+
573
+ let segment = PathSegment {
574
+ ident,
575
+ id : ast:: DUMMY_NODE_ID ,
576
+ args : Some ( P ( GenericArgs :: AngleBracketed ( args) ) ) ,
577
+ } ;
578
+
579
+ let path = Path { span, segments : thin_vec ! [ segment] , tokens : None } ;
580
+
581
+ ecx. expr_path ( path)
582
+ }
583
+
545
584
// Generate dummy const to prevent primal function
546
585
// from being optimized away before applying enzyme
547
586
// ```
0 commit comments