Skip to content

Commit d3ed369

Browse files
committed
upstream rustc_codegen_ssa/rustc_middle changes for enzyme/autodiff
1 parent ebcf860 commit d3ed369

File tree

25 files changed

+440
-28
lines changed

25 files changed

+440
-28
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4234,6 +4234,7 @@ name = "rustc_monomorphize"
42344234
version = "0.0.0"
42354235
dependencies = [
42364236
"rustc_abi",
4237+
"rustc_ast",
42374238
"rustc_attr_parsing",
42384239
"rustc_data_structures",
42394240
"rustc_errors",
@@ -4243,6 +4244,7 @@ dependencies = [
42434244
"rustc_middle",
42444245
"rustc_session",
42454246
"rustc_span",
4247+
"rustc_symbol_mangling",
42464248
"rustc_target",
42474249
"serde",
42484250
"serde_json",

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ fn generate_enzyme_call<'ll>(
6262
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
6363
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
6464
let name = llvm::get_value_name(outer_fn);
65-
let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap();
66-
ad_name.push_str(outer_fn_name.to_string().as_str());
65+
let outer_fn_name = std::str::from_utf8(name).unwrap();
66+
ad_name.push_str(outer_fn_name);
6767

6868
// Let us assume the user wrote the following function square:
6969
//
@@ -255,21 +255,25 @@ fn generate_enzyme_call<'ll>(
255255
// have no debug info to copy, which would then be ok.
256256
trace!("no dbg info");
257257
}
258+
258259
// Now that we copied the metadata, get rid of dummy code.
259-
llvm::LLVMRustEraseInstBefore(entry, last_inst);
260-
llvm::LLVMRustEraseInstFromParent(last_inst);
260+
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
261+
dbg!("cleaneup done");
262+
llvm::LLVMDumpValue(outer_fn);
261263

262264
if cx.val_ty(outer_fn) != cx.type_void() {
263265
builder.ret(call);
264266
} else {
265267
builder.ret_void();
266268
}
269+
dbg!("build ret done");
267270

268271
// Let's crash in case that we messed something up above and generated invalid IR.
269272
llvm::LLVMRustVerifyFunction(
270273
outer_fn,
271274
llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction,
272275
);
276+
dbg!("verification done");
273277
}
274278
}
275279

@@ -308,6 +312,7 @@ pub(crate) fn differentiate<'ll>(
308312
};
309313

310314
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
315+
dbg!("generated enzyme call");
311316
}
312317

313318
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts

compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ struct UsageSets<'tcx> {
298298
/// Prepare sets of definitions that are relevant to deciding whether something
299299
/// is an "unused function" for coverage purposes.
300300
fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> {
301-
let MonoItemPartitions { all_mono_items, codegen_units } =
301+
let MonoItemPartitions { all_mono_items, codegen_units, .. } =
302302
tcx.collect_and_partition_mono_items(());
303303

304304
// Obtain a MIR body for each function participating in codegen, via an

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ use crate::llvm::Bool;
77
extern "C" {
88
// Enzyme
99
pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;
10-
pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value);
10+
pub fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value);
1111
pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
1212
pub fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
1313
pub fn LLVMRustEraseInstFromParent(V: &Value);
1414
pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
15+
pub fn LLVMDumpModule(M: &Module);
16+
pub fn LLVMDumpValue(V: &Value);
1517
pub fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
1618

1719
pub fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;

compiler/rustc_codegen_ssa/messages.ftl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,6 @@ codegen_ssa_use_cargo_directive = use the `cargo:rustc-link-lib` directive to sp
369369
codegen_ssa_version_script_write_failure = failed to write version script: {$error}
370370
371371
codegen_ssa_visual_studio_not_installed = you may need to install Visual Studio build tools with the "C++ build tools" workload
372+
373+
codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto
374+

compiler/rustc_codegen_ssa/src/back/write.rs

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
77
use std::{fs, io, mem, str, thread};
88

99
use rustc_ast::attr;
10+
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
1011
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
1112
use rustc_data_structures::jobserver::{self, Acquired};
1213
use rustc_data_structures::memmap::Mmap;
@@ -40,7 +41,7 @@ use tracing::debug;
4041
use super::link::{self, ensure_removed};
4142
use super::lto::{self, SerializedModule};
4243
use super::symbol_export::symbol_name_for_instance_in_crate;
43-
use crate::errors::ErrorCreatingRemarkDir;
44+
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
4445
use crate::traits::*;
4546
use crate::{
4647
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
@@ -118,6 +119,7 @@ pub struct ModuleConfig {
118119
pub merge_functions: bool,
119120
pub emit_lifetime_markers: bool,
120121
pub llvm_plugins: Vec<String>,
122+
pub autodiff: Vec<config::AutoDiff>,
121123
}
122124

123125
impl ModuleConfig {
@@ -266,6 +268,7 @@ impl ModuleConfig {
266268

267269
emit_lifetime_markers: sess.emit_lifetime_markers(),
268270
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
271+
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
269272
}
270273
}
271274

@@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
389392

390393
fn generate_lto_work<B: ExtraBackendMethods>(
391394
cgcx: &CodegenContext<B>,
395+
autodiff: Vec<AutoDiffItem>,
392396
needs_fat_lto: Vec<FatLtoInput<B>>,
393397
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
394398
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
@@ -397,11 +401,19 @@ fn generate_lto_work<B: ExtraBackendMethods>(
397401

398402
if !needs_fat_lto.is_empty() {
399403
assert!(needs_thin_lto.is_empty());
400-
let module =
404+
let mut module =
401405
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
406+
if cgcx.lto == Lto::Fat {
407+
let config = cgcx.config(ModuleKind::Regular);
408+
module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
409+
}
402410
// We are adding a single work item, so the cost doesn't matter.
403411
vec![(WorkItem::LTO(module), 0)]
404412
} else {
413+
if !autodiff.is_empty() {
414+
let dcx = cgcx.create_dcx();
415+
dcx.handle().emit_fatal(AutodiffWithoutLto {});
416+
}
405417
assert!(needs_fat_lto.is_empty());
406418
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
407419
.unwrap_or_else(|e| e.raise());
@@ -1021,6 +1033,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
10211033
/// Sent from a backend worker thread.
10221034
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },
10231035

1036+
/// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
1037+
AddAutoDiffItems(Vec<AutoDiffItem>),
1038+
10241039
/// The frontend has finished generating something (backend IR or a
10251040
/// post-LTO artifact) for a codegen unit, and it should be passed to the
10261041
/// backend. Sent from the main thread.
@@ -1348,6 +1363,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
13481363

13491364
// This is where we collect codegen units that have gone all the way
13501365
// through codegen and LLVM.
1366+
let mut autodiff_items = Vec::new();
13511367
let mut compiled_modules = vec![];
13521368
let mut compiled_allocator_module = None;
13531369
let mut needs_link = Vec::new();
@@ -1459,9 +1475,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
14591475
let needs_thin_lto = mem::take(&mut needs_thin_lto);
14601476
let import_only_modules = mem::take(&mut lto_import_only_modules);
14611477

1462-
for (work, cost) in
1463-
generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
1464-
{
1478+
for (work, cost) in generate_lto_work(
1479+
&cgcx,
1480+
autodiff_items.clone(),
1481+
needs_fat_lto,
1482+
needs_thin_lto,
1483+
import_only_modules,
1484+
) {
14651485
let insertion_index = work_items
14661486
.binary_search_by_key(&cost, |&(_, cost)| cost)
14671487
.unwrap_or_else(|e| e);
@@ -1596,6 +1616,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
15961616
main_thread_state = MainThreadState::Idle;
15971617
}
15981618

1619+
Message::AddAutoDiffItems(mut items) => {
1620+
autodiff_items.append(&mut items);
1621+
}
1622+
15991623
Message::CodegenComplete => {
16001624
if codegen_state != Aborted {
16011625
codegen_state = Completed;
@@ -2070,6 +2094,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
20702094
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
20712095
}
20722096

2097+
pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
2098+
drop(self.coordinator.sender.send(Box::new(Message::<B>::AddAutoDiffItems(items))));
2099+
}
2100+
20732101
pub(crate) fn check_for_errors(&self, sess: &Session) {
20742102
self.shared_emitter_main.check(sess, false);
20752103
}

compiler/rustc_codegen_ssa/src/base.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,8 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
620620
// Run the monomorphization collector and partition the collected items into
621621
// codegen units.
622622
let codegen_units = tcx.collect_and_partition_mono_items(()).codegen_units;
623+
let autodiff_fncs = tcx.collect_and_partition_mono_items(()).autodiff_items;
624+
let autodiff_fncs = autodiff_fncs.to_vec();
623625

624626
// Force all codegen_unit queries so they are already either red or green
625627
// when compile_codegen_unit accesses them. We are not able to re-execute
@@ -690,6 +692,10 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
690692
);
691693
}
692694

695+
if !autodiff_fncs.is_empty() {
696+
ongoing_codegen.submit_autodiff_items(autodiff_fncs);
697+
}
698+
693699
// For better throughput during parallel processing by LLVM, we used to sort
694700
// CGUs largest to smallest. This would lead to better thread utilization
695701
// by, for example, preventing a large CGU from being processed last and

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
use std::str::FromStr;
2+
13
use rustc_ast::attr::list_contains_name;
2-
use rustc_ast::{MetaItemInner, attr};
4+
use rustc_ast::expand::autodiff_attrs::{
5+
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
6+
};
7+
use rustc_ast::{MetaItem, MetaItemInner, attr};
38
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
49
use rustc_data_structures::fx::FxHashMap;
510
use rustc_errors::codes::*;
@@ -13,6 +18,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
1318
};
1419
use rustc_middle::mir::mono::Linkage;
1520
use rustc_middle::query::Providers;
21+
use rustc_middle::span_bug;
1622
use rustc_middle::ty::{self as ty, TyCtxt};
1723
use rustc_session::parse::feature_err;
1824
use rustc_session::{Session, lint};
@@ -856,6 +862,110 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
856862
}
857863
}
858864

865+
/// We now check the #[rustc_autodiff] attributes which we generated from the #[autodiff(...)]
866+
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
867+
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
868+
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
869+
/// panic, unless we introduced a bug when parsing the autodiff macro.
870+
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
871+
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
872+
873+
let attrs =
874+
attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>();
875+
876+
// check for exactly one autodiff attribute on placeholder functions.
877+
// There should only be one, since we generate a new placeholder per ad macro.
878+
// FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but
879+
// looks strange e.g. under cargo-expand.
880+
let attr = match &attrs[..] {
881+
[] => return AutoDiffAttrs::error(),
882+
[attr] => attr,
883+
// These two attributes are the same and unfortunately duplicated due to a previous bug.
884+
[attr, _attr2] => attr,
885+
_ => {
886+
//FIXME(ZuseZ4): Once we fixed our parser, we should also prohibit the two-attribute
887+
//branch above.
888+
span_bug!(attrs[1].span, "cg_ssa: rustc_autodiff should only exist once per source");
889+
}
890+
};
891+
892+
let list = attr.meta_item_list().unwrap_or_default();
893+
894+
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
895+
if list.is_empty() {
896+
return AutoDiffAttrs::source();
897+
}
898+
899+
let [mode, input_activities @ .., ret_activity] = &list[..] else {
900+
span_bug!(attr.span, "rustc_autodiff attribute must contain mode and activities");
901+
};
902+
let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
903+
p1.segments.first().unwrap().ident
904+
} else {
905+
span_bug!(attr.span, "rustc_autodiff attribute must contain mode");
906+
};
907+
908+
// parse mode
909+
let mode = match mode.as_str() {
910+
"Forward" => DiffMode::Forward,
911+
"Reverse" => DiffMode::Reverse,
912+
"ForwardFirst" => DiffMode::ForwardFirst,
913+
"ReverseFirst" => DiffMode::ReverseFirst,
914+
_ => {
915+
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
916+
}
917+
};
918+
919+
// First read the ret symbol from the attribute
920+
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
921+
p1.segments.first().unwrap().ident
922+
} else {
923+
span_bug!(attr.span, "rustc_autodiff attribute must contain the return activity");
924+
};
925+
926+
// Then parse it into an actual DiffActivity
927+
let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
928+
span_bug!(ret_symbol.span, "invalid return activity");
929+
};
930+
931+
// Now parse all the intermediate (input) activities
932+
let mut arg_activities: Vec<DiffActivity> = vec![];
933+
for arg in input_activities {
934+
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
935+
match p2.segments.first() {
936+
Some(x) => x.ident,
937+
None => {
938+
span_bug!(
939+
arg.span(),
940+
"rustc_autodiff attribute must contain the input activity"
941+
);
942+
}
943+
}
944+
} else {
945+
span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity");
946+
};
947+
948+
match DiffActivity::from_str(arg_symbol.as_str()) {
949+
Ok(arg_activity) => arg_activities.push(arg_activity),
950+
Err(_) => {
951+
span_bug!(arg_symbol.span, "invalid input activity");
952+
}
953+
}
954+
}
955+
956+
for &input in &arg_activities {
957+
if !valid_input_activity(mode, input) {
958+
span_bug!(attr.span, "Invalid input activity {} for {} mode", input, mode);
959+
}
960+
}
961+
if !valid_ret_activity(mode, ret_activity) {
962+
span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode);
963+
}
964+
965+
AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities }
966+
}
967+
859968
pub(crate) fn provide(providers: &mut Providers) {
860-
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
969+
*providers =
970+
Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers };
861971
}

compiler/rustc_codegen_ssa/src/errors.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ pub(crate) struct CguNotRecorded<'a> {
3939
pub cgu_name: &'a str,
4040
}
4141

42+
#[derive(Diagnostic)]
43+
#[diag(codegen_ssa_autodiff_without_lto)]
44+
pub struct AutodiffWithoutLto;
45+
4246
#[derive(Diagnostic)]
4347
#[diag(codegen_ssa_unknown_reuse_kind)]
4448
pub(crate) struct UnknownReuseKind {

compiler/rustc_interface/src/tests.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,7 @@ fn test_unstable_options_tracking_hash() {
760760
tracked!(allow_features, Some(vec![String::from("lang_items")]));
761761
tracked!(always_encode_mir, true);
762762
tracked!(assume_incomplete_release, true);
763+
tracked!(autodiff, vec![String::from("ad_flags")]);
763764
tracked!(binary_dep_depinfo, true);
764765
tracked!(box_noalias, false);
765766
tracked!(

0 commit comments

Comments
 (0)