Skip to content

Commit 4520926

Browse files
committed
autodiff: recurion added for typetree
1 parent 4f3f0f4 commit 4520926

File tree

12 files changed

+191
-20
lines changed

12 files changed

+191
-20
lines changed

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ pub(crate) mod Enzyme_AD {
127127
);
128128
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
129129
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
130+
pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
130131
}
131132

132133
unsafe extern "C" {

compiler/rustc_codegen_llvm/src/typetree.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,7 @@ fn process_typetree_recursive(
3939

4040
let mut indices = parent_indices.to_vec();
4141
if !parent_indices.is_empty() {
42-
if rust_type.offset == -1 {
43-
indices.push(-1);
44-
} else {
45-
indices.push(rust_type.offset as i64);
46-
}
42+
indices.push(rust_type.offset as i64);
4743
} else if rust_type.offset == -1 {
4844
indices.push(-1);
4945
} else {
@@ -52,7 +48,9 @@ fn process_typetree_recursive(
5248

5349
enzyme_tt.insert(&indices, concrete_type, llcx);
5450

55-
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() {
51+
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer
52+
&& !rust_type.child.0.is_empty()
53+
{
5654
process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
5755
}
5856
}

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2252,6 +2252,61 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
22522252
/// Generate TypeTree for a specific type.
22532253
/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
22542254
pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
2255+
let mut visited = Vec::new();
2256+
typetree_from_ty_inner(tcx, ty, 0, &mut visited)
2257+
}
2258+
2259+
/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
2260+
fn typetree_from_ty_inner<'tcx>(
2261+
tcx: TyCtxt<'tcx>,
2262+
ty: Ty<'tcx>,
2263+
depth: usize,
2264+
visited: &mut Vec<Ty<'tcx>>,
2265+
) -> TypeTree {
2266+
#[cfg(llvm_enzyme)]
2267+
{
2268+
unsafe extern "C" {
2269+
fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
2270+
}
2271+
let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
2272+
if depth > max_depth {
2273+
return TypeTree::new();
2274+
}
2275+
}
2276+
2277+
#[cfg(not(llvm_enzyme))]
2278+
if depth > 6 {
2279+
return TypeTree::new();
2280+
}
2281+
2282+
if visited.contains(&ty) {
2283+
return TypeTree::new();
2284+
}
2285+
2286+
visited.push(ty);
2287+
let result = typetree_from_ty_impl(tcx, ty, depth, visited);
2288+
visited.pop();
2289+
result
2290+
}
2291+
2292+
/// Implementation of TypeTree generation logic.
2293+
fn typetree_from_ty_impl<'tcx>(
2294+
tcx: TyCtxt<'tcx>,
2295+
ty: Ty<'tcx>,
2296+
depth: usize,
2297+
visited: &mut Vec<Ty<'tcx>>,
2298+
) -> TypeTree {
2299+
typetree_from_ty_impl_inner(tcx, ty, depth, visited, false)
2300+
}
2301+
2302+
/// Internal implementation with context about whether this is for a reference target.
2303+
fn typetree_from_ty_impl_inner<'tcx>(
2304+
tcx: TyCtxt<'tcx>,
2305+
ty: Ty<'tcx>,
2306+
depth: usize,
2307+
visited: &mut Vec<Ty<'tcx>>,
2308+
is_reference_target: bool,
2309+
) -> TypeTree {
22552310
if ty.is_scalar() {
22562311
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
22572312
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
@@ -2267,7 +2322,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22672322
(Kind::Integer, 0)
22682323
};
22692324

2270-
return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]);
2325+
// Use offset 0 for scalars that are direct targets of references (like &f64)
2326+
// Use offset -1 for scalars used directly (like function return types)
2327+
let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 };
2328+
return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]);
22712329
}
22722330

22732331
if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
@@ -2277,7 +2335,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22772335
return TypeTree::new();
22782336
};
22792337

2280-
let child = typetree_from_ty(tcx, inner_ty);
2338+
let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true);
22812339
return TypeTree(vec![Type {
22822340
offset: -1,
22832341
size: tcx.data_layout.pointer_size().bytes_usize(),
@@ -2292,9 +2350,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22922350
if len == 0 {
22932351
return TypeTree::new();
22942352
}
2295-
2296-
let element_tree = typetree_from_ty(tcx, *element_ty);
2297-
2353+
let element_tree =
2354+
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
22982355
let mut types = Vec::new();
22992356
for elem_type in &element_tree.0 {
23002357
types.push(Type {
@@ -2311,7 +2368,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23112368

23122369
if ty.is_slice() {
23132370
if let ty::Slice(element_ty) = ty.kind() {
2314-
let element_tree = typetree_from_ty(tcx, *element_ty);
2371+
let element_tree =
2372+
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
23152373
return element_tree;
23162374
}
23172375
}
@@ -2325,7 +2383,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23252383
let mut current_offset = 0;
23262384

23272385
for tuple_ty in tuple_types.iter() {
2328-
let element_tree = typetree_from_ty(tcx, tuple_ty);
2386+
let element_tree =
2387+
typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false);
23292388

23302389
let element_layout = tcx
23312390
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
@@ -2361,7 +2420,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23612420

23622421
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
23632422
let field_ty = field_def.ty(tcx, args);
2364-
let field_tree = typetree_from_ty(tcx, field_ty);
2423+
let field_tree =
2424+
typetree_from_ty_impl_inner(tcx, field_ty, depth + 1, visited, false);
23652425

23662426
let field_offset = layout.fields.offset(field_idx).bytes_usize();
23672427

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
// Check that enzyme_type attributes are present when TypeTree is enabled
22
// This verifies our TypeTree metadata attachment is working
33

4-
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_deep{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
2+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_graph{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Integer, [-1,16]:Integer, [-1,24]:Float@double}"
3+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_node{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
8+
llvm_filecheck().patterns("recursion.check").stdin_buf(rfs::read("test.ll")).run();
9+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
// Self-referential struct to test recursion detection
6+
#[derive(Clone)]
7+
struct Node {
8+
value: f64,
9+
next: Option<Box<Node>>,
10+
}
11+
12+
// Mutually recursive structs to test cycle detection
13+
#[derive(Clone)]
14+
struct GraphNodeA {
15+
value: f64,
16+
connections: Vec<GraphNodeB>,
17+
}
18+
19+
#[derive(Clone)]
20+
struct GraphNodeB {
21+
weight: f64,
22+
target: Option<Box<GraphNodeA>>,
23+
}
24+
25+
#[autodiff_reverse(d_test_node, Duplicated, Active)]
26+
#[no_mangle]
27+
fn test_node(node: &Node) -> f64 {
28+
node.value * 2.0
29+
}
30+
31+
#[autodiff_reverse(d_test_graph, Duplicated, Active)]
32+
#[no_mangle]
33+
fn test_graph(a: &GraphNodeA) -> f64 {
34+
a.value * 3.0
35+
}
36+
37+
// Simple depth test - deeply nested but not circular
38+
#[derive(Clone)]
39+
struct Level1 {
40+
val: f64,
41+
next: Option<Box<Level2>>,
42+
}
43+
#[derive(Clone)]
44+
struct Level2 {
45+
val: f64,
46+
next: Option<Box<Level3>>,
47+
}
48+
#[derive(Clone)]
49+
struct Level3 {
50+
val: f64,
51+
next: Option<Box<Level4>>,
52+
}
53+
#[derive(Clone)]
54+
struct Level4 {
55+
val: f64,
56+
next: Option<Box<Level5>>,
57+
}
58+
#[derive(Clone)]
59+
struct Level5 {
60+
val: f64,
61+
next: Option<Box<Level6>>,
62+
}
63+
#[derive(Clone)]
64+
struct Level6 {
65+
val: f64,
66+
next: Option<Box<Level7>>,
67+
}
68+
#[derive(Clone)]
69+
struct Level7 {
70+
val: f64,
71+
next: Option<Box<Level8>>,
72+
}
73+
#[derive(Clone)]
74+
struct Level8 {
75+
val: f64,
76+
}
77+
78+
#[autodiff_reverse(d_test_deep, Duplicated, Active)]
79+
#[no_mangle]
80+
fn test_deep(deep: &Level1) -> f64 {
81+
deep.val * 4.0
82+
}
83+
84+
fn main() {
85+
let node = Node { value: 1.0, next: None };
86+
87+
let graph = GraphNodeA { value: 2.0, connections: vec![] };
88+
89+
let deep = Level1 { val: 5.0, next: None };
90+
91+
let mut d_node = Node { value: 0.0, next: None };
92+
93+
let mut d_graph = GraphNodeA { value: 0.0, connections: vec![] };
94+
95+
let mut d_deep = Level1 { val: 0.0, next: None };
96+
97+
let _result1 = d_test_node(&node, &mut d_node, 1.0);
98+
let _result2 = d_test_graph(&graph, &mut d_graph, 1.0);
99+
let _result3 = d_test_deep(&deep, &mut d_deep, 1.0);
100+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
; Check that f128 TypeTree metadata is correctly generated
22
; Should show Float@fp128 for f128 values and Pointer for references
33

4-
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@fp128}"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
; Check that f16 TypeTree metadata is correctly generated
22
; Should show Float@half for f16 values and Pointer for references
33

4-
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@half}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@half}"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
; Check that f32 TypeTree metadata is correctly generated
22
; Should show Float@float for f32 values and Pointer for references
33

4-
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}"

0 commit comments

Comments
 (0)