Skip to content

Commit 83a5a89

Browse files
committed
move to OnceLock, compiles but fails all tests and SIGSEGV
1 parent 595473f commit 83a5a89

File tree

3 files changed

+178
-89
lines changed

3 files changed

+178
-89
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -523,30 +523,32 @@ fn thin_lto(
523523
}
524524

525525
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
526+
use crate::llvm::EnzymeWrapper;
527+
let enzyme = EnzymeWrapper::current();
526528
for val in ad {
527529
// We intentionally don't use a wildcard, to not forget handling anything new.
528530
match val {
529531
config::AutoDiff::PrintPerf => {
530-
llvm::set_print_perf(true);
532+
enzyme.lock().unwrap().set_print_perf(true);
531533
}
532534
config::AutoDiff::PrintAA => {
533-
llvm::set_print_activity(true);
535+
enzyme.lock().unwrap().set_print_activity(true);
534536
}
535537
config::AutoDiff::PrintTA => {
536-
llvm::set_print_type(true);
538+
enzyme.lock().unwrap().set_print_type(true);
537539
}
538540
config::AutoDiff::PrintTAFn(fun) => {
539-
llvm::set_print_type(true); // Enable general type printing
540-
llvm::set_print_type_fun(&fun); // Set specific function to analyze
541+
enzyme.lock().unwrap().set_print_type(true); // Enable general type printing
542+
enzyme.lock().unwrap().set_print_type_fun(&fun); // Set specific function to analyze
541543
}
542544
config::AutoDiff::Inline => {
543-
llvm::set_inline(true);
545+
enzyme.lock().unwrap().set_inline(true);
544546
}
545547
config::AutoDiff::LooseTypes => {
546-
llvm::set_loose_types(true);
548+
enzyme.lock().unwrap().set_loose_types(true);
547549
}
548550
config::AutoDiff::PrintSteps => {
549-
llvm::set_print(true);
551+
enzyme.lock().unwrap().set_print(true);
550552
}
551553
// We handle this in the PassWrapper.cpp
552554
config::AutoDiff::PrintPasses => {}
@@ -563,9 +565,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
563565
}
564566
}
565567
// This helps with handling enums for now.
566-
llvm::set_strict_aliasing(false);
568+
enzyme.lock().unwrap().set_strict_aliasing(false);
567569
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
568-
llvm::set_rust_rules(true);
570+
enzyme.lock().unwrap().set_rust_rules(true);
569571
}
570572

571573
pub(crate) fn run_pass_manager(

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -665,21 +665,17 @@ pub(crate) unsafe fn llvm_optimize(
665665

666666
let llvm_plugins = config.llvm_plugins.join(",");
667667

668-
fn call_dynamic() -> Result<*const c_void, Box<dyn std::error::Error>> {
669-
unsafe {
670-
let lib = libloading::Library::new("/home/manuel/prog/rust/build/x86_64-unknown-linux-gnu/enzyme/lib/libEnzyme-21.so")?;
671-
let func: libloading::Symbol<'_, c_void> = lib.get(b"registerEnzymeAndPassPipeline")?;
672-
let func = func.try_as_raw_ptr().unwrap();
673-
dbg!(func);
674-
Ok(func as *const c_void)
675-
}
676-
}
677-
let enzyme_fn = if cfg!(llvm_enzyme) && run_enzyme {
678-
call_dynamic().unwrap_or(std::ptr::null())
668+
let enzyme_fn = if run_enzyme {
669+
let wrapper = crate::llvm::EnzymeWrapper::current();
670+
wrapper.lock().unwrap().registerEnzymeAndPassPipeline
679671
} else {
680-
std::ptr::null()
672+
//dbg!(run_enzyme);
673+
//dbg!(consider_ad);
674+
std::ptr::null()
681675
};
682676

677+
dbg!(&enzyme_fn);
678+
683679

684680
let result = unsafe {
685681
llvm::LLVMRustOptimize(

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 158 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ unsafe extern "C" {
5151
pub(crate) fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>;
5252
}
5353

54+
55+
56+
5457
#[repr(C)]
5558
#[derive(Copy, Clone, PartialEq)]
5659
pub(crate) enum LLVMRustVerifierFailureAction {
@@ -59,113 +62,201 @@ pub(crate) enum LLVMRustVerifierFailureAction {
5962
LLVMReturnStatusAction = 2,
6063
}
6164

62-
#[cfg(not(llvm_enzyme))]
65+
#[cfg(llvm_enzyme)]
6366
pub(crate) use self::Enzyme_AD::*;
6467

65-
#[cfg(not(llvm_enzyme))]
68+
//#[cfg(llvm_enzyme)]
6669
pub(crate) mod Enzyme_AD {
67-
use std::ffi::{CString, c_char};
70+
use std::ffi::CString;
71+
//use std::ffi::{CString, c_char};
6872

6973
use libc::c_void;
7074

71-
unsafe extern "C" {
72-
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
73-
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
74-
}
75-
unsafe extern "C" {
76-
static mut EnzymePrintPerf: c_void;
77-
static mut EnzymePrintActivity: c_void;
78-
static mut EnzymePrintType: c_void;
79-
static mut EnzymeFunctionToAnalyze: c_void;
80-
static mut EnzymePrint: c_void;
81-
static mut EnzymeStrictAliasing: c_void;
82-
static mut looseTypeAnalysis: c_void;
83-
static mut EnzymeInline: c_void;
84-
static mut RustTypeRules: c_void;
75+
type SetFlag = unsafe extern "C" fn(*mut c_void, u8);
76+
77+
#[derive(Debug)]
78+
pub(crate) struct EnzymeFns {
79+
pub set_cl: SetFlag,
80+
}
81+
82+
#[derive(Debug)]
83+
pub(crate) struct EnzymeWrapper {
84+
EnzymePrintPerf: *mut c_void,
85+
EnzymePrintActivity: *mut c_void,
86+
EnzymePrintType: *mut c_void,
87+
EnzymeFunctionToAnalyze: *mut c_void,
88+
EnzymePrint: *mut c_void,
89+
EnzymeStrictAliasing: *mut c_void,
90+
looseTypeAnalysis: *mut c_void,
91+
EnzymeInline: *mut c_void,
92+
RustTypeRules: *mut c_void,
93+
94+
EnzymeSetCLBool: EnzymeFns,
95+
EnzymeSetCLString: EnzymeFns,
96+
pub registerEnzymeAndPassPipeline: *const c_void,
8597
}
86-
pub(crate) fn set_print_perf(print: bool) {
87-
unsafe {
88-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8);
98+
fn call_dynamic() -> Result<EnzymeWrapper, Box<dyn std::error::Error>> {
99+
fn load_ptr(lib: &libloading::Library, bytes: &[u8]) -> Result<*mut c_void, Box<dyn std::error::Error>> {
100+
// Safety: symbol lookup from a loaded shared object.
101+
unsafe {
102+
let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?;
103+
let s = s.try_as_raw_ptr().unwrap();
104+
Ok(s as *mut c_void)
105+
}
106+
}
107+
dbg!("starting");
108+
dbg!("Loading Enzyme");
109+
let lib = unsafe {libloading::Library::new("/home/manuel/prog/rust/build/x86_64-unknown-linux-gnu/enzyme/lib/libEnzyme-21.so")?};
110+
dbg!("second");
111+
let EnzymeSetCLBool: libloading::Symbol<'_, SetFlag> = unsafe{lib.get(b"EnzymeSetCLBool")?};
112+
dbg!("third");
113+
let registerEnzymeAndPassPipeline =
114+
load_ptr(&lib, b"registerEnzymeAndPassPipeline").unwrap() as *const c_void;
115+
dbg!("fourth");
116+
//let EnzymeSetCLBool: libloading::Symbol<'_, unsafe extern "C" fn(&mut c_void, u8) -> ()> = unsafe{lib.get(b"registerEnzymeAndPassPipeline")?};
117+
//let EnzymeSetCLBool = unsafe {EnzymeSetCLBool.try_as_raw_ptr().unwrap()};
118+
let EnzymeSetCLString: libloading::Symbol<'_, SetFlag> = unsafe{ lib.get(b"EnzymeSetCLString")?};
119+
dbg!("done");
120+
//let EnzymeSetCLString = unsafe {EnzymeSetCLString.try_as_raw_ptr().unwrap()};
121+
122+
let EnzymePrintPerf = load_ptr(&lib, b"EnzymePrintPerf").unwrap();
123+
let EnzymePrintActivity = load_ptr(&lib, b"EnzymePrintActivity").unwrap();
124+
let EnzymePrintType = load_ptr(&lib, b"EnzymePrintType").unwrap();
125+
let EnzymeFunctionToAnalyze = load_ptr(&lib, b"EnzymeFunctionToAnalyze").unwrap();
126+
let EnzymePrint = load_ptr(&lib, b"EnzymePrint").unwrap();
127+
128+
let EnzymeStrictAliasing = load_ptr(&lib, b"EnzymeStrictAliasing").unwrap();
129+
let looseTypeAnalysis = load_ptr(&lib, b"looseTypeAnalysis").unwrap();
130+
let EnzymeInline = load_ptr(&lib, b"EnzymeInline").unwrap();
131+
let RustTypeRules = load_ptr(&lib, b"RustTypeRules").unwrap();
132+
133+
let wrap = EnzymeWrapper {
134+
EnzymePrintPerf,
135+
EnzymePrintActivity,
136+
EnzymePrintType,
137+
EnzymeFunctionToAnalyze,
138+
EnzymePrint,
139+
EnzymeStrictAliasing,
140+
looseTypeAnalysis,
141+
EnzymeInline,
142+
RustTypeRules,
143+
//EnzymeSetCLBool: EnzymeFns {set_cl: unsafe{*EnzymeSetCLBool}},
144+
//EnzymeSetCLString: EnzymeFns {set_cl: unsafe{*EnzymeSetCLString}},
145+
EnzymeSetCLBool: EnzymeFns {set_cl: *EnzymeSetCLBool},
146+
EnzymeSetCLString: EnzymeFns {set_cl: *EnzymeSetCLString},
147+
registerEnzymeAndPassPipeline,
148+
};
149+
dbg!(&wrap);
150+
Ok(wrap)
89151
}
90-
}
91-
pub(crate) fn set_print_activity(print: bool) {
92-
unsafe {
93-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8);
152+
use std::sync::Mutex;
153+
unsafe impl Sync for EnzymeWrapper {}
154+
unsafe impl Send for EnzymeWrapper {}
155+
impl EnzymeWrapper {
156+
pub(crate) fn current() -> &'static Mutex<EnzymeWrapper> {
157+
use std::sync::OnceLock;
158+
static CELL: OnceLock<Mutex<EnzymeWrapper>> = OnceLock::new();
159+
fn init_enzyme() -> Mutex<EnzymeWrapper> {
160+
call_dynamic().unwrap().into()
161+
}
162+
CELL.get_or_init(|| init_enzyme())
94163
}
95-
}
96-
pub(crate) fn set_print_type(print: bool) {
97-
unsafe {
98-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
164+
pub(crate) fn set_print_perf(&mut self, print: bool) {
165+
unsafe {
166+
//(self.EnzymeSetCLBool.set_cl)(self.EnzymePrintPerf, print as u8);
167+
//(self.EnzymeSetCLBool)(std::ptr::addr_of_mut!(self.EnzymePrintPerf), print as u8);
168+
}
99169
}
100-
}
101-
pub(crate) fn set_print_type_fun(fun_name: &str) {
102-
let c_fun_name = CString::new(fun_name).unwrap();
103-
unsafe {
104-
EnzymeSetCLString(
105-
std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze),
106-
c_fun_name.as_ptr() as *const c_char,
107-
);
170+
171+
pub(crate) fn set_print_activity(&mut self, print: bool) {
172+
unsafe {
173+
//(self.EnzymeSetCLBool.set_cl)(self.EnzymePrintActivity, print as u8);
174+
//(self.EnzymeSetCLBool)(std::ptr::addr_of_mut!(self.EnzymePrintActivity), print as u8);
175+
}
108176
}
109-
}
110-
pub(crate) fn set_print(print: bool) {
111-
unsafe {
112-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
177+
178+
pub(crate) fn set_print_type(&mut self, print: bool) {
179+
unsafe {
180+
// (self.EnzymeSetCLBool.set_cl)(self.EnzymePrintType, print as u8);
181+
}
113182
}
114-
}
115-
pub(crate) fn set_strict_aliasing(strict: bool) {
116-
unsafe {
117-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8);
183+
184+
pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) {
185+
let _c_fun_name = CString::new(fun_name).unwrap();
186+
//unsafe {
187+
// (self.EnzymeSetCLString.set_cl)(
188+
// self.EnzymeFunctionToAnalyze,
189+
// c_fun_name.as_ptr() as *const c_char,
190+
// );
191+
//}
118192
}
119-
}
120-
pub(crate) fn set_loose_types(loose: bool) {
121-
unsafe {
122-
EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8);
193+
194+
pub(crate) fn set_print(&mut self, print: bool) {
195+
unsafe {
196+
//(self.EnzymeSetCLBool.set_cl)(self.EnzymePrint, print as u8);
197+
}
123198
}
124-
}
125-
pub(crate) fn set_inline(val: bool) {
126-
unsafe {
127-
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8);
199+
200+
pub(crate) fn set_strict_aliasing(&mut self, strict: bool) {
201+
unsafe {
202+
//(self.EnzymeSetCLBool.set_cl)(self.EnzymeStrictAliasing, strict as u8);
203+
}
128204
}
129-
}
130-
pub(crate) fn set_rust_rules(val: bool) {
131-
unsafe {
132-
EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8);
205+
206+
pub(crate) fn set_loose_types(&mut self, loose: bool) {
207+
unsafe {
208+
//(self.EnzymeSetCLBool.set_cl)(self.looseTypeAnalysis, loose as u8);
209+
}
210+
}
211+
212+
pub(crate) fn set_inline(&mut self, val: bool) {
213+
unsafe {
214+
//(self.EnzymeSetCLBool.set_cl)(self.EnzymeInline, val as u8);
215+
}
216+
}
217+
218+
pub(crate) fn set_rust_rules(&mut self, val: bool) {
219+
unsafe {
220+
//(self.EnzymeSetCLBool.set_cl)(self.RustTypeRules, val as u8);
221+
}
133222
}
134223
}
224+
225+
135226
}
136227

137-
#[cfg(llvm_enzyme)]
228+
#[cfg(not(llvm_enzyme))]
138229
pub(crate) use self::Fallback_AD::*;
139230

140-
#[cfg(llvm_enzyme)]
231+
#[cfg(not(llvm_enzyme))]
141232
pub(crate) mod Fallback_AD {
142233
#![allow(unused_variables)]
143234

144235
pub(crate) fn set_inline(val: bool) {
145-
//unimplemented!()
236+
unimplemented!()
146237
}
147238
pub(crate) fn set_print_perf(print: bool) {
148-
//unimplemented!()
239+
unimplemented!()
149240
}
150241
pub(crate) fn set_print_activity(print: bool) {
151-
//unimplemented!()
242+
unimplemented!()
152243
}
153244
pub(crate) fn set_print_type(print: bool) {
154-
//unimplemented!()
245+
unimplemented!()
155246
}
156247
pub(crate) fn set_print_type_fun(fun_name: &str) {
157-
//unimplemented!()
248+
unimplemented!()
158249
}
159250
pub(crate) fn set_print(print: bool) {
160-
//unimplemented!()
251+
unimplemented!()
161252
}
162253
pub(crate) fn set_strict_aliasing(strict: bool) {
163-
//unimplemented!()
254+
unimplemented!()
164255
}
165256
pub(crate) fn set_loose_types(loose: bool) {
166-
//unimplemented!()
257+
unimplemented!()
167258
}
168259
pub(crate) fn set_rust_rules(val: bool) {
169-
//unimplemented!()
260+
unimplemented!()
170261
}
171262
}

0 commit comments

Comments
 (0)