Skip to content

Commit 4cd65a5

Browse files
committed
added PrintTAFn flag
1 parent c6a9554 commit 4cd65a5

File tree

6 files changed

+44
-5
lines changed

6 files changed

+44
-5
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ fn thin_lto(
587587
}
588588

589589
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
590-
for &val in ad {
590+
for val in ad {
591591
// We intentionally don't use a wildcard, to not forget handling anything new.
592592
match val {
593593
config::AutoDiff::PrintPerf => {
@@ -599,6 +599,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
599599
config::AutoDiff::PrintTA => {
600600
llvm::set_print_type(true);
601601
}
602+
config::AutoDiff::PrintTAFn(fun) => {
603+
llvm::set_print_type_fun(&fun);
604+
}
602605
config::AutoDiff::Inline => {
603606
llvm::set_inline(true);
604607
}

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,17 @@ pub(crate) use self::Enzyme_AD::*;
5858
#[cfg(llvm_enzyme)]
5959
pub(crate) mod Enzyme_AD {
6060
use libc::c_void;
61+
use std::ffi::{CString, c_char};
62+
6163
unsafe extern "C" {
6264
pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
65+
pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
6366
}
6467
unsafe extern "C" {
6568
static mut EnzymePrintPerf: c_void;
6669
static mut EnzymePrintActivity: c_void;
6770
static mut EnzymePrintType: c_void;
71+
static mut FunctionToAnalyze: c_void;
6872
static mut EnzymePrint: c_void;
6973
static mut EnzymeStrictAliasing: c_void;
7074
static mut looseTypeAnalysis: c_void;
@@ -86,6 +90,15 @@ pub(crate) mod Enzyme_AD {
8690
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
8791
}
8892
}
93+
pub(crate) fn set_print_type_fun(fun_name: &str) {
94+
let c_fun_name = CString::new(fun_name).unwrap();
95+
unsafe {
96+
EnzymeSetCLString(
97+
std::ptr::addr_of_mut!(FunctionToAnalyze),
98+
c_fun_name.as_ptr() as *const c_char,
99+
);
100+
}
101+
}
89102
pub(crate) fn set_print(print: bool) {
90103
unsafe {
91104
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
@@ -132,6 +145,9 @@ pub(crate) mod Fallback_AD {
132145
pub(crate) fn set_print_type(print: bool) {
133146
unimplemented!()
134147
}
148+
pub(crate) fn set_print_type_fun(fun_name: &str) {
149+
unimplemented!()
150+
}
135151
pub(crate) fn set_print(print: bool) {
136152
unimplemented!()
137153
}

compiler/rustc_session/src/config.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,15 @@ pub enum CoverageLevel {
227227
}
228228

229229
/// The different settings that the `-Z autodiff` flag can have.
230-
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
230+
#[derive(Clone, PartialEq, Hash, Debug)]
231231
pub enum AutoDiff {
232232
/// Enable the autodiff opt pipeline
233233
Enable,
234234

235235
/// Print TypeAnalysis information
236236
PrintTA,
237+
/// Print TypeAnalysis information for a specific function
238+
PrintTAFn(String),
237239
/// Print ActivityAnalysis Information
238240
PrintAA,
239241
/// Print Performance Warnings from Enzyme

compiler/rustc_session/src/options.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ mod desc {
711711
pub(crate) const parse_list: &str = "a space-separated list of strings";
712712
pub(crate) const parse_list_with_polarity: &str =
713713
"a comma-separated list of strings, with elements beginning with + or -";
714-
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
714+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
715715
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
716716
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
717717
pub(crate) const parse_number: &str = "a number";
@@ -1351,9 +1351,25 @@ pub mod parse {
13511351
let mut v: Vec<&str> = v.split(",").collect();
13521352
v.sort_unstable();
13531353
for &val in v.iter() {
1354-
let variant = match val {
1354+
// Split each entry on '=' if it has an argument
1355+
let (key, arg) = match val.split_once('=') {
1356+
Some((k, a)) => (k, Some(a)),
1357+
None => (val, None),
1358+
};
1359+
1360+
let variant = match key {
13551361
"Enable" => AutoDiff::Enable,
13561362
"PrintTA" => AutoDiff::PrintTA,
1363+
"PrintTAFn" => {
1364+
if let Some(fun) = arg {
1365+
AutoDiff::PrintTAFn(fun.to_string())
1366+
} else {
1367+
eprintln!(
1368+
"Missing argument for PrintTAFn (expected PrintTAFn=<function_name>)"
1369+
);
1370+
return false;
1371+
}
1372+
}
13571373
"PrintAA" => AutoDiff::PrintAA,
13581374
"PrintPerf" => AutoDiff::PrintPerf,
13591375
"PrintSteps" => AutoDiff::PrintSteps,
@@ -1365,7 +1381,7 @@ pub mod parse {
13651381
"LooseTypes" => AutoDiff::LooseTypes,
13661382
"Inline" => AutoDiff::Inline,
13671383
_ => {
1368-
// FIXME(ZuseZ4): print an error saying which value is not recognized
1384+
eprintln!("Unknown autodiff option: {key}");
13691385
return false;
13701386
}
13711387
};

src/doc/rustc-dev-guide/src/autodiff/flags.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ To support you while debugging or profiling, we have added support for an experi
66

77
```text
88
PrintTA // Print TypeAnalysis information
9+
PrintTAFn // Print TypeAnalysis information for a specific function
910
PrintAA // Print ActivityAnalysis information
1011
Print // Print differentiated functions while they are being generated and optimized
1112
PrintPerf // Print AD related Performance warnings

src/doc/unstable-book/src/compiler-flags/autodiff.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Multiple options can be separated with a comma. Valid options are:
1010

1111
`Enable` - Required flag to enable autodiff
1212
`PrintTA` - print Type Analysis Information
13+
`PrintTAFn` - print Type Analysis Information for a specific function
1314
`PrintAA` - print Activity Analysis Information
1415
`PrintPerf` - print Performance Warnings from Enzyme
1516
`PrintSteps` - prints all intermediate transformations

0 commit comments

Comments
 (0)