Skip to content

Commit 0a0e48a

Browse files
committed
Naive impl of intrinsic codegen
Note(Sa4dUs): Most tests are still broken due to `sret` and how funcs are searched in the current logic
1 parent 6cbf9cc commit 0a0e48a

File tree

3 files changed

+70
-55
lines changed

3 files changed

+70
-55
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 19 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::ModuleCodegen;
55
use rustc_codegen_ssa::common::TypeKind;
6-
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
6+
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
77
use rustc_errors::FatalError;
88
use rustc_middle::bug;
99
use tracing::{debug, trace};
1010

1111
use crate::back::write::llvm_err;
12-
use crate::builder::{SBuilder, UNNAMED};
12+
use crate::builder::{Builder, OperandRef, PlaceRef, UNNAMED};
1313
use crate::context::SimpleCx;
1414
use crate::declare::declare_simple_fn;
1515
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
@@ -18,7 +18,7 @@ use crate::llvm::{Metadata, True};
1818
use crate::value::Value;
1919
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
2020

21-
fn get_params(fnc: &Value) -> Vec<&Value> {
21+
fn _get_params(fnc: &Value) -> Vec<&Value> {
2222
let param_num = llvm::LLVMCountParams(fnc) as usize;
2323
let mut fnc_args: Vec<&Value> = vec![];
2424
fnc_args.reserve(param_num);
@@ -48,9 +48,9 @@ fn has_sret(fnc: &Value) -> bool {
4848
// need to match those.
4949
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
5050
// using iterators and peek()?
51-
fn match_args_from_caller_to_enzyme<'ll>(
51+
fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
5252
cx: &SimpleCx<'ll>,
53-
builder: &SBuilder<'ll, 'll>,
53+
builder: &mut Builder<'_, 'll, 'tcx>,
5454
width: u32,
5555
args: &mut Vec<&'ll llvm::Value>,
5656
inputs: &[DiffActivity],
@@ -288,11 +288,14 @@ fn compute_enzyme_fn_ty<'ll>(
288288
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
289289
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
290290
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
291-
fn generate_enzyme_call<'ll>(
291+
pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
292+
builder: &mut Builder<'_, 'll, 'tcx>,
292293
cx: &SimpleCx<'ll>,
293294
fn_to_diff: &'ll Value,
294295
outer_fn: &'ll Value,
296+
fn_args: &[OperandRef<'tcx, &'ll Value>],
295297
attrs: AutoDiffAttrs,
298+
dest: PlaceRef<'tcx, &'ll Value>,
296299
) {
297300
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
298301
let mut ad_name: String = match attrs.mode {
@@ -365,14 +368,6 @@ fn generate_enzyme_call<'ll>(
365368
let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
366369
attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);
367370

368-
// first, remove all calls from fnc
369-
let entry = llvm::LLVMGetFirstBasicBlock(outer_fn);
370-
let br = llvm::LLVMRustGetTerminator(entry);
371-
llvm::LLVMRustEraseInstFromParent(br);
372-
373-
let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
374-
let mut builder = SBuilder::build(cx, entry);
375-
376371
let num_args = llvm::LLVMCountParams(&fn_to_diff);
377372
let mut args = Vec::with_capacity(num_args as usize + 1);
378373
args.push(fn_to_diff);
@@ -388,40 +383,20 @@ fn generate_enzyme_call<'ll>(
388383
}
389384

390385
let has_sret = has_sret(outer_fn);
391-
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
386+
let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect();
392387
match_args_from_caller_to_enzyme(
393388
&cx,
394-
&builder,
389+
builder,
395390
attrs.width,
396391
&mut args,
397392
&attrs.input_activity,
398393
&outer_args,
399394
has_sret,
400395
);
401396

402-
let call = builder.call(enzyme_ty, ad_fn, &args, None);
403-
404-
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
405-
// metadata attached to it, but we just created this code oota. Given that the
406-
// differentiated function already has partly confusing metadata, and given that this
407-
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
408-
// dummy code which we inserted at a higher level.
409-
// FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
410-
// and how to best improve it for enzyme core and rust-enzyme.
411-
let md_ty = cx.get_md_kind_id("dbg");
412-
if llvm::LLVMRustHasMetadata(last_inst, md_ty) {
413-
let md = llvm::LLVMRustDIGetInstMetadata(last_inst)
414-
.expect("failed to get instruction metadata");
415-
let md_todiff = cx.get_metadata_value(md);
416-
llvm::LLVMSetMetadata(call, md_ty, md_todiff);
417-
} else {
418-
// We don't panic, since depending on whether we are in debug or release mode, we might
419-
// have no debug info to copy, which would then be ok.
420-
trace!("no dbg info");
421-
}
397+
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
422398

423-
// Now that we copied the metadata, get rid of dummy code.
424-
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
399+
builder.store_to_place(call, dest.val);
425400

426401
if cx.val_ty(call) == cx.type_void() || has_sret {
427402
if has_sret {
@@ -444,10 +419,10 @@ fn generate_enzyme_call<'ll>(
444419
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
445420
}
446421
builder.ret_void();
447-
} else {
448-
builder.ret(call);
449422
}
450423

424+
builder.store_to_place(call, dest.val);
425+
451426
// Let's crash in case that we messed something up above and generated invalid IR.
452427
llvm::LLVMRustVerifyFunction(
453428
outer_fn,
@@ -461,6 +436,7 @@ pub(crate) fn differentiate<'ll>(
461436
cgcx: &CodegenContext<LlvmCodegenBackend>,
462437
diff_items: Vec<AutoDiffItem>,
463438
) -> Result<(), FatalError> {
439+
// TODO(Sa4dUs): delete all this logic
464440
for item in &diff_items {
465441
trace!("{}", item);
466442
}
@@ -480,7 +456,7 @@ pub(crate) fn differentiate<'ll>(
480456
for item in diff_items.iter() {
481457
let name = item.source.clone();
482458
let fn_def: Option<&llvm::Value> = cx.get_function(&name);
483-
let Some(fn_def) = fn_def else {
459+
let Some(_fn_def) = fn_def else {
484460
return Err(llvm_err(
485461
diag_handler.handle(),
486462
LlvmError::PrepareAutoDiff {
@@ -492,7 +468,7 @@ pub(crate) fn differentiate<'ll>(
492468
};
493469
debug!(?item.target);
494470
let fn_target: Option<&llvm::Value> = cx.get_function(&item.target);
495-
let Some(fn_target) = fn_target else {
471+
let Some(_fn_target) = fn_target else {
496472
return Err(llvm_err(
497473
diag_handler.handle(),
498474
LlvmError::PrepareAutoDiff {
@@ -503,7 +479,7 @@ pub(crate) fn differentiate<'ll>(
503479
));
504480
};
505481

506-
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
482+
// generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
507483
}
508484

509485
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,19 @@ use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
99
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
1010
use rustc_codegen_ssa::traits::*;
1111
use rustc_hir as hir;
12+
use rustc_hir::def_id::LOCAL_CRATE;
1213
use rustc_middle::mir::BinOp;
1314
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
14-
use rustc_middle::ty::{self, GenericArgsRef, Ty};
15+
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty};
1516
use rustc_middle::{bug, span_bug};
1617
use rustc_span::{Span, Symbol, sym};
17-
use rustc_symbol_mangling::mangle_internal_symbol;
18+
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
1819
use rustc_target::spec::PanicStrategy;
1920
use tracing::debug;
2021

2122
use crate::abi::FnAbiLlvmExt;
2223
use crate::builder::Builder;
24+
use crate::builder::autodiff::generate_enzyme_call;
2325
use crate::context::CodegenCx;
2426
use crate::llvm::{self, Metadata};
2527
use crate::type_::Type;
@@ -189,23 +191,59 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
189191
&[ptr, args[1].immediate()],
190192
)
191193
}
192-
_ if tcx.has_attr(def_id, sym::rustc_autodiff) => {
194+
_ if tcx.has_attr(instance.def_id(), sym::rustc_autodiff) => {
193195
// NOTE(Sa4dUs): This is a hacky way to get the autodiff items
194196
// so we can focus on the lowering of the intrinsic call
197+
let mut source_id = None;
198+
let mut diff_attrs = None;
199+
let items: Vec<_> = tcx.hir_body_owners().map(|i| i.to_def_id()).collect();
200+
201+
// Hacky way of getting primal-diff pair, only works for code with 1 autodiff call
202+
for target_id in &items {
203+
let Some(target_attrs) = &tcx.codegen_fn_attrs(target_id).autodiff_item else {
204+
continue;
205+
};
195206

196-
// `diff_items` is empty even when autodiff is enabled, and if we're here,
197-
// it's because some function was marked as intrinsic and had the `rustc_autodiff` attr
198-
let diff_items = tcx.collect_and_partition_mono_items(()).autodiff_items;
207+
if target_attrs.is_source() {
208+
source_id = Some(*target_id);
209+
} else {
210+
diff_attrs = Some(target_attrs);
211+
}
212+
}
199213

200-
// this shouldn't happen?
201-
if diff_items.is_empty() {
202-
bug!("no autodiff items found for {def_id:?}");
214+
if source_id.is_none() || diff_attrs.is_none() {
215+
bug!("could not find source_id={source_id:?} or diff_attrs={diff_attrs:?}");
203216
}
204217

205-
// TODO(Sa4dUs): generate the enzyme call itself, based on the logic in `builder.rs`
218+
let diff_attrs = diff_attrs.unwrap().clone();
219+
220+
// Get source fn
221+
let source_id = source_id.unwrap();
222+
let fn_source = Instance::mono(tcx, source_id);
223+
let source_symbol =
224+
symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
225+
let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol);
226+
let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };
227+
228+
// Declare target fn
229+
let target_symbol =
230+
symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
231+
let fn_abi = self.cx.fn_abi_of_instance(instance, ty::List::empty());
232+
let outer_fn: &'ll Value =
233+
self.cx.declare_fn(&target_symbol, fn_abi, Some(instance));
234+
235+
// Build body
236+
generate_enzyme_call(
237+
self,
238+
self.cx,
239+
fn_to_diff,
240+
outer_fn,
241+
args, // This argument was not in the original `generate_enzyme_call`, now it's included because `get_params` is not working anymore
242+
diff_attrs.clone(),
243+
result,
244+
);
206245

207-
// Just gen the fallback body for now
208-
return Err(ty::Instance::new_raw(def_id, instance.args));
246+
return Ok(());
209247
}
210248
sym::is_val_statically_known => {
211249
if let OperandValue::Immediate(imm) = args[0].val {

tests/codegen-llvm/autodiff/scalar.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
//@ no-prefer-dynamic
33
//@ needs-enzyme
44
#![feature(autodiff)]
5+
#![feature(intrinsics)]
56

67
use std::autodiff::autodiff_reverse;
78

0 commit comments

Comments
 (0)