Skip to content

Commit e98dc80

Browse files
committed
add sret recognition, helpers
1 parent 8f2fa11 commit e98dc80

File tree

6 files changed

+106
-7
lines changed

6 files changed

+106
-7
lines changed

compiler/rustc_codegen_llvm/messages.ftl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
codegen_llvm_autodiff_unused_args = implementation bug, failed to match all args on llvm level
12
codegen_llvm_autodiff_without_enable = using the autodiff feature requires -Z autodiff=Enable
23
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
34

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,10 @@ pub(crate) fn run_pass_manager(
655655
unsafe {
656656
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
657657
}
658+
// This is the final IR, so people should be able to inspect the optimized autodiff output.
659+
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
660+
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
661+
}
658662

659663
if cfg!(llvm_enzyme) && enable_ad {
660664
let opt_stage = llvm::OptStage::FatLTO;

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::ModuleCodegen;
55
use rustc_codegen_ssa::back::write::ModuleConfig;
6-
use rustc_errors::FatalError;
6+
use rustc_errors::{DiagCtxt, FatalError};
77
use tracing::{debug, trace};
88

99
use crate::back::write::llvm_err;
1010
use crate::builder::SBuilder;
1111
use crate::context::SimpleCx;
1212
use crate::declare::declare_simple_fn;
13-
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
13+
use crate::errors::{AutoDiffUnusedArgs, AutoDiffWithoutEnable, LlvmError};
1414
use crate::llvm::AttributePlace::Function;
1515
use crate::llvm::{Metadata, True};
1616
use crate::value::Value;
@@ -27,6 +27,69 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
2727
}
2828
}
2929

30+
// A helper object to make sure, that we pass all of the input and output arguments of the outer
31+
// wrapper into the inner enzyme call.
32+
struct FunctionArgs {
33+
input_args: Vec<bool>,
34+
return_arg: bool,
35+
has_sret: bool,
36+
}
37+
38+
39+
impl FunctionArgs {
40+
fn fully_used(&self) -> bool {
41+
self.input_args.iter().all(|x| *x) && self.return_arg
42+
}
43+
44+
fn use_input(&mut self, idx: usize) {
45+
assert!(!self.input_args[idx]);
46+
self.input_args[idx] = true;
47+
}
48+
49+
fn use_output(&mut self) {
50+
assert!(!self.return_arg);
51+
self.return_arg = true;
52+
}
53+
54+
fn has_sret(&self) -> bool {
55+
self.has_sret
56+
}
57+
}
58+
59+
impl<'ll> From<(&'ll llvm::Context, &'ll Value)> for FunctionArgs {
60+
fn from(wrapper: (&'ll llvm::Context, &'ll Value)) -> Self {
61+
let (llcx, fnc) = wrapper;
62+
let num_args = unsafe { llvm::LLVMCountParams(fnc) as usize };
63+
let input_args = vec![false; num_args];
64+
65+
let fn_ty = unsafe { llvm::LLVMGlobalGetValueType(fnc) };
66+
let ret_ty = unsafe { llvm::LLVMGetReturnType(fn_ty) };
67+
let void_ty = unsafe { llvm::LLVMVoidTypeInContext(llcx) };
68+
let return_arg = ret_ty != void_ty;
69+
70+
71+
let has_sret = if num_args == 0 {
72+
false
73+
} else {
74+
unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) }
75+
};
76+
if has_sret {
77+
dbg!("has sret");
78+
} else {
79+
dbg!("no sret");
80+
}
81+
82+
FunctionArgs { input_args, return_arg, has_sret }
83+
}
84+
}
85+
86+
// The drop implementation makes sure that when we're done, all input and output args are matched.
87+
//impl Drop for FunctionArgs {
88+
// fn drop(&mut self) {
89+
// assert!(self.fully_used());
90+
// }
91+
//}
92+
3093
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
3194
/// function with expected naming and calling conventions[^1] which will be
3295
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -37,14 +100,18 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
37100
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
38101
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
39102
fn generate_enzyme_call<'ll>(
103+
_dcx: &DiagCtxt,
40104
cx: &SimpleCx<'ll>,
41105
fn_to_diff: &'ll Value,
42106
outer_fn: &'ll Value,
43107
attrs: AutoDiffAttrs,
44-
) {
108+
) -> Result<(), FatalError>
109+
{
45110
let inputs = attrs.input_activity;
46111
let output = attrs.ret_activity;
47112

113+
let fa = FunctionArgs::from((cx.llcx, outer_fn));
114+
48115
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
49116
let mut ad_name: String = match attrs.mode {
50117
DiffMode::Forward => "__enzyme_fwddiff",
@@ -99,6 +166,9 @@ fn generate_enzyme_call<'ll>(
99166
let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn);
100167
let ret_ty = llvm::LLVMGetReturnType(fn_ty);
101168

169+
dbg!(&outer_fn);
170+
dbg!(&fn_ty);
171+
102172
// LLVM can figure out the input types on it's own, so we take a shortcut here.
103173
let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True);
104174

@@ -163,6 +233,15 @@ fn generate_enzyme_call<'ll>(
163233
// using iterators and peek()?
164234
let mut outer_pos: usize = 0;
165235
let mut activity_pos = 0;
236+
237+
if fa.has_sret() {
238+
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
239+
// inner function will still return something. We increase our outer_pos by one,
240+
// and once we're done with all other args we will take the return of the inner call and
241+
// update the sret pointer with it
242+
outer_pos = 1;
243+
}
244+
166245
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
167246
while activity_pos < inputs.len() {
168247
let diff_activity = inputs[activity_pos as usize];
@@ -293,6 +372,11 @@ fn generate_enzyme_call<'ll>(
293372
llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction,
294373
);
295374
}
375+
376+
//if !fa.fully_used() {
377+
// return Err(dcx.handle().emit_almost_fatal(AutoDiffUnusedArgs));
378+
//}
379+
Ok(())
296380
}
297381

298382
pub(crate) fn differentiate<'ll>(
@@ -312,8 +396,7 @@ pub(crate) fn differentiate<'ll>(
312396
if !diff_items.is_empty()
313397
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
314398
{
315-
let dcx = cgcx.create_dcx();
316-
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
399+
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
317400
}
318401

319402
// Before dumping the module, we want all the TypeTrees to become part of the module.
@@ -343,7 +426,7 @@ pub(crate) fn differentiate<'ll>(
343426
));
344427
};
345428

346-
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
429+
generate_enzyme_call(&diag_handler, &cx, fn_def, fn_target, item.attrs.clone())?;
347430
}
348431

349432
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts

compiler/rustc_codegen_llvm/src/errors.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {
9494
#[diag(codegen_llvm_autodiff_without_lto)]
9595
pub(crate) struct AutoDiffWithoutLTO;
9696

97+
#[derive(Diagnostic)]
98+
#[diag(codegen_llvm_autodiff_unused_args)]
99+
pub(crate) struct AutoDiffUnusedArgs;
100+
97101
#[derive(Diagnostic)]
98102
#[diag(codegen_llvm_autodiff_without_enable)]
99103
pub(crate) struct AutoDiffWithoutEnable;

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
use libc::{c_char, c_uint};
55

6-
use super::ffi::{BasicBlock, Metadata, Module, Type, Value};
6+
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
77
use crate::llvm::Bool;
88

99
#[link(name = "llvm-wrapper", kind = "static")]
@@ -16,6 +16,7 @@ unsafe extern "C" {
1616
pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
1717
pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
1818
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
19+
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
1920
}
2021

2122
unsafe extern "C" {

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,12 @@ static inline void AddAttributes(T *t, unsigned Index, LLVMAttributeRef *Attrs,
380380
t->setAttributes(PALNew);
381381
}
382382

383+
extern "C" bool LLVMRustHasAttributeAtIndex(LLVMValueRef Fn, unsigned Index,
384+
LLVMRustAttributeKind RustAttr) {
385+
Function *F = unwrap<Function>(Fn);
386+
return F->hasParamAttribute(Index, fromRust(RustAttr));
387+
}
388+
383389
extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index,
384390
LLVMAttributeRef *Attrs,
385391
size_t AttrsLen) {

0 commit comments

Comments
 (0)