Skip to content

Commit cc292c8

Browse files
futileKobzol
authored andcommitted
Implement incremental caching for derive macro expansions
1 parent a8537ab commit cc292c8

File tree

15 files changed

+262
-35
lines changed

15 files changed

+262
-35
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3748,6 +3748,7 @@ dependencies = [
37483748
"rustc_lexer",
37493749
"rustc_lint_defs",
37503750
"rustc_macros",
3751+
"rustc_middle",
37513752
"rustc_parse",
37523753
"rustc_proc_macro",
37533754
"rustc_serialize",

compiler/rustc_ast/src/tokenstream.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@
55
//! which are themselves a single [`Token`] or a `Delimited` subsequence of tokens.
66
77
use std::borrow::Cow;
8+
use std::hash::Hash;
89
use std::ops::Range;
910
use std::sync::Arc;
1011
use std::{cmp, fmt, iter, mem};
1112

1213
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
1314
use rustc_data_structures::sync;
1415
use rustc_macros::{Decodable, Encodable, HashStable_Generic, Walkable};
15-
use rustc_serialize::{Decodable, Encodable};
16-
use rustc_span::{DUMMY_SP, Span, SpanDecoder, SpanEncoder, Symbol, sym};
16+
use rustc_serialize::{Decodable, Encodable, Encoder};
17+
use rustc_span::def_id::{CrateNum, DefIndex};
18+
use rustc_span::{ByteSymbol, DUMMY_SP, Span, SpanDecoder, SpanEncoder, Symbol, sym};
1719
use thin_vec::ThinVec;
1820

1921
use crate::ast::AttrStyle;

compiler/rustc_expand/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ rustc_hir = { path = "../rustc_hir" }
2121
rustc_lexer = { path = "../rustc_lexer" }
2222
rustc_lint_defs = { path = "../rustc_lint_defs" }
2323
rustc_macros = { path = "../rustc_macros" }
24+
rustc_middle = { path = "../rustc_middle" }
2425
rustc_parse = { path = "../rustc_parse" }
2526
# We must use the proc_macro version that we will compile proc-macros against,
2627
# not the one from our own sysroot.

compiler/rustc_expand/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,8 @@ pub mod module;
3131
#[allow(rustc::untranslatable_diagnostic)]
3232
pub mod proc_macro;
3333

34+
pub fn provide(providers: &mut rustc_middle::util::Providers) {
35+
providers.derive_macro_expansion = proc_macro::provide_derive_macro_expansion;
36+
}
37+
3438
rustc_fluent_macro::fluent_messages! { "../messages.ftl" }

compiler/rustc_expand/src/proc_macro.rs

Lines changed: 146 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
use std::cell::Cell;
2+
use std::ptr::NonNull;
3+
14
use rustc_ast::tokenstream::TokenStream;
5+
use rustc_data_structures::svh::Svh;
26
use rustc_errors::ErrorGuaranteed;
7+
use rustc_middle::ty::{self, TyCtxt};
38
use rustc_parse::parser::{ForceCollect, Parser};
9+
use rustc_session::Session;
410
use rustc_session::config::ProcMacroExecutionStrategy;
5-
use rustc_span::Span;
611
use rustc_span::profiling::SpannedEventArgRecorder;
12+
use rustc_span::{LocalExpnId, Span};
713
use {rustc_ast as ast, rustc_proc_macro as pm};
814

915
use crate::base::{self, *};
@@ -30,9 +36,9 @@ impl<T> pm::bridge::server::MessagePipe<T> for MessagePipe<T> {
3036
}
3137
}
3238

33-
fn exec_strategy(ecx: &ExtCtxt<'_>) -> impl pm::bridge::server::ExecutionStrategy + 'static {
39+
pub fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
3440
pm::bridge::server::MaybeCrossThread::<MessagePipe<_>>::new(
35-
ecx.sess.opts.unstable_opts.proc_macro_execution_strategy
41+
sess.opts.unstable_opts.proc_macro_execution_strategy
3642
== ProcMacroExecutionStrategy::CrossThread,
3743
)
3844
}
@@ -54,7 +60,7 @@ impl base::BangProcMacro for BangProcMacro {
5460
});
5561

5662
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
57-
let strategy = exec_strategy(ecx);
63+
let strategy = exec_strategy(ecx.sess);
5864
let server = proc_macro_server::Rustc::new(ecx);
5965
self.client.run(&strategy, server, input, proc_macro_backtrace).map_err(|e| {
6066
ecx.dcx().emit_err(errors::ProcMacroPanicked {
@@ -85,7 +91,7 @@ impl base::AttrProcMacro for AttrProcMacro {
8591
});
8692

8793
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
88-
let strategy = exec_strategy(ecx);
94+
let strategy = exec_strategy(ecx.sess);
8995
let server = proc_macro_server::Rustc::new(ecx);
9096
self.client.run(&strategy, server, annotation, annotated, proc_macro_backtrace).map_err(
9197
|e| {
@@ -113,6 +119,13 @@ impl MultiItemModifier for DeriveProcMacro {
113119
item: Annotatable,
114120
_is_derive_const: bool,
115121
) -> ExpandResult<Vec<Annotatable>, Annotatable> {
122+
let _timer = ecx.sess.prof.generic_activity_with_arg_recorder(
123+
"expand_derive_proc_macro_outer",
124+
|recorder| {
125+
recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
126+
},
127+
);
128+
116129
// We need special handling for statement items
117130
// (e.g. `fn foo() { #[derive(Debug)] struct Bar; }`)
118131
let is_stmt = matches!(item, Annotatable::Stmt(..));
@@ -123,36 +136,39 @@ impl MultiItemModifier for DeriveProcMacro {
123136
// altogether. See #73345.
124137
crate::base::ann_pretty_printing_compatibility_hack(&item, &ecx.sess.psess);
125138
let input = item.to_tokens();
126-
let stream = {
127-
let _timer =
128-
ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
129-
recorder.record_arg_with_span(
130-
ecx.sess.source_map(),
131-
ecx.expansion_descr(),
132-
span,
133-
);
134-
});
135-
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
136-
let strategy = exec_strategy(ecx);
137-
let server = proc_macro_server::Rustc::new(ecx);
138-
match self.client.run(&strategy, server, input, proc_macro_backtrace) {
139-
Ok(stream) => stream,
140-
Err(e) => {
141-
ecx.dcx().emit_err({
142-
errors::ProcMacroDerivePanicked {
143-
span,
144-
message: e.as_str().map(|message| {
145-
errors::ProcMacroDerivePanickedHelp { message: message.into() }
146-
}),
147-
}
148-
});
149-
return ExpandResult::Ready(vec![]);
150-
}
139+
let res = ty::tls::with(|tcx| {
140+
let input = tcx.arena.alloc(input) as &TokenStream;
141+
let invoc_id = ecx.current_expansion.id;
142+
let invoc_expn_data = invoc_id.expn_data();
143+
144+
assert_eq!(invoc_expn_data.call_site, span);
145+
146+
// FIXME(pr-time): Is this the correct way to check for incremental compilation (as
147+
// well as for `cache_proc_macros`)?
148+
if tcx.sess.opts.incremental.is_some() && tcx.sess.opts.unstable_opts.cache_proc_macros
149+
{
150+
// FIXME(pr-time): Just using the crate hash to notice when the proc-macro code has
151+
// changed. How to *correctly* depend on exactly the macro definition?
152+
// I.e., depending on the crate hash is just a HACK, and ideally the dependency would be
153+
// more narrow.
154+
let macro_def_id = invoc_expn_data.macro_def_id.unwrap();
155+
let proc_macro_crate_hash = tcx.crate_hash(macro_def_id.krate);
156+
157+
let key = (invoc_id, proc_macro_crate_hash, input);
158+
159+
enter_context((ecx, self.client), move || tcx.derive_macro_expansion(key).cloned())
160+
} else {
161+
expand_derive_macro(tcx, invoc_id, input, ecx, self.client).cloned()
151162
}
163+
});
164+
165+
let Ok(output) = res else {
166+
// error will already have been emitted
167+
return ExpandResult::Ready(vec![]);
152168
};
153169

154170
let error_count_before = ecx.dcx().err_count();
155-
let mut parser = Parser::new(&ecx.sess.psess, stream, Some("proc-macro derive"));
171+
let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
156172
let mut items = vec![];
157173

158174
loop {
@@ -180,3 +196,102 @@ impl MultiItemModifier for DeriveProcMacro {
180196
ExpandResult::Ready(items)
181197
}
182198
}
199+
200+
pub(super) fn provide_derive_macro_expansion<'tcx>(
201+
tcx: TyCtxt<'tcx>,
202+
key: (LocalExpnId, Svh, &'tcx TokenStream),
203+
) -> Result<&'tcx TokenStream, ()> {
204+
let (invoc_id, _macro_crate_hash, input) = key;
205+
206+
with_context(|(ecx, client)| expand_derive_macro(tcx, invoc_id, input, ecx, *client))
207+
}
208+
209+
type CLIENT = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
210+
211+
fn expand_derive_macro<'tcx>(
212+
tcx: TyCtxt<'tcx>,
213+
invoc_id: LocalExpnId,
214+
input: &'tcx TokenStream,
215+
ecx: &mut ExtCtxt<'_>,
216+
client: CLIENT,
217+
) -> Result<&'tcx TokenStream, ()> {
218+
let invoc_expn_data = invoc_id.expn_data();
219+
let span = invoc_expn_data.call_site;
220+
let event_arg = invoc_expn_data.kind.descr();
221+
let _timer = tcx.sess.prof.generic_activity_with_arg_recorder(
222+
"expand_derive_proc_macro_inner",
223+
|recorder| {
224+
recorder.record_arg_with_span(tcx.sess.source_map(), event_arg.clone(), span);
225+
},
226+
);
227+
228+
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
229+
let strategy = crate::proc_macro::exec_strategy(tcx.sess);
230+
let server = crate::proc_macro_server::Rustc::new(ecx);
231+
232+
match client.run(&strategy, server, input.clone(), proc_macro_backtrace) {
233+
Ok(stream) => Ok(tcx.arena.alloc(stream) as &TokenStream),
234+
Err(e) => {
235+
tcx.dcx().emit_err({
236+
errors::ProcMacroDerivePanicked {
237+
span,
238+
message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp {
239+
message: message.into(),
240+
}),
241+
}
242+
});
243+
Err(())
244+
}
245+
}
246+
}
247+
248+
// based on rust/compiler/rustc_middle/src/ty/context/tls.rs
249+
thread_local! {
250+
/// A thread local variable that stores a pointer to the current `CONTEXT`.
251+
static TLV: Cell<(*mut (), Option<CLIENT>)> = const { Cell::new((std::ptr::null_mut(), None)) };
252+
}
253+
254+
/// Sets `context` as the new current `CONTEXT` for the duration of the function `f`.
255+
#[inline]
256+
pub(crate) fn enter_context<'a, F, R>(context: (&mut ExtCtxt<'a>, CLIENT), f: F) -> R
257+
where
258+
F: FnOnce() -> R,
259+
{
260+
let (ectx, client) = context;
261+
let erased = (ectx as *mut _ as *mut (), Some(client));
262+
TLV.with(|tlv| {
263+
let old = tlv.replace(erased);
264+
let _reset = rustc_data_structures::defer(move || tlv.set(old));
265+
f()
266+
})
267+
}
268+
269+
/// Allows access to the current `CONTEXT`.
270+
/// Panics if there is no `CONTEXT` available.
271+
#[inline]
272+
#[track_caller]
273+
fn with_context<F, R>(f: F) -> R
274+
where
275+
F: for<'a, 'b> FnOnce(&'b mut (&mut ExtCtxt<'a>, CLIENT)) -> R,
276+
{
277+
let (ectx, client_opt) = TLV.get();
278+
let ectx = NonNull::new(ectx).expect("no CONTEXT stored in tls");
279+
280+
// We could get an `CONTEXT` pointer from another thread.
281+
// Ensure that `CONTEXT` is `DynSync`.
282+
// FIXME(pr-time): we should not be able to?
283+
// sync::assert_dyn_sync::<CONTEXT<'_>>();
284+
285+
// prevent double entering, as that would allow creating two `&mut ExtCtxt`s
286+
// FIXME(pr-time): probably use a RefCell instead (which checks this properly)?
287+
TLV.with(|tlv| {
288+
let old = tlv.replace((std::ptr::null_mut(), None));
289+
let _reset = rustc_data_structures::defer(move || tlv.set(old));
290+
let ectx = {
291+
let mut casted = ectx.cast::<ExtCtxt<'_>>();
292+
unsafe { casted.as_mut() }
293+
};
294+
295+
f(&mut (ectx, client_opt.unwrap()))
296+
})
297+
}

compiler/rustc_interface/src/passes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,7 @@ pub static DEFAULT_QUERY_PROVIDERS: LazyLock<Providers> = LazyLock::new(|| {
881881
providers.env_var_os = env_var_os;
882882
limits::provide(providers);
883883
proc_macro_decls::provide(providers);
884+
rustc_expand::provide(providers);
884885
rustc_const_eval::provide(providers);
885886
rustc_middle::hir::provide(providers);
886887
rustc_borrowck::provide(providers);

compiler/rustc_middle/src/arena.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ macro_rules! arena_types {
116116
[decode] specialization_graph: rustc_middle::traits::specialization_graph::Graph,
117117
[] crate_inherent_impls: rustc_middle::ty::CrateInherentImpls,
118118
[] hir_owner_nodes: rustc_hir::OwnerNodes<'tcx>,
119+
[decode] token_stream: rustc_ast::tokenstream::TokenStream,
119120
]);
120121
)
121122
}

compiler/rustc_middle/src/query/erase.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::ffi::OsStr;
22
use std::intrinsics::transmute_unchecked;
33
use std::mem::MaybeUninit;
44

5+
use rustc_ast::tokenstream::TokenStream;
56
use rustc_span::ErrorGuaranteed;
67

78
use crate::mir::interpret::EvalToValTreeResult;
@@ -170,6 +171,10 @@ impl EraseType for Result<ty::EarlyBinder<'_, Ty<'_>>, CyclePlaceholder> {
170171
type Result = [u8; size_of::<Result<ty::EarlyBinder<'static, Ty<'_>>, CyclePlaceholder>>()];
171172
}
172173

174+
impl EraseType for Result<&'_ TokenStream, ()> {
175+
type Result = [u8; size_of::<Result<&'static TokenStream, ()>>()];
176+
}
177+
173178
impl<T> EraseType for Option<&'_ T> {
174179
type Result = [u8; size_of::<Option<&'static ()>>()];
175180
}

compiler/rustc_middle/src/query/keys.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
33
use std::ffi::OsStr;
44

5+
use rustc_ast::tokenstream::TokenStream;
6+
use rustc_data_structures::svh::Svh;
57
use rustc_hir::def_id::{CrateNum, DefId, LOCAL_CRATE, LocalDefId, LocalModDefId, ModDefId};
68
use rustc_hir::hir_id::{HirId, OwnerId};
79
use rustc_query_system::dep_graph::DepNodeIndex;
810
use rustc_query_system::query::{DefIdCache, DefaultCache, SingleCache, VecCache};
9-
use rustc_span::{DUMMY_SP, Ident, Span, Symbol};
11+
use rustc_span::{DUMMY_SP, Ident, LocalExpnId, Span, Symbol};
1012

1113
use crate::infer::canonical::CanonicalQueryInput;
1214
use crate::mir::mono::CollectionMode;
@@ -616,6 +618,19 @@ impl Key for (LocalDefId, HirId) {
616618
}
617619
}
618620

621+
impl<'tcx> Key for (LocalExpnId, Svh, &'tcx TokenStream) {
622+
type Cache<V> = DefaultCache<Self, V>;
623+
624+
fn default_span(&self, _tcx: TyCtxt<'_>) -> Span {
625+
self.0.expn_data().call_site
626+
}
627+
628+
#[inline(always)]
629+
fn key_as_def_id(&self) -> Option<DefId> {
630+
None
631+
}
632+
}
633+
619634
impl<'tcx> Key for (ValidityRequirement, ty::PseudoCanonicalInput<'tcx, Ty<'tcx>>) {
620635
type Cache<V> = DefaultCache<Self, V>;
621636

compiler/rustc_middle/src/query/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ use std::sync::Arc;
7070
use rustc_abi::Align;
7171
use rustc_arena::TypedArena;
7272
use rustc_ast::expand::allocator::AllocatorKind;
73+
use rustc_ast::tokenstream::TokenStream;
7374
use rustc_data_structures::fingerprint::Fingerprint;
7475
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
7576
use rustc_data_structures::sorted_map::SortedMap;
@@ -99,7 +100,7 @@ use rustc_session::cstore::{
99100
use rustc_session::lint::LintExpectationId;
100101
use rustc_span::def_id::LOCAL_CRATE;
101102
use rustc_span::source_map::Spanned;
102-
use rustc_span::{DUMMY_SP, Span, Symbol};
103+
use rustc_span::{DUMMY_SP, LocalExpnId, Span, Symbol};
103104
use rustc_target::spec::{PanicStrategy, SanitizerSet};
104105
use {rustc_abi as abi, rustc_ast as ast, rustc_hir as hir};
105106

@@ -164,6 +165,13 @@ pub use plumbing::{IntoQueryParam, TyCtxtAt, TyCtxtEnsureDone, TyCtxtEnsureOk};
164165
// Queries marked with `fatal_cycle` do not need the latter implementation,
165166
// as they will raise an fatal error on query cycles instead.
166167
rustc_queries! {
168+
query derive_macro_expansion(key: (LocalExpnId, Svh, &'tcx TokenStream)) -> Result<&'tcx TokenStream, ()> {
169+
// eval_always
170+
// no_hash
171+
desc { "expanding a derive (proc) macro" }
172+
cache_on_disk_if { true }
173+
}
174+
167175
/// This exists purely for testing the interactions between delayed bugs and incremental.
168176
query trigger_delayed_bug(key: DefId) {
169177
desc { "triggering a delayed bug for testing incremental" }

0 commit comments

Comments
 (0)