Skip to content

Commit e8f2cfd

Browse files
committed
Basic implementation of autodiff intrinsic
1 parent 53af067 commit e8f2cfd

File tree

7 files changed

+284
-245
lines changed

7 files changed

+284
-245
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 133 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -329,17 +329,22 @@ mod llvm_enzyme {
329329
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
330330
.count() as u32;
331331
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
332-
let d_body = gen_enzyme_body(
332+
333+
// TODO(Sa4dUs): Remove this and all the related logic
334+
let _d_body = gen_enzyme_body(
333335
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
334336
&generics,
335337
);
336338

339+
let d_body =
340+
call_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig);
341+
337342
// The first element of it is the name of the function to be generated
338343
let asdf = Box::new(ast::Fn {
339344
defaultness: ast::Defaultness::Final,
340345
sig: d_sig,
341346
ident: first_ident(&meta_item_vec[0]),
342-
generics,
347+
generics: generics.clone(),
343348
contract: None,
344349
body: Some(d_body),
345350
define_opaque: None,
@@ -428,12 +433,15 @@ mod llvm_enzyme {
428433
tokens: ts,
429434
});
430435

436+
let vis_clone = vis.clone();
437+
438+
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
431439
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
432440
let d_annotatable = match &item {
433441
Annotatable::AssocItem(_, _) => {
434442
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
435443
let d_fn = Box::new(ast::AssocItem {
436-
attrs: thin_vec![d_attr, inline_never],
444+
attrs: thin_vec![d_attr],
437445
id: ast::DUMMY_NODE_ID,
438446
span,
439447
vis,
@@ -443,13 +451,13 @@ mod llvm_enzyme {
443451
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
444452
}
445453
Annotatable::Item(_) => {
446-
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
454+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
447455
d_fn.vis = vis;
448456

449457
Annotatable::Item(d_fn)
450458
}
451459
Annotatable::Stmt(_) => {
452-
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
460+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
453461
d_fn.vis = vis;
454462

455463
Annotatable::Stmt(Box::new(ast::Stmt {
@@ -463,7 +471,9 @@ mod llvm_enzyme {
463471
}
464472
};
465473

466-
return vec![orig_annotatable, d_annotatable];
474+
let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone);
475+
476+
return vec![orig_annotatable, dummy_const_annotatable, d_annotatable];
467477
}
468478

469479
// shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -484,6 +494,123 @@ mod llvm_enzyme {
484494
ty
485495
}
486496

497+
// Generate `autodiff` intrinsic call
498+
// ```
499+
// std::intrinsics::autodiff(source, diff, (args))
500+
// ```
501+
fn call_autodiff(
502+
ecx: &ExtCtxt<'_>,
503+
primal: Ident,
504+
diff: Ident,
505+
span: Span,
506+
d_sig: &FnSig,
507+
) -> P<ast::Block> {
508+
let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal));
509+
let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff));
510+
511+
let tuple_expr = ecx.expr_tuple(
512+
span,
513+
d_sig
514+
.decl
515+
.inputs
516+
.iter()
517+
.map(|arg| match arg.pat.kind {
518+
PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
519+
_ => todo!(),
520+
})
521+
.collect::<ThinVec<_>>()
522+
.into(),
523+
);
524+
525+
let enzyme_path = ecx.path(
526+
span,
527+
vec![
528+
Ident::from_str("std"),
529+
Ident::from_str("intrinsics"),
530+
Ident::from_str("autodiff"),
531+
],
532+
);
533+
let call_expr = ecx.expr_call(
534+
span,
535+
ecx.expr_path(enzyme_path),
536+
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
537+
);
538+
539+
let block = ecx.block_expr(call_expr);
540+
541+
block
542+
}
543+
544+
// Generate dummy const to prevent primal function
545+
// from being optimized away before applying enzyme
546+
// ```
547+
// const _: () =
548+
// {
549+
// #[used]
550+
// pub static DUMMY_PTR: fn_type = primal_fn;
551+
// };
552+
// ```
553+
fn gen_dummy_const(
554+
ecx: &ExtCtxt<'_>,
555+
span: Span,
556+
primal: Ident,
557+
sig: FnSig,
558+
generics: Generics,
559+
vis: Visibility,
560+
) -> Annotatable {
561+
// #[used]
562+
let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used)));
563+
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
564+
let used_attr = outer_normal_attr(&used_attr, new_id, span);
565+
566+
// static DUMMY_PTR: <fn_type> = <primal_ident>
567+
let static_ident = Ident::from_str_and_span("DUMMY_PTR", span);
568+
let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy {
569+
safety: sig.header.safety,
570+
ext: sig.header.ext,
571+
generic_params: generics.params,
572+
decl: sig.decl,
573+
decl_span: sig.span,
574+
}));
575+
let static_ty = ecx.ty(span, fn_ptr_ty);
576+
577+
let static_expr = ecx.expr_path(ecx.path(span, vec![primal]));
578+
let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem {
579+
ident: static_ident,
580+
ty: static_ty,
581+
safety: ast::Safety::Default,
582+
mutability: ast::Mutability::Not,
583+
expr: Some(static_expr),
584+
define_opaque: None,
585+
}));
586+
587+
let static_item = ast::Item {
588+
attrs: thin_vec![used_attr],
589+
id: ast::DUMMY_NODE_ID,
590+
span,
591+
vis,
592+
kind: static_item_kind,
593+
tokens: None,
594+
};
595+
596+
let block_expr = ecx.expr_block(Box::new(ast::Block {
597+
stmts: thin_vec![ecx.stmt_item(span, P(static_item))],
598+
id: ast::DUMMY_NODE_ID,
599+
rules: ast::BlockCheckMode::Default,
600+
span,
601+
tokens: None,
602+
}));
603+
604+
let const_item = ecx.item_const(
605+
span,
606+
Ident::from_str_and_span("_", span),
607+
ecx.ty(span, ast::TyKind::Tup(thin_vec![])),
608+
block_expr,
609+
);
610+
611+
Annotatable::Item(const_item)
612+
}
613+
487614
// Will generate a body of the type:
488615
// ```
489616
// {

0 commit comments

Comments
 (0)