@@ -15,8 +15,9 @@ mod llvm_enzyme {
15
15
use rustc_ast:: tokenstream:: * ;
16
16
use rustc_ast:: visit:: AssocCtxt :: * ;
17
17
use rustc_ast:: {
18
- self as ast, AssocItemKind , BindingMode , ExprKind , FnRetTy , FnSig , Generics , ItemKind ,
19
- MetaItemInner , PatKind , QSelf , TyKind , Visibility ,
18
+ self as ast, AngleBracketedArg , AngleBracketedArgs , AssocItemKind , BindingMode , ExprKind ,
19
+ FnRetTy , FnSig , GenericArg , GenericArgs , Generics , ItemKind , MetaItemInner , PatKind , Path ,
20
+ PathSegment , QSelf , TyKind , Visibility ,
20
21
} ;
21
22
use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
22
23
use rustc_span:: { Ident , Span , Symbol , kw, sym} ;
@@ -336,8 +337,14 @@ mod llvm_enzyme {
336
337
& generics,
337
338
) ;
338
339
339
- let d_body =
340
- call_enzyme_autodiff ( ecx, primal, first_ident ( & meta_item_vec[ 0 ] ) , span, & d_sig) ;
340
+ let d_body = call_enzyme_autodiff (
341
+ ecx,
342
+ primal,
343
+ first_ident ( & meta_item_vec[ 0 ] ) ,
344
+ span,
345
+ & d_sig,
346
+ & generics,
347
+ ) ;
341
348
342
349
// The first element of it is the name of the function to be generated
343
350
let asdf = Box :: new ( ast:: Fn {
@@ -504,9 +511,10 @@ mod llvm_enzyme {
504
511
diff : Ident ,
505
512
span : Span ,
506
513
d_sig : & FnSig ,
514
+ generics : & Generics ,
507
515
) -> P < ast:: Block > {
508
- let primal_path_expr = ecx . expr_path ( ecx. path_ident ( span , primal) ) ;
509
- let diff_path_expr = ecx . expr_path ( ecx. path_ident ( span , diff) ) ;
516
+ let primal_path_expr = gen_turbofish_expr ( ecx, primal, generics , span ) ;
517
+ let diff_path_expr = gen_turbofish_expr ( ecx, diff, generics , span ) ;
510
518
511
519
let tuple_expr = ecx. expr_tuple (
512
520
span,
@@ -541,6 +549,37 @@ mod llvm_enzyme {
541
549
block
542
550
}
543
551
552
+ // Generate turbofish expression from fn name and generics
553
+ // Given `foo` and `<A, B, C>`, gen `foo::<A, B, C>`
554
+ fn gen_turbofish_expr (
555
+ ecx : & ExtCtxt < ' _ > ,
556
+ ident : Ident ,
557
+ generics : & Generics ,
558
+ span : Span ,
559
+ ) -> P < ast:: Expr > {
560
+ let generic_args = generics
561
+ . params
562
+ . iter ( )
563
+ . map ( |p| {
564
+ let path = ast:: Path :: from_ident ( p. ident ) ;
565
+ let ty = ecx. ty_path ( path) ;
566
+ AngleBracketedArg :: Arg ( GenericArg :: Type ( ty) )
567
+ } )
568
+ . collect :: < ThinVec < _ > > ( ) ;
569
+
570
+ let args = AngleBracketedArgs { span, args : generic_args } ;
571
+
572
+ let segment = PathSegment {
573
+ ident,
574
+ id : ast:: DUMMY_NODE_ID ,
575
+ args : Some ( P ( GenericArgs :: AngleBracketed ( args) ) ) ,
576
+ } ;
577
+
578
+ let path = Path { span, segments : thin_vec ! [ segment] , tokens : None } ;
579
+
580
+ ecx. expr_path ( path)
581
+ }
582
+
544
583
// Generate dummy const to prevent primal function
545
584
// from being optimized away before applying enzyme
546
585
// ```
0 commit comments