Skip to content

Commit 8c90dd0

Browse files
committed
remove perf impact of autodiff when not in use
1 parent 3f47395 commit 8c90dd0

File tree

6 files changed

+37
-53
lines changed

6 files changed

+37
-53
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ use crate::{Ty, TyKind};
1717
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
1818
/// as it's already done in the C++ and Julia frontend of Enzyme.
1919
///
20-
/// (FIXME) remove *First variants.
2120
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
2221
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
2322
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -607,19 +607,20 @@ pub(crate) fn run_pass_manager(
607607
// If this rustc version was built with enzyme/autodiff enabled, and if users applied the
608608
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
609609
debug!("running llvm pm opt pipeline");
610+
611+
// The PostAD behavior is the same as we would have if no autodiff was used.
612+
// It will run the default optimization pipeline. If AD is enabled we select
613+
// the DuringAD stage, which will disable vectorization and loop unrolling, and
614+
// schedule two autodiff optimization + differentiation passes.
615+
// We then run the llvm_optimize function a second time, to optimize the code which we generated
616+
// in the enzyme differentiation pass.
617+
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
618+
let stage =
619+
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
610620
unsafe {
611-
write::llvm_optimize(
612-
cgcx,
613-
dcx,
614-
module,
615-
config,
616-
opt_level,
617-
opt_stage,
618-
write::AutodiffStage::DuringAD,
619-
)?;
621+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
620622
}
621-
// FIXME(ZuseZ4): Make this more granular
622-
if cfg!(llvm_enzyme) && !thin {
623+
if cfg!(llvm_enzyme) && !thin && enable_ad {
623624
unsafe {
624625
write::llvm_optimize(
625626
cgcx,

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -557,19 +557,16 @@ pub(crate) unsafe fn llvm_optimize(
557557
// FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
558558
// differentiated.
559559

560+
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
561+
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
560562
let unroll_loops;
561563
let vectorize_slp;
562564
let vectorize_loop;
563-
let run_enzyme = cfg!(llvm_enzyme) && autodiff_stage == AutodiffStage::DuringAD;
564565

565566
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
566567
// optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
567568
// We therefore have two calls to llvm_optimize, if autodiff is used.
568-
//
569-
// FIXME(ZuseZ4): Before shipping on nightly,
570-
// we should make this more granular, or at least check that the user has at least one autodiff
571-
// call in their code, to justify altering the compilation pipeline.
572-
if cfg!(llvm_enzyme) && autodiff_stage != AutodiffStage::PostAD {
569+
if consider_ad && autodiff_stage != AutodiffStage::PostAD {
573570
unroll_loops = false;
574571
vectorize_slp = false;
575572
vectorize_loop = false;
@@ -701,10 +698,8 @@ pub(crate) unsafe fn optimize(
701698

702699
// If we know that we will later run AD, then we disable vectorization and loop unrolling.
703700
// Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD).
704-
// FIXME(ZuseZ4): Make this more granular, only set PreAD if we actually have autodiff
705-
// usages, not just if we build rustc with autodiff support.
706-
let autodiff_stage =
707-
if cfg!(llvm_enzyme) { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
701+
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
702+
let autodiff_stage = if consider_ad { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
708703
return unsafe {
709704
llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, autodiff_stage)
710705
};

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ fn generate_enzyme_call<'ll>(
4646
let output = attrs.ret_activity;
4747

4848
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
49-
// FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since
50-
// it will handle higher-order derivatives correctly automatically (in theory). Currently
51-
// higher-order derivatives fail, so we should debug that before adjusting this code.
5249
let mut ad_name: String = match attrs.mode {
5350
DiffMode::Forward => "__enzyme_fwddiff",
5451
DiffMode::Reverse => "__enzyme_autodiff",

compiler/rustc_session/src/config.rs

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -192,33 +192,28 @@ pub enum CoverageLevel {
192192
/// The different settings that the `-Z autodiff` flag can have.
193193
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
194194
pub enum AutoDiff {
195+
/// Enable the autodiff opt pipeline
196+
Enable,
197+
195198
/// Print TypeAnalysis information
196199
PrintTA,
197200
/// Print ActivityAnalysis Information
198201
PrintAA,
199202
/// Print Performance Warnings from Enzyme
200203
PrintPerf,
201-
/// Combines the three print flags above.
202-
Print,
204+
/// Print intermediate IR generation steps
205+
PrintSteps,
203206
/// Print the whole module, before running opts.
204207
PrintModBefore,
205-
/// Print the whole module just before we pass it to Enzyme.
206-
/// For Debug purpose, prefer the OPT flag below
207-
PrintModAfterOpts,
208208
/// Print the module after Enzyme differentiated everything.
209-
PrintModAfterEnzyme,
209+
PrintModAfter,
210210

211-
/// Enzyme's loose type debug helper (can cause incorrect gradients)
211+
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
212+
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
212213
LooseTypes,
213-
214-
/// More flags
215-
NoModOptAfter,
216-
/// Tell Enzyme to run LLVM Opts on each function it generated. By default off,
217-
/// since we already optimize the whole module after Enzyme is done.
218-
EnableFncOpt,
219-
NoVecUnroll,
214+
/// See Enzyme core documentation. FIXME(ZuseZ4): Clarify usages
220215
RuntimeActivity,
221-
/// Runs Enzyme specific Inlining
216+
/// Runs Enzyme's aggressive inlining
222217
Inline,
223218
}
224219

compiler/rustc_session/src/options.rs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,17 +1348,15 @@ pub mod parse {
13481348
v.sort_unstable();
13491349
for &val in v.iter() {
13501350
let variant = match val {
1351+
"Enable" => AutoDiff::Enable,
13511352
"PrintTA" => AutoDiff::PrintTA,
13521353
"PrintAA" => AutoDiff::PrintAA,
13531354
"PrintPerf" => AutoDiff::PrintPerf,
1354-
"Print" => AutoDiff::Print,
1355+
"PrintSteps" => AutoDiff::PrintSteps,
13551356
"PrintModBefore" => AutoDiff::PrintModBefore,
1356-
"PrintModAfterOpts" => AutoDiff::PrintModAfterOpts,
1357-
"PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme,
1357+
"PrintModAfter" => AutoDiff::PrintModAfter,
13581358
"LooseTypes" => AutoDiff::LooseTypes,
1359-
"NoModOptAfter" => AutoDiff::NoModOptAfter,
1360-
"EnableFncOpt" => AutoDiff::EnableFncOpt,
1361-
"NoVecUnroll" => AutoDiff::NoVecUnroll,
1359+
"RuntimeActivity" => AutoDiff::RuntimeActivity,
13621360
"Inline" => AutoDiff::Inline,
13631361
_ => {
13641362
// FIXME(ZuseZ4): print an error saying which value is not recognized
@@ -2081,19 +2079,18 @@ options! {
20812079
assume_incomplete_release: bool = (false, parse_bool, [TRACKED],
20822080
"make cfg(version) treat the current version as incomplete (default: no)"),
20832081
autodiff: Vec<crate::config::AutoDiff> = (Vec::new(), parse_autodiff, [TRACKED],
2084-
"a list of optional autodiff flags to enable
2082+
"a list of autodiff flags to enable
2083+
Mandatory setting:
2084+
`=Enable`
20852085
Optional extra settings:
20862086
`=PrintTA`
20872087
`=PrintAA`
20882088
`=PrintPerf`
2089-
`=Print`
2089+
`=PrintSteps`
20902090
`=PrintModBefore`
2091-
`=PrintModAfterOpts`
2092-
`=PrintModAfterEnzyme`
2091+
`=PrintModAfter`
20932092
`=LooseTypes`
2094-
`=NoModOptAfter`
2095-
`=EnableFncOpt`
2096-
`=NoVecUnroll`
2093+
`=RuntimeActivity`
20972094
`=Inline`
20982095
Multiple options can be combined with commas."),
20992096
#[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]

0 commit comments

Comments
 (0)