Skip to content

Commit a03bfe5

Browse files
committed
memcpy compiling
Signed-off-by: Karan Janthe <[email protected]>
1 parent 0a56a89 commit a03bfe5

File tree

11 files changed

+292
-8
lines changed

11 files changed

+292
-8
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
238238
scratch_align,
239239
bx.const_usize(copy_bytes),
240240
MemFlags::empty(),
241+
None,
241242
);
242243
bx.lifetime_end(llscratch, scratch_size);
243244
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

5+
use rustc_ast::expand::typetree::FncTree;
56
pub(crate) mod autodiff;
67
pub(crate) mod gpu_offload;
78

@@ -31,6 +32,7 @@ use tracing::{debug, instrument};
3132

3233
use crate::abi::FnAbiLlvmExt;
3334
use crate::attributes;
35+
use crate::builder::autodiff::add_tt;
3436
use crate::common::Funclet;
3537
use crate::context::{CodegenCx, FullCx, GenericCx, SCx};
3638
use crate::llvm::{
@@ -1105,11 +1107,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11051107
src_align: Align,
11061108
size: &'ll Value,
11071109
flags: MemFlags,
1110+
tt: Option<FncTree>,
11081111
) {
11091112
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
11101113
let size = self.intcast(size, self.type_isize(), false);
11111114
let is_volatile = flags.contains(MemFlags::VOLATILE);
1112-
unsafe {
1115+
let memcpy = unsafe {
11131116
llvm::LLVMRustBuildMemCpy(
11141117
self.llbuilder,
11151118
dst,
@@ -1118,7 +1121,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11181121
src_align.bytes() as c_uint,
11191122
size,
11201123
is_volatile,
1121-
);
1124+
)
1125+
};
1126+
1127+
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
1128+
// a memcpy during autodiff, it needs to know the structure of the data being
1129+
// copied to properly track derivatives. For example, copying an array of floats
1130+
// vs. copying a struct with mixed types requires different derivative handling.
1131+
// The TypeTree tells Enzyme exactly what memory layout to expect.
1132+
if let Some(tt) = tt {
1133+
add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
11221134
}
11231135
}
11241136

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use std::os::raw::{c_char, c_uint};
12
use std::ptr;
23

34
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
5+
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
46
use rustc_codegen_ssa::ModuleCodegen;
57
use rustc_codegen_ssa::common::TypeKind;
68
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
@@ -512,3 +514,128 @@ pub(crate) fn differentiate<'ll>(
512514

513515
Ok(())
514516
}
517+
518+
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
519+
///
520+
/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree)
521+
/// and converts it to Enzyme's internal C++ TypeTree representation that
522+
/// Enzyme can understand during differentiation analysis.
523+
fn to_enzyme_typetree(
524+
rust_typetree: RustTypeTree,
525+
data_layout: &str,
526+
llcx: &llvm::Context,
527+
) -> llvm::TypeTree {
528+
// Start with an empty TypeTree
529+
let mut enzyme_tt = llvm::TypeTree::new();
530+
531+
// Convert each Type in the Rust TypeTree to Enzyme format
532+
for rust_type in rust_typetree.0 {
533+
let concrete_type = match rust_type.kind {
534+
rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything,
535+
rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer,
536+
rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer,
537+
rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half,
538+
rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float,
539+
rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double,
540+
rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown,
541+
};
542+
543+
// Create a TypeTree for this specific type
544+
let type_tt = llvm::TypeTree::from_type(concrete_type, llcx);
545+
546+
// Apply offset if specified
547+
let type_tt = if rust_type.offset == -1 {
548+
type_tt // -1 means everywhere/no specific offset
549+
} else {
550+
// Apply specific offset positioning
551+
type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0)
552+
};
553+
554+
// Merge this type into the main TypeTree
555+
enzyme_tt = enzyme_tt.merge(type_tt);
556+
}
557+
558+
enzyme_tt
559+
}
560+
561+
/// Attaches TypeTree information to LLVM function as enzyme_type attributes.
562+
///
563+
/// This function converts Rust TypeTrees to Enzyme format and attaches them as
564+
/// LLVM string attributes. Enzyme reads these attributes during autodiff analysis
565+
/// to understand the memory layout and generate correct derivative code.
566+
///
567+
/// # Arguments
568+
/// * `llmod` - LLVM module containing the function
569+
/// * `llcx` - LLVM context for creating attributes
570+
/// * `fn_def` - LLVM function to attach TypeTrees to
571+
/// * `tt` - Function TypeTree containing input and return type information
572+
pub(crate) fn add_tt<'ll>(
573+
llmod: &'ll llvm::Module,
574+
llcx: &'ll llvm::Context,
575+
fn_def: &'ll Value,
576+
tt: FncTree,
577+
) {
578+
let inputs = tt.args;
579+
let ret_tt: RustTypeTree = tt.ret;
580+
581+
// Get LLVM data layout string for TypeTree conversion
582+
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
583+
let llvm_data_layout =
584+
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
585+
.expect("got a non-UTF8 data-layout from LLVM");
586+
587+
// Attribute name that Enzyme recognizes for TypeTree information
588+
let attr_name = "enzyme_type";
589+
let c_attr_name = std::ffi::CString::new(attr_name).unwrap();
590+
591+
// Attach TypeTree attributes to each input parameter
592+
// Enzyme uses these to understand parameter memory layouts during differentiation
593+
for (i, input) in inputs.iter().enumerate() {
594+
unsafe {
595+
// Convert Rust TypeTree to Enzyme's internal format
596+
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
597+
598+
// Serialize TypeTree to string format that Enzyme can parse
599+
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
600+
let c_str = std::ffi::CStr::from_ptr(c_str);
601+
602+
// Create LLVM string attribute with TypeTree information
603+
let attr = llvm::LLVMCreateStringAttribute(
604+
llcx,
605+
c_attr_name.as_ptr(),
606+
c_attr_name.as_bytes().len() as c_uint,
607+
c_str.as_ptr(),
608+
c_str.to_bytes().len() as c_uint,
609+
);
610+
611+
// Attach attribute to the specific function parameter
612+
// Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments
613+
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
614+
615+
// Free the C string to prevent memory leaks
616+
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
617+
}
618+
}
619+
620+
// Attach TypeTree attribute to the return type
621+
// Enzyme needs this to understand how to handle return value derivatives
622+
unsafe {
623+
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
624+
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
625+
let c_str = std::ffi::CStr::from_ptr(c_str);
626+
627+
let ret_attr = llvm::LLVMCreateStringAttribute(
628+
llcx,
629+
c_attr_name.as_ptr(),
630+
c_attr_name.as_bytes().len() as c_uint,
631+
c_str.as_ptr(),
632+
c_str.to_bytes().len() as c_uint,
633+
);
634+
635+
// Attach to function return type
636+
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
637+
638+
// Free the C string
639+
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
640+
}
641+
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2142,7 +2142,7 @@ unsafe extern "C" {
21422142
SPFlags: DISPFlags,
21432143
MaybeFn: Option<&'a Value>,
21442144
TParam: &'a DIArray,
2145-
Decl: Option<&'a DIDescriptor>,
2145+
Decl: Option<&'a DIDescriptor>,
21462146
) -> &'a DISubprogram;
21472147

21482148
pub(crate) fn LLVMRustDIBuilderCreateMethod<'a>(
@@ -2673,4 +2673,144 @@ unsafe extern "C" {
26732673

26742674
pub(crate) fn LLVMRustSetNoSanitizeAddress(Global: &Value);
26752675
pub(crate) fn LLVMRustSetNoSanitizeHWAddress(Global: &Value);
2676+
2677+
// ========== ENZYME AUTODIFF FFI FUNCTIONS ==========
2678+
2679+
// Enzyme Type Tree Functions (minimal set for TypeTree support)
2680+
pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef;
2681+
pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef);
2682+
pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef;
2683+
pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef;
2684+
pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef);
2685+
pub(crate) fn EnzymeTypeTreeShiftIndiciesEq(
2686+
arg1: CTypeTreeRef,
2687+
data_layout: *const c_char,
2688+
offset: i64,
2689+
max_size: i64,
2690+
add_offset: u64,
2691+
);
2692+
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
2693+
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
26762694
}
2695+
2696+
// ========== ENZYME TYPES AND ENUMS ==========
2697+
2698+
// Type Tree Support for Autodiff
2699+
2700+
2701+
2702+
2703+
2704+
2705+
2706+
2707+
2708+
2709+
2710+
2711+
2712+
2713+
#[repr(u32)]
2714+
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
2715+
pub(crate) enum CConcreteType {
2716+
DT_Anything = 0,
2717+
DT_Integer = 1,
2718+
DT_Pointer = 2,
2719+
DT_Half = 3,
2720+
DT_Float = 4,
2721+
DT_Double = 5,
2722+
DT_Unknown = 6,
2723+
}
2724+
2725+
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
2726+
2727+
#[repr(C)]
2728+
#[derive(Debug, Copy, Clone)]
2729+
pub(crate) struct EnzymeTypeTree {
2730+
_unused: [u8; 0],
2731+
}
2732+
2733+
2734+
2735+
2736+
2737+
2738+
2739+
// TypeTree wrapper for Rust-side type safety and memory management
2740+
pub(crate) struct TypeTree {
2741+
pub(crate) inner: CTypeTreeRef,
2742+
}
2743+
2744+
impl TypeTree {
2745+
pub(crate) fn new() -> TypeTree {
2746+
let inner = unsafe { EnzymeNewTypeTree() };
2747+
TypeTree { inner }
2748+
}
2749+
2750+
pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
2751+
let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) };
2752+
TypeTree { inner }
2753+
}
2754+
2755+
pub(crate) fn merge(self, other: Self) -> Self {
2756+
unsafe {
2757+
EnzymeMergeTypeTree(self.inner, other.inner);
2758+
}
2759+
drop(other);
2760+
self
2761+
}
2762+
2763+
#[must_use]
2764+
pub(crate) fn shift(self, layout: &str, offset: isize, max_size: isize, add_offset: usize) -> Self {
2765+
let layout = std::ffi::CString::new(layout).unwrap();
2766+
2767+
unsafe {
2768+
EnzymeTypeTreeShiftIndiciesEq(
2769+
self.inner,
2770+
layout.as_ptr(),
2771+
offset as i64,
2772+
max_size as i64,
2773+
add_offset as u64,
2774+
)
2775+
}
2776+
2777+
self
2778+
}
2779+
}
2780+
2781+
impl Clone for TypeTree {
2782+
fn clone(&self) -> Self {
2783+
let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
2784+
TypeTree { inner }
2785+
}
2786+
}
2787+
2788+
impl std::fmt::Display for TypeTree {
2789+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2790+
let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
2791+
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
2792+
match cstr.to_str() {
2793+
Ok(x) => write!(f, "{}", x)?,
2794+
Err(err) => write!(f, "could not parse: {}", err)?,
2795+
}
2796+
2797+
// delete C string pointer
2798+
unsafe { EnzymeTypeTreeToStringFree(ptr) }
2799+
2800+
Ok(())
2801+
}
2802+
}
2803+
2804+
impl std::fmt::Debug for TypeTree {
2805+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2806+
<Self as std::fmt::Display>::fmt(self, f)
2807+
}
2808+
}
2809+
2810+
impl Drop for TypeTree {
2811+
fn drop(&mut self) {
2812+
unsafe { EnzymeFreeTypeTree(self.inner) }
2813+
}
2814+
}
2815+
2816+

compiler/rustc_codegen_llvm/src/va_arg.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
735735
src_align,
736736
bx.const_u32(layout.layout.size().bytes() as u32),
737737
MemFlags::empty(),
738+
None,
738739
);
739740
tmp
740741
} else {

compiler/rustc_codegen_ssa/src/mir/block.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,6 +1550,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
15501550
align,
15511551
bx.const_usize(copy_bytes),
15521552
MemFlags::empty(),
1553+
None,
15531554
);
15541555
// ...and then load it with the ABI type.
15551556
llval = load_cast(bx, cast, llscratch, scratch_align);

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
3030
if allow_overlap {
3131
bx.memmove(dst, align, src, align, size, flags);
3232
} else {
33-
bx.memcpy(dst, align, src, align, size, flags);
33+
bx.memcpy(dst, align, src, align, size, flags, None);
3434
}
3535
}
3636

compiler/rustc_codegen_ssa/src/mir/statement.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
9090
let align = pointee_layout.align;
9191
let dst = dst_val.immediate();
9292
let src = src_val.immediate();
93-
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
93+
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
9494
}
9595
mir::StatementKind::FakeRead(..)
9696
| mir::StatementKind::Retag { .. }

0 commit comments

Comments
 (0)