Skip to content

Commit e6982b4

Browse files
committed
memcpy compiling
Signed-off-by: Karan Janthe <[email protected]>
1 parent f293c6d commit e6982b4

File tree

9 files changed

+290
-7
lines changed

9 files changed

+290
-7
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
238238
scratch_align,
239239
bx.const_usize(copy_bytes),
240240
MemFlags::empty(),
241+
None,
241242
);
242243
bx.lifetime_end(llscratch, scratch_size);
243244
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

5+
use rustc_ast::expand::typetree::FncTree;
56
pub(crate) mod autodiff;
6-
77
use libc::{c_char, c_uint, size_t};
88
use rustc_abi as abi;
99
use rustc_abi::{Align, Size, WrappingRange};
@@ -30,6 +30,7 @@ use tracing::{debug, instrument};
3030

3131
use crate::abi::FnAbiLlvmExt;
3232
use crate::attributes;
33+
use crate::builder::autodiff::add_tt;
3334
use crate::common::Funclet;
3435
use crate::context::{CodegenCx, FullCx, GenericCx, SCx};
3536
use crate::llvm::{
@@ -1036,11 +1037,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
10361037
src_align: Align,
10371038
size: &'ll Value,
10381039
flags: MemFlags,
1040+
tt: Option<FncTree>,
10391041
) {
10401042
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
10411043
let size = self.intcast(size, self.type_isize(), false);
10421044
let is_volatile = flags.contains(MemFlags::VOLATILE);
1043-
unsafe {
1045+
let memcpy = unsafe {
10441046
llvm::LLVMRustBuildMemCpy(
10451047
self.llbuilder,
10461048
dst,
@@ -1049,7 +1051,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
10491051
src_align.bytes() as c_uint,
10501052
size,
10511053
is_volatile,
1052-
);
1054+
)
1055+
};
1056+
1057+
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
1058+
// a memcpy during autodiff, it needs to know the structure of the data being
1059+
// copied to properly track derivatives. For example, copying an array of floats
1060+
// vs. copying a struct with mixed types requires different derivative handling.
1061+
// The TypeTree tells Enzyme exactly what memory layout to expect.
1062+
if let Some(tt) = tt {
1063+
add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
10531064
}
10541065
}
10551066

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use std::os::raw::{c_char, c_uint};
12
use std::ptr;
23

34
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
5+
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
46
use rustc_codegen_ssa::ModuleCodegen;
57
use rustc_codegen_ssa::common::TypeKind;
68
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
@@ -512,3 +514,128 @@ pub(crate) fn differentiate<'ll>(
512514

513515
Ok(())
514516
}
517+
518+
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
519+
///
520+
/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree)
521+
/// and converts it to Enzyme's internal C++ TypeTree representation that
522+
/// Enzyme can understand during differentiation analysis.
523+
fn to_enzyme_typetree(
524+
rust_typetree: RustTypeTree,
525+
data_layout: &str,
526+
llcx: &llvm::Context,
527+
) -> llvm::TypeTree {
528+
// Start with an empty TypeTree
529+
let mut enzyme_tt = llvm::TypeTree::new();
530+
531+
// Convert each Type in the Rust TypeTree to Enzyme format
532+
for rust_type in rust_typetree.0 {
533+
let concrete_type = match rust_type.kind {
534+
rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything,
535+
rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer,
536+
rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer,
537+
rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half,
538+
rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float,
539+
rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double,
540+
rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown,
541+
};
542+
543+
// Create a TypeTree for this specific type
544+
let type_tt = llvm::TypeTree::from_type(concrete_type, llcx);
545+
546+
// Apply offset if specified
547+
let type_tt = if rust_type.offset == -1 {
548+
type_tt // -1 means everywhere/no specific offset
549+
} else {
550+
// Apply specific offset positioning
551+
type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0)
552+
};
553+
554+
// Merge this type into the main TypeTree
555+
enzyme_tt = enzyme_tt.merge(type_tt);
556+
}
557+
558+
enzyme_tt
559+
}
560+
561+
/// Attaches TypeTree information to LLVM function as enzyme_type attributes.
562+
///
563+
/// This function converts Rust TypeTrees to Enzyme format and attaches them as
564+
/// LLVM string attributes. Enzyme reads these attributes during autodiff analysis
565+
/// to understand the memory layout and generate correct derivative code.
566+
///
567+
/// # Arguments
568+
/// * `llmod` - LLVM module containing the function
569+
/// * `llcx` - LLVM context for creating attributes
570+
/// * `fn_def` - LLVM function to attach TypeTrees to
571+
/// * `tt` - Function TypeTree containing input and return type information
572+
pub(crate) fn add_tt<'ll>(
573+
llmod: &'ll llvm::Module,
574+
llcx: &'ll llvm::Context,
575+
fn_def: &'ll Value,
576+
tt: FncTree,
577+
) {
578+
let inputs = tt.args;
579+
let ret_tt: RustTypeTree = tt.ret;
580+
581+
// Get LLVM data layout string for TypeTree conversion
582+
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
583+
let llvm_data_layout =
584+
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
585+
.expect("got a non-UTF8 data-layout from LLVM");
586+
587+
// Attribute name that Enzyme recognizes for TypeTree information
588+
let attr_name = "enzyme_type";
589+
let c_attr_name = std::ffi::CString::new(attr_name).unwrap();
590+
591+
// Attach TypeTree attributes to each input parameter
592+
// Enzyme uses these to understand parameter memory layouts during differentiation
593+
for (i, input) in inputs.iter().enumerate() {
594+
unsafe {
595+
// Convert Rust TypeTree to Enzyme's internal format
596+
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
597+
598+
// Serialize TypeTree to string format that Enzyme can parse
599+
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
600+
let c_str = std::ffi::CStr::from_ptr(c_str);
601+
602+
// Create LLVM string attribute with TypeTree information
603+
let attr = llvm::LLVMCreateStringAttribute(
604+
llcx,
605+
c_attr_name.as_ptr(),
606+
c_attr_name.as_bytes().len() as c_uint,
607+
c_str.as_ptr(),
608+
c_str.to_bytes().len() as c_uint,
609+
);
610+
611+
// Attach attribute to the specific function parameter
612+
// Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments
613+
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
614+
615+
// Free the C string to prevent memory leaks
616+
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
617+
}
618+
}
619+
620+
// Attach TypeTree attribute to the return type
621+
// Enzyme needs this to understand how to handle return value derivatives
622+
unsafe {
623+
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
624+
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
625+
let c_str = std::ffi::CStr::from_ptr(c_str);
626+
627+
let ret_attr = llvm::LLVMCreateStringAttribute(
628+
llcx,
629+
c_attr_name.as_ptr(),
630+
c_attr_name.as_bytes().len() as c_uint,
631+
c_str.as_ptr(),
632+
c_str.to_bytes().len() as c_uint,
633+
);
634+
635+
// Attach to function return type
636+
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
637+
638+
// Free the C string
639+
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
640+
}
641+
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2135,7 +2135,7 @@ unsafe extern "C" {
21352135
SPFlags: DISPFlags,
21362136
MaybeFn: Option<&'a Value>,
21372137
TParam: &'a DIArray,
2138-
Decl: Option<&'a DIDescriptor>,
2138+
Decl: Option<&'a DIDescriptor>,
21392139
) -> &'a DISubprogram;
21402140

21412141
pub(crate) fn LLVMRustDIBuilderCreateMethod<'a>(
@@ -2664,4 +2664,144 @@ unsafe extern "C" {
26642664

26652665
pub(crate) fn LLVMRustSetNoSanitizeAddress(Global: &Value);
26662666
pub(crate) fn LLVMRustSetNoSanitizeHWAddress(Global: &Value);
2667+
2668+
// ========== ENZYME AUTODIFF FFI FUNCTIONS ==========
2669+
2670+
// Enzyme Type Tree Functions (minimal set for TypeTree support)
2671+
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
2672+
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
2673+
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
2674+
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
2675+
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef);
2676+
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
2677+
arg1: CTypeTreeRef,
2678+
data_layout: *const c_char,
2679+
offset: i64,
2680+
max_size: i64,
2681+
add_offset: u64,
2682+
);
2683+
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
2684+
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
26672685
}
2686+
2687+
// ========== ENZYME TYPES AND ENUMS ==========
2688+
2689+
// Type Tree Support for Autodiff
2690+
2691+
2692+
2693+
2694+
2695+
2696+
2697+
2698+
2699+
2700+
2701+
2702+
2703+
2704+
/// Concrete type tag handed to Enzyme when creating a type tree through
/// `EnzymeNewTypeTreeCT`.
///
/// `#[repr(u32)]` together with the explicit discriminants pins the value passed across
/// the FFI boundary. Presumably these mirror Enzyme's C-side enum; verify against the
/// Enzyme headers before renumbering.
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub(crate) enum CConcreteType {
    DT_Anything = 0,
    DT_Integer = 1,
    DT_Pointer = 2,
    DT_Half = 3,
    DT_Float = 4,
    DT_Double = 5,
    DT_Unknown = 6,
}
2715+
2716+
/// Raw handle to an Enzyme type tree, as taken and returned by the `Enzyme*TypeTree*`
/// FFI functions.
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;

/// Opaque stand-in for Enzyme's type-tree object.
///
/// The zero-sized `_unused` field means Rust code never reads or writes its contents;
/// it is only ever used behind a [`CTypeTreeRef`] pointer.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub(crate) struct EnzymeTypeTree {
    _unused: [u8; 0],
}
2723+
2724+
2725+
2726+
2727+
2728+
2729+
2730+
// TypeTree wrapper for Rust-side type safety and memory management
2731+
pub(crate) struct TypeTree {
2732+
pub(crate) inner: CTypeTreeRef,
2733+
}
2734+
2735+
impl TypeTree {
2736+
pub(crate) fn new() -> TypeTree {
2737+
let inner = unsafe { EnzymeNewTypeTree() };
2738+
TypeTree { inner }
2739+
}
2740+
2741+
pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
2742+
let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) };
2743+
TypeTree { inner }
2744+
}
2745+
2746+
pub(crate) fn merge(self, other: Self) -> Self {
2747+
unsafe {
2748+
EnzymeMergeTypeTree(self.inner, other.inner);
2749+
}
2750+
drop(other);
2751+
self
2752+
}
2753+
2754+
#[must_use]
2755+
pub(crate) fn shift(self, layout: &str, offset: isize, max_size: isize, add_offset: usize) -> Self {
2756+
let layout = std::ffi::CString::new(layout).unwrap();
2757+
2758+
unsafe {
2759+
EnzymeTypeTreeShiftIndiciesEq(
2760+
self.inner,
2761+
layout.as_ptr(),
2762+
offset as i64,
2763+
max_size as i64,
2764+
add_offset as u64,
2765+
)
2766+
}
2767+
2768+
self
2769+
}
2770+
}
2771+
2772+
impl Clone for TypeTree {
2773+
fn clone(&self) -> Self {
2774+
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
2775+
TypeTree { inner }
2776+
}
2777+
}
2778+
2779+
impl std::fmt::Display for TypeTree {
2780+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2781+
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
2782+
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
2783+
match cstr.to_str() {
2784+
Ok(x) => write!(f, "{}", x)?,
2785+
Err(err) => write!(f, "could not parse: {}", err)?,
2786+
}
2787+
2788+
// delete C string pointer
2789+
unsafe { EnzymeTypeTreeToStringFree(ptr) }
2790+
2791+
Ok(())
2792+
}
2793+
}
2794+
2795+
impl std::fmt::Debug for TypeTree {
2796+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2797+
<Self as std::fmt::Display>::fmt(self, f)
2798+
}
2799+
}
2800+
2801+
impl Drop for TypeTree {
2802+
fn drop(&mut self) {
2803+
unsafe { EnzymeFreeTypeTree(self.inner) }
2804+
}
2805+
}
2806+
2807+

compiler/rustc_codegen_llvm/src/va_arg.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
735735
src_align,
736736
bx.const_u32(layout.layout.size().bytes() as u32),
737737
MemFlags::empty(),
738+
None,
738739
);
739740
tmp
740741
} else {

compiler/rustc_codegen_ssa/src/mir/block.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,6 +1550,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
15501550
align,
15511551
bx.const_usize(copy_bytes),
15521552
MemFlags::empty(),
1553+
None,
15531554
);
15541555
// ...and then load it with the ABI type.
15551556
llval = load_cast(bx, cast, llscratch, scratch_align);

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
3030
if allow_overlap {
3131
bx.memmove(dst, align, src, align, size, flags);
3232
} else {
33-
bx.memcpy(dst, align, src, align, size, flags);
33+
bx.memcpy(dst, align, src, align, size, flags, None);
3434
}
3535
}
3636

compiler/rustc_codegen_ssa/src/mir/statement.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
9090
let align = pointee_layout.align;
9191
let dst = dst_val.immediate();
9292
let src = src_val.immediate();
93-
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
93+
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
9494
}
9595
mir::StatementKind::FakeRead(..)
9696
| mir::StatementKind::Retag { .. }

0 commit comments

Comments
 (0)