Skip to content

Commit d7f80fe

Browse files
committed
Handle slices when extracting args from tuple
1 parent f81fd98 commit d7f80fe

File tree

3 files changed

+39
-12
lines changed

3 files changed

+39
-12
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use std::cmp::Ordering;
33

44
use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size};
55
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
6-
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
76
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
87
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
98
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
@@ -1165,7 +1164,13 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11651164
Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap();
11661165
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
11671166

1168-
let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id());
1167+
// TODO(Sa4dUs): Store autodiff items in a single pass and just get them here
1168+
// in a O(1) step
1169+
let diff_attrs = tcx
1170+
.collect_and_partition_mono_items(())
1171+
.autodiff_items
1172+
.iter()
1173+
.find(|item| item.target == diff_symbol);
11691174
let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") };
11701175

11711176
// Build body
@@ -1176,7 +1181,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11761181
&diff_symbol,
11771182
llret_ty,
11781183
&val_arr,
1179-
diff_attrs.clone(),
1184+
diff_attrs.attrs.clone(),
11801185
result,
11811186
);
11821187
}
@@ -1193,11 +1198,22 @@ fn get_args_from_tuple<'ll, 'tcx>(
11931198
for i in 0..tuple_place.layout.layout.0.fields.count() {
11941199
let field_place = tuple_place.project_field(bx, i);
11951200
let field_layout = tuple_place.layout.field(bx, i);
1201+
let field_ty = field_layout.ty;
11961202
let llvm_ty = field_layout.llvm_type(bx.cx);
11971203

11981204
let field_val = bx.load(llvm_ty, field_place.val.llval, field_place.val.align);
11991205

1200-
ret_arr.push(field_val)
1206+
match field_ty.kind() {
1207+
ty::Ref(_, inner_ty, _) if matches!(inner_ty.kind(), ty::Slice(_)) => {
1208+
let ptr = bx.extract_value(field_val, 0);
1209+
let len = bx.extract_value(field_val, 1);
1210+
ret_arr.push(ptr);
1211+
ret_arr.push(len);
1212+
}
1213+
_ => {
1214+
ret_arr.push(field_val);
1215+
}
1216+
}
12011217
}
12021218

12031219
ret_arr

compiler/rustc_monomorphize/src/collector.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ use rustc_span::{DUMMY_SP, Span};
240240
use tracing::{debug, instrument, trace};
241241

242242
#[cfg(llvm_enzyme)]
243-
use crate::collector::autodiff::collect_enzyme_autodiff_source_fn;
243+
use crate::collector::autodiff::collect_enzyme_autodiff_fn;
244244
use crate::errors::{self, EncounteredErrorWhileInstantiating, NoOptimizedMir, RecursionLimit};
245245

246246
#[derive(PartialEq)]
@@ -921,7 +921,7 @@ fn visit_instance_use<'tcx>(
921921
}
922922
if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) {
923923
#[cfg(llvm_enzyme)]
924-
collect_enzyme_autodiff_source_fn(tcx, instance, intrinsic, output);
924+
collect_enzyme_autodiff_fn(tcx, instance, intrinsic, output);
925925

926926
if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) {
927927
// The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will

compiler/rustc_monomorphize/src/collector/autodiff.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
use rustc_middle::bug;
2-
use rustc_middle::ty::{self, IntrinsicDef, TyCtxt};
3-
use tracing::debug;
2+
use rustc_middle::ty::{self, GenericArg, IntrinsicDef, TyCtxt};
43

54
use crate::collector::{MonoItems, create_fn_mono_item};
65

7-
pub(crate) fn collect_enzyme_autodiff_source_fn<'tcx>(
6+
// Here, we force both primal and diff function to be collected in
7+
// mono so this does not interfere in `enzyme_autodiff` intrinsics
8+
// codegen process. If they are unused, they will be removed later and
9+
// won't be present at LLVM-IR.
10+
pub(crate) fn collect_enzyme_autodiff_fn<'tcx>(
811
tcx: TyCtxt<'tcx>,
912
instance: ty::Instance<'tcx>,
1013
intrinsic: IntrinsicDef,
@@ -14,8 +17,16 @@ pub(crate) fn collect_enzyme_autodiff_source_fn<'tcx>(
1417
return;
1518
};
1619

17-
debug!("enzyme_autodiff found");
18-
let (primal, span) = match instance.args[0].kind() {
20+
collect_autodiff_fn_from_arg(instance.args[0], tcx, output);
21+
collect_autodiff_fn_from_arg(instance.args[1], tcx, output);
22+
}
23+
24+
fn collect_autodiff_fn_from_arg<'tcx>(
25+
arg: GenericArg<'tcx>,
26+
tcx: TyCtxt<'tcx>,
27+
output: &mut MonoItems<'tcx>,
28+
) {
29+
let (instance, span) = match arg.kind() {
1930
rustc_middle::infer::canonical::ir::GenericArgKind::Type(ty) => match ty.kind() {
2031
ty::FnDef(def_id, substs) => {
2132
let span = tcx.def_span(def_id);
@@ -34,5 +45,5 @@ pub(crate) fn collect_enzyme_autodiff_source_fn<'tcx>(
3445
_ => bug!("expected type"),
3546
};
3647

37-
output.push(create_fn_mono_item(tcx, primal, span));
48+
output.push(create_fn_mono_item(tcx, instance, span));
3849
}

0 commit comments

Comments
 (0)