Skip to content

Commit 3ba5f19

Browse files
committed
autodiff: typetree recursive depth query from enzyme with fallback
Signed-off-by: Karan Janthe <[email protected]>
1 parent 4520926 commit 3ba5f19

File tree

6 files changed

+26
-22
lines changed

6 files changed

+26
-22
lines changed

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ pub(crate) mod Enzyme_AD {
127127
);
128128
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
129129
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
130-
pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
131130
}
132131

133132
unsafe extern "C" {

compiler/rustc_codegen_llvm/src/typetree.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use rustc_ast::expand::typetree::FncTree;
2-
#[cfg(llvm_enzyme)]
2+
#[cfg(feature = "llvm_enzyme")]
33
use {
44
crate::attributes,
55
rustc_ast::expand::typetree::TypeTree as RustTypeTree,
@@ -8,7 +8,7 @@ use {
88

99
use crate::llvm::{self, Value};
1010

11-
#[cfg(llvm_enzyme)]
11+
#[cfg(feature = "llvm_enzyme")]
1212
fn to_enzyme_typetree(
1313
rust_typetree: RustTypeTree,
1414
_data_layout: &str,
@@ -18,7 +18,7 @@ fn to_enzyme_typetree(
1818
process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
1919
enzyme_tt
2020
}
21-
#[cfg(llvm_enzyme)]
21+
#[cfg(feature = "llvm_enzyme")]
2222
fn process_typetree_recursive(
2323
enzyme_tt: &mut llvm::TypeTree,
2424
rust_typetree: &RustTypeTree,
@@ -56,7 +56,7 @@ fn process_typetree_recursive(
5656
}
5757
}
5858

59-
#[cfg(llvm_enzyme)]
59+
#[cfg(feature = "llvm_enzyme")]
6060
pub(crate) fn add_tt<'ll>(
6161
llmod: &'ll llvm::Module,
6262
llcx: &'ll llvm::Context,
@@ -111,7 +111,7 @@ pub(crate) fn add_tt<'ll>(
111111
}
112112
}
113113

114-
#[cfg(not(llvm_enzyme))]
114+
#[cfg(not(feature = "llvm_enzyme"))]
115115
pub(crate) fn add_tt<'ll>(
116116
_llmod: &'ll llvm::Module,
117117
_llcx: &'ll llvm::Context,

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,3 +1847,15 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
18471847
MD.NoHWAddress = true;
18481848
GV.setSanitizerMetadata(MD);
18491849
}
1850+
1851+
#ifdef ENZYME
1852+
extern "C" {
1853+
extern llvm::cl::opt<unsigned> EnzymeMaxTypeDepth;
1854+
}
1855+
1856+
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; }
1857+
#else
1858+
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() {
1859+
return 6; // Default fallback depth
1860+
}
1861+
#endif

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ pub use rustc_type_ir::solve::SizedTraitKind;
6363
pub use rustc_type_ir::*;
6464
#[allow(hidden_glob_reexports, unused_imports)]
6565
use rustc_type_ir::{InferCtxtLike, Interner};
66-
use tracing::{debug, instrument};
66+
use tracing::{debug, instrument, trace};
6767
pub use vtable::*;
6868
use {rustc_ast as ast, rustc_hir as hir};
6969

@@ -2256,26 +2256,19 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22562256
typetree_from_ty_inner(tcx, ty, 0, &mut visited)
22572257
}
22582258

2259+
/// Maximum recursion depth for TypeTree generation to prevent stack overflow
2260+
/// from pathological deeply nested types. Combined with cycle detection.
2261+
const MAX_TYPETREE_DEPTH: usize = 6;
2262+
22592263
/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
22602264
fn typetree_from_ty_inner<'tcx>(
22612265
tcx: TyCtxt<'tcx>,
22622266
ty: Ty<'tcx>,
22632267
depth: usize,
22642268
visited: &mut Vec<Ty<'tcx>>,
22652269
) -> TypeTree {
2266-
#[cfg(llvm_enzyme)]
2267-
{
2268-
unsafe extern "C" {
2269-
fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
2270-
}
2271-
let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
2272-
if depth > max_depth {
2273-
return TypeTree::new();
2274-
}
2275-
}
2276-
2277-
#[cfg(not(llvm_enzyme))]
2278-
if depth > 6 {
2270+
if depth >= MAX_TYPETREE_DEPTH {
2271+
trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty);
22792272
return TypeTree::new();
22802273
}
22812274

src/llvm-project

0 commit comments

Comments
 (0)