diff --git a/compiler/rustc_codegen_llvm/Cargo.toml b/compiler/rustc_codegen_llvm/Cargo.toml index 2d11628250cd2..a30bf95f6f1cb 100644 --- a/compiler/rustc_codegen_llvm/Cargo.toml +++ b/compiler/rustc_codegen_llvm/Cargo.toml @@ -14,6 +14,7 @@ bitflags = "2.4.1" gimli = "0.31" itertools = "0.12" libc = "0.2" +libloading = "0.8.8" measureme = "12.0.1" object = { version = "0.37.0", default-features = false, features = ["std", "read"] } rustc-demangle = "0.1.21" diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 326b876e7e689..910467a907043 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -522,31 +522,33 @@ fn thin_lto( } } -fn enable_autodiff_settings(ad: &[config::AutoDiff]) { +fn enable_autodiff_settings(cgcx: &CodegenContext, ad: &[config::AutoDiff]) { + use rustc_codegen_ssa::back::write::EnzymeWrapper; + let enzyme = rustc_codegen_ssa::back::write::EnzymeWrapper::current(cgcx); for val in ad { // We intentionally don't use a wildcard, to not forget handling anything new. match val { config::AutoDiff::PrintPerf => { - llvm::set_print_perf(true); + enzyme.lock().unwrap().set_print_perf(true); } config::AutoDiff::PrintAA => { - llvm::set_print_activity(true); + enzyme.lock().unwrap().set_print_activity(true); } config::AutoDiff::PrintTA => { - llvm::set_print_type(true); + enzyme.lock().unwrap().set_print_type(true); } config::AutoDiff::PrintTAFn(fun) => { - llvm::set_print_type(true); // Enable general type printing - llvm::set_print_type_fun(&fun); // Set specific function to analyze + enzyme.lock().unwrap().set_print_type(true); // Enable general type printing + enzyme.lock().unwrap().set_print_type_fun(&fun); // Set specific function to analyze } config::AutoDiff::Inline => { - llvm::set_inline(true); + enzyme.lock().unwrap().set_inline(true); } config::AutoDiff::LooseTypes => { - llvm::set_loose_types(true); + enzyme.lock().unwrap().set_loose_types(true); } config::AutoDiff::PrintSteps => { - llvm::set_print(true); + enzyme.lock().unwrap().set_print(true); } // We handle this in the PassWrapper.cpp config::AutoDiff::PrintPasses => {} @@ -563,9 +565,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) { } } // This helps with handling enums for now. - llvm::set_strict_aliasing(false); + enzyme.lock().unwrap().set_strict_aliasing(false); // FIXME(ZuseZ4): Test this, since it was added a long time ago. - llvm::set_rust_rules(true); + enzyme.lock().unwrap().set_rust_rules(true); } pub(crate) fn run_pass_manager( @@ -601,7 +603,7 @@ pub(crate) fn run_pass_manager( }; if enable_ad { - enable_autodiff_settings(&config.autodiff); + enable_autodiff_settings(&cgcx, &config.autodiff); } unsafe { diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 7ea2ae6673b0f..cb70421e99c7e 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -665,6 +665,18 @@ pub(crate) unsafe fn llvm_optimize( let llvm_plugins = config.llvm_plugins.join(","); + let enzyme_fn = if consider_ad { + let wrapper = rustc_codegen_ssa::back::write::EnzymeWrapper::current(cgcx); + wrapper.lock().unwrap().registerEnzymeAndPassPipeline + } else { + //dbg!(run_enzyme); + //dbg!(consider_ad); + std::ptr::null() + }; + + dbg!(&enzyme_fn); + + let result = unsafe { llvm::LLVMRustOptimize( module.module_llvm.llmod(), @@ -684,7 +696,8 @@ pub(crate) unsafe fn llvm_optimize( vectorize_loop, config.no_builtins, config.emit_lifetime_markers, - run_enzyme, + enzyme_fn, + //run_enzyme, print_before_enzyme, print_after_enzyme, print_passes, diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 56d756e52cce1..543e2cf4bf3f4 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -1,5 +1,10 @@ #![expect(dead_code)] + +use tracing::info; + +use std::path::Path; + use libc::{c_char, c_uint}; use super::MetadataKindId; @@ -51,6 +56,9 @@ unsafe extern "C" { pub(crate) fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; } + + + #[repr(C)] #[derive(Copy, Clone, PartialEq)] pub(crate) enum LLVMRustVerifierFailureAction { @@ -62,76 +70,220 @@ pub(crate) enum LLVMRustVerifierFailureAction { #[cfg(llvm_enzyme)] pub(crate) use self::Enzyme_AD::*; -#[cfg(llvm_enzyme)] +//#[cfg(llvm_enzyme)] pub(crate) mod Enzyme_AD { - use std::ffi::{CString, c_char}; +// use std::ffi::CString; +// +// use libc::c_void; +// +//type SetFlag = unsafe extern "C" fn(*mut c_void, u8); +// +//#[derive(Debug)] +//pub(crate) struct EnzymeFns { +// pub set_cl: SetFlag, +//} +// +//#[derive(Debug)] +//pub(crate) struct EnzymeWrapper { +// EnzymePrintPerf: *mut c_void, +// EnzymePrintActivity: *mut c_void, +// EnzymePrintType: *mut c_void, +// EnzymeFunctionToAnalyze: *mut c_void, +// EnzymePrint: *mut c_void, +// EnzymeStrictAliasing: *mut c_void, +// looseTypeAnalysis: *mut c_void, +// EnzymeInline: *mut c_void, +// RustTypeRules: *mut c_void, +// +// EnzymeSetCLBool: EnzymeFns, +// EnzymeSetCLString: EnzymeFns, +// pub registerEnzymeAndPassPipeline: *const c_void, +// lib: libloading::Library, +//} +// fn call_dynamic() -> Result> { +// fn load_ptr(lib: &libloading::Library, bytes: &[u8]) -> Result<*mut c_void, Box> { +// // Safety: symbol lookup from a loaded shared object. +// unsafe { +// let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?; +// let s = s.try_as_raw_ptr().unwrap(); +// Ok(s as *mut c_void) +// } +// } +// dbg!("starting"); +// dbg!("Loading Enzyme"); +// use std::sync::OnceLock; +// static ENZYME_PATH: OnceLock = OnceLock::new(); +// assert!(ENZYME_PATH.get().is_some()); +// let mypath = ENZYME_PATH.get().unwrap(); // load Library from mypath +// let lib = unsafe {libloading::Library::new(mypath)?}; +// //let lib = unsafe {libloading::Library::new("/home/manuel/prog/rust/build/x86_64-unknown-linux-gnu/enzyme/lib/libEnzyme-21.so")?}; +// dbg!("second"); +// let EnzymeSetCLBool: libloading::Symbol<'_, SetFlag> = unsafe{lib.get(b"EnzymeSetCLBool")?}; +// dbg!("third"); +// let registerEnzymeAndPassPipeline = +// load_ptr(&lib, b"registerEnzymeAndPassPipeline").unwrap() as *const c_void; +// dbg!("fourth"); +// let EnzymeSetCLString: libloading::Symbol<'_, SetFlag> = unsafe{ lib.get(b"EnzymeSetCLString")?}; +// dbg!("done"); +// +// let EnzymePrintPerf = load_ptr(&lib, b"EnzymePrintPerf").unwrap(); +// let EnzymePrintActivity = load_ptr(&lib, b"EnzymePrintActivity").unwrap(); +// let EnzymePrintType = load_ptr(&lib, b"EnzymePrintType").unwrap(); +// let EnzymeFunctionToAnalyze = load_ptr(&lib, b"EnzymeFunctionToAnalyze").unwrap(); +// let EnzymePrint = load_ptr(&lib, b"EnzymePrint").unwrap(); +// +// let EnzymeStrictAliasing = load_ptr(&lib, b"EnzymeStrictAliasing").unwrap(); +// let looseTypeAnalysis = load_ptr(&lib, b"looseTypeAnalysis").unwrap(); +// let EnzymeInline = load_ptr(&lib, b"EnzymeInline").unwrap(); +// let RustTypeRules = load_ptr(&lib, b"RustTypeRules").unwrap(); +// +// let wrap = EnzymeWrapper { +// EnzymePrintPerf, +// EnzymePrintActivity, +// EnzymePrintType, +// EnzymeFunctionToAnalyze, +// EnzymePrint, +// EnzymeStrictAliasing, +// looseTypeAnalysis, +// EnzymeInline, +// RustTypeRules, +// EnzymeSetCLBool: EnzymeFns {set_cl: *EnzymeSetCLBool}, +// EnzymeSetCLString: EnzymeFns {set_cl: *EnzymeSetCLString}, +// registerEnzymeAndPassPipeline, +// lib +// }; +// dbg!(&wrap); +// Ok(wrap) +// } +use std::sync::Mutex; +use rustc_middle::bug; +use tracing::info; +use rustc_session::filesearch; +use rustc_session::Session; +use rustc_session::config::host_tuple; +//unsafe impl Sync for EnzymeWrapper {} +//unsafe impl Send for EnzymeWrapper {} +// impl EnzymeWrapper { +// pub(crate) fn current() -> &'static Mutex { +// use std::sync::OnceLock; +// static CELL: OnceLock> = OnceLock::new(); +// static ENZYME_PATH: OnceLock = OnceLock::new(); +// fn init_enzyme() -> Mutex { +// call_dynamic().unwrap().into() +// } +// //ENZYME_PATH.wait(); +// if ENZYME_PATH.get().is_none() { +// bug!("enzyme path is none!"); +// } +// CELL.get_or_init(|| init_enzyme()) +// } +// pub(crate) fn set_path(session: &Session) -> String { +// fn get_enzyme_path(session: &Session) -> String { +// dbg!("starting"); +// dbg!("Loading Enzyme"); +// let target = host_tuple(); +// let lib_ext = std::env::consts::DLL_EXTENSION; +// let sysroot = &session.opts.sysroot; +// //dbg!(sysroot); +// +// let sysroot = sysroot +// .all_paths() +// .map(|sysroot| { +// filesearch::make_target_lib_path(sysroot, target).join("lib").with_file_name("libEnzyme-21").with_extension(lib_ext) +// //filesearch::make_target_lib_path(sysroot, target).join("lib").with_file_name("lib") +// }) +// .find(|f| { +// info!("Enzyme candidate: {}", f.display()); +// f.exists() +// }) +// .unwrap_or_else(|| { +// let candidates = sysroot +// .all_paths() +// .map(|p| p.join("lib").display().to_string()) +// .collect::>() +// .join("\n* "); +// let err = format!( +// "failed to find a `libEnzyme` folder \ +// in the sysroot candidates:\n* {candidates}" +// ); +// dbg!(&err); +// bug!("asdf"); +// //early_dcx.early_fatal(err); +// }); +// +// info!("probing {} for a codegen backend", sysroot.display()); +// let enzyme_path = sysroot.to_str().unwrap().to_string(); +// //dbg!(&enzyme_path); +// enzyme_path +// } +// use std::sync::OnceLock; +// static ENZYME_PATH: OnceLock = OnceLock::new(); +// ENZYME_PATH.get_or_init(|| get_enzyme_path(session)).to_string() +// //ENZYME_PATH.get().unwrap().to_string() +// //ENZYME_PATH.get_or_init(|| get_enzyme_path(session)).clone() +// } +// pub(crate) fn set_print_perf(&mut self, print: bool) { +// unsafe { +// //(self.EnzymeSetCLBool.set_cl)(self.EnzymePrintPerf, print as u8); +// //(self.EnzymeSetCLBool)(std::ptr::addr_of_mut!(self.EnzymePrintPerf), print as u8); +// } +// } +// +// pub(crate) fn set_print_activity(&mut self, print: bool) { +// unsafe { +// //(self.EnzymeSetCLBool.set_cl)(self.EnzymePrintActivity, print as u8); +// //(self.EnzymeSetCLBool)(std::ptr::addr_of_mut!(self.EnzymePrintActivity), print as u8); +// } +// } +// +// pub(crate) fn set_print_type(&mut self, print: bool) { +// unsafe { +// // (self.EnzymeSetCLBool.set_cl)(self.EnzymePrintType, print as u8); +// } +// } +// +// pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) { +// let _c_fun_name = CString::new(fun_name).unwrap(); +// //unsafe { +// // (self.EnzymeSetCLString.set_cl)( +// // self.EnzymeFunctionToAnalyze, +// // c_fun_name.as_ptr() as *const c_char, +// // ); +// //} +// } +// +// pub(crate) fn set_print(&mut self, print: bool) { +// unsafe { +// //(self.EnzymeSetCLBool.set_cl)(self.EnzymePrint, print as u8); +// } +// } +// +// pub(crate) fn set_strict_aliasing(&mut self, strict: bool) { +// unsafe { +// //(self.EnzymeSetCLBool.set_cl)(self.EnzymeStrictAliasing, strict as u8); +// } +// } +// +// pub(crate) fn set_loose_types(&mut self, loose: bool) { +// unsafe { +// //(self.EnzymeSetCLBool.set_cl)(self.looseTypeAnalysis, loose as u8); +// } +// } +// +// pub(crate) fn set_inline(&mut self, val: bool) { +// unsafe { +// //(self.EnzymeSetCLBool.set_cl)(self.EnzymeInline, val as u8); +// } +// } +// +// pub(crate) fn set_rust_rules(&mut self, val: bool) { +// unsafe { +// //(self.EnzymeSetCLBool.set_cl)(self.RustTypeRules, val as u8); +// } +// } +// } - use libc::c_void; - unsafe extern "C" { - pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); - pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char); - } - unsafe extern "C" { - static mut EnzymePrintPerf: c_void; - static mut EnzymePrintActivity: c_void; - static mut EnzymePrintType: c_void; - static mut EnzymeFunctionToAnalyze: c_void; - static mut EnzymePrint: c_void; - static mut EnzymeStrictAliasing: c_void; - static mut looseTypeAnalysis: c_void; - static mut EnzymeInline: c_void; - static mut RustTypeRules: c_void; - } - pub(crate) fn set_print_perf(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8); - } - } - pub(crate) fn set_print_activity(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8); - } - } - pub(crate) fn set_print_type(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8); - } - } - pub(crate) fn set_print_type_fun(fun_name: &str) { - let c_fun_name = CString::new(fun_name).unwrap(); - unsafe { - EnzymeSetCLString( - std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze), - c_fun_name.as_ptr() as *const c_char, - ); - } - } - pub(crate) fn set_print(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8); - } - } - pub(crate) fn set_strict_aliasing(strict: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8); - } - } - pub(crate) fn set_loose_types(loose: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8); - } - } - pub(crate) fn set_inline(val: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8); - } - } - pub(crate) fn set_rust_rules(val: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8); - } - } } #[cfg(not(llvm_enzyme))] diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index b66fc157b3cb2..790101739cc1e 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2518,7 +2518,8 @@ unsafe extern "C" { LoopVectorize: bool, DisableSimplifyLibCalls: bool, EmitLifetimeMarkers: bool, - RunEnzyme: bool, + RunEnzyme: *const c_void, + //RunEnzyme: bool, PrintBeforeEnzyme: bool, PrintAfterEnzyme: bool, PrintPasses: bool, diff --git a/compiler/rustc_codegen_ssa/Cargo.toml b/compiler/rustc_codegen_ssa/Cargo.toml index 2dfbc58164323..d39fbb46244de 100644 --- a/compiler/rustc_codegen_ssa/Cargo.toml +++ b/compiler/rustc_codegen_ssa/Cargo.toml @@ -12,6 +12,7 @@ bstr = "1.11.3" # per crate", so if you change this, you need to also change it in `rustc_llvm`. cc = "=1.2.16" itertools = "0.12" +libloading = "0.8.9" pathdiff = "0.2.0" regex = "1.4" rustc_abi = { path = "../rustc_abi" } diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index f637e7f58dbf7..61f7c711378e8 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -6,6 +6,11 @@ use std::sync::Arc; use std::sync::mpsc::{Receiver, Sender, channel}; use std::{fs, io, mem, str, thread}; +use tracing::info; +use rustc_session::filesearch; +use rustc_session::config::host_tuple; +use std::sync::Mutex; + use rustc_abi::Size; use rustc_ast::attr; use rustc_data_structures::fx::FxIndexMap; @@ -28,7 +33,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::ty::TyCtxt; use rustc_session::Session; use rustc_session::config::{ - self, CrateType, Lto, OutFileName, OutputFilenames, OutputType, Passes, SwitchWithOptPath, + self, CrateType, Lto, OutFileName, OutputFilenames, OutputType, Passes, SwitchWithOptPath, Sysroot, }; use rustc_span::source_map::SourceMap; use rustc_span::{FileName, InnerSpan, Span, SpanData, sym}; @@ -344,6 +349,7 @@ pub struct CodegenContext { pub split_debuginfo: rustc_target::spec::SplitDebuginfo, pub split_dwarf_kind: rustc_session::config::SplitDwarfKind, pub pointer_size: Size, + pub sysroot: Sysroot, /// All commandline args used to invoke the compiler, with @file args fully expanded. /// This will only be used within debug info, e.g. in the pdb file on windows @@ -1072,6 +1078,7 @@ fn start_executing_work( ) -> thread::JoinHandle> { let coordinator_send = tx_to_llvm_workers; let sess = tcx.sess; + let sysroot = sess.opts.sysroot.clone(); let mut each_linked_rlib_for_lto = Vec::new(); let mut each_linked_rlib_file_for_lto = Vec::new(); @@ -1145,6 +1152,7 @@ fn start_executing_work( parallel: backend.supports_parallel() && !sess.opts.unstable_opts.no_parallel_backend, pointer_size: tcx.data_layout.pointer_size(), invocation_temp: sess.invocation_temp.clone(), + sysroot, }; let compiled_allocator_module = allocator_module.map(|mut allocator_module| { @@ -1679,6 +1687,183 @@ fn start_executing_work( #[must_use] pub(crate) struct WorkerFatalError; +use std::ffi::CString; +use libc::c_void; + +type SetFlag = unsafe extern "C" fn(*mut c_void, u8); + +#[derive(Debug)] +pub(crate) struct EnzymeFns { + pub set_cl: SetFlag, +} + +#[derive(Debug)] +#[allow(non_snake_case)] +pub struct EnzymeWrapper { + EnzymePrintPerf: *mut c_void, + EnzymePrintActivity: *mut c_void, + EnzymePrintType: *mut c_void, + EnzymeFunctionToAnalyze: *mut c_void, + EnzymePrint: *mut c_void, + EnzymeStrictAliasing: *mut c_void, + looseTypeAnalysis: *mut c_void, + EnzymeInline: *mut c_void, + RustTypeRules: *mut c_void, + + EnzymeSetCLBool: EnzymeFns, + EnzymeSetCLString: EnzymeFns, + pub registerEnzymeAndPassPipeline: *const c_void, + lib: libloading::Library, +} +unsafe impl Sync for EnzymeWrapper {} +unsafe impl Send for EnzymeWrapper {} + #[allow(non_snake_case)] + fn call_dynamic<'a, B: WriteBackendMethods>(cgcx: &'a CodegenContext) -> Result> { + fn load_ptr(lib: &libloading::Library, bytes: &[u8]) -> Result<*mut c_void, Box> { + // Safety: symbol lookup from a loaded shared object. + unsafe { + let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?; + let s = s.try_as_raw_ptr().unwrap(); + Ok(s as *mut c_void) + } + } + dbg!("Loading Enzyme"); + let mypath = EnzymeWrapper::get_enzyme_path(&cgcx.sysroot); + let lib = unsafe {libloading::Library::new(mypath)?}; + let EnzymeSetCLBool: libloading::Symbol<'_, SetFlag> = unsafe{lib.get(b"EnzymeSetCLBool")?}; + let registerEnzymeAndPassPipeline = + load_ptr(&lib, b"registerEnzymeAndPassPipeline").unwrap() as *const c_void; + let EnzymeSetCLString: libloading::Symbol<'_, SetFlag> = unsafe{ lib.get(b"EnzymeSetCLString")?}; + + let EnzymePrintPerf = load_ptr(&lib, b"EnzymePrintPerf").unwrap(); + let EnzymePrintActivity = load_ptr(&lib, b"EnzymePrintActivity").unwrap(); + let EnzymePrintType = load_ptr(&lib, b"EnzymePrintType").unwrap(); + let EnzymeFunctionToAnalyze = load_ptr(&lib, b"EnzymeFunctionToAnalyze").unwrap(); + let EnzymePrint = load_ptr(&lib, b"EnzymePrint").unwrap(); + + let EnzymeStrictAliasing = load_ptr(&lib, b"EnzymeStrictAliasing").unwrap(); + let looseTypeAnalysis = load_ptr(&lib, b"looseTypeAnalysis").unwrap(); + let EnzymeInline = load_ptr(&lib, b"EnzymeInline").unwrap(); + let RustTypeRules = load_ptr(&lib, b"RustTypeRules").unwrap(); + + let wrap = EnzymeWrapper { + EnzymePrintPerf, + EnzymePrintActivity, + EnzymePrintType, + EnzymeFunctionToAnalyze, + EnzymePrint, + EnzymeStrictAliasing, + looseTypeAnalysis, + EnzymeInline, + RustTypeRules, + EnzymeSetCLBool: EnzymeFns {set_cl: *EnzymeSetCLBool}, + EnzymeSetCLString: EnzymeFns {set_cl: *EnzymeSetCLString}, + registerEnzymeAndPassPipeline, + lib + }; + Ok(wrap) + } + impl EnzymeWrapper { + pub fn current<'a, B: WriteBackendMethods>(cgcx: &'a CodegenContext) -> &'static Mutex { + use std::sync::OnceLock; + static CELL: OnceLock> = OnceLock::new(); + fn init_enzyme<'a, B: WriteBackendMethods>(cgcx: &'a CodegenContext) -> Mutex { + call_dynamic(cgcx).unwrap().into() + } + CELL.get_or_init(|| init_enzyme(cgcx)) + } + fn get_enzyme_path(sysroot: &Sysroot) -> String { + let target = host_tuple(); + let lib_ext = std::env::consts::DLL_EXTENSION; + + let sysroot = sysroot + .all_paths() + .map(|sysroot| { + filesearch::make_target_lib_path(sysroot, target).join("lib").with_file_name("libEnzyme-21").with_extension(lib_ext) + }) + .find(|f| { + info!("Enzyme candidate: {}", f.display()); + f.exists() + }) + .unwrap_or_else(|| { + let candidates = sysroot + .all_paths() + .map(|p| p.join("lib").display().to_string()) + .collect::>() + .join("\n* "); + let err = format!( + "failed to find a `libEnzyme` folder \ + in the sysroot candidates:\n* {candidates}" + ); + dbg!(&err); + bug!("asdf"); + //early_dcx.early_fatal(err); + }); + + info!("probing {} for a codegen backend", sysroot.display()); + let enzyme_path = sysroot.to_str().unwrap().to_string(); + enzyme_path + } + pub fn set_print_perf(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool.set_cl)(self.EnzymePrintPerf, print as u8); + } + } + + pub fn set_print_activity(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool.set_cl)(self.EnzymePrintActivity, print as u8); + //(self.EnzymeSetCLBool)(std::ptr::addr_of_mut!(self.EnzymePrintActivity), print as u8); + } + } + + pub fn set_print_type(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool.set_cl)(self.EnzymePrintType, print as u8); + } + } + + pub fn set_print_type_fun(&mut self, fun_name: &str) { + let _c_fun_name = CString::new(fun_name).unwrap(); + //unsafe { + // (self.EnzymeSetCLString.set_cl)( + // self.EnzymeFunctionToAnalyze, + // c_fun_name.as_ptr() as *const c_char, + // ); + //} + } + + pub fn set_print(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool.set_cl)(self.EnzymePrint, print as u8); + } + } + + pub fn set_strict_aliasing(&mut self, strict: bool) { + unsafe { + (self.EnzymeSetCLBool.set_cl)(self.EnzymeStrictAliasing, strict as u8); + } + } + + pub fn set_loose_types(&mut self, loose: bool) { + unsafe { + (self.EnzymeSetCLBool.set_cl)(self.looseTypeAnalysis, loose as u8); + } + } + + pub fn set_inline(&mut self, val: bool) { + unsafe { + (self.EnzymeSetCLBool.set_cl)(self.EnzymeInline, val as u8); + } + } + + pub fn set_rust_rules(&mut self, val: bool) { + unsafe { + (self.EnzymeSetCLBool.set_cl)(self.RustTypeRules, val as u8); + } + } + } + fn spawn_work<'a, B: ExtraBackendMethods>( cgcx: &'a CodegenContext, coordinator_send: Sender>, diff --git a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp index dd49232581497..bd54399ce1951 100644 --- a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp @@ -704,11 +704,13 @@ struct LLVMRustSanitizerOptions { extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB, /* augmentPassBuilder */ bool); -extern "C" { -extern llvm::cl::opt EnzymeFunctionToAnalyze; -} +//extern "C" { +//extern llvm::cl::opt EnzymeFunctionToAnalyze; +//} #endif +extern "C" typedef void (*registerEnzymeAndPassPipelineFn)(llvm::PassBuilder &PB, bool augment); + extern "C" LLVMRustResult LLVMRustOptimize( LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef, LLVMRustPassBuilderOptLevel OptLevelRust, LLVMRustOptStage OptStage, @@ -716,7 +718,7 @@ extern "C" LLVMRustResult LLVMRustOptimize( bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO, bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops, bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls, - bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme, + bool EmitLifetimeMarkers, registerEnzymeAndPassPipelineFn EnzymePtr, bool PrintBeforeEnzyme, bool PrintAfterEnzyme, bool PrintPasses, LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage, @@ -1061,7 +1063,7 @@ extern "C" LLVMRustResult LLVMRustOptimize( // now load "-enzyme" pass: #ifdef ENZYME - if (RunEnzyme) { + if (EnzymePtr) { if (PrintBeforeEnzyme) { // Handle the Rust flag `-Zautodiff=PrintModBefore`. @@ -1069,7 +1071,8 @@ extern "C" LLVMRustResult LLVMRustOptimize( MPM.addPass(PrintModulePass(outs(), Banner, true, false)); } - registerEnzymeAndPassPipeline(PB, false); + EnzymePtr(PB, false); + //registerEnzymeAndPassPipeline(PB, false); if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) { std::string ErrMsg = toString(std::move(Err)); LLVMRustSetLastError(ErrMsg.c_str()); @@ -1077,13 +1080,13 @@ extern "C" LLVMRustResult LLVMRustOptimize( } // Check if PrintTAFn was used and add type analysis pass if needed - if (!EnzymeFunctionToAnalyze.empty()) { - if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) { - std::string ErrMsg = toString(std::move(Err)); - LLVMRustSetLastError(ErrMsg.c_str()); - return LLVMRustResult::Failure; - } - } + //if (!EnzymeFunctionToAnalyze.empty()) { + // if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) { + // std::string ErrMsg = toString(std::move(Err)); + // LLVMRustSetLastError(ErrMsg.c_str()); + // return LLVMRustResult::Failure; + // } + //} if (PrintAfterEnzyme) { // Handle the Rust flag `-Zautodiff=PrintModAfter`. diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index 0b75e85772f86..99b4ca84420fe 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1205,14 +1205,14 @@ pub fn rustc_cargo( // We want to link against registerEnzyme and in the future we want to use additional // functionality from Enzyme core. For that we need to link against Enzyme. if builder.config.llvm_enzyme { - let arch = builder.build.host_target; - let enzyme_dir = builder.build.out.join(arch).join("enzyme").join("lib"); - cargo.rustflag("-L").rustflag(enzyme_dir.to_str().expect("Invalid path")); - - if let Some(llvm_config) = builder.llvm_config(builder.config.host_target) { - let llvm_version_major = llvm::get_llvm_version_major(builder, &llvm_config); - cargo.rustflag("-l").rustflag(&format!("Enzyme-{llvm_version_major}")); - } + //let arch = builder.build.host_target; + //let enzyme_dir = builder.build.out.join(arch).join("enzyme").join("lib"); + //cargo.rustflag("-L").rustflag(enzyme_dir.to_str().expect("Invalid path")); + + //if let Some(llvm_config) = builder.llvm_config(builder.config.host_target) { + // let llvm_version_major = llvm::get_llvm_version_major(builder, &llvm_config); + // cargo.rustflag("-l").rustflag(&format!("Enzyme-{llvm_version_major}")); + //} } // Building with protected visibility reduces the number of dynamic relocations needed, giving diff --git a/src/tools/enzyme b/src/tools/enzyme index 58af4e9e6c047..942063499a4cf 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 58af4e9e6c047534ba059b12af17cecb8a2e9f9e +Subproject commit 942063499a4cf733dd0fd1bd2b30c28fd31334ee