Skip to content

Commit 222886d

Browse files
committed
Move logic to a dedicated enzyme_autodiff intrinsic
1 parent 0b34fb9 commit 222886d

File tree

7 files changed

+181
-46
lines changed

7 files changed

+181
-46
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 130 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -331,20 +331,23 @@ mod llvm_enzyme {
331331
.count() as u32;
332332
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
333333

334-
// UNUSED
334+
// TODO(Sa4dUs): Remove this and all the related logic
335335
let _d_body = gen_enzyme_body(
336336
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
337337
&generics,
338338
);
339339

340+
let d_body =
341+
call_enzyme_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig);
342+
340343
// The first element of it is the name of the function to be generated
341344
let asdf = Box::new(ast::Fn {
342345
defaultness: ast::Defaultness::Final,
343346
sig: d_sig,
344347
ident: first_ident(&meta_item_vec[0]),
345-
generics,
348+
generics: generics.clone(),
346349
contract: None,
347-
body: None, // This leads to an error when the ad function is inside a traits
350+
body: Some(d_body),
348351
define_opaque: None,
349352
});
350353
let mut rustc_ad_attr =
@@ -431,18 +434,15 @@ mod llvm_enzyme {
431434
tokens: ts,
432435
});
433436

434-
let rustc_intrinsic_attr =
435-
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic)));
436-
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
437-
let intrinsic_attr = outer_normal_attr(&rustc_intrinsic_attr, new_id, span);
437+
let vis_clone = vis.clone();
438438

439439
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
440440
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
441441
let d_annotatable = match &item {
442442
Annotatable::AssocItem(_, _) => {
443443
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
444444
let d_fn = P(ast::AssocItem {
445-
attrs: thin_vec![d_attr, intrinsic_attr],
445+
attrs: thin_vec![d_attr],
446446
id: ast::DUMMY_NODE_ID,
447447
span,
448448
vis,
@@ -452,15 +452,13 @@ mod llvm_enzyme {
452452
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
453453
}
454454
Annotatable::Item(_) => {
455-
let mut d_fn =
456-
ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf));
455+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
457456
d_fn.vis = vis;
458457

459458
Annotatable::Item(d_fn)
460459
}
461460
Annotatable::Stmt(_) => {
462-
let mut d_fn =
463-
ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf));
461+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
464462
d_fn.vis = vis;
465463

466464
Annotatable::Stmt(P(ast::Stmt {
@@ -474,7 +472,9 @@ mod llvm_enzyme {
474472
}
475473
};
476474

477-
return vec![orig_annotatable, d_annotatable];
475+
let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone);
476+
477+
return vec![orig_annotatable, dummy_const_annotatable, d_annotatable];
478478
}
479479

480480
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -495,6 +495,123 @@ mod llvm_enzyme {
495495
ty
496496
}
497497

498+
// Generate `enzyme_autodiff` intrinsic call
499+
// ```
500+
// std::intrinsics::enzyme_autodiff(source, diff, (args))
501+
// ```
502+
fn call_enzyme_autodiff(
503+
ecx: &ExtCtxt<'_>,
504+
primal: Ident,
505+
diff: Ident,
506+
span: Span,
507+
d_sig: &FnSig,
508+
) -> P<ast::Block> {
509+
let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal));
510+
let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff));
511+
512+
let tuple_expr = ecx.expr_tuple(
513+
span,
514+
d_sig
515+
.decl
516+
.inputs
517+
.iter()
518+
.map(|arg| match arg.pat.kind {
519+
PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
520+
_ => todo!(),
521+
})
522+
.collect::<ThinVec<_>>()
523+
.into(),
524+
);
525+
526+
let enzyme_path = ecx.path(
527+
span,
528+
vec![
529+
Ident::from_str("std"),
530+
Ident::from_str("intrinsics"),
531+
Ident::from_str("enzyme_autodiff"),
532+
],
533+
);
534+
let call_expr = ecx.expr_call(
535+
span,
536+
ecx.expr_path(enzyme_path),
537+
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
538+
);
539+
540+
let block = ecx.block_expr(call_expr);
541+
542+
block
543+
}
544+
545+
// Generate dummy const to prevent primal function
546+
// from being optimized away before applying enzyme
547+
// ```
548+
// const _: () =
549+
// {
550+
// #[used]
551+
// pub static DUMMY_PTR: fn_type = primal_fn;
552+
// };
553+
// ```
554+
fn gen_dummy_const(
555+
ecx: &ExtCtxt<'_>,
556+
span: Span,
557+
primal: Ident,
558+
sig: FnSig,
559+
generics: Generics,
560+
vis: Visibility,
561+
) -> Annotatable {
562+
// #[used]
563+
let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used)));
564+
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
565+
let used_attr = outer_normal_attr(&used_attr, new_id, span);
566+
567+
// static DUMMY_PTR: <fn_type> = <primal_ident>
568+
let static_ident = Ident::from_str_and_span("DUMMY_PTR", span);
569+
let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy {
570+
safety: sig.header.safety,
571+
ext: sig.header.ext,
572+
generic_params: generics.params,
573+
decl: sig.decl,
574+
decl_span: sig.span,
575+
}));
576+
let static_ty = ecx.ty(span, fn_ptr_ty);
577+
578+
let static_expr = ecx.expr_path(ecx.path(span, vec![primal]));
579+
let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem {
580+
ident: static_ident,
581+
ty: static_ty,
582+
safety: ast::Safety::Default,
583+
mutability: ast::Mutability::Not,
584+
expr: Some(static_expr),
585+
define_opaque: None,
586+
}));
587+
588+
let static_item = ast::Item {
589+
attrs: thin_vec![used_attr],
590+
id: ast::DUMMY_NODE_ID,
591+
span,
592+
vis,
593+
kind: static_item_kind,
594+
tokens: None,
595+
};
596+
597+
let block_expr = ecx.expr_block(Box::new(ast::Block {
598+
stmts: thin_vec![ecx.stmt_item(span, P(static_item))],
599+
id: ast::DUMMY_NODE_ID,
600+
rules: ast::BlockCheckMode::Default,
601+
span,
602+
tokens: None,
603+
}));
604+
605+
let const_item = ecx.item_const(
606+
span,
607+
Ident::from_str_and_span("_", span),
608+
ecx.ty(span, ast::TyKind::Tup(thin_vec![])),
609+
block_expr,
610+
);
611+
612+
Annotatable::Item(const_item)
613+
}
614+
498615
// Will generate a body of the type:
499616
// ```
500617
// {

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use rustc_middle::bug;
99
use tracing::{debug, trace};
1010

1111
use crate::back::write::llvm_err;
12-
use crate::builder::{Builder, OperandRef, PlaceRef, UNNAMED};
12+
use crate::builder::{Builder, PlaceRef, UNNAMED};
1313
use crate::context::SimpleCx;
1414
use crate::declare::declare_simple_fn;
1515
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
@@ -199,7 +199,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
199199
fn_to_diff: &'ll Value,
200200
outer_name: &str,
201201
ret_ty: &'ll Type,
202-
fn_args: &[OperandRef<'tcx, &'ll Value>],
202+
fn_args: &[&'ll Value],
203203
attrs: AutoDiffAttrs,
204204
dest: PlaceRef<'tcx, &'ll Value>,
205205
) {
@@ -275,15 +275,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
275275
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
276276
}
277277

278-
let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect();
279-
280278
match_args_from_caller_to_enzyme(
281279
&cx,
282280
builder,
283281
attrs.width,
284282
&mut args,
285283
&attrs.input_activity,
286-
&outer_args,
284+
fn_args,
287285
);
288286

289287
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::cmp::Ordering;
33

44
use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size};
55
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
6+
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
67
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
78
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
89
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
@@ -198,48 +199,60 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
198199
&[ptr, args[1].immediate()],
199200
)
200201
}
201-
_ if tcx.has_attr(instance.def_id(), sym::rustc_autodiff) => {
202-
// NOTE(Sa4dUs): This is a hacky way to get the autodiff items
203-
// so we can focus on the lowering of the intrinsic call
204-
let mut source_id = None;
205-
let mut diff_attrs = None;
206-
let items: Vec<_> = tcx.hir_body_owners().map(|i| i.to_def_id()).collect();
207-
208-
// Hacky way of getting primal-diff pair, only works for code with 1 autodiff call
209-
for target_id in &items {
210-
let Some(target_attrs) = &tcx.codegen_fn_attrs(target_id).autodiff_item else {
211-
continue;
212-
};
202+
sym::enzyme_autodiff => {
203+
let val_arr: Vec<&'ll Value> = match args[2].val {
204+
crate::intrinsic::OperandValue::Ref(ref place_value) => {
205+
let mut ret_arr = vec![];
206+
let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout };
213207

214-
if target_attrs.is_source() {
215-
source_id = Some(*target_id);
216-
} else {
217-
diff_attrs = Some(target_attrs);
218-
}
219-
}
208+
for i in 0..tuple_place.layout.layout.0.fields.count() {
209+
let field_place = tuple_place.project_field(self, i);
210+
let field_layout = tuple_place.layout.field(self, i);
211+
let llvm_ty = field_layout.llvm_type(self.cx);
220212

221-
if source_id.is_none() || diff_attrs.is_none() {
222-
bug!("could not find source_id={source_id:?} or diff_attrs={diff_attrs:?}");
223-
}
213+
let field_val =
214+
self.load(llvm_ty, field_place.val.llval, field_place.val.align);
215+
216+
ret_arr.push(field_val)
217+
}
224218

225-
let diff_attrs = diff_attrs.unwrap().clone();
219+
ret_arr
220+
}
221+
crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2],
222+
OperandValue::Immediate(v) => vec![v],
223+
OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"),
224+
};
226225

227-
// Get source fn
228-
let source_id = source_id.unwrap();
229-
let fn_source = Instance::mono(tcx, source_id);
226+
// Get source, diff, and attrs
227+
let source_id = match fn_args.into_type_list(tcx)[0].kind() {
228+
ty::FnDef(def_id, _) => def_id,
229+
_ => bug!("invalid args"),
230+
};
231+
let fn_source = Instance::mono(tcx, *source_id);
230232
let source_symbol =
231233
symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
232234
let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol);
233235
let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };
234236

237+
let diff_id = match fn_args.into_type_list(tcx)[1].kind() {
238+
ty::FnDef(def_id, _) => def_id,
239+
_ => bug!("invalid args"),
240+
};
241+
let fn_diff = Instance::mono(tcx, *diff_id);
242+
let diff_symbol =
243+
symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
244+
245+
let diff_attrs = autodiff_attrs(tcx, *diff_id);
246+
let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") };
247+
235248
// Build body
236249
generate_enzyme_call(
237250
self,
238251
self.cx,
239252
fn_to_diff,
240-
name.as_str(),
253+
&diff_symbol,
241254
llret_ty,
242-
args,
255+
&val_arr,
243256
diff_attrs.clone(),
244257
result,
245258
);

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ fn check_link_name_xor_ordinal(
587587
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
588588
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
589589
/// panic, unless we introduced a bug when parsing the autodiff macro.
590-
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
590+
pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
591591
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
592592

593593
let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::<Vec<_>>();

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
135135
| sym::round_ties_even_f32
136136
| sym::round_ties_even_f64
137137
| sym::round_ties_even_f128
138+
| sym::enzyme_autodiff
138139
| sym::const_eval_select => hir::Safety::Safe,
139140
_ => hir::Safety::Unsafe,
140141
};
@@ -216,6 +217,7 @@ pub(crate) fn check_intrinsic_type(
216217

217218
(n_tps, n_cts, inputs, output)
218219
}
220+
sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)),
219221
sym::abort => (0, 0, vec![], tcx.types.never),
220222
sym::unreachable => (0, 0, vec![], tcx.types.never),
221223
sym::breakpoint => (0, 0, vec![], tcx.types.unit),

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,7 @@ symbols! {
915915
enumerate_method,
916916
env,
917917
env_CFG_RELEASE: env!("CFG_RELEASE"),
918+
enzyme_autodiff,
918919
eprint_macro,
919920
eprintln_macro,
920921
eq,

library/core/src/intrinsics/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3163,6 +3163,10 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64;
31633163
#[rustc_intrinsic]
31643164
pub const unsafe fn copysignf128(x: f128, y: f128) -> f128;
31653165

3166+
#[rustc_nounwind]
3167+
#[rustc_intrinsic]
3168+
pub const fn enzyme_autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R;
3169+
31663170
/// Inform Miri that a given pointer definitely has a certain alignment.
31673171
#[cfg(miri)]
31683172
#[rustc_allow_const_fn_unstable(const_eval_select)]

0 commit comments

Comments
 (0)