Skip to content

Commit e949d8b

Browse files
committed
Use libloading to open Enzyme at runtime
1 parent 59c665b commit e949d8b

File tree

9 files changed

+466
-106
lines changed

9 files changed

+466
-106
lines changed

compiler/rustc_codegen_llvm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ bitflags = "2.4.1"
1414
gimli = "0.31"
1515
itertools = "0.12"
1616
libc = "0.2"
17+
libloading = "0.8.8"
1718
measureme = "12.0.1"
1819
object = { version = "0.37.0", default-features = false, features = ["std", "read"] }
1920
rustc-demangle = "0.1.21"

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -522,31 +522,33 @@ fn thin_lto(
522522
}
523523
}
524524

525-
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
525+
fn enable_autodiff_settings(cgcx: &CodegenContext<LlvmCodegenBackend>, ad: &[config::AutoDiff]) {
526+
use rustc_codegen_ssa::back::write::EnzymeWrapper;
527+
let enzyme = rustc_codegen_ssa::back::write::EnzymeWrapper::current(cgcx);
526528
for val in ad {
527529
// We intentionally don't use a wildcard, to not forget handling anything new.
528530
match val {
529531
config::AutoDiff::PrintPerf => {
530-
llvm::set_print_perf(true);
532+
enzyme.lock().unwrap().set_print_perf(true);
531533
}
532534
config::AutoDiff::PrintAA => {
533-
llvm::set_print_activity(true);
535+
enzyme.lock().unwrap().set_print_activity(true);
534536
}
535537
config::AutoDiff::PrintTA => {
536-
llvm::set_print_type(true);
538+
enzyme.lock().unwrap().set_print_type(true);
537539
}
538540
config::AutoDiff::PrintTAFn(fun) => {
539-
llvm::set_print_type(true); // Enable general type printing
540-
llvm::set_print_type_fun(&fun); // Set specific function to analyze
541+
enzyme.lock().unwrap().set_print_type(true); // Enable general type printing
542+
enzyme.lock().unwrap().set_print_type_fun(&fun); // Set specific function to analyze
541543
}
542544
config::AutoDiff::Inline => {
543-
llvm::set_inline(true);
545+
enzyme.lock().unwrap().set_inline(true);
544546
}
545547
config::AutoDiff::LooseTypes => {
546-
llvm::set_loose_types(true);
548+
enzyme.lock().unwrap().set_loose_types(true);
547549
}
548550
config::AutoDiff::PrintSteps => {
549-
llvm::set_print(true);
551+
enzyme.lock().unwrap().set_print(true);
550552
}
551553
// We handle this in the PassWrapper.cpp
552554
config::AutoDiff::PrintPasses => {}
@@ -563,9 +565,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
563565
}
564566
}
565567
// This helps with handling enums for now.
566-
llvm::set_strict_aliasing(false);
568+
enzyme.lock().unwrap().set_strict_aliasing(false);
567569
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
568-
llvm::set_rust_rules(true);
570+
enzyme.lock().unwrap().set_rust_rules(true);
569571
}
570572

571573
pub(crate) fn run_pass_manager(
@@ -601,7 +603,7 @@ pub(crate) fn run_pass_manager(
601603
};
602604

603605
if enable_ad {
604-
enable_autodiff_settings(&config.autodiff);
606+
enable_autodiff_settings(&cgcx, &config.autodiff);
605607
}
606608

607609
unsafe {

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,18 @@ pub(crate) unsafe fn llvm_optimize(
665665

666666
let llvm_plugins = config.llvm_plugins.join(",");
667667

668+
let enzyme_fn = if consider_ad {
669+
let wrapper = rustc_codegen_ssa::back::write::EnzymeWrapper::current(cgcx);
670+
wrapper.lock().unwrap().registerEnzymeAndPassPipeline
671+
} else {
672+
//dbg!(run_enzyme);
673+
//dbg!(consider_ad);
674+
std::ptr::null()
675+
};
676+
677+
dbg!(&enzyme_fn);
678+
679+
668680
let result = unsafe {
669681
llvm::LLVMRustOptimize(
670682
module.module_llvm.llmod(),
@@ -684,7 +696,8 @@ pub(crate) unsafe fn llvm_optimize(
684696
vectorize_loop,
685697
config.no_builtins,
686698
config.emit_lifetime_markers,
687-
run_enzyme,
699+
enzyme_fn,
700+
//run_enzyme,
688701
print_before_enzyme,
689702
print_after_enzyme,
690703
print_passes,

0 commit comments

Comments
 (0)