Skip to content

Commit 6339a2c

Browse files
committed
autodiff: typetree recursive depth query from enzyme with fallback
Signed-off-by: Karan Janthe <[email protected]>
1 parent 218c95d commit 6339a2c

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ pub(crate) mod Enzyme_AD {
127127
);
128128
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
129129
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
130-
pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
131130
}
132131

133132
unsafe extern "C" {

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,3 +1947,15 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
19471947
MD.NoHWAddress = true;
19481948
GV.setSanitizerMetadata(MD);
19491949
}
1950+
1951+
#ifdef ENZYME
1952+
extern "C" {
1953+
extern llvm::cl::opt<unsigned> EnzymeMaxTypeDepth;
1954+
}
1955+
1956+
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; }
1957+
#else
1958+
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() {
1959+
return 6; // Default fallback depth
1960+
}
1961+
#endif

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2272,16 +2272,16 @@ fn typetree_from_ty_inner<'tcx>(
22722272
#[cfg(llvm_enzyme)]
22732273
{
22742274
unsafe extern "C" {
2275-
fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
2275+
fn LLVMRustEnzymeGetMaxTypeDepth() -> usize;
22762276
}
2277-
let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
2278-
if depth > max_depth {
2277+
let max_depth = unsafe { LLVMRustEnzymeGetMaxTypeDepth() };
2278+
if depth >= max_depth {
22792279
return TypeTree::new();
22802280
}
22812281
}
22822282

22832283
#[cfg(not(llvm_enzyme))]
2284-
if depth > 6 {
2284+
if depth >= 6 {
22852285
return TypeTree::new();
22862286
}
22872287

0 commit comments

Comments
 (0)