Skip to content

Commit cc7c7b2

Browse files
committed
Use libloading to open Enzyme at runtime
1 parent 59c665b commit cc7c7b2

File tree

7 files changed

+201
-88
lines changed

7 files changed

+201
-88
lines changed

compiler/rustc_codegen_llvm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ bitflags = "2.4.1"
1414
gimli = "0.31"
1515
itertools = "0.12"
1616
libc = "0.2"
17+
libloading = "0.8.8"
1718
measureme = "12.0.1"
1819
object = { version = "0.37.0", default-features = false, features = ["std", "read"] }
1920
rustc-demangle = "0.1.21"

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -523,30 +523,32 @@ fn thin_lto(
523523
}
524524

525525
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
526+
use crate::llvm::EnzymeWrapper;
527+
let enzyme = EnzymeWrapper::current();
526528
for val in ad {
527529
// We intentionally don't use a wildcard, to not forget handling anything new.
528530
match val {
529531
config::AutoDiff::PrintPerf => {
530-
llvm::set_print_perf(true);
532+
enzyme.lock().unwrap().set_print_perf(true);
531533
}
532534
config::AutoDiff::PrintAA => {
533-
llvm::set_print_activity(true);
535+
enzyme.lock().unwrap().set_print_activity(true);
534536
}
535537
config::AutoDiff::PrintTA => {
536-
llvm::set_print_type(true);
538+
enzyme.lock().unwrap().set_print_type(true);
537539
}
538540
config::AutoDiff::PrintTAFn(fun) => {
539-
llvm::set_print_type(true); // Enable general type printing
540-
llvm::set_print_type_fun(&fun); // Set specific function to analyze
541+
enzyme.lock().unwrap().set_print_type(true); // Enable general type printing
542+
enzyme.lock().unwrap().set_print_type_fun(&fun); // Set specific function to analyze
541543
}
542544
config::AutoDiff::Inline => {
543-
llvm::set_inline(true);
545+
enzyme.lock().unwrap().set_inline(true);
544546
}
545547
config::AutoDiff::LooseTypes => {
546-
llvm::set_loose_types(true);
548+
enzyme.lock().unwrap().set_loose_types(true);
547549
}
548550
config::AutoDiff::PrintSteps => {
549-
llvm::set_print(true);
551+
enzyme.lock().unwrap().set_print(true);
550552
}
551553
// We handle this in the PassWrapper.cpp
552554
config::AutoDiff::PrintPasses => {}
@@ -563,9 +565,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
563565
}
564566
}
565567
// This helps with handling enums for now.
566-
llvm::set_strict_aliasing(false);
568+
enzyme.lock().unwrap().set_strict_aliasing(false);
567569
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
568-
llvm::set_rust_rules(true);
570+
enzyme.lock().unwrap().set_rust_rules(true);
569571
}
570572

571573
pub(crate) fn run_pass_manager(

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,18 @@ pub(crate) unsafe fn llvm_optimize(
665665

666666
let llvm_plugins = config.llvm_plugins.join(",");
667667

668+
let enzyme_fn = if run_enzyme {
669+
let wrapper = crate::llvm::EnzymeWrapper::current();
670+
wrapper.lock().unwrap().registerEnzymeAndPassPipeline
671+
} else {
672+
//dbg!(run_enzyme);
673+
//dbg!(consider_ad);
674+
std::ptr::null()
675+
};
676+
677+
dbg!(&enzyme_fn);
678+
679+
668680
let result = unsafe {
669681
llvm::LLVMRustOptimize(
670682
module.module_llvm.llmod(),
@@ -684,7 +696,8 @@ pub(crate) unsafe fn llvm_optimize(
684696
vectorize_loop,
685697
config.no_builtins,
686698
config.emit_lifetime_markers,
687-
run_enzyme,
699+
enzyme_fn,
700+
//run_enzyme,
688701
print_before_enzyme,
689702
print_after_enzyme,
690703
print_passes,

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 148 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ unsafe extern "C" {
5151
pub(crate) fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>;
5252
}
5353

54+
55+
56+
5457
#[repr(C)]
5558
#[derive(Copy, Clone, PartialEq)]
5659
pub(crate) enum LLVMRustVerifierFailureAction {
@@ -62,76 +65,166 @@ pub(crate) enum LLVMRustVerifierFailureAction {
6265
#[cfg(llvm_enzyme)]
6366
pub(crate) use self::Enzyme_AD::*;
6467

65-
#[cfg(llvm_enzyme)]
68+
//#[cfg(llvm_enzyme)]
6669
pub(crate) mod Enzyme_AD {
67-
use std::ffi::{CString, c_char};
70+
use std::ffi::CString;
71+
//use std::ffi::{CString, c_char};
6872

6973
use libc::c_void;
7074

71-
unsafe extern "C" {
72-
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
73-
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
74-
}
75-
unsafe extern "C" {
76-
static mut EnzymePrintPerf: c_void;
77-
static mut EnzymePrintActivity: c_void;
78-
static mut EnzymePrintType: c_void;
79-
static mut EnzymeFunctionToAnalyze: c_void;
80-
static mut EnzymePrint: c_void;
81-
static mut EnzymeStrictAliasing: c_void;
82-
static mut looseTypeAnalysis: c_void;
83-
static mut EnzymeInline: c_void;
84-
static mut RustTypeRules: c_void;
75+
type SetFlag = unsafe extern "C" fn(*mut c_void, u8);
76+
77+
#[derive(Debug)]
78+
pub(crate) struct EnzymeFns {
79+
pub set_cl: SetFlag,
80+
}
81+
82+
#[derive(Debug)]
83+
pub(crate) struct EnzymeWrapper {
84+
EnzymePrintPerf: *mut c_void,
85+
EnzymePrintActivity: *mut c_void,
86+
EnzymePrintType: *mut c_void,
87+
EnzymeFunctionToAnalyze: *mut c_void,
88+
EnzymePrint: *mut c_void,
89+
EnzymeStrictAliasing: *mut c_void,
90+
looseTypeAnalysis: *mut c_void,
91+
EnzymeInline: *mut c_void,
92+
RustTypeRules: *mut c_void,
93+
94+
EnzymeSetCLBool: EnzymeFns,
95+
EnzymeSetCLString: EnzymeFns,
96+
pub registerEnzymeAndPassPipeline: *const c_void,
97+
lib: libloading::Library,
8598
}
86-
pub(crate) fn set_print_perf(print: bool) {
87-
unsafe {
88-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8);
99+
fn call_dynamic() -> Result<EnzymeWrapper, Box<dyn std::error::Error>> {
100+
fn load_ptr(lib: &libloading::Library, bytes: &[u8]) -> Result<*mut c_void, Box<dyn std::error::Error>> {
101+
// Safety: symbol lookup from a loaded shared object.
102+
unsafe {
103+
let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?;
104+
let s = s.try_as_raw_ptr().unwrap();
105+
Ok(s as *mut c_void)
106+
}
107+
}
108+
dbg!("starting");
109+
dbg!("Loading Enzyme");
110+
let lib = unsafe {libloading::Library::new("/home/manuel/prog/rust/build/x86_64-unknown-linux-gnu/enzyme/lib/libEnzyme-21.so")?};
111+
dbg!("second");
112+
let EnzymeSetCLBool: libloading::Symbol<'_, SetFlag> = unsafe{lib.get(b"EnzymeSetCLBool")?};
113+
dbg!("third");
114+
let registerEnzymeAndPassPipeline =
115+
load_ptr(&lib, b"registerEnzymeAndPassPipeline").unwrap() as *const c_void;
116+
dbg!("fourth");
117+
//let EnzymeSetCLBool: libloading::Symbol<'_, unsafe extern "C" fn(&mut c_void, u8) -> ()> = unsafe{lib.get(b"registerEnzymeAndPassPipeline")?};
118+
//let EnzymeSetCLBool = unsafe {EnzymeSetCLBool.try_as_raw_ptr().unwrap()};
119+
let EnzymeSetCLString: libloading::Symbol<'_, SetFlag> = unsafe{ lib.get(b"EnzymeSetCLString")?};
120+
dbg!("done");
121+
//let EnzymeSetCLString = unsafe {EnzymeSetCLString.try_as_raw_ptr().unwrap()};
122+
123+
let EnzymePrintPerf = load_ptr(&lib, b"EnzymePrintPerf").unwrap();
124+
let EnzymePrintActivity = load_ptr(&lib, b"EnzymePrintActivity").unwrap();
125+
let EnzymePrintType = load_ptr(&lib, b"EnzymePrintType").unwrap();
126+
let EnzymeFunctionToAnalyze = load_ptr(&lib, b"EnzymeFunctionToAnalyze").unwrap();
127+
let EnzymePrint = load_ptr(&lib, b"EnzymePrint").unwrap();
128+
129+
let EnzymeStrictAliasing = load_ptr(&lib, b"EnzymeStrictAliasing").unwrap();
130+
let looseTypeAnalysis = load_ptr(&lib, b"looseTypeAnalysis").unwrap();
131+
let EnzymeInline = load_ptr(&lib, b"EnzymeInline").unwrap();
132+
let RustTypeRules = load_ptr(&lib, b"RustTypeRules").unwrap();
133+
134+
let wrap = EnzymeWrapper {
135+
EnzymePrintPerf,
136+
EnzymePrintActivity,
137+
EnzymePrintType,
138+
EnzymeFunctionToAnalyze,
139+
EnzymePrint,
140+
EnzymeStrictAliasing,
141+
looseTypeAnalysis,
142+
EnzymeInline,
143+
RustTypeRules,
144+
//EnzymeSetCLBool: EnzymeFns {set_cl: unsafe{*EnzymeSetCLBool}},
145+
//EnzymeSetCLString: EnzymeFns {set_cl: unsafe{*EnzymeSetCLString}},
146+
EnzymeSetCLBool: EnzymeFns {set_cl: *EnzymeSetCLBool},
147+
EnzymeSetCLString: EnzymeFns {set_cl: *EnzymeSetCLString},
148+
registerEnzymeAndPassPipeline,
149+
lib
150+
};
151+
dbg!(&wrap);
152+
Ok(wrap)
89153
}
90-
}
91-
pub(crate) fn set_print_activity(print: bool) {
92-
unsafe {
93-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8);
154+
use std::sync::Mutex;
155+
unsafe impl Sync for EnzymeWrapper {}
156+
unsafe impl Send for EnzymeWrapper {}
157+
impl EnzymeWrapper {
158+
pub(crate) fn current() -> &'static Mutex<EnzymeWrapper> {
159+
use std::sync::OnceLock;
160+
static CELL: OnceLock<Mutex<EnzymeWrapper>> = OnceLock::new();
161+
fn init_enzyme() -> Mutex<EnzymeWrapper> {
162+
call_dynamic().unwrap().into()
163+
}
164+
CELL.get_or_init(|| init_enzyme())
94165
}
95-
}
96-
pub(crate) fn set_print_type(print: bool) {
97-
unsafe {
98-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
166+
pub(crate) fn set_print_perf(&mut self, print: bool) {
167+
unsafe {
168+
//(self.EnzymeSetCLBool.set_cl)(self.EnzymePrintPerf, print as u8);
169+
//(self.EnzymeSetCLBool)(std::ptr::addr_of_mut!(self.EnzymePrintPerf), print as u8);
170+
}
99171
}
100-
}
101-
pub(crate) fn set_print_type_fun(fun_name: &str) {
102-
let c_fun_name = CString::new(fun_name).unwrap();
103-
unsafe {
104-
EnzymeSetCLString(
105-
std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze),
106-
c_fun_name.as_ptr() as *const c_char,
107-
);
172+
173+
pub(crate) fn set_print_activity(&mut self, print: bool) {
174+
unsafe {
175+
//(self.EnzymeSetCLBool.set_cl)(self.EnzymePrintActivity, print as u8);
176+
//(self.EnzymeSetCLBool)(std::ptr::addr_of_mut!(self.EnzymePrintActivity), print as u8);
177+
}
108178
}
109-
}
110-
pub(crate) fn set_print(print: bool) {
111-
unsafe {
112-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
179+
180+
pub(crate) fn set_print_type(&mut self, print: bool) {
181+
unsafe {
182+
// (self.EnzymeSetCLBool.set_cl)(self.EnzymePrintType, print as u8);
183+
}
113184
}
114-
}
115-
pub(crate) fn set_strict_aliasing(strict: bool) {
116-
unsafe {
117-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8);
185+
186+
pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) {
187+
let _c_fun_name = CString::new(fun_name).unwrap();
188+
//unsafe {
189+
// (self.EnzymeSetCLString.set_cl)(
190+
// self.EnzymeFunctionToAnalyze,
191+
// c_fun_name.as_ptr() as *const c_char,
192+
// );
193+
//}
118194
}
119-
}
120-
pub(crate) fn set_loose_types(loose: bool) {
121-
unsafe {
122-
EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8);
195+
196+
pub(crate) fn set_print(&mut self, print: bool) {
197+
unsafe {
198+
//(self.EnzymeSetCLBool.set_cl)(self.EnzymePrint, print as u8);
199+
}
123200
}
124-
}
125-
pub(crate) fn set_inline(val: bool) {
126-
unsafe {
127-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8);
201+
202+
pub(crate) fn set_strict_aliasing(&mut self, strict: bool) {
203+
unsafe {
204+
//(self.EnzymeSetCLBool.set_cl)(self.EnzymeStrictAliasing, strict as u8);
205+
}
128206
}
129-
}
130-
pub(crate) fn set_rust_rules(val: bool) {
131-
unsafe {
132-
EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8);
207+
208+
pub(crate) fn set_loose_types(&mut self, loose: bool) {
209+
unsafe {
210+
//(self.EnzymeSetCLBool.set_cl)(self.looseTypeAnalysis, loose as u8);
211+
}
212+
}
213+
214+
pub(crate) fn set_inline(&mut self, val: bool) {
215+
unsafe {
216+
//(self.EnzymeSetCLBool.set_cl)(self.EnzymeInline, val as u8);
217+
}
218+
}
219+
220+
pub(crate) fn set_rust_rules(&mut self, val: bool) {
221+
unsafe {
222+
//(self.EnzymeSetCLBool.set_cl)(self.RustTypeRules, val as u8);
223+
}
133224
}
134225
}
226+
227+
135228
}
136229

137230
#[cfg(not(llvm_enzyme))]

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2518,7 +2518,8 @@ unsafe extern "C" {
25182518
LoopVectorize: bool,
25192519
DisableSimplifyLibCalls: bool,
25202520
EmitLifetimeMarkers: bool,
2521-
RunEnzyme: bool,
2521+
RunEnzyme: *const c_void,
2522+
//RunEnzyme: bool,
25222523
PrintBeforeEnzyme: bool,
25232524
PrintAfterEnzyme: bool,
25242525
PrintPasses: bool,

compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -704,19 +704,21 @@ struct LLVMRustSanitizerOptions {
704704
extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB,
705705
/* augmentPassBuilder */ bool);
706706

707-
extern "C" {
708-
extern llvm::cl::opt<std::string> EnzymeFunctionToAnalyze;
709-
}
707+
//extern "C" {
708+
//extern llvm::cl::opt<std::string> EnzymeFunctionToAnalyze;
709+
//}
710710
#endif
711711

712+
extern "C" typedef void (*registerEnzymeAndPassPipelineFn)(llvm::PassBuilder &PB, bool augment);
713+
712714
extern "C" LLVMRustResult LLVMRustOptimize(
713715
LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef,
714716
LLVMRustPassBuilderOptLevel OptLevelRust, LLVMRustOptStage OptStage,
715717
bool IsLinkerPluginLTO, bool NoPrepopulatePasses, bool VerifyIR,
716718
bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
717719
bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
718720
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
719-
bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
721+
bool EmitLifetimeMarkers, registerEnzymeAndPassPipelineFn EnzymePtr, bool PrintBeforeEnzyme,
720722
bool PrintAfterEnzyme, bool PrintPasses,
721723
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
722724
const char *PGOUsePath, bool InstrumentCoverage,
@@ -1061,29 +1063,30 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10611063

10621064
// now load "-enzyme" pass:
10631065
#ifdef ENZYME
1064-
if (RunEnzyme) {
1066+
if (EnzymePtr) {
10651067

10661068
if (PrintBeforeEnzyme) {
10671069
// Handle the Rust flag `-Zautodiff=PrintModBefore`.
10681070
std::string Banner = "Module before EnzymeNewPM";
10691071
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
10701072
}
10711073

1072-
registerEnzymeAndPassPipeline(PB, false);
1074+
EnzymePtr(PB, false);
1075+
//registerEnzymeAndPassPipeline(PB, false);
10731076
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
10741077
std::string ErrMsg = toString(std::move(Err));
10751078
LLVMRustSetLastError(ErrMsg.c_str());
10761079
return LLVMRustResult::Failure;
10771080
}
10781081

10791082
// Check if PrintTAFn was used and add type analysis pass if needed
1080-
if (!EnzymeFunctionToAnalyze.empty()) {
1081-
if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) {
1082-
std::string ErrMsg = toString(std::move(Err));
1083-
LLVMRustSetLastError(ErrMsg.c_str());
1084-
return LLVMRustResult::Failure;
1085-
}
1086-
}
1083+
//if (!EnzymeFunctionToAnalyze.empty()) {
1084+
// if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) {
1085+
// std::string ErrMsg = toString(std::move(Err));
1086+
// LLVMRustSetLastError(ErrMsg.c_str());
1087+
// return LLVMRustResult::Failure;
1088+
// }
1089+
//}
10871090

10881091
if (PrintAfterEnzyme) {
10891092
// Handle the Rust flag `-Zautodiff=PrintModAfter`.

0 commit comments

Comments
 (0)