Skip to content

Commit 0a8a3a7

Browse files
committed
Remove gen_enzyme_body and call it in place instead
1 parent e8919f4 commit 0a8a3a7

File tree

1 file changed

+10
-44
lines changed

1 file changed

+10
-44
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 10 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -323,14 +323,17 @@ mod llvm_enzyme {
323323

324324
let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);
325325

326-
let d_body = gen_enzyme_body(
327-
ecx,
328-
&d_sig,
329-
primal,
326+
let d_body = ecx.block(
330327
span,
331-
first_ident(&meta_item_vec[0]),
332-
&generics,
333-
impl_of_trait,
328+
thin_vec![call_autodiff(
329+
ecx,
330+
primal,
331+
first_ident(&meta_item_vec[0]),
332+
span,
333+
&d_sig,
334+
&generics,
335+
impl_of_trait,
336+
)],
334337
);
335338

336339
// The first element of it is the name of the function to be generated
@@ -584,43 +587,6 @@ mod llvm_enzyme {
584587
ecx.expr_path(path)
585588
}
586589

587-
/// We only want this function to type-check, since we will replace the body
588-
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
589-
/// so instead we manually build something that should pass the type checker.
590-
/// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
591-
/// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
592-
/// bug would ever try to accidentally differentiate this placeholder function body.
593-
/// Finally, we also add back_box usages of all input arguments, to prevent rustc
594-
/// from optimizing any arguments away.
595-
fn gen_enzyme_body(
596-
ecx: &ExtCtxt<'_>,
597-
d_sig: &ast::FnSig,
598-
primal: Ident,
599-
span: Span,
600-
diff_ident: Ident,
601-
generics: &Generics,
602-
is_impl: bool,
603-
) -> Box<ast::Block> {
604-
let new_decl_span = d_sig.span;
605-
606-
// Add a call to the primal function to prevent it from being inlined
607-
// and call `autodiff` intrinsic (this also covers the return type)
608-
let body = ecx.block(
609-
span,
610-
thin_vec![call_autodiff(
611-
ecx,
612-
primal,
613-
diff_ident,
614-
new_decl_span,
615-
d_sig,
616-
generics,
617-
is_impl,
618-
)],
619-
);
620-
621-
body
622-
}
623-
624590
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
625591
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
626592
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be

0 commit comments

Comments
 (0)