Skip to content

Commit 116eb77

Browse files
committed
Remove sret logic
1 parent 8c856fd commit 116eb77

File tree

4 files changed

+74
-208
lines changed

4 files changed

+74
-208
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 56 additions & 196 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::context::SimpleCx;
1414
use crate::declare::declare_simple_fn;
1515
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
1616
use crate::llvm::AttributePlace::Function;
17-
use crate::llvm::{Metadata, True};
17+
use crate::llvm::{Metadata, True, Type};
1818
use crate::value::Value;
1919
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
2020

@@ -29,7 +29,7 @@ fn _get_params(fnc: &Value) -> Vec<&Value> {
2929
fnc_args
3030
}
3131

32-
fn has_sret(fnc: &Value) -> bool {
32+
fn _has_sret(fnc: &Value) -> bool {
3333
let num_args = llvm::LLVMCountParams(fnc) as usize;
3434
if num_args == 0 {
3535
false
@@ -55,7 +55,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
5555
args: &mut Vec<&'ll llvm::Value>,
5656
inputs: &[DiffActivity],
5757
outer_args: &[&'ll llvm::Value],
58-
has_sret: bool,
5958
) {
6059
debug!("matching autodiff arguments");
6160
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -67,20 +66,12 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
6766
let mut outer_pos: usize = 0;
6867
let mut activity_pos = 0;
6968

70-
if has_sret {
71-
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
72-
// inner function will still return something. We increase our outer_pos by one,
73-
// and once we're done with all other args we will take the return of the inner call and
74-
// update the sret pointer with it
75-
outer_pos = 1;
76-
}
77-
78-
let enzyme_const = cx.create_metadata(b"enzyme_const");
79-
let enzyme_out = cx.create_metadata(b"enzyme_out");
80-
let enzyme_dup = cx.create_metadata(b"enzyme_dup");
81-
let enzyme_dupv = cx.create_metadata(b"enzyme_dupv");
82-
let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed");
83-
let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv");
69+
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
70+
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
71+
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
72+
let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
73+
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
74+
let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();
8475

8576
while activity_pos < inputs.len() {
8677
let diff_activity = inputs[activity_pos as usize];
@@ -193,92 +184,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
193184
}
194185
}
195186

196-
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
197-
// arguments. We do however need to declare them with their correct return type.
198-
// We already figured the correct return type out in our frontend, when generating the outer_fn,
199-
// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
200-
// Beyond sret, this article describes our challenges nicely:
201-
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
202-
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
203-
fn compute_enzyme_fn_ty<'ll>(
204-
cx: &SimpleCx<'ll>,
205-
attrs: &AutoDiffAttrs,
206-
fn_to_diff: &'ll Value,
207-
outer_fn: &'ll Value,
208-
) -> &'ll llvm::Type {
209-
let fn_ty = cx.get_type_of_global(outer_fn);
210-
let mut ret_ty = cx.get_return_type(fn_ty);
211-
212-
let has_sret = has_sret(outer_fn);
213-
214-
if has_sret {
215-
// Now we don't just forward the return type, so we have to figure it out based on the
216-
// primal return type, in combination with the autodiff settings.
217-
let fn_ty = cx.get_type_of_global(fn_to_diff);
218-
let inner_ret_ty = cx.get_return_type(fn_ty);
219-
220-
let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
221-
if inner_ret_ty == void_ty {
222-
// This indicates that even the inner function has an sret.
223-
// Right now I only look for an sret in the outer function.
224-
// This *probably* needs some extra handling, but I never ran
225-
// into such a case. So I'll wait for user reports to have a test case.
226-
bug!("sret in inner function");
227-
}
228-
229-
if attrs.width == 1 {
230-
// Enzyme returns a struct of style:
231-
// `{ original_ret(if requested), float, float, ... }`
232-
let mut struct_elements = vec![];
233-
if attrs.has_primal_ret() {
234-
struct_elements.push(inner_ret_ty);
235-
}
236-
// Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
237-
// and therefore part of the return struct.
238-
let param_tys = cx.func_params_types(fn_ty);
239-
for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
240-
if matches!(act, DiffActivity::Active) {
241-
// Now find the float type at position i based on the fn_ty,
242-
// to know what (f16/f32/f64/...) to add to the struct.
243-
struct_elements.push(param_ty);
244-
}
245-
}
246-
ret_ty = cx.type_struct(&struct_elements, false);
247-
} else {
248-
// First we check if we also have to deal with the primal return.
249-
match attrs.mode {
250-
DiffMode::Forward => match attrs.ret_activity {
251-
DiffActivity::Dual => {
252-
let arr_ty =
253-
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
254-
ret_ty = arr_ty;
255-
}
256-
DiffActivity::DualOnly => {
257-
let arr_ty =
258-
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
259-
ret_ty = arr_ty;
260-
}
261-
DiffActivity::Const => {
262-
todo!("Not sure, do we need to do something here?");
263-
}
264-
_ => {
265-
bug!("unreachable");
266-
}
267-
},
268-
DiffMode::Reverse => {
269-
todo!("Handle sret for reverse mode");
270-
}
271-
_ => {
272-
bug!("unreachable");
273-
}
274-
}
275-
}
276-
}
277-
278-
// LLVM can figure out the input types on it's own, so we take a shortcut here.
279-
unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
280-
}
281-
282187
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
283188
/// function with expected naming and calling conventions[^1] which will be
284189
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -292,7 +197,8 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
292197
builder: &mut Builder<'_, 'll, 'tcx>,
293198
cx: &SimpleCx<'ll>,
294199
fn_to_diff: &'ll Value,
295-
outer_fn: &'ll Value,
200+
outer_name: &str,
201+
ret_ty: &'ll Type,
296202
fn_args: &[OperandRef<'tcx, &'ll Value>],
297203
attrs: AutoDiffAttrs,
298204
dest: PlaceRef<'tcx, &'ll Value>,
@@ -305,11 +211,9 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
305211
}
306212
.to_string();
307213

308-
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
214+
// add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
309215
// functions. `outer_name` is passed in as a `&str`, so no UTF-8 conversion (or unwrap) is needed.
310-
let name = llvm::get_value_name(outer_fn);
311-
let outer_fn_name = std::str::from_utf8(&name).unwrap();
312-
ad_name.push_str(outer_fn_name);
216+
ad_name.push_str(outer_name);
313217

314218
// Let us assume the user wrote the following function square:
315219
//
@@ -320,13 +224,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
320224
// ret double %0
321225
// }
322226
// ```
323-
//
324-
// The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
325-
// Our macro generates the following placeholder code (slightly simplified):
326-
//
327-
// ```llvm
328227
// define double @dsquare(double %x) {
329-
// ; placeholder code
330228
// return 0.0;
331229
// }
332230
// ```
@@ -343,92 +241,54 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
343241
// ret double %0
344242
// }
345243
// ```
346-
unsafe {
347-
let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
348-
349-
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
350-
// think a bit more about what should go here.
351-
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
352-
let ad_fn = declare_simple_fn(
353-
cx,
354-
&ad_name,
355-
llvm::CallConv::try_from(cc).expect("invalid callconv"),
356-
llvm::UnnamedAddr::No,
357-
llvm::Visibility::Default,
358-
enzyme_ty,
359-
);
360-
361-
// Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
362-
// do it's work.
363-
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
364-
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
365-
366-
// We add a made-up attribute just such that we can recognize it after AD to update
367-
// (no)-inline attributes. We'll then also remove this attribute.
368-
let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
369-
attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);
370-
371-
let num_args = llvm::LLVMCountParams(&fn_to_diff);
372-
let mut args = Vec::with_capacity(num_args as usize + 1);
373-
args.push(fn_to_diff);
374-
375-
let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
376-
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
377-
args.push(cx.get_metadata_value(enzyme_primal_ret));
378-
}
379-
if attrs.width > 1 {
380-
let enzyme_width = cx.create_metadata(b"enzyme_width");
381-
args.push(cx.get_metadata_value(enzyme_width));
382-
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
383-
}
384-
385-
let has_sret = has_sret(outer_fn);
386-
let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect();
387-
match_args_from_caller_to_enzyme(
388-
&cx,
389-
builder,
390-
attrs.width,
391-
&mut args,
392-
&attrs.input_activity,
393-
&outer_args,
394-
has_sret,
395-
);
396-
397-
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
398-
399-
builder.store_to_place(call, dest.val);
244+
let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) };
245+
246+
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
247+
// think a bit more about what should go here.
248+
// FIXME(Sa4dUs): we still have to find a way to get the actual calling convention (cc); hard-coding 8 (`FastCallConv`) for now
249+
let cc = 8;
250+
let ad_fn = declare_simple_fn(
251+
cx,
252+
&ad_name,
253+
llvm::CallConv::try_from(cc).expect("invalid callconv"),
254+
llvm::UnnamedAddr::No,
255+
llvm::Visibility::Default,
256+
enzyme_ty,
257+
);
258+
259+
// Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
260+
// do its work.
261+
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
262+
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
263+
264+
let num_args = llvm::LLVMCountParams(&fn_to_diff);
265+
let mut args = Vec::with_capacity(num_args as usize + 1);
266+
args.push(fn_to_diff);
267+
268+
let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap();
269+
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
270+
args.push(cx.get_metadata_value(enzyme_primal_ret));
271+
}
272+
if attrs.width > 1 {
273+
let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
274+
args.push(cx.get_metadata_value(enzyme_width));
275+
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
276+
}
400277

401-
if cx.val_ty(call) == cx.type_void() || has_sret {
402-
if has_sret {
403-
// This is what we already have in our outer_fn (shortened):
404-
// define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
405-
// %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
406-
// <Here we are, we want to add the following two lines>
407-
// store [4 x double] %7, ptr %0, align 8
408-
// ret void
409-
// }
278+
let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect();
410279

411-
// now store the result of the enzyme call into the sret pointer.
412-
let sret_ptr = outer_args[0];
413-
let call_ty = cx.val_ty(call);
414-
if attrs.width == 1 {
415-
assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
416-
} else {
417-
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
418-
}
419-
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
420-
}
421-
builder.ret_void();
422-
}
280+
match_args_from_caller_to_enzyme(
281+
&cx,
282+
builder,
283+
attrs.width,
284+
&mut args,
285+
&attrs.input_activity,
286+
&outer_args,
287+
);
423288

424-
builder.store_to_place(call, dest.val);
289+
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
425290

426-
// Let's crash in case that we messed something up above and generated invalid IR.
427-
llvm::LLVMRustVerifyFunction(
428-
outer_fn,
429-
llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction,
430-
);
431-
}
291+
builder.store_to_place(call, dest.val);
432292
}
433293

434294
pub(crate) fn differentiate<'ll>(

compiler/rustc_codegen_llvm/src/context.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
654654
}
655655
}
656656
impl<'ll> SimpleCx<'ll> {
657-
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
657+
pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type {
658658
assert_eq!(self.type_kind(ty), TypeKind::Function);
659659
unsafe { llvm::LLVMGetReturnType(ty) }
660660
}

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,17 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
176176
span: Span,
177177
) -> Result<(), ty::Instance<'tcx>> {
178178
let tcx = self.tcx;
179+
let callee_ty = instance.ty(tcx, self.typing_env());
179180

180-
let name = tcx.item_name(instance.def_id());
181181
let fn_args = instance.args;
182182

183+
let sig = callee_ty.fn_sig(tcx);
184+
let sig = tcx.normalize_erasing_late_bound_regions(self.typing_env(), sig);
185+
let ret_ty = sig.output();
186+
let name = tcx.item_name(instance.def_id());
187+
188+
let llret_ty = self.layout_of(ret_ty).llvm_type(self);
189+
183190
let simple = call_simple_intrinsic(self, name, args);
184191
let llval = match name {
185192
_ if simple.is_some() => simple.unwrap(),
@@ -225,20 +232,14 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
225232
let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol);
226233
let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };
227234

228-
// Declare target fn
229-
let target_symbol =
230-
symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
231-
let fn_abi = self.cx.fn_abi_of_instance(instance, ty::List::empty());
232-
let outer_fn: &'ll Value =
233-
self.cx.declare_fn(&target_symbol, fn_abi, Some(instance));
234-
235235
// Build body
236236
generate_enzyme_call(
237237
self,
238238
self.cx,
239239
fn_to_diff,
240-
outer_fn,
241-
args, // This argument was not in the original `generate_enzyme_call`, now it's included because `get_params` is not working anymore
240+
name.as_str(),
241+
llret_ty,
242+
args,
242243
diff_attrs.clone(),
243244
result,
244245
);

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,12 @@ pub(crate) fn check_intrinsic_type(
197197
(Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty)
198198
};
199199

200-
let safety = intrinsic_operation_unsafety(tcx, intrinsic_id);
200+
// FIXME(Sa4dUs): Get the actual safety level of the diff function
201+
let safety = if has_autodiff {
202+
hir::Safety::Safe
203+
} else {
204+
intrinsic_operation_unsafety(tcx, intrinsic_id)
205+
};
201206
let n_lts = 0;
202207
let (n_tps, n_cts, inputs, output) = match intrinsic_name {
203208
_ if has_autodiff => {

0 commit comments

Comments
 (0)