Skip to content

Commit 0c22a2c

Browse files
committed
Macro expansion cleanup
1 parent 3aecee9 commit 0c22a2c

File tree

4 files changed

+58
-155
lines changed

4 files changed

+58
-155
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 29 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ mod llvm_enzyme {
2020
MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility,
2121
};
2222
use rustc_expand::base::{Annotatable, ExtCtxt};
23-
use rustc_span::{Ident, Span, Symbol, kw, sym};
23+
use rustc_span::{Ident, Span, Symbol, sym};
2424
use thin_vec::{ThinVec, thin_vec};
2525
use tracing::{debug, trace};
2626

@@ -73,9 +73,7 @@ mod llvm_enzyme {
7373
}
7474

7575
// Get information about the function the macro is applied to
76-
fn extract_item_info(
77-
iitem: &Box<ast::Item>,
78-
) -> Option<(Visibility, FnSig, Ident, Generics, bool)> {
76+
fn extract_item_info(iitem: &Box<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
7977
match &iitem.kind {
8078
ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
8179
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone(), false))
@@ -182,11 +180,8 @@ mod llvm_enzyme {
182180
}
183181

184182
/// We expand the autodiff macro to generate a new placeholder function which passes
185-
/// type-checking and can be called by users. The function body of the placeholder function will
186-
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
187-
/// should just prevent early inlining and optimizations which alter the function signature.
188-
/// The exact signature of the generated function depends on the configuration provided by the
189-
/// user, but here is an example:
183+
/// type-checking and can be called by users. The exact signature of the generated function
184+
/// depends on the configuration provided by the user, but here is an example:
190185
///
191186
/// ```
192187
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
@@ -202,14 +197,8 @@ mod llvm_enzyme {
202197
/// f32::sin(**x)
203198
/// }
204199
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
205-
/// #[inline(never)]
206200
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
207-
/// unsafe {
208-
/// asm!("NOP");
209-
/// };
210-
/// ::core::hint::black_box(sin(x));
211-
/// ::core::hint::black_box((dx, dret));
212-
/// ::core::hint::black_box(sin(x))
201+
/// std::intrinsics::enzyme_autodiff(sin::<>, cos_box::<>, (x, dx, dret))
213202
/// }
214203
/// ```
215204
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@@ -329,22 +318,20 @@ mod llvm_enzyme {
329318
}
330319
let span = ecx.with_def_site_ctxt(expand_span);
331320

332-
let (d_sig, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
321+
let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);
333322

334323
let d_body = gen_enzyme_body(
335324
ecx,
336325
&d_sig,
337326
primal,
338327
span,
339-
idents,
340-
errored,
341328
first_ident(&meta_item_vec[0]),
342329
&generics,
343330
impl_of_trait,
344331
);
345332

346333
// The first element of it is the name of the function to be generated
347-
let asdf = Box::new(ast::Fn {
334+
let d_fn = Box::new(ast::Fn {
348335
defaultness: ast::Defaultness::Final,
349336
sig: d_sig,
350337
ident: first_ident(&meta_item_vec[0]),
@@ -453,13 +440,13 @@ mod llvm_enzyme {
453440
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
454441
}
455442
Annotatable::Item(_) => {
456-
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
443+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
457444
d_fn.vis = vis;
458445

459446
Annotatable::Item(d_fn)
460447
}
461448
Annotatable::Stmt(_) => {
462-
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
449+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
463450
d_fn.vis = vis;
464451

465452
Annotatable::Stmt(Box::new(ast::Stmt {
@@ -524,14 +511,8 @@ mod llvm_enzyme {
524511
.into(),
525512
);
526513

527-
let enzyme_path = ecx.path(
528-
span,
529-
vec![
530-
Ident::from_str("std"),
531-
Ident::from_str("intrinsics"),
532-
Ident::with_dummy_span(sym::enzyme_autodiff),
533-
],
534-
);
514+
let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::enzyme_autodiff]);
515+
let enzyme_path = ecx.path(span, enzyme_path_idents);
535516
let call_expr = ecx.expr_call(
536517
span,
537518
ecx.expr_path(enzyme_path),
@@ -549,7 +530,7 @@ mod llvm_enzyme {
549530
generics: &Generics,
550531
span: Span,
551532
is_impl: bool,
552-
) -> P<ast::Expr> {
533+
) -> Box<ast::Expr> {
553534
let generic_args = generics
554535
.params
555536
.iter()
@@ -573,7 +554,7 @@ mod llvm_enzyme {
573554
let segment = PathSegment {
574555
ident,
575556
id: ast::DUMMY_NODE_ID,
576-
args: Some(P(GenericArgs::AngleBracketed(args))),
557+
args: Some(Box::new(GenericArgs::AngleBracketed(args))),
577558
};
578559

579560
let segments = if is_impl {
@@ -590,25 +571,6 @@ mod llvm_enzyme {
590571
ecx.expr_path(path)
591572
}
592573

593-
// Will generate a body of the type:
594-
// ```
595-
// primal(args);
596-
// std::intrinsics::enzyme_autodiff(primal, diff, (args))
597-
// }
598-
// ```
599-
fn init_body_helper(
600-
ecx: &ExtCtxt<'_>,
601-
span: Span,
602-
primal: Ident,
603-
idents: &[Ident],
604-
_errored: bool,
605-
generics: &Generics,
606-
) -> Box<ast::Block> {
607-
let _primal_call = gen_primal_call(ecx, span, primal, idents, generics);
608-
let body = ecx.block(span, ThinVec::new());
609-
body
610-
}
611-
612574
/// We only want this function to type-check, since we will replace the body
613575
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
614576
/// so instead we manually build something that should pass the type checker.
@@ -622,8 +584,6 @@ mod llvm_enzyme {
622584
d_sig: &ast::FnSig,
623585
primal: Ident,
624586
span: Span,
625-
idents: Vec<Ident>,
626-
errored: bool,
627587
diff_ident: Ident,
628588
generics: &Generics,
629589
is_impl: bool,
@@ -632,87 +592,22 @@ mod llvm_enzyme {
632592

633593
// Add a call to the primal function to prevent it from being inlined
634594
// and call `enzyme_autodiff` intrinsic (this also covers the return type)
635-
let mut body = init_body_helper(ecx, span, primal, &idents, errored, generics);
636-
637-
body.stmts.push(call_enzyme_autodiff(
638-
ecx,
639-
primal,
640-
diff_ident,
641-
new_decl_span,
642-
d_sig,
643-
generics,
644-
is_impl,
645-
));
595+
let body = ecx.block(
596+
span,
597+
thin_vec![call_enzyme_autodiff(
598+
ecx,
599+
primal,
600+
diff_ident,
601+
new_decl_span,
602+
d_sig,
603+
generics,
604+
is_impl,
605+
)],
606+
);
646607

647608
body
648609
}
649610

650-
fn gen_primal_call(
651-
ecx: &ExtCtxt<'_>,
652-
span: Span,
653-
primal: Ident,
654-
idents: &[Ident],
655-
generics: &Generics,
656-
) -> Box<ast::Expr> {
657-
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
658-
659-
if has_self {
660-
let args: ThinVec<_> =
661-
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
662-
let self_expr = ecx.expr_self(span);
663-
ecx.expr_method_call(span, self_expr, primal, args)
664-
} else {
665-
let args: ThinVec<_> =
666-
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
667-
let mut primal_path = ecx.path_ident(span, primal);
668-
669-
let is_generic = !generics.params.is_empty();
670-
671-
match (is_generic, primal_path.segments.last_mut()) {
672-
(true, Some(function_path)) => {
673-
let primal_generic_types = generics
674-
.params
675-
.iter()
676-
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
677-
678-
let generated_generic_types = primal_generic_types
679-
.map(|type_param| {
680-
let generic_param = TyKind::Path(
681-
None,
682-
ast::Path {
683-
span,
684-
segments: thin_vec![ast::PathSegment {
685-
ident: type_param.ident,
686-
args: None,
687-
id: ast::DUMMY_NODE_ID,
688-
}],
689-
tokens: None,
690-
},
691-
);
692-
693-
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(Box::new(ast::Ty {
694-
id: type_param.id,
695-
span,
696-
kind: generic_param,
697-
tokens: None,
698-
})))
699-
})
700-
.collect();
701-
702-
function_path.args =
703-
Some(Box::new(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
704-
span,
705-
args: generated_generic_types,
706-
})));
707-
}
708-
_ => {}
709-
}
710-
711-
let primal_call_expr = ecx.expr_path(primal_path);
712-
ecx.expr_call(span, primal_call_expr, args)
713-
}
714-
}
715-
716611
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
717612
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
718613
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
@@ -729,7 +624,7 @@ mod llvm_enzyme {
729624
sig: &ast::FnSig,
730625
x: &AutoDiffAttrs,
731626
span: Span,
732-
) -> (ast::FnSig, Vec<Ident>, bool) {
627+
) -> ast::FnSig {
733628
let dcx = ecx.sess.dcx();
734629
let has_ret = has_ret(&sig.decl.output);
735630
let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
@@ -741,7 +636,7 @@ mod llvm_enzyme {
741636
found: num_activities,
742637
});
743638
// This is not the right signature, but we can continue parsing.
744-
return (sig.clone(), vec![], true);
639+
return sig.clone();
745640
}
746641
assert!(sig.decl.inputs.len() == x.input_activity.len());
747642
assert!(has_ret == x.has_ret_activity());
@@ -784,7 +679,7 @@ mod llvm_enzyme {
784679

785680
if errors {
786681
// This is not the right signature, but we can continue parsing.
787-
return (sig.clone(), idents, true);
682+
return sig.clone();
788683
}
789684

790685
let unsafe_activities = x
@@ -998,7 +893,7 @@ mod llvm_enzyme {
998893
}
999894
let d_sig = FnSig { header: d_header, decl: d_decl, span };
1000895
trace!("Generated signature: {:?}", d_sig);
1001-
(d_sig, idents, false)
896+
d_sig
1002897
}
1003898
}
1004899

0 commit comments

Comments
 (0)