Skip to content

Commit 664e83b

Browse files
committed
added typetree support for memcpy
1 parent 5d3ebc3 commit 664e83b

File tree

21 files changed

+135
-34
lines changed

21 files changed

+135
-34
lines changed

compiler/rustc_codegen_gcc/src/builder.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
13831383
_src_align: Align,
13841384
size: RValue<'gcc>,
13851385
flags: MemFlags,
1386+
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
13861387
) {
13871388
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
13881389
let size = self.intcast(size, self.type_size_t(), false);

compiler/rustc_codegen_gcc/src/intrinsic/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
771771
scratch_align,
772772
bx.const_usize(self.layout.size.bytes()),
773773
MemFlags::empty(),
774+
None,
774775
);
775776

776777
bx.lifetime_end(scratch, scratch_size);

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
246246
scratch_align,
247247
bx.const_usize(copy_bytes),
248248
MemFlags::empty(),
249+
None,
249250
);
250251
bx.lifetime_end(llscratch, scratch_size);
251252
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

5+
use rustc_ast::expand::typetree::FncTree;
56
pub(crate) mod autodiff;
67
pub(crate) mod gpu_offload;
78

@@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11071108
src_align: Align,
11081109
size: &'ll Value,
11091110
flags: MemFlags,
1111+
tt: Option<FncTree>,
11101112
) {
11111113
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
11121114
let size = self.intcast(size, self.type_isize(), false);
11131115
let is_volatile = flags.contains(MemFlags::VOLATILE);
1114-
unsafe {
1116+
let memcpy = unsafe {
11151117
llvm::LLVMRustBuildMemCpy(
11161118
self.llbuilder,
11171119
dst,
@@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11201122
src_align.bytes() as c_uint,
11211123
size,
11221124
is_volatile,
1123-
);
1125+
)
1126+
};
1127+
1128+
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
1129+
// a memcpy during autodiff, it needs to know the structure of the data being
1130+
// copied to properly track derivatives. For example, copying an array of floats
1131+
// vs. copying a struct with mixed types requires different derivative handling.
1132+
// The TypeTree tells Enzyme exactly what memory layout to expect.
1133+
if let Some(tt) = tt {
1134+
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
11241135
}
11251136
}
11261137

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pub(crate) enum CConcreteType {
2525
DT_Half = 3,
2626
DT_Float = 4,
2727
DT_Double = 5,
28+
// FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600)
2829
DT_Unknown = 6,
2930
}
3031

compiler/rustc_codegen_llvm/src/typetree.rs

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
use std::ffi::{CString, c_char, c_uint};
2-
3-
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
1+
use rustc_ast::expand::typetree::FncTree;
2+
#[cfg(llvm_enzyme)]
3+
use {
4+
crate::attributes,
5+
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
6+
std::ffi::{CString, c_char, c_uint},
7+
};
48

5-
use crate::attributes;
69
use crate::llvm::{self, Value};
710

811
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
@@ -50,15 +53,6 @@ fn to_enzyme_typetree(
5053
enzyme_tt
5154
}
5255

53-
#[cfg(not(llvm_enzyme))]
54-
fn to_enzyme_typetree(
55-
_rust_typetree: RustTypeTree,
56-
_data_layout: &str,
57-
_llcx: &llvm::Context,
58-
) -> ! {
59-
unimplemented!("TypeTree conversion not available without llvm_enzyme support")
60-
}
61-
6256
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
6357
#[cfg(llvm_enzyme)]
6458
pub(crate) fn add_tt<'ll>(

compiler/rustc_codegen_llvm/src/va_arg.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
738738
src_align,
739739
bx.const_u32(layout.layout.size().bytes() as u32),
740740
MemFlags::empty(),
741+
None,
741742
);
742743
tmp
743744
} else {

compiler/rustc_codegen_ssa/src/mir/block.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,6 +1626,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
16261626
align,
16271627
bx.const_usize(copy_bytes),
16281628
MemFlags::empty(),
1629+
None,
16291630
);
16301631
// ...and then load it with the ABI type.
16311632
llval = load_cast(bx, cast, llscratch, scratch_align);

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
3030
if allow_overlap {
3131
bx.memmove(dst, align, src, align, size, flags);
3232
} else {
33-
bx.memcpy(dst, align, src, align, size, flags);
33+
bx.memcpy(dst, align, src, align, size, flags, None);
3434
}
3535
}
3636

compiler/rustc_codegen_ssa/src/mir/statement.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
9090
let align = pointee_layout.align;
9191
let dst = dst_val.immediate();
9292
let src = src_val.immediate();
93-
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
93+
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
9494
}
9595
mir::StatementKind::FakeRead(..)
9696
| mir::StatementKind::Retag { .. }

0 commit comments

Comments
 (0)