Skip to content

Commit 218c95d

Browse files
committed
autodiff: recurion added for typetree
1 parent 20974ec commit 218c95d

File tree

12 files changed

+191
-20
lines changed

12 files changed

+191
-20
lines changed

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ pub(crate) mod Enzyme_AD {
127127
);
128128
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
129129
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
130+
pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
130131
}
131132

132133
unsafe extern "C" {

compiler/rustc_codegen_llvm/src/typetree.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,7 @@ fn process_typetree_recursive(
3939

4040
let mut indices = parent_indices.to_vec();
4141
if !parent_indices.is_empty() {
42-
if rust_type.offset == -1 {
43-
indices.push(-1);
44-
} else {
45-
indices.push(rust_type.offset as i64);
46-
}
42+
indices.push(rust_type.offset as i64);
4743
} else if rust_type.offset == -1 {
4844
indices.push(-1);
4945
} else {
@@ -52,7 +48,9 @@ fn process_typetree_recursive(
5248

5349
enzyme_tt.insert(&indices, concrete_type, llcx);
5450

55-
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() {
51+
if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer
52+
&& !rust_type.child.0.is_empty()
53+
{
5654
process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
5755
}
5856
}

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2258,6 +2258,61 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
22582258
/// Generate TypeTree for a specific type.
22592259
/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
22602260
pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
2261+
let mut visited = Vec::new();
2262+
typetree_from_ty_inner(tcx, ty, 0, &mut visited)
2263+
}
2264+
2265+
/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
2266+
fn typetree_from_ty_inner<'tcx>(
2267+
tcx: TyCtxt<'tcx>,
2268+
ty: Ty<'tcx>,
2269+
depth: usize,
2270+
visited: &mut Vec<Ty<'tcx>>,
2271+
) -> TypeTree {
2272+
#[cfg(llvm_enzyme)]
2273+
{
2274+
unsafe extern "C" {
2275+
fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
2276+
}
2277+
let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
2278+
if depth > max_depth {
2279+
return TypeTree::new();
2280+
}
2281+
}
2282+
2283+
#[cfg(not(llvm_enzyme))]
2284+
if depth > 6 {
2285+
return TypeTree::new();
2286+
}
2287+
2288+
if visited.contains(&ty) {
2289+
return TypeTree::new();
2290+
}
2291+
2292+
visited.push(ty);
2293+
let result = typetree_from_ty_impl(tcx, ty, depth, visited);
2294+
visited.pop();
2295+
result
2296+
}
2297+
2298+
/// Implementation of TypeTree generation logic.
2299+
fn typetree_from_ty_impl<'tcx>(
2300+
tcx: TyCtxt<'tcx>,
2301+
ty: Ty<'tcx>,
2302+
depth: usize,
2303+
visited: &mut Vec<Ty<'tcx>>,
2304+
) -> TypeTree {
2305+
typetree_from_ty_impl_inner(tcx, ty, depth, visited, false)
2306+
}
2307+
2308+
/// Internal implementation with context about whether this is for a reference target.
2309+
fn typetree_from_ty_impl_inner<'tcx>(
2310+
tcx: TyCtxt<'tcx>,
2311+
ty: Ty<'tcx>,
2312+
depth: usize,
2313+
visited: &mut Vec<Ty<'tcx>>,
2314+
is_reference_target: bool,
2315+
) -> TypeTree {
22612316
if ty.is_scalar() {
22622317
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
22632318
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
@@ -2273,7 +2328,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22732328
(Kind::Integer, 0)
22742329
};
22752330

2276-
return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]);
2331+
// Use offset 0 for scalars that are direct targets of references (like &f64)
2332+
// Use offset -1 for scalars used directly (like function return types)
2333+
let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 };
2334+
return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]);
22772335
}
22782336

22792337
if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
@@ -2283,7 +2341,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22832341
return TypeTree::new();
22842342
};
22852343

2286-
let child = typetree_from_ty(tcx, inner_ty);
2344+
let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true);
22872345
return TypeTree(vec![Type {
22882346
offset: -1,
22892347
size: tcx.data_layout.pointer_size().bytes_usize(),
@@ -2298,9 +2356,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22982356
if len == 0 {
22992357
return TypeTree::new();
23002358
}
2301-
2302-
let element_tree = typetree_from_ty(tcx, *element_ty);
2303-
2359+
let element_tree =
2360+
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
23042361
let mut types = Vec::new();
23052362
for elem_type in &element_tree.0 {
23062363
types.push(Type {
@@ -2317,7 +2374,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23172374

23182375
if ty.is_slice() {
23192376
if let ty::Slice(element_ty) = ty.kind() {
2320-
let element_tree = typetree_from_ty(tcx, *element_ty);
2377+
let element_tree =
2378+
typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
23212379
return element_tree;
23222380
}
23232381
}
@@ -2331,7 +2389,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23312389
let mut current_offset = 0;
23322390

23332391
for tuple_ty in tuple_types.iter() {
2334-
let element_tree = typetree_from_ty(tcx, tuple_ty);
2392+
let element_tree =
2393+
typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false);
23352394

23362395
let element_layout = tcx
23372396
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
@@ -2367,7 +2426,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
23672426

23682427
for (field_idx, field_def) in adt_def.all_fields().enumerate() {
23692428
let field_ty = field_def.ty(tcx, args);
2370-
let field_tree = typetree_from_ty(tcx, field_ty);
2429+
let field_tree =
2430+
typetree_from_ty_impl_inner(tcx, field_ty, depth + 1, visited, false);
23712431

23722432
let field_offset = layout.fields.offset(field_idx).bytes_usize();
23732433

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
// Check that enzyme_type attributes are present when TypeTree is enabled
22
// This verifies our TypeTree metadata attachment is working
33

4-
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_deep{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
2+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_graph{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Integer, [-1,16]:Integer, [-1,24]:Float@double}"
3+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_node{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
8+
llvm_filecheck().patterns("recursion.check").stdin_buf(rfs::read("test.ll")).run();
9+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
// Self-referential struct to test recursion detection
6+
#[derive(Clone)]
7+
struct Node {
8+
value: f64,
9+
next: Option<Box<Node>>,
10+
}
11+
12+
// Mutually recursive structs to test cycle detection
13+
#[derive(Clone)]
14+
struct GraphNodeA {
15+
value: f64,
16+
connections: Vec<GraphNodeB>,
17+
}
18+
19+
#[derive(Clone)]
20+
struct GraphNodeB {
21+
weight: f64,
22+
target: Option<Box<GraphNodeA>>,
23+
}
24+
25+
#[autodiff_reverse(d_test_node, Duplicated, Active)]
26+
#[no_mangle]
27+
fn test_node(node: &Node) -> f64 {
28+
node.value * 2.0
29+
}
30+
31+
#[autodiff_reverse(d_test_graph, Duplicated, Active)]
32+
#[no_mangle]
33+
fn test_graph(a: &GraphNodeA) -> f64 {
34+
a.value * 3.0
35+
}
36+
37+
// Simple depth test - deeply nested but not circular
38+
#[derive(Clone)]
39+
struct Level1 {
40+
val: f64,
41+
next: Option<Box<Level2>>,
42+
}
43+
#[derive(Clone)]
44+
struct Level2 {
45+
val: f64,
46+
next: Option<Box<Level3>>,
47+
}
48+
#[derive(Clone)]
49+
struct Level3 {
50+
val: f64,
51+
next: Option<Box<Level4>>,
52+
}
53+
#[derive(Clone)]
54+
struct Level4 {
55+
val: f64,
56+
next: Option<Box<Level5>>,
57+
}
58+
#[derive(Clone)]
59+
struct Level5 {
60+
val: f64,
61+
next: Option<Box<Level6>>,
62+
}
63+
#[derive(Clone)]
64+
struct Level6 {
65+
val: f64,
66+
next: Option<Box<Level7>>,
67+
}
68+
#[derive(Clone)]
69+
struct Level7 {
70+
val: f64,
71+
next: Option<Box<Level8>>,
72+
}
73+
#[derive(Clone)]
74+
struct Level8 {
75+
val: f64,
76+
}
77+
78+
#[autodiff_reverse(d_test_deep, Duplicated, Active)]
79+
#[no_mangle]
80+
fn test_deep(deep: &Level1) -> f64 {
81+
deep.val * 4.0
82+
}
83+
84+
fn main() {
85+
let node = Node { value: 1.0, next: None };
86+
87+
let graph = GraphNodeA { value: 2.0, connections: vec![] };
88+
89+
let deep = Level1 { val: 5.0, next: None };
90+
91+
let mut d_node = Node { value: 0.0, next: None };
92+
93+
let mut d_graph = GraphNodeA { value: 0.0, connections: vec![] };
94+
95+
let mut d_deep = Level1 { val: 0.0, next: None };
96+
97+
let _result1 = d_test_node(&node, &mut d_node, 1.0);
98+
let _result2 = d_test_graph(&graph, &mut d_graph, 1.0);
99+
let _result3 = d_test_deep(&deep, &mut d_deep, 1.0);
100+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
; Check that f128 TypeTree metadata is correctly generated
22
; Should show Float@fp128 for f128 values and Pointer for references
33

4-
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@fp128}"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
; Check that f16 TypeTree metadata is correctly generated
22
; Should show Float@half for f16 values and Pointer for references
33

4-
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@half}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@half}"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
; Check that f32 TypeTree metadata is correctly generated
22
; Should show Float@float for f32 values and Pointer for references
33

4-
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"
4+
CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}"

0 commit comments

Comments
 (0)