Skip to content

Commit 2b96629

Browse files
committed
wip
1 parent 51634d1 commit 2b96629

File tree

6 files changed

+46
-28
lines changed

6 files changed

+46
-28
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ pub struct AutoDiffItem {
7979
pub target: String,
8080
pub attrs: AutoDiffAttrs,
8181
}
82+
8283
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
8384
pub struct AutoDiffAttrs {
8485
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
@@ -231,7 +232,7 @@ impl AutoDiffAttrs {
231232
self.ret_activity == DiffActivity::ActiveOnly
232233
}
233234

234-
pub fn error() -> Self {
235+
pub const fn error() -> Self {
235236
AutoDiffAttrs {
236237
mode: DiffMode::Error,
237238
ret_activity: DiffActivity::None,

compiler/rustc_codegen_ssa/src/base.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ use std::cmp;
22
use std::collections::BTreeSet;
33
use std::time::{Duration, Instant};
44

5-
use rustc_middle::mir::mono::MonoItemPartitions;
65
use itertools::Itertools;
76
use rustc_abi::FIRST_VARIANT;
87
use rustc_ast::expand::allocator::{ALLOCATOR_METHODS, AllocatorKind, global_fn_name};
@@ -19,7 +18,7 @@ use rustc_middle::middle::debugger_visualizer::{DebuggerVisualizerFile, Debugger
1918
use rustc_middle::middle::exported_symbols::SymbolExportKind;
2019
use rustc_middle::middle::{exported_symbols, lang_items};
2120
use rustc_middle::mir::BinOp;
22-
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem};
21+
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem, MonoItemPartitions};
2322
use rustc_middle::query::Providers;
2423
use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout};
2524
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
@@ -620,7 +619,8 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
620619

621620
// Run the monomorphization collector and partition the collected items into
622621
// codegen units.
623-
let MonoItemPartitions {codegen_units, autodiff_items, ..} = tcx.collect_and_partition_mono_items(());
622+
let MonoItemPartitions { codegen_units, autodiff_items, .. } =
623+
tcx.collect_and_partition_mono_items(());
624624
let autodiff_fncs = autodiff_items.to_vec();
625625

626626
// Force all codegen_unit queries so they are already either red or green

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs {
7171
codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER;
7272
}
7373

74+
// If our rustc version supports autodiff/enzyme, then we call our handler
75+
// to check for any `#[rustc_autodiff(...)]` attributes.
76+
//if cfg!(llvm_enzyme) {
77+
// let ad = autodiff_attrs(tcx, did.into());
78+
// codegen_fn_attrs.autodiff_item = ad;
79+
//}
80+
7481
// When `no_builtins` is applied at the crate level, we should add the
7582
// `no-builtins` attribute to each function to ensure it takes effect in LTO.
7683
let crate_attrs = tcx.hir().attrs(rustc_hir::CRATE_HIR_ID);
@@ -867,7 +874,7 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
867874
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
868875
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
869876
/// panic, unless we introduced a bug when parsing the autodiff macro.
870-
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
877+
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
871878
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
872879

873880
let attrs =
@@ -878,7 +885,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
878885
// FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but
879886
// looks strange e.g. under cargo-expand.
880887
let attr = match &attrs[..] {
881-
[] => return AutoDiffAttrs::error(),
888+
[] => return None,
882889
[attr] => attr,
883890
// These two attributes are the same and unfortunately duplicated due to a previous bug.
884891
[attr, _attr2] => attr,
@@ -893,7 +900,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
893900

894901
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
895902
if list.is_empty() {
896-
return AutoDiffAttrs::source();
903+
return Some(AutoDiffAttrs::source());
897904
}
898905

899906
let [mode, input_activities @ .., ret_activity] = &list[..] else {
@@ -962,10 +969,10 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
962969
span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode);
963970
}
964971

965-
AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities }
972+
Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
966973
}
967974

968975
pub(crate) fn provide(providers: &mut Providers) {
969976
*providers =
970-
Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers };
977+
Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
971978
}

compiler/rustc_middle/src/middle/codegen_fn_attrs.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use rustc_abi::Align;
2+
use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs;
23
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
34
use rustc_macros::{HashStable, TyDecodable, TyEncodable};
45
use rustc_span::Symbol;
@@ -52,6 +53,8 @@ pub struct CodegenFnAttrs {
5253
/// The `#[patchable_function_entry(...)]` attribute. Indicates how many nops should be around
5354
/// the function entry.
5455
pub patchable_function_entry: Option<PatchableFunctionEntry>,
56+
/// For the `#[autodiff]` macros.
57+
pub autodiff_item: Option<AutoDiffAttrs>,
5558
}
5659

5760
#[derive(Copy, Clone, Debug, TyEncodable, TyDecodable, HashStable)]
@@ -160,6 +163,7 @@ impl CodegenFnAttrs {
160163
instruction_set: None,
161164
alignment: None,
162165
patchable_function_entry: None,
166+
autodiff_item: None,
163167
}
164168
}
165169

compiler/rustc_middle/src/query/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use std::sync::Arc;
1313
use rustc_arena::TypedArena;
1414
use rustc_ast::expand::StrippedCfgItem;
1515
use rustc_ast::expand::allocator::AllocatorKind;
16-
use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs;
16+
//use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs;
1717
use rustc_data_structures::fingerprint::Fingerprint;
1818
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
1919
use rustc_data_structures::sorted_map::SortedMap;
@@ -1394,11 +1394,11 @@ rustc_queries! {
13941394
}
13951395

13961396
/// List of autodiff extern functions in the current crate.
1397-
query autodiff_attrs(def_id: DefId) -> &'tcx AutoDiffAttrs {
1398-
desc { |tcx| "computing autodiff attributes of `{}`", tcx.def_path_str(def_id) }
1399-
arena_cache
1400-
cache_on_disk_if { def_id.is_local() }
1401-
}
1397+
//query autodiff_attrs(def_id: DefId) -> &'tcx AutoDiffAttrs {
1398+
// desc { |tcx| "computing autodiff attributes of `{}`", tcx.def_path_str(def_id) }
1399+
// arena_cache
1400+
// cache_on_disk_if { def_id.is_local() }
1401+
//}
14021402

14031403
query asm_target_features(def_id: DefId) -> &'tcx FxIndexSet<Symbol> {
14041404
desc { |tcx| "computing target features for inline asm of `{}`", tcx.def_path_str(def_id) }

compiler/rustc_monomorphize/src/partitioning.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ use std::fs::{self, File};
9898
use std::io::Write;
9999
use std::path::{Path, PathBuf};
100100

101-
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity};
101+
use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
102102
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
103103
use rustc_data_structures::sync;
104104
use rustc_data_structures::unord::{UnordMap, UnordSet};
@@ -254,19 +254,20 @@ where
254254
always_export_generics,
255255
);
256256

257+
// TODO: This currently crashes compilation:
258+
// thread 'rustc' panicked at compiler/rustc_metadata/src/rmeta/decoder/cstore_impl.rs:242:1:
259+
// DefId(2:45681 ~ core[1460]::iter::adapters::copied::Copied) does not have a "codegen_fn_attrs"
260+
//
257261
// We can't differentiate something that got inlined.
258-
let autodiff_active = if cfg!(llvm_enzyme) {
259-
match characteristic_def_id {
260-
Some(def_id) => cx.tcx.autodiff_attrs(def_id).is_active(),
261-
None => false,
262-
}
263-
} else {
264-
false
262+
let autodiff_active = cfg!(llvm_enzyme) && {
263+
characteristic_def_id
264+
.and_then(|def_id| cx.tcx.codegen_fn_attrs(def_id).autodiff_item.as_ref())
265+
.is_some_and(|ad| ad.is_active())
265266
};
266267

267-
if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
268-
internalization_candidates.insert(mono_item);
269-
}
268+
//if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
269+
// internalization_candidates.insert(mono_item);
270+
//}
270271
let size_estimate = mono_item.size_estimate(cx.tcx);
271272

272273
cgu.items_mut().insert(mono_item, MonoItemData {
@@ -1254,7 +1255,11 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
12541255

12551256
for (item, instance) in autodiff_mono_items {
12561257
let target_id = instance.def_id();
1257-
let target_attrs: &AutoDiffAttrs = tcx.autodiff_attrs(target_id);
1258+
let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone();
1259+
let Some(target_attrs) = cg_fn_attr else {
1260+
continue;
1261+
};
1262+
//let target_attrs: &AutoDiffAttrs = tcx.autodiff_attrs(target_id);
12581263
let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
12591264
if target_attrs.is_source() {
12601265
trace!("source found: {:?}", target_id);
@@ -1269,7 +1274,8 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
12691274
usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
12701275
MonoItem::Fn(ref instance_s) => {
12711276
let source_id = instance_s.def_id();
1272-
if tcx.autodiff_attrs(source_id).is_active() {
1277+
//if tcx.autodiff_attrs(source_id).is_active() {
1278+
if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item && ad.is_active() {
12731279
return Some(instance_s);
12741280
}
12751281
None

0 commit comments

Comments
 (0)