Skip to content

Commit cc0b6c4

Browse files
committed
autodiff: typetree recursive depth query from enzyme with fallback
Signed-off-by: Karan Janthe <[email protected]>
1 parent 4520926 commit cc0b6c4

File tree

5 files changed

+18
-7
lines changed

5 files changed

+18
-7
lines changed

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ pub(crate) mod Enzyme_AD {
127127
);
128128
pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
129129
pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
130-
pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
131130
}
132131

133132
unsafe extern "C" {

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,3 +1847,15 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
18471847
MD.NoHWAddress = true;
18481848
GV.setSanitizerMetadata(MD);
18491849
}
1850+
1851+
#ifdef ENZYME
1852+
extern "C" {
1853+
extern llvm::cl::opt<unsigned> EnzymeMaxTypeDepth;
1854+
}
1855+
1856+
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; }
1857+
#else
1858+
extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() {
1859+
return 6; // Default fallback depth
1860+
}
1861+
#endif

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,16 +2266,16 @@ fn typetree_from_ty_inner<'tcx>(
22662266
#[cfg(llvm_enzyme)]
22672267
{
22682268
unsafe extern "C" {
2269-
fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
2269+
fn LLVMRustEnzymeGetMaxTypeDepth() -> usize;
22702270
}
2271-
let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
2272-
if depth > max_depth {
2271+
let max_depth = unsafe { LLVMRustEnzymeGetMaxTypeDepth() };
2272+
if depth >= max_depth {
22732273
return TypeTree::new();
22742274
}
22752275
}
22762276

22772277
#[cfg(not(llvm_enzyme))]
2278-
if depth > 6 {
2278+
if depth >= 6 {
22792279
return TypeTree::new();
22802280
}
22812281

src/llvm-project

src/tools/cargo

Submodule cargo updated 582 files

0 commit comments

Comments
 (0)