Skip to content

Commit 6c484e9

Browse files
committed
Move logic to a dedicated enzyme_autodiff intrinsic
1 parent b4ac54a commit 6c484e9

File tree

6 files changed

+180
-45
lines changed

6 files changed

+180
-45
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 130 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -330,20 +330,23 @@ mod llvm_enzyme {
330330
.count() as u32;
331331
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
332332

333-
// UNUSED
333+
// TODO(Sa4dUs): Remove this and all the related logic
334334
let _d_body = gen_enzyme_body(
335335
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
336336
&generics,
337337
);
338338

339+
let d_body =
340+
call_enzyme_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig);
341+
339342
// The first element of it is the name of the function to be generated
340343
let asdf = Box::new(ast::Fn {
341344
defaultness: ast::Defaultness::Final,
342345
sig: d_sig,
343346
ident: first_ident(&meta_item_vec[0]),
344-
generics,
347+
generics: generics.clone(),
345348
contract: None,
346-
body: None, // This leads to an error when the ad function is inside a traits
349+
body: Some(d_body),
347350
define_opaque: None,
348351
});
349352
let mut rustc_ad_attr =
@@ -430,18 +433,15 @@ mod llvm_enzyme {
430433
tokens: ts,
431434
});
432435

433-
let rustc_intrinsic_attr =
434-
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic)));
435-
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
436-
let intrinsic_attr = outer_normal_attr(&rustc_intrinsic_attr, new_id, span);
436+
let vis_clone = vis.clone();
437437

438438
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
439439
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
440440
let d_annotatable = match &item {
441441
Annotatable::AssocItem(_, _) => {
442442
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
443443
let d_fn = Box::new(ast::AssocItem {
444-
attrs: thin_vec![d_attr, inline_never],
444+
attrs: thin_vec![d_attr],
445445
id: ast::DUMMY_NODE_ID,
446446
span,
447447
vis,
@@ -451,15 +451,13 @@ mod llvm_enzyme {
451451
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
452452
}
453453
Annotatable::Item(_) => {
454-
let mut d_fn =
455-
ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf));
454+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
456455
d_fn.vis = vis;
457456

458457
Annotatable::Item(d_fn)
459458
}
460459
Annotatable::Stmt(_) => {
461-
let mut d_fn =
462-
ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf));
460+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
463461
d_fn.vis = vis;
464462

465463
Annotatable::Stmt(Box::new(ast::Stmt {
@@ -473,7 +471,9 @@ mod llvm_enzyme {
473471
}
474472
};
475473

476-
return vec![orig_annotatable, d_annotatable];
474+
let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone);
475+
476+
return vec![orig_annotatable, dummy_const_annotatable, d_annotatable];
477477
}
478478

479479
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -494,6 +494,123 @@ mod llvm_enzyme {
494494
ty
495495
}
496496

497+
// Generate `enzyme_autodiff` intrinsic call
498+
// ```
499+
// std::intrinsics::enzyme_autodiff(source, diff, (args))
500+
// ```
501+
fn call_enzyme_autodiff(
502+
ecx: &ExtCtxt<'_>,
503+
primal: Ident,
504+
diff: Ident,
505+
span: Span,
506+
d_sig: &FnSig,
507+
) -> P<ast::Block> {
508+
let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal));
509+
let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff));
510+
511+
let tuple_expr = ecx.expr_tuple(
512+
span,
513+
d_sig
514+
.decl
515+
.inputs
516+
.iter()
517+
.map(|arg| match arg.pat.kind {
518+
PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
519+
_ => todo!(),
520+
})
521+
.collect::<ThinVec<_>>()
522+
.into(),
523+
);
524+
525+
let enzyme_path = ecx.path(
526+
span,
527+
vec![
528+
Ident::from_str("std"),
529+
Ident::from_str("intrinsics"),
530+
Ident::from_str("enzyme_autodiff"),
531+
],
532+
);
533+
let call_expr = ecx.expr_call(
534+
span,
535+
ecx.expr_path(enzyme_path),
536+
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
537+
);
538+
539+
let block = ecx.block_expr(call_expr);
540+
541+
block
542+
}
543+
544+
// Generate dummy const to prevent primal function
545+
// from being optimized away before applying enzyme
546+
// ```
547+
// const _: () =
548+
// {
549+
// #[used]
550+
// pub static DUMMY_PTR: fn_type = primal_fn;
551+
// };
552+
// ```
553+
fn gen_dummy_const(
554+
ecx: &ExtCtxt<'_>,
555+
span: Span,
556+
primal: Ident,
557+
sig: FnSig,
558+
generics: Generics,
559+
vis: Visibility,
560+
) -> Annotatable {
561+
// #[used]
562+
let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used)));
563+
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
564+
let used_attr = outer_normal_attr(&used_attr, new_id, span);
565+
566+
// static DUMMY_PTR: <fn_type> = <primal_ident>
567+
let static_ident = Ident::from_str_and_span("DUMMY_PTR", span);
568+
let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy {
569+
safety: sig.header.safety,
570+
ext: sig.header.ext,
571+
generic_params: generics.params,
572+
decl: sig.decl,
573+
decl_span: sig.span,
574+
}));
575+
let static_ty = ecx.ty(span, fn_ptr_ty);
576+
577+
let static_expr = ecx.expr_path(ecx.path(span, vec![primal]));
578+
let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem {
579+
ident: static_ident,
580+
ty: static_ty,
581+
safety: ast::Safety::Default,
582+
mutability: ast::Mutability::Not,
583+
expr: Some(static_expr),
584+
define_opaque: None,
585+
}));
586+
587+
let static_item = ast::Item {
588+
attrs: thin_vec![used_attr],
589+
id: ast::DUMMY_NODE_ID,
590+
span,
591+
vis,
592+
kind: static_item_kind,
593+
tokens: None,
594+
};
595+
596+
let block_expr = ecx.expr_block(Box::new(ast::Block {
597+
stmts: thin_vec![ecx.stmt_item(span, P(static_item))],
598+
id: ast::DUMMY_NODE_ID,
599+
rules: ast::BlockCheckMode::Default,
600+
span,
601+
tokens: None,
602+
}));
603+
604+
let const_item = ecx.item_const(
605+
span,
606+
Ident::from_str_and_span("_", span),
607+
ecx.ty(span, ast::TyKind::Tup(thin_vec![])),
608+
block_expr,
609+
);
610+
611+
Annotatable::Item(const_item)
612+
}
613+
497614
// Will generate a body of the type:
498615
// ```
499616
// {

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use rustc_middle::bug;
99
use tracing::{debug, trace};
1010

1111
use crate::back::write::llvm_err;
12-
use crate::builder::{Builder, OperandRef, PlaceRef, UNNAMED};
12+
use crate::builder::{Builder, PlaceRef, UNNAMED};
1313
use crate::context::SimpleCx;
1414
use crate::declare::declare_simple_fn;
1515
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
@@ -199,7 +199,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
199199
fn_to_diff: &'ll Value,
200200
outer_name: &str,
201201
ret_ty: &'ll Type,
202-
fn_args: &[OperandRef<'tcx, &'ll Value>],
202+
fn_args: &[&'ll Value],
203203
attrs: AutoDiffAttrs,
204204
dest: PlaceRef<'tcx, &'ll Value>,
205205
) {
@@ -275,15 +275,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
275275
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
276276
}
277277

278-
let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect();
279-
280278
match_args_from_caller_to_enzyme(
281279
&cx,
282280
builder,
283281
attrs.width,
284282
&mut args,
285283
&attrs.input_activity,
286-
&outer_args,
284+
fn_args,
287285
);
288286

289287
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::cmp::Ordering;
33

44
use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size};
55
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
6+
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
67
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
78
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
89
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
@@ -198,48 +199,60 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
198199
&[ptr, args[1].immediate()],
199200
)
200201
}
201-
_ if tcx.has_attr(instance.def_id(), sym::rustc_autodiff) => {
202-
// NOTE(Sa4dUs): This is a hacky way to get the autodiff items
203-
// so we can focus on the lowering of the intrinsic call
204-
let mut source_id = None;
205-
let mut diff_attrs = None;
206-
let items: Vec<_> = tcx.hir_body_owners().map(|i| i.to_def_id()).collect();
207-
208-
// Hacky way of getting primal-diff pair, only works for code with 1 autodiff call
209-
for target_id in &items {
210-
let Some(target_attrs) = &tcx.codegen_fn_attrs(target_id).autodiff_item else {
211-
continue;
212-
};
202+
sym::enzyme_autodiff => {
203+
let val_arr: Vec<&'ll Value> = match args[2].val {
204+
crate::intrinsic::OperandValue::Ref(ref place_value) => {
205+
let mut ret_arr = vec![];
206+
let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout };
213207

214-
if target_attrs.is_source() {
215-
source_id = Some(*target_id);
216-
} else {
217-
diff_attrs = Some(target_attrs);
218-
}
219-
}
208+
for i in 0..tuple_place.layout.layout.0.fields.count() {
209+
let field_place = tuple_place.project_field(self, i);
210+
let field_layout = tuple_place.layout.field(self, i);
211+
let llvm_ty = field_layout.llvm_type(self.cx);
220212

221-
if source_id.is_none() || diff_attrs.is_none() {
222-
bug!("could not find source_id={source_id:?} or diff_attrs={diff_attrs:?}");
223-
}
213+
let field_val =
214+
self.load(llvm_ty, field_place.val.llval, field_place.val.align);
215+
216+
ret_arr.push(field_val)
217+
}
224218

225-
let diff_attrs = diff_attrs.unwrap().clone();
219+
ret_arr
220+
}
221+
crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2],
222+
OperandValue::Immediate(v) => vec![v],
223+
OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"),
224+
};
226225

227-
// Get source fn
228-
let source_id = source_id.unwrap();
229-
let fn_source = Instance::mono(tcx, source_id);
226+
// Get source, diff, and attrs
227+
let source_id = match fn_args.into_type_list(tcx)[0].kind() {
228+
ty::FnDef(def_id, _) => def_id,
229+
_ => bug!("invalid args"),
230+
};
231+
let fn_source = Instance::mono(tcx, *source_id);
230232
let source_symbol =
231233
symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
232234
let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol);
233235
let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };
234236

237+
let diff_id = match fn_args.into_type_list(tcx)[1].kind() {
238+
ty::FnDef(def_id, _) => def_id,
239+
_ => bug!("invalid args"),
240+
};
241+
let fn_diff = Instance::mono(tcx, *diff_id);
242+
let diff_symbol =
243+
symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
244+
245+
let diff_attrs = autodiff_attrs(tcx, *diff_id);
246+
let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") };
247+
235248
// Build body
236249
generate_enzyme_call(
237250
self,
238251
self.cx,
239252
fn_to_diff,
240-
name.as_str(),
253+
&diff_symbol,
241254
llret_ty,
242-
args,
255+
&val_arr,
243256
diff_attrs.clone(),
244257
result,
245258
);

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
135135
| sym::round_ties_even_f32
136136
| sym::round_ties_even_f64
137137
| sym::round_ties_even_f128
138+
| sym::enzyme_autodiff
138139
| sym::const_eval_select => hir::Safety::Safe,
139140
_ => hir::Safety::Unsafe,
140141
};
@@ -216,6 +217,7 @@ pub(crate) fn check_intrinsic_type(
216217

217218
(n_tps, n_cts, inputs, output)
218219
}
220+
sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)),
219221
sym::abort => (0, 0, vec![], tcx.types.never),
220222
sym::unreachable => (0, 0, vec![], tcx.types.never),
221223
sym::breakpoint => (0, 0, vec![], tcx.types.unit),

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,7 @@ symbols! {
917917
enumerate_method,
918918
env,
919919
env_CFG_RELEASE: env!("CFG_RELEASE"),
920+
enzyme_autodiff,
920921
eprint_macro,
921922
eprintln_macro,
922923
eq,

library/core/src/intrinsics/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,6 +3157,10 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64;
31573157
#[rustc_intrinsic]
31583158
pub const unsafe fn copysignf128(x: f128, y: f128) -> f128;
31593159

3160+
#[rustc_nounwind]
3161+
#[rustc_intrinsic]
3162+
pub const fn enzyme_autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R;
3163+
31603164
/// Inform Miri that a given pointer definitely has a certain alignment.
31613165
#[cfg(miri)]
31623166
#[rustc_allow_const_fn_unstable(const_eval_select)]

0 commit comments

Comments
 (0)