Skip to content

Commit 8761799

Browse files
committed
remove perf impact of autodiff when not in use
1 parent 104cd33 commit 8761799

File tree

8 files changed

+50
-69
lines changed

8 files changed

+50
-69
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ use crate::{Ty, TyKind};
1717
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
1818
/// as it's already done in the C++ and Julia frontend of Enzyme.
1919
///
20-
/// (FIXME) remove *First variants.
2120
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
2221
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
2322
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -607,19 +607,20 @@ pub(crate) fn run_pass_manager(
607607
// If this rustc version was built with enzyme/autodiff enabled, and if users applied the
608608
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
609609
debug!("running llvm pm opt pipeline");
610+
611+
// The PostAD behavior is the same as what we would have if no autodiff was used.
612+
// It will run the default optimization pipeline. If AD is enabled we select
613+
// the DuringAD stage, which will disable vectorization and loop unrolling, and
614+
// schedule two autodiff optimization + differentiation passes.
615+
// We then run the llvm_optimize function a second time, to optimize the code which we generated
616+
// in the enzyme differentiation pass.
617+
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
618+
let stage =
619+
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
610620
unsafe {
611-
write::llvm_optimize(
612-
cgcx,
613-
dcx,
614-
module,
615-
config,
616-
opt_level,
617-
opt_stage,
618-
write::AutodiffStage::DuringAD,
619-
)?;
621+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
620622
}
621-
// FIXME(ZuseZ4): Make this more granular
622-
if cfg!(llvm_enzyme) && !thin {
623+
if cfg!(llvm_enzyme) && !thin && enable_ad {
623624
unsafe {
624625
write::llvm_optimize(
625626
cgcx,

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -564,19 +564,16 @@ pub(crate) unsafe fn llvm_optimize(
564564
// FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
565565
// differentiated.
566566

567+
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
568+
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
567569
let unroll_loops;
568570
let vectorize_slp;
569571
let vectorize_loop;
570-
let run_enzyme = cfg!(llvm_enzyme) && autodiff_stage == AutodiffStage::DuringAD;
571572

572573
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
573574
// optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
574575
// We therefore have two calls to llvm_optimize, if autodiff is used.
575-
//
576-
// FIXME(ZuseZ4): Before shipping on nightly,
577-
// we should make this more granular, or at least check that the user has at least one autodiff
578-
// call in their code, to justify altering the compilation pipeline.
579-
if cfg!(llvm_enzyme) && autodiff_stage != AutodiffStage::PostAD {
576+
if consider_ad && autodiff_stage != AutodiffStage::PostAD {
580577
unroll_loops = false;
581578
vectorize_slp = false;
582579
vectorize_loop = false;
@@ -706,10 +703,8 @@ pub(crate) unsafe fn optimize(
706703

707704
// If we know that we will later run AD, then we disable vectorization and loop unrolling.
708705
// Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD).
709-
// FIXME(ZuseZ4): Make this more granular, only set PreAD if we actually have autodiff
710-
// usages, not just if we build rustc with autodiff support.
711-
let autodiff_stage =
712-
if cfg!(llvm_enzyme) { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
706+
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
707+
let autodiff_stage = if consider_ad { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
713708
return unsafe {
714709
llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, autodiff_stage)
715710
};

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ fn generate_enzyme_call<'ll>(
4646
let output = attrs.ret_activity;
4747

4848
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
49-
// FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since
50-
// it will handle higher-order derivatives correctly automatically (in theory). Currently
51-
// higher-order derivatives fail, so we should debug that before adjusting this code.
5249
let mut ad_name: String = match attrs.mode {
5350
DiffMode::Forward => "__enzyme_fwddiff",
5451
DiffMode::Reverse => "__enzyme_autodiff",

compiler/rustc_interface/src/tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ fn test_unstable_options_tracking_hash() {
759759
tracked!(allow_features, Some(vec![String::from("lang_items")]));
760760
tracked!(always_encode_mir, true);
761761
tracked!(assume_incomplete_release, true);
762-
tracked!(autodiff, vec![AutoDiff::Print]);
762+
tracked!(autodiff, vec![AutoDiff::Enable]);
763763
tracked!(binary_dep_depinfo, true);
764764
tracked!(box_noalias, false);
765765
tracked!(

compiler/rustc_session/src/config.rs

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -198,33 +198,28 @@ pub enum CoverageLevel {
198198
/// The different settings that the `-Z autodiff` flag can have.
199199
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
200200
pub enum AutoDiff {
201+
/// Enable the autodiff opt pipeline
202+
Enable,
203+
201204
/// Print TypeAnalysis information
202205
PrintTA,
203206
/// Print ActivityAnalysis Information
204207
PrintAA,
205208
/// Print Performance Warnings from Enzyme
206209
PrintPerf,
207-
/// Combines the three print flags above.
208-
Print,
210+
/// Print intermediate IR generation steps
211+
PrintSteps,
209212
/// Print the whole module, before running opts.
210213
PrintModBefore,
211-
/// Print the whole module just before we pass it to Enzyme.
212-
/// For Debug purpose, prefer the OPT flag below
213-
PrintModAfterOpts,
214214
/// Print the module after Enzyme differentiated everything.
215-
PrintModAfterEnzyme,
215+
PrintModAfter,
216216

217-
/// Enzyme's loose type debug helper (can cause incorrect gradients)
217+
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
218+
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
218219
LooseTypes,
219-
220-
/// More flags
221-
NoModOptAfter,
222-
/// Tell Enzyme to run LLVM Opts on each function it generated. By default off,
223-
/// since we already optimize the whole module after Enzyme is done.
224-
EnableFncOpt,
225-
NoVecUnroll,
220+
/// See Enzyme core documentation. FIXME(ZuseZ4): Clarify usages
226221
RuntimeActivity,
227-
/// Runs Enzyme specific Inlining
222+
/// Runs Enzyme's aggressive inlining
228223
Inline,
229224
}
230225

compiler/rustc_session/src/options.rs

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ mod desc {
707707
pub(crate) const parse_list: &str = "a space-separated list of strings";
708708
pub(crate) const parse_list_with_polarity: &str =
709709
"a comma-separated list of strings, with elements beginning with + or -";
710-
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Print`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfterOpts`, `PrintModAfterEnzyme`, `LooseTypes`, `NoModOptAfter`, `EnableFncOpt`, `NoVecUnroll`, `Inline`";
710+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `LooseTypes`, `RuntimeActivity`, `Inline`";
711711
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
712712
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
713713
pub(crate) const parse_number: &str = "a number";
@@ -1348,17 +1348,15 @@ pub mod parse {
13481348
v.sort_unstable();
13491349
for &val in v.iter() {
13501350
let variant = match val {
1351+
"Enable" => AutoDiff::Enable,
13511352
"PrintTA" => AutoDiff::PrintTA,
13521353
"PrintAA" => AutoDiff::PrintAA,
13531354
"PrintPerf" => AutoDiff::PrintPerf,
1354-
"Print" => AutoDiff::Print,
1355+
"PrintSteps" => AutoDiff::PrintSteps,
13551356
"PrintModBefore" => AutoDiff::PrintModBefore,
1356-
"PrintModAfterOpts" => AutoDiff::PrintModAfterOpts,
1357-
"PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme,
1357+
"PrintModAfter" => AutoDiff::PrintModAfter,
13581358
"LooseTypes" => AutoDiff::LooseTypes,
1359-
"NoModOptAfter" => AutoDiff::NoModOptAfter,
1360-
"EnableFncOpt" => AutoDiff::EnableFncOpt,
1361-
"NoVecUnroll" => AutoDiff::NoVecUnroll,
1359+
"RuntimeActivity" => AutoDiff::RuntimeActivity,
13621360
"Inline" => AutoDiff::Inline,
13631361
_ => {
13641362
// FIXME(ZuseZ4): print an error saying which value is not recognized
@@ -2081,21 +2079,20 @@ options! {
20812079
assume_incomplete_release: bool = (false, parse_bool, [TRACKED],
20822080
"make cfg(version) treat the current version as incomplete (default: no)"),
20832081
autodiff: Vec<crate::config::AutoDiff> = (Vec::new(), parse_autodiff, [TRACKED],
2084-
"a list of optional autodiff flags to enable
2085-
Optional extra settings:
2086-
`=PrintTA`
2087-
`=PrintAA`
2088-
`=PrintPerf`
2089-
`=Print`
2090-
`=PrintModBefore`
2091-
`=PrintModAfterOpts`
2092-
`=PrintModAfterEnzyme`
2093-
`=LooseTypes`
2094-
`=NoModOptAfter`
2095-
`=EnableFncOpt`
2096-
`=NoVecUnroll`
2097-
`=Inline`
2098-
Multiple options can be combined with commas."),
2082+
"a list of autodiff flags to enable
2083+
Mandatory setting:
2084+
`=Enable`
2085+
Optional extra settings:
2086+
`=PrintTA`
2087+
`=PrintAA`
2088+
`=PrintPerf`
2089+
`=PrintSteps`
2090+
`=PrintModBefore`
2091+
`=PrintModAfter`
2092+
`=LooseTypes`
2093+
`=RuntimeActivity`
2094+
`=Inline`
2095+
Multiple options can be combined with commas."),
20992096
#[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]
21002097
binary_dep_depinfo: bool = (false, parse_bool, [TRACKED],
21012098
"include artifacts (sysroot, crate dependencies) used during compilation in dep-info \

src/doc/unstable-book/src/compiler-flags/autodiff.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,13 @@ This feature allows you to differentiate functions using automatic differentiati
88
Set the `-Zautodiff=<options>` compiler flag to adjust the behaviour of the autodiff feature.
99
Multiple options can be separated with a comma. Valid options are:
1010

11+
`Enable` - Required flag to enable autodiff
1112
`PrintTA` - print Type Analysis Information
1213
`PrintAA` - print Activity Analysis Information
1314
`PrintPerf` - print Performance Warnings from Enzyme
14-
`Print` - prints all intermediate transformations
15+
`PrintSteps` - prints all intermediate transformations
1516
`PrintModBefore` - print the whole module, before running opts
16-
`PrintModAfterOpts` - print the whole module just before we pass it to Enzyme
17-
`PrintModAfterEnzyme` - print the module after Enzyme differentiated everything
17+
`PrintModAfter` - print the module after Enzyme differentiated everything
1818
`LooseTypes` - Enzyme's loose type debug helper (can cause incorrect gradients)
1919
`Inline` - runs Enzyme-specific inlining
20-
`NoModOptAfter` - do not optimize the module after Enzyme is done
21-
`EnableFncOpt` - tell Enzyme to run LLVM Opts on each function it generated
22-
`NoVecUnroll` - do not unroll vectorized loops
2320
`RuntimeActivity` - allow specifying activity at runtime

0 commit comments

Comments
 (0)