Skip to content

Commit 4f3f0f4

Browse files
committed
autodiff: fixed test to be more precise for type tree checking
1 parent 574f0b9 commit 4f3f0f4

File tree

19 files changed

+120
-87
lines changed

19 files changed

+120
-87
lines changed

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,13 @@ pub(crate) mod Enzyme_AD {
118118
max_size: i64,
119119
add_offset: u64,
120120
);
121+
pub(crate) fn EnzymeTypeTreeInsertEq(
122+
CTT: CTypeTreeRef,
123+
indices: *const i64,
124+
len: usize,
125+
ct: CConcreteType,
126+
ctx: &Context,
127+
);
121128
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
122129
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
123130
}
@@ -234,6 +241,16 @@ pub(crate) mod Fallback_AD {
234241
unimplemented!()
235242
}
236243

244+
pub(crate) unsafe fn EnzymeTypeTreeInsertEq(
245+
CTT: CTypeTreeRef,
246+
indices: *const i64,
247+
len: usize,
248+
ct: CConcreteType,
249+
ctx: &Context,
250+
) {
251+
unimplemented!()
252+
}
253+
237254
pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
238255
unimplemented!()
239256
}
@@ -312,6 +329,12 @@ impl TypeTree {
312329

313330
self
314331
}
332+
333+
pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
334+
unsafe {
335+
EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
336+
}
337+
}
315338
}
316339

317340
impl Clone for TypeTree {

compiler/rustc_codegen_llvm/src/typetree.rs

Lines changed: 28 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,24 @@ use {
88

99
use crate::llvm::{self, Value};
1010

11-
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
12-
///
13-
/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree)
14-
/// and converts it to Enzyme's internal C++ TypeTree representation that
15-
/// Enzyme can understand during differentiation analysis.
1611
#[cfg(llvm_enzyme)]
1712
fn to_enzyme_typetree(
1813
rust_typetree: RustTypeTree,
19-
data_layout: &str,
14+
_data_layout: &str,
2015
llcx: &llvm::Context,
2116
) -> llvm::TypeTree {
22-
// Start with an empty TypeTree
2317
let mut enzyme_tt = llvm::TypeTree::new();
24-
25-
// Convert each Type in the Rust TypeTree to Enzyme format
26-
for rust_type in rust_typetree.0 {
18+
process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
19+
enzyme_tt
20+
}
21+
#[cfg(llvm_enzyme)]
22+
fn process_typetree_recursive(
23+
enzyme_tt: &mut llvm::TypeTree,
24+
rust_typetree: &RustTypeTree,
25+
parent_indices: &[i64],
26+
llcx: &llvm::Context,
27+
) {
28+
for rust_type in &rust_typetree.0 {
2729
let concrete_type = match rust_type.kind {
2830
rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything,
2931
rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer,
@@ -35,25 +37,27 @@ fn to_enzyme_typetree(
3537
rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown,
3638
};
3739

38-
// Create a TypeTree for this specific type
39-
let type_tt = llvm::TypeTree::from_type(concrete_type, llcx);
40-
41-
// Apply offset if specified
42-
let type_tt = if rust_type.offset == -1 {
43-
type_tt // -1 means everywhere/no specific offset
40+
let mut indices = parent_indices.to_vec();
41+
if !parent_indices.is_empty() {
42+
if rust_type.offset == -1 {
43+
indices.push(-1);
44+
} else {
45+
indices.push(rust_type.offset as i64);
46+
}
47+
} else if rust_type.offset == -1 {
48+
indices.push(-1);
4449
} else {
45-
// Apply specific offset positioning
46-
type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0)
47-
};
50+
indices.push(rust_type.offset as i64);
51+
}
4852

49-
// Merge this type into the main TypeTree
50-
enzyme_tt = enzyme_tt.merge(type_tt);
51-
}
53+
enzyme_tt.insert(&indices, concrete_type, llcx);
5254

53-
enzyme_tt
55+
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() {
56+
process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
57+
}
58+
}
5459
}
5560

56-
// Attaches TypeTree information to LLVM function as enzyme_type attributes.
5761
#[cfg(llvm_enzyme)]
5862
pub(crate) fn add_tt<'ll>(
5963
llmod: &'ll llvm::Module,
@@ -64,28 +68,20 @@ pub(crate) fn add_tt<'ll>(
6468
let inputs = tt.args;
6569
let ret_tt: RustTypeTree = tt.ret;
6670

67-
// Get LLVM data layout string for TypeTree conversion
6871
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
6972
let llvm_data_layout =
7073
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
7174
.expect("got a non-UTF8 data-layout from LLVM");
7275

73-
// Attribute name that Enzyme recognizes for TypeTree information
7476
let attr_name = "enzyme_type";
7577
let c_attr_name = CString::new(attr_name).unwrap();
7678

77-
// Attach TypeTree attributes to each input parameter
78-
// Enzyme uses these to understand parameter memory layouts during differentiation
7979
for (i, input) in inputs.iter().enumerate() {
8080
unsafe {
81-
// Convert Rust TypeTree to Enzyme's internal format
8281
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
83-
84-
// Serialize TypeTree to string format that Enzyme can parse
8582
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
8683
let c_str = std::ffi::CStr::from_ptr(c_str);
8784

88-
// Create LLVM string attribute with TypeTree information
8985
let attr = llvm::LLVMCreateStringAttribute(
9086
llcx,
9187
c_attr_name.as_ptr(),
@@ -94,17 +90,11 @@ pub(crate) fn add_tt<'ll>(
9490
c_str.to_bytes().len() as c_uint,
9591
);
9692

97-
// Attach attribute to the specific function parameter
98-
// Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments
9993
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
100-
101-
// Free the C string to prevent memory leaks
10294
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
10395
}
10496
}
10597

106-
// Attach TypeTree attribute to the return type
107-
// Enzyme needs this to understand how to handle return value derivatives
10898
unsafe {
10999
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
110100
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
@@ -118,15 +108,11 @@ pub(crate) fn add_tt<'ll>(
118108
c_str.to_bytes().len() as c_uint,
119109
);
120110

121-
// Attach to function return type
122111
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
123-
124-
// Free the C string
125112
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
126113
}
127114
}
128115

129-
// Fallback implementation when Enzyme is not available
130116
#[cfg(not(llvm_enzyme))]
131117
pub(crate) fn add_tt<'ll>(
132118
_llmod: &'ll llvm::Module,

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,10 +2261,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22612261
x if x == tcx.types.f32 => (Kind::Float, 4),
22622262
x if x == tcx.types.f64 => (Kind::Double, 8),
22632263
x if x == tcx.types.f128 => (Kind::F128, 16),
2264-
_ => return TypeTree::new(),
2264+
_ => (Kind::Integer, 0),
22652265
}
22662266
} else {
2267-
return TypeTree::new();
2267+
(Kind::Integer, 0)
22682268
};
22692269

22702270
return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]);
@@ -2295,32 +2295,14 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22952295

22962296
let element_tree = typetree_from_ty(tcx, *element_ty);
22972297

2298-
let element_layout = tcx
2299-
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty))
2300-
.ok()
2301-
.map(|layout| layout.size.bytes_usize())
2302-
.unwrap_or(0);
2303-
2304-
if element_layout == 0 {
2305-
return TypeTree::new();
2306-
}
2307-
23082298
let mut types = Vec::new();
2309-
for i in 0..len {
2310-
let base_offset = (i as usize * element_layout) as isize;
2311-
2312-
for elem_type in &element_tree.0 {
2313-
types.push(Type {
2314-
offset: if elem_type.offset == -1 {
2315-
base_offset
2316-
} else {
2317-
base_offset + elem_type.offset
2318-
},
2319-
size: elem_type.size,
2320-
kind: elem_type.kind,
2321-
child: elem_type.child.clone(),
2322-
});
2323-
}
2299+
for elem_type in &element_tree.0 {
2300+
types.push(Type {
2301+
offset: -1,
2302+
size: elem_type.size,
2303+
kind: elem_type.kind,
2304+
child: elem_type.child.clone(),
2305+
});
23242306
}
23252307

23262308
return TypeTree(types);
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
; Check that array TypeTree metadata is correctly generated
22
; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes)
33

4-
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[]:Pointer}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"

tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; Check that enzyme_type attributes are present in the LLVM IR function definition
22
; This verifies our TypeTree system correctly attaches metadata for Enzyme
33

4-
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
55

66
; Check that llvm.memcpy exists (either call or declare)
77
CHECK: {{(call|declare).*}}@llvm.memcpy
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
; Check that mixed struct with large array generates correct detailed type tree
2+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_mixed_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Float@float}"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
rustc()
8+
.input("test.rs")
9+
.arg("-Zautodiff=Enable")
10+
.arg("-Zautodiff=NoPostopt")
11+
.opt_level("0")
12+
.emit("llvm-ir")
13+
.run();
14+
15+
llvm_filecheck().patterns("mixed.check").stdin_buf(rfs::read("test.ll")).run();
16+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[repr(C)]
6+
struct Container {
7+
header: i64,
8+
data: [f32; 1000],
9+
}
10+
11+
#[autodiff_reverse(d_test, Duplicated, Active)]
12+
#[no_mangle]
13+
#[inline(never)]
14+
fn test_mixed_struct(container: &Container) -> f32 {
15+
container.data[0] + container.data[999]
16+
}
17+
18+
fn main() {
19+
let container = Container { header: 42, data: [1.0; 1000] };
20+
let mut d_container = Container { header: 0, data: [0.0; 1000] };
21+
let result = d_test(&container, &mut d_container, 1.0);
22+
std::hint::black_box(result);
23+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
// Check that enzyme_type attributes are present when TypeTree is enabled
22
// This verifies our TypeTree metadata attachment is working
33

4-
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[]:Pointer}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
; Check that f128 TypeTree metadata is correctly generated
2-
; f128 maps to Unknown in our current implementation since CConcreteType doesn't have DT_F128
2+
; Should show Float@fp128 for f128 values and Pointer for references
33

4-
CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f128{{.*}}"enzyme_type"="{[]:Pointer}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}"

0 commit comments

Comments
 (0)