Skip to content

Commit 595473f

Browse files
committed
passing some enzyme tests
1 parent 59c665b commit 595473f

File tree

6 files changed

+58
-36
lines changed

6 files changed

+58
-36
lines changed

compiler/rustc_codegen_llvm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ bitflags = "2.4.1"
1414
gimli = "0.31"
1515
itertools = "0.12"
1616
libc = "0.2"
17+
libloading = "0.8.8"
1718
measureme = "12.0.1"
1819
object = { version = "0.37.0", default-features = false, features = ["std", "read"] }
1920
rustc-demangle = "0.1.21"

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,22 @@ pub(crate) unsafe fn llvm_optimize(
665665

666666
let llvm_plugins = config.llvm_plugins.join(",");
667667

668+
fn call_dynamic() -> Result<*const c_void, Box<dyn std::error::Error>> {
669+
unsafe {
670+
let lib = libloading::Library::new("/home/manuel/prog/rust/build/x86_64-unknown-linux-gnu/enzyme/lib/libEnzyme-21.so")?;
671+
let func: libloading::Symbol<'_, c_void> = lib.get(b"registerEnzymeAndPassPipeline")?;
672+
let func = func.try_as_raw_ptr().unwrap();
673+
dbg!(func);
674+
Ok(func as *const c_void)
675+
}
676+
}
677+
let enzyme_fn = if cfg!(llvm_enzyme) && run_enzyme {
678+
call_dynamic().unwrap_or(std::ptr::null())
679+
} else {
680+
std::ptr::null()
681+
};
682+
683+
668684
let result = unsafe {
669685
llvm::LLVMRustOptimize(
670686
module.module_llvm.llmod(),
@@ -684,7 +700,8 @@ pub(crate) unsafe fn llvm_optimize(
684700
vectorize_loop,
685701
config.no_builtins,
686702
config.emit_lifetime_markers,
687-
run_enzyme,
703+
enzyme_fn,
704+
//run_enzyme,
688705
print_before_enzyme,
689706
print_after_enzyme,
690707
print_passes,

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ pub(crate) enum LLVMRustVerifierFailureAction {
5959
LLVMReturnStatusAction = 2,
6060
}
6161

62-
#[cfg(llvm_enzyme)]
62+
#[cfg(not(llvm_enzyme))]
6363
pub(crate) use self::Enzyme_AD::*;
6464

65-
#[cfg(llvm_enzyme)]
65+
#[cfg(not(llvm_enzyme))]
6666
pub(crate) mod Enzyme_AD {
6767
use std::ffi::{CString, c_char};
6868

@@ -134,38 +134,38 @@ pub(crate) mod Enzyme_AD {
134134
}
135135
}
136136

137-
#[cfg(not(llvm_enzyme))]
137+
#[cfg(llvm_enzyme)]
138138
pub(crate) use self::Fallback_AD::*;
139139

140-
#[cfg(not(llvm_enzyme))]
140+
#[cfg(llvm_enzyme)]
141141
pub(crate) mod Fallback_AD {
142142
#![allow(unused_variables)]
143143

144144
pub(crate) fn set_inline(val: bool) {
145-
unimplemented!()
145+
//unimplemented!()
146146
}
147147
pub(crate) fn set_print_perf(print: bool) {
148-
unimplemented!()
148+
//unimplemented!()
149149
}
150150
pub(crate) fn set_print_activity(print: bool) {
151-
unimplemented!()
151+
//unimplemented!()
152152
}
153153
pub(crate) fn set_print_type(print: bool) {
154-
unimplemented!()
154+
//unimplemented!()
155155
}
156156
pub(crate) fn set_print_type_fun(fun_name: &str) {
157-
unimplemented!()
157+
//unimplemented!()
158158
}
159159
pub(crate) fn set_print(print: bool) {
160-
unimplemented!()
160+
//unimplemented!()
161161
}
162162
pub(crate) fn set_strict_aliasing(strict: bool) {
163-
unimplemented!()
163+
//unimplemented!()
164164
}
165165
pub(crate) fn set_loose_types(loose: bool) {
166-
unimplemented!()
166+
//unimplemented!()
167167
}
168168
pub(crate) fn set_rust_rules(val: bool) {
169-
unimplemented!()
169+
//unimplemented!()
170170
}
171171
}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2518,7 +2518,8 @@ unsafe extern "C" {
25182518
LoopVectorize: bool,
25192519
DisableSimplifyLibCalls: bool,
25202520
EmitLifetimeMarkers: bool,
2521-
RunEnzyme: bool,
2521+
RunEnzyme: *const c_void,
2522+
//RunEnzyme: bool,
25222523
PrintBeforeEnzyme: bool,
25232524
PrintAfterEnzyme: bool,
25242525
PrintPasses: bool,

compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -704,19 +704,21 @@ struct LLVMRustSanitizerOptions {
704704
extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB,
705705
/* augmentPassBuilder */ bool);
706706

707-
extern "C" {
708-
extern llvm::cl::opt<std::string> EnzymeFunctionToAnalyze;
709-
}
707+
//extern "C" {
708+
//extern llvm::cl::opt<std::string> EnzymeFunctionToAnalyze;
709+
//}
710710
#endif
711711

712+
extern "C" typedef void (*registerEnzymeAndPassPipelineFn)(llvm::PassBuilder &PB, bool augment);
713+
712714
extern "C" LLVMRustResult LLVMRustOptimize(
713715
LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef,
714716
LLVMRustPassBuilderOptLevel OptLevelRust, LLVMRustOptStage OptStage,
715717
bool IsLinkerPluginLTO, bool NoPrepopulatePasses, bool VerifyIR,
716718
bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
717719
bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
718720
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
719-
bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
721+
bool EmitLifetimeMarkers, registerEnzymeAndPassPipelineFn EnzymePtr, bool PrintBeforeEnzyme,
720722
bool PrintAfterEnzyme, bool PrintPasses,
721723
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
722724
const char *PGOUsePath, bool InstrumentCoverage,
@@ -1061,29 +1063,30 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10611063

10621064
// now load "-enzyme" pass:
10631065
#ifdef ENZYME
1064-
if (RunEnzyme) {
1066+
if (EnzymePtr) {
10651067

10661068
if (PrintBeforeEnzyme) {
10671069
// Handle the Rust flag `-Zautodiff=PrintModBefore`.
10681070
std::string Banner = "Module before EnzymeNewPM";
10691071
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
10701072
}
10711073

1072-
registerEnzymeAndPassPipeline(PB, false);
1074+
EnzymePtr(PB, false);
1075+
//registerEnzymeAndPassPipeline(PB, false);
10731076
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
10741077
std::string ErrMsg = toString(std::move(Err));
10751078
LLVMRustSetLastError(ErrMsg.c_str());
10761079
return LLVMRustResult::Failure;
10771080
}
10781081

10791082
// Check if PrintTAFn was used and add type analysis pass if needed
1080-
if (!EnzymeFunctionToAnalyze.empty()) {
1081-
if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) {
1082-
std::string ErrMsg = toString(std::move(Err));
1083-
LLVMRustSetLastError(ErrMsg.c_str());
1084-
return LLVMRustResult::Failure;
1085-
}
1086-
}
1083+
//if (!EnzymeFunctionToAnalyze.empty()) {
1084+
// if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) {
1085+
// std::string ErrMsg = toString(std::move(Err));
1086+
// LLVMRustSetLastError(ErrMsg.c_str());
1087+
// return LLVMRustResult::Failure;
1088+
// }
1089+
//}
10871090

10881091
if (PrintAfterEnzyme) {
10891092
// Handle the Rust flag `-Zautodiff=PrintModAfter`.

src/bootstrap/src/core/build_steps/compile.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,14 +1205,14 @@ pub fn rustc_cargo(
12051205
// We want to link against registerEnzyme and in the future we want to use additional
12061206
// functionality from Enzyme core. For that we need to link against Enzyme.
12071207
if builder.config.llvm_enzyme {
1208-
let arch = builder.build.host_target;
1209-
let enzyme_dir = builder.build.out.join(arch).join("enzyme").join("lib");
1210-
cargo.rustflag("-L").rustflag(enzyme_dir.to_str().expect("Invalid path"));
1211-
1212-
if let Some(llvm_config) = builder.llvm_config(builder.config.host_target) {
1213-
let llvm_version_major = llvm::get_llvm_version_major(builder, &llvm_config);
1214-
cargo.rustflag("-l").rustflag(&format!("Enzyme-{llvm_version_major}"));
1215-
}
1208+
//let arch = builder.build.host_target;
1209+
//let enzyme_dir = builder.build.out.join(arch).join("enzyme").join("lib");
1210+
//cargo.rustflag("-L").rustflag(enzyme_dir.to_str().expect("Invalid path"));
1211+
1212+
//if let Some(llvm_config) = builder.llvm_config(builder.config.host_target) {
1213+
// let llvm_version_major = llvm::get_llvm_version_major(builder, &llvm_config);
1214+
// cargo.rustflag("-l").rustflag(&format!("Enzyme-{llvm_version_major}"));
1215+
//}
12161216
}
12171217

12181218
// Building with protected visibility reduces the number of dynamic relocations needed, giving

0 commit comments

Comments
 (0)