Skip to content

Commit c9c1c17

Browse files
committed
Remove inlining for autodiff handling
1 parent 250d77e commit c9c1c17

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ mod llvm_enzyme {
192192
/// which becomes expanded to:
193193
/// ```
194194
/// #[rustc_autodiff]
195-
/// #[inline(never)]
196195
/// fn sin(x: &Box<f32>) -> f32 {
197196
/// f32::sin(**x)
198197
/// }
@@ -371,7 +370,7 @@ mod llvm_enzyme {
371370
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
372371
let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
373372

374-
// We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
373+
// We're avoid duplicating the attribute `#[rustc_autodiff]`.
375374
fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
376375
match (attr, item) {
377376
(ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
@@ -384,23 +383,25 @@ mod llvm_enzyme {
384383
}
385384
}
386385

386+
let mut has_inline_never = false;
387+
387388
// Don't add it multiple times:
388389
let orig_annotatable: Annotatable = match item {
389390
Annotatable::Item(ref mut iitem) => {
390391
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
391392
iitem.attrs.push(attr);
392393
}
393-
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
394-
iitem.attrs.push(inline_never.clone());
394+
if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
395+
has_inline_never = true;
395396
}
396397
Annotatable::Item(iitem.clone())
397398
}
398399
Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
399400
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
400401
assoc_item.attrs.push(attr);
401402
}
402-
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
403-
assoc_item.attrs.push(inline_never.clone());
403+
if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
404+
has_inline_never = true;
404405
}
405406
Annotatable::AssocItem(assoc_item.clone(), i)
406407
}
@@ -410,9 +411,8 @@ mod llvm_enzyme {
410411
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
411412
iitem.attrs.push(attr);
412413
}
413-
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
414-
{
415-
iitem.attrs.push(inline_never.clone());
414+
if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
415+
has_inline_never = true;
416416
}
417417
}
418418
_ => unreachable!("stmt kind checked previously"),
@@ -433,11 +433,19 @@ mod llvm_enzyme {
433433

434434
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
435435
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
436+
437+
// If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function
438+
let mut d_attrs = thin_vec![d_attr];
439+
440+
if has_inline_never {
441+
d_attrs.push(inline_never);
442+
}
443+
436444
let d_annotatable = match &item {
437445
Annotatable::AssocItem(_, _) => {
438446
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
439447
let d_fn = Box::new(ast::AssocItem {
440-
attrs: thin_vec![d_attr],
448+
attrs: d_attrs,
441449
id: ast::DUMMY_NODE_ID,
442450
span,
443451
vis,
@@ -447,13 +455,13 @@ mod llvm_enzyme {
447455
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
448456
}
449457
Annotatable::Item(_) => {
450-
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
458+
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
451459
d_fn.vis = vis;
452460

453461
Annotatable::Item(d_fn)
454462
}
455463
Annotatable::Stmt(_) => {
456-
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
464+
let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
457465
d_fn.vis = vis;
458466

459467
Annotatable::Stmt(Box::new(ast::Stmt {

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ use tracing::debug;
1010
use crate::builder::{Builder, PlaceRef, UNNAMED};
1111
use crate::context::SimpleCx;
1212
use crate::declare::declare_simple_fn;
13-
use crate::llvm::AttributePlace::Function;
13+
use crate::llvm;
1414
use crate::llvm::{Metadata, True, Type};
1515
use crate::value::Value;
16-
use crate::{attributes, llvm};
1716

1817
pub(crate) fn adjust_activity_to_abi<'tcx>(
1918
tcx: TyCtxt<'tcx>,
@@ -308,11 +307,6 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
308307
enzyme_ty,
309308
);
310309

311-
// Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
312-
// do it's work.
313-
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
314-
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
315-
316310
let num_args = llvm::LLVMCountParams(&fn_to_diff);
317311
let mut args = Vec::with_capacity(num_args as usize + 1);
318312
args.push(fn_to_diff);

0 commit comments

Comments
 (0)