@@ -323,14 +323,17 @@ mod llvm_enzyme {
323
323
324
324
let d_sig = gen_enzyme_decl ( ecx, & sig, & x, span) ;
325
325
326
- let d_body = gen_enzyme_body (
327
- ecx,
328
- & d_sig,
329
- primal,
326
+ let d_body = ecx. block (
330
327
span,
331
- first_ident ( & meta_item_vec[ 0 ] ) ,
332
- & generics,
333
- impl_of_trait,
328
+ thin_vec ! [ call_autodiff(
329
+ ecx,
330
+ primal,
331
+ first_ident( & meta_item_vec[ 0 ] ) ,
332
+ span,
333
+ & d_sig,
334
+ & generics,
335
+ impl_of_trait,
336
+ ) ] ,
334
337
) ;
335
338
336
339
// The first element of it is the name of the function to be generated
@@ -584,43 +587,6 @@ mod llvm_enzyme {
584
587
ecx. expr_path ( path)
585
588
}
586
589
587
- /// We only want this function to type-check, since we will replace the body
588
- /// later on llvm level. Using `loop {}` does not cover all return types anymore,
589
- /// so instead we manually build something that should pass the type checker.
590
- /// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
591
- /// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
592
- /// bug would ever try to accidentally differentiate this placeholder function body.
593
- /// Finally, we also add back_box usages of all input arguments, to prevent rustc
594
- /// from optimizing any arguments away.
595
- fn gen_enzyme_body (
596
- ecx : & ExtCtxt < ' _ > ,
597
- d_sig : & ast:: FnSig ,
598
- primal : Ident ,
599
- span : Span ,
600
- diff_ident : Ident ,
601
- generics : & Generics ,
602
- is_impl : bool ,
603
- ) -> Box < ast:: Block > {
604
- let new_decl_span = d_sig. span ;
605
-
606
- // Add a call to the primal function to prevent it from being inlined
607
- // and call `autodiff` intrinsic (this also covers the return type)
608
- let body = ecx. block (
609
- span,
610
- thin_vec ! [ call_autodiff(
611
- ecx,
612
- primal,
613
- diff_ident,
614
- new_decl_span,
615
- d_sig,
616
- generics,
617
- is_impl,
618
- ) ] ,
619
- ) ;
620
-
621
- body
622
- }
623
-
624
590
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
625
591
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
626
592
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
0 commit comments