Skip to content

Commit cc2a752

Browse files
committed
Do not depend on mono anymore
1 parent 5df63cb commit cc2a752

File tree

2 files changed

+90
-12
lines changed

2 files changed

+90
-12
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::common::TypeKind;
55
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
6-
use rustc_middle::bug;
6+
use rustc_middle::{bug, ty};
7+
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
78
use tracing::debug;
89

910
use crate::builder::{Builder, PlaceRef, UNNAMED};
@@ -14,6 +15,82 @@ use crate::llvm::{Metadata, True, Type};
1415
use crate::value::Value;
1516
use crate::{attributes, llvm};
1617

18+
pub(crate) fn adjust_activity_to_abi<'tcx>(
19+
tcx: TyCtxt<'tcx>,
20+
fn_ty: Ty<'tcx>,
21+
da: &mut Vec<DiffActivity>,
22+
) {
23+
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
24+
bug!("expected fn def for autodiff, got {:?}", fn_ty);
25+
}
26+
27+
// We don't actually pass the types back into the type system.
28+
// All we do is decide how to handle the arguments.
29+
let sig = fn_ty.fn_sig(tcx).skip_binder();
30+
31+
let mut new_activities = vec![];
32+
let mut new_positions = vec![];
33+
for (i, ty) in sig.inputs().iter().enumerate() {
34+
if let Some(inner_ty) = ty.builtin_deref(true) {
35+
if inner_ty.is_slice() {
36+
// Now we need to figure out the size of each slice element in memory to allow
37+
// safety checks and usability improvements in the backend.
38+
let sty = match inner_ty.builtin_index() {
39+
Some(sty) => sty,
40+
None => {
41+
panic!("slice element type unknown");
42+
}
43+
};
44+
let pci = PseudoCanonicalInput {
45+
typing_env: TypingEnv::fully_monomorphized(),
46+
value: sty,
47+
};
48+
49+
let layout = tcx.layout_of(pci);
50+
let elem_size = match layout {
51+
Ok(layout) => layout.size,
52+
Err(_) => {
53+
bug!("autodiff failed to compute slice element size");
54+
}
55+
};
56+
let elem_size: u32 = elem_size.bytes() as u32;
57+
58+
// We know that the length will be passed as extra arg.
59+
if !da.is_empty() {
60+
// We are looking at a slice. The length of that slice will become an
61+
// extra integer on llvm level. Integers are always const.
62+
// However, if the slice get's duplicated, we want to know to later check the
63+
// size. So we mark the new size argument as FakeActivitySize.
64+
// There is one FakeActivitySize per slice, so for convenience we store the
65+
// slice element size in bytes in it. We will use the size in the backend.
66+
let activity = match da[i] {
67+
DiffActivity::DualOnly
68+
| DiffActivity::Dual
69+
| DiffActivity::Dualv
70+
| DiffActivity::DuplicatedOnly
71+
| DiffActivity::Duplicated => {
72+
DiffActivity::FakeActivitySize(Some(elem_size))
73+
}
74+
DiffActivity::Const => DiffActivity::Const,
75+
_ => bug!("unexpected activity for ptr/ref"),
76+
};
77+
new_activities.push(activity);
78+
new_positions.push(i + 1);
79+
}
80+
81+
continue;
82+
}
83+
}
84+
}
85+
// now add the extra activities coming from slices
86+
// Reverse order to not invalidate the indices
87+
for _ in 0..new_activities.len() {
88+
let pos = new_positions.pop().unwrap();
89+
let activity = new_activities.pop().unwrap();
90+
da.insert(pos, activity);
91+
}
92+
}
93+
1794
// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
1895
// original inputs, as well as metadata and the additional shadow arguments.
1996
// This function matches the arguments from the outer function to the inner enzyme call.

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::cmp::Ordering;
33

44
use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size};
55
use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh};
6+
use rustc_codegen_ssa::codegen_attrs::autodiff_attrs;
67
use rustc_codegen_ssa::common::{IntPredicate, TypeKind};
78
use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization};
89
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
@@ -12,7 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE;
1213
use rustc_hir::{self as hir};
1314
use rustc_middle::mir::BinOp;
1415
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
15-
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt};
16+
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
1617
use rustc_middle::{bug, span_bug};
1718
use rustc_span::{Span, Symbol, sym};
1819
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
@@ -21,7 +22,7 @@ use tracing::debug;
2122

2223
use crate::abi::FnAbiLlvmExt;
2324
use crate::builder::Builder;
24-
use crate::builder::autodiff::generate_enzyme_call;
25+
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
2526
use crate::context::CodegenCx;
2627
use crate::llvm::{self, Metadata};
2728
use crate::type_::Type;
@@ -1166,14 +1167,14 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11661167
Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap();
11671168
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
11681169

1169-
// TODO(Sa4dUs): Store autodiff items in a single pass and just get them here
1170-
// in a O(1) step
1171-
let diff_attrs = tcx
1172-
.collect_and_partition_mono_items(())
1173-
.autodiff_items
1174-
.iter()
1175-
.find(|item| item.target == diff_symbol);
1176-
let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") };
1170+
let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id());
1171+
let Some(mut diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") };
1172+
1173+
adjust_activity_to_abi(
1174+
tcx,
1175+
fn_source.ty(tcx, TypingEnv::fully_monomorphized()),
1176+
&mut diff_attrs.input_activity,
1177+
);
11771178

11781179
// Build body
11791180
generate_enzyme_call(
@@ -1183,7 +1184,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11831184
&diff_symbol,
11841185
llret_ty,
11851186
&val_arr,
1186-
diff_attrs.attrs.clone(),
1187+
diff_attrs.clone(),
11871188
result,
11881189
);
11891190
}

0 commit comments

Comments
 (0)