Skip to content

Commit c8905ea

Browse files
committed
Auto merge of #147128 - matthiaskrgr:rollup-mqey4c4, r=matthiaskrgr
Rollup of 6 pull requests Successful merges: - #140482 (std::net: update tcp deferaccept delay type to Duration.) - #141469 (Allow `&raw [mut | const]` for union field in safe code) - #144197 (TypeTree support in autodiff) - #146675 (Allow shared access to `Exclusive<T>` when `T: Sync`) - #147113 (Reland "Add LSX accelerated implementation for source file analysis") - #147120 (Fix --extra-checks=spellcheck to prevent cargo install every time) r? `@ghost` `@rustbot` modify labels: rollup
2 parents 8d72d3e + 4eb6b8f commit c8905ea

File tree

82 files changed

+1631
-62
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+1631
-62
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
use std::fmt::{self, Display, Formatter};
77
use std::str::FromStr;
88

9+
use crate::expand::typetree::TypeTree;
910
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1011
use crate::{Ty, TyKind};
1112

@@ -84,6 +85,8 @@ pub struct AutoDiffItem {
8485
/// The name of the function being generated
8586
pub target: String,
8687
pub attrs: AutoDiffAttrs,
88+
pub inputs: Vec<TypeTree>,
89+
pub output: TypeTree,
8790
}
8891

8992
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -275,14 +278,22 @@ impl AutoDiffAttrs {
275278
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
276279
}
277280

278-
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
279-
AutoDiffItem { source, target, attrs: self }
281+
pub fn into_item(
282+
self,
283+
source: String,
284+
target: String,
285+
inputs: Vec<TypeTree>,
286+
output: TypeTree,
287+
) -> AutoDiffItem {
288+
AutoDiffItem { source, target, inputs, output, attrs: self }
280289
}
281290
}
282291

283292
impl fmt::Display for AutoDiffItem {
284293
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285294
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
286-
write!(f, " with attributes: {:?}", self.attrs)
295+
write!(f, " with attributes: {:?}", self.attrs)?;
296+
write!(f, " with inputs: {:?}", self.inputs)?;
297+
write!(f, " with output: {:?}", self.output)
287298
}
288299
}

compiler/rustc_ast/src/expand/typetree.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub enum Kind {
3131
Half,
3232
Float,
3333
Double,
34+
F128,
3435
Unknown,
3536
}
3637

compiler/rustc_codegen_gcc/src/builder.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
13831383
_src_align: Align,
13841384
size: RValue<'gcc>,
13851385
flags: MemFlags,
1386+
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
13861387
) {
13871388
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
13881389
let size = self.intcast(size, self.type_size_t(), false);

compiler/rustc_codegen_gcc/src/intrinsic/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
770770
scratch_align,
771771
bx.const_usize(self.layout.size.bytes()),
772772
MemFlags::empty(),
773+
None,
773774
);
774775

775776
bx.lifetime_end(scratch, scratch_size);

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
246246
scratch_align,
247247
bx.const_usize(copy_bytes),
248248
MemFlags::empty(),
249+
None,
249250
);
250251
bx.lifetime_end(llscratch, scratch_size);
251252
}

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
563563
config::AutoDiff::Enable => {}
564564
// We handle this below
565565
config::AutoDiff::NoPostopt => {}
566+
// Disables TypeTree generation
567+
config::AutoDiff::NoTT => {}
566568
}
567569
}
568570
// This helps with handling enums for now.

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

5+
use rustc_ast::expand::typetree::FncTree;
56
pub(crate) mod autodiff;
67
pub(crate) mod gpu_offload;
78

@@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11071108
src_align: Align,
11081109
size: &'ll Value,
11091110
flags: MemFlags,
1111+
tt: Option<FncTree>,
11101112
) {
11111113
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
11121114
let size = self.intcast(size, self.type_isize(), false);
11131115
let is_volatile = flags.contains(MemFlags::VOLATILE);
1114-
unsafe {
1116+
let memcpy = unsafe {
11151117
llvm::LLVMRustBuildMemCpy(
11161118
self.llbuilder,
11171119
dst,
@@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11201122
src_align.bytes() as c_uint,
11211123
size,
11221124
is_volatile,
1123-
);
1125+
)
1126+
};
1127+
1128+
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
1129+
// a memcpy during autodiff, it needs to know the structure of the data being
1130+
// copied to properly track derivatives. For example, copying an array of floats
1131+
// vs. copying a struct with mixed types requires different derivative handling.
1132+
// The TypeTree tells Enzyme exactly what memory layout to expect.
1133+
if let Some(tt) = tt {
1134+
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
11241135
}
11251136
}
11261137

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::ptr;
22

33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
4+
use rustc_ast::expand::typetree::FncTree;
45
use rustc_codegen_ssa::common::TypeKind;
56
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
67
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
@@ -294,6 +295,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
294295
fn_args: &[&'ll Value],
295296
attrs: AutoDiffAttrs,
296297
dest: PlaceRef<'tcx, &'ll Value>,
298+
fnc_tree: FncTree,
297299
) {
298300
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
299301
let mut ad_name: String = match attrs.mode {
@@ -370,6 +372,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
370372
fn_args,
371373
);
372374

375+
if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
376+
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
377+
}
378+
373379
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
374380

375381
builder.store_to_place(call, dest.val);

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,9 @@ fn codegen_autodiff<'ll, 'tcx>(
12121212
&mut diff_attrs.input_activity,
12131213
);
12141214

1215+
let fnc_tree =
1216+
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
1217+
12151218
// Build body
12161219
generate_enzyme_call(
12171220
bx,
@@ -1222,6 +1225,7 @@ fn codegen_autodiff<'ll, 'tcx>(
12221225
&val_arr,
12231226
diff_attrs.clone(),
12241227
result,
1228+
fnc_tree,
12251229
);
12261230
}
12271231

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ mod llvm_util;
6868
mod mono_item;
6969
mod type_;
7070
mod type_of;
71+
mod typetree;
7172
mod va_arg;
7273
mod value;
7374

0 commit comments

Comments
 (0)