Skip to content

Commit b8c1ca7

Browse files
committed
Remove primal call and collect it in mono instead
1 parent d0ae1fd commit b8c1ca7

File tree

3 files changed

+48
-9
lines changed

3 files changed

+48
-9
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -594,17 +594,11 @@ mod llvm_enzyme {
594594
span: Span,
595595
primal: Ident,
596596
idents: &[Ident],
597-
errored: bool,
597+
_errored: bool,
598598
generics: &Generics,
599599
) -> P<ast::Block> {
600-
let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
601-
let mut body = ecx.block(span, ThinVec::new());
602-
603-
// This uses primal args which won't be available if we errored before
604-
if !errored {
605-
body.stmts.push(ecx.stmt_semi(primal_call.clone()));
606-
}
607-
600+
let _primal_call = gen_primal_call(ecx, span, primal, idents, generics);
601+
let body = ecx.block(span, ThinVec::new());
608602
body
609603
}
610604

compiler/rustc_monomorphize/src/collector.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@
205205
//! this is not implemented however: a mono item will be produced
206206
//! regardless of whether it is actually needed or not.
207207
208+
mod autodiff;
209+
208210
use std::cell::OnceCell;
209211
use std::path::PathBuf;
210212

@@ -237,6 +239,8 @@ use rustc_span::source_map::{Spanned, dummy_spanned, respan};
237239
use rustc_span::{DUMMY_SP, Span};
238240
use tracing::{debug, instrument, trace};
239241

242+
#[cfg(llvm_enzyme)]
243+
use crate::collector::autodiff::collect_enzyme_autodiff_source_fn;
240244
use crate::errors::{self, EncounteredErrorWhileInstantiating, NoOptimizedMir, RecursionLimit};
241245

242246
#[derive(PartialEq)]
@@ -913,6 +917,9 @@ fn visit_instance_use<'tcx>(
913917
return;
914918
}
915919
if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) {
920+
#[cfg(llvm_enzyme)]
921+
collect_enzyme_autodiff_source_fn(tcx, instance, intrinsic, output);
922+
916923
if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) {
917924
// The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will
918925
// be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
use rustc_middle::bug;
2+
use rustc_middle::ty::{self, IntrinsicDef, TyCtxt};
3+
use tracing::debug;
4+
5+
use crate::collector::{MonoItems, create_fn_mono_item};
6+
7+
pub(crate) fn collect_enzyme_autodiff_source_fn<'tcx>(
8+
tcx: TyCtxt<'tcx>,
9+
instance: ty::Instance<'tcx>,
10+
intrinsic: IntrinsicDef,
11+
output: &mut MonoItems<'tcx>,
12+
) {
13+
if intrinsic.name != rustc_span::sym::enzyme_autodiff {
14+
return;
15+
};
16+
17+
debug!("enzyme_autodiff found");
18+
let (primal, span) = match instance.args[0].kind() {
19+
rustc_middle::infer::canonical::ir::GenericArgKind::Type(ty) => match ty.kind() {
20+
ty::FnDef(def_id, substs) => {
21+
let span = tcx.def_span(def_id);
22+
let instance = ty::Instance::expect_resolve(
23+
tcx,
24+
ty::TypingEnv::non_body_analysis(tcx, def_id),
25+
*def_id,
26+
substs,
27+
span,
28+
);
29+
30+
(instance, span)
31+
}
32+
_ => bug!("expected function"),
33+
},
34+
_ => bug!("expected type"),
35+
};
36+
37+
output.push(create_fn_mono_item(tcx, primal, span));
38+
}

0 commit comments

Comments
 (0)