Skip to content

Commit 7070455

Browse files
committed
autodiff: typetree recursive depth query from enzyme with fallback
Signed-off-by: Karan Janthe <[email protected]>
1 parent 4520926 commit 7070455

File tree

5 files changed

+21
-17
lines changed

5 files changed

+21
-17
lines changed

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ pub(crate) mod Enzyme_AD {
127127
);
128128
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
129129
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
130-
pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
131130
}
132131

133132
unsafe extern "C" {

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,3 +1847,15 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
18471847
MD.NoHWAddress = true;
18481848
GV.setSanitizerMetadata(MD);
18491849
}
1850+
1851+
#ifdef ENZYME
1852+
extern "C" {
1853+
extern llvm::cl::opt<unsigned> EnzymeMaxTypeDepth;
1854+
}
1855+
1856+
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; }
1857+
#else
1858+
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() {
1859+
return 6; // Default fallback depth
1860+
}
1861+
#endif

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ pub use rustc_type_ir::solve::SizedTraitKind;
6363
pub use rustc_type_ir::*;
6464
#[allow(hidden_glob_reexports, unused_imports)]
6565
use rustc_type_ir::{InferCtxtLike, Interner};
66-
use tracing::{debug, instrument};
66+
use tracing::{debug, instrument, trace};
6767
pub use vtable::*;
6868
use {rustc_ast as ast, rustc_hir as hir};
6969

@@ -2256,26 +2256,19 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22562256
typetree_from_ty_inner(tcx, ty, 0, &mut visited)
22572257
}
22582258

2259+
/// Maximum recursion depth for TypeTree generation to prevent stack overflow
2260+
/// from pathological deeply nested types. Combined with cycle detection.
2261+
const MAX_TYPETREE_DEPTH: usize = 6;
2262+
22592263
/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
22602264
fn typetree_from_ty_inner<'tcx>(
22612265
tcx: TyCtxt<'tcx>,
22622266
ty: Ty<'tcx>,
22632267
depth: usize,
22642268
visited: &mut Vec<Ty<'tcx>>,
22652269
) -> TypeTree {
2266-
#[cfg(llvm_enzyme)]
2267-
{
2268-
unsafe extern "C" {
2269-
fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
2270-
}
2271-
let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
2272-
if depth > max_depth {
2273-
return TypeTree::new();
2274-
}
2275-
}
2276-
2277-
#[cfg(not(llvm_enzyme))]
2278-
if depth > 6 {
2270+
if depth >= MAX_TYPETREE_DEPTH {
2271+
trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty);
22792272
return TypeTree::new();
22802273
}
22812274

src/llvm-project

0 commit comments

Comments
 (0)