Skip to content

Commit 21e746d

Browse files
Auto merge of #145354 - Kobzol:cache-proc-derive-macros, r=<try>
Cache derive proc macro expansion with incremental query
2 parents 1c9952f + 2f37ce2 commit 21e746d

File tree

15 files changed

+343
-37
lines changed

15 files changed

+343
-37
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3833,11 +3833,13 @@ dependencies = [
38333833
"rustc_lexer",
38343834
"rustc_lint_defs",
38353835
"rustc_macros",
3836+
"rustc_middle",
38363837
"rustc_parse",
38373838
"rustc_proc_macro",
38383839
"rustc_serialize",
38393840
"rustc_session",
38403841
"rustc_span",
3842+
"scoped-tls",
38413843
"smallvec",
38423844
"thin-vec",
38433845
"tracing",

compiler/rustc_ast/src/tokenstream.rs

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,17 @@
1414
//! ownership of the original.
1515
1616
use std::borrow::Cow;
17+
use std::hash::Hash;
1718
use std::ops::Range;
1819
use std::sync::Arc;
1920
use std::{cmp, fmt, iter, mem};
2021

2122
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
2223
use rustc_data_structures::sync;
2324
use rustc_macros::{Decodable, Encodable, HashStable_Generic, Walkable};
24-
use rustc_serialize::{Decodable, Encodable};
25-
use rustc_span::{DUMMY_SP, Span, SpanDecoder, SpanEncoder, Symbol, sym};
25+
use rustc_serialize::{Decodable, Encodable, Encoder};
26+
use rustc_span::def_id::{CrateNum, DefIndex};
27+
use rustc_span::{ByteSymbol, DUMMY_SP, Span, SpanDecoder, SpanEncoder, Symbol, sym};
2628
use thin_vec::ThinVec;
2729

2830
use crate::ast::AttrStyle;
@@ -560,6 +562,102 @@ pub struct AttrsTarget {
560562
#[derive(Clone, Debug, Default, Encodable, Decodable)]
561563
pub struct TokenStream(pub(crate) Arc<Vec<TokenTree>>);
562564

565+
/// Adapter that lets a value's `Encodable` implementation drive a hasher: every primitive the
/// value "encodes" is written into `hasher` instead of a byte buffer.
///
/// The `Hasher` bound deliberately lives on the trait impls rather than on the struct itself, so
/// the type definition does not repeat a requirement that only the impls rely on.
struct HashEncoder<H> {
    /// The hasher that receives everything emitted through the encoder traits.
    hasher: H,
}
568+
569+
impl<H: std::hash::Hasher> Encoder for HashEncoder<H> {
570+
fn emit_usize(&mut self, v: usize) {
571+
self.hasher.write_usize(v)
572+
}
573+
574+
fn emit_u128(&mut self, v: u128) {
575+
self.hasher.write_u128(v)
576+
}
577+
578+
fn emit_u64(&mut self, v: u64) {
579+
self.hasher.write_u64(v)
580+
}
581+
582+
fn emit_u32(&mut self, v: u32) {
583+
self.hasher.write_u32(v)
584+
}
585+
586+
fn emit_u16(&mut self, v: u16) {
587+
self.hasher.write_u16(v)
588+
}
589+
590+
fn emit_u8(&mut self, v: u8) {
591+
self.hasher.write_u8(v)
592+
}
593+
594+
fn emit_isize(&mut self, v: isize) {
595+
self.hasher.write_isize(v)
596+
}
597+
598+
fn emit_i128(&mut self, v: i128) {
599+
self.hasher.write_i128(v)
600+
}
601+
602+
fn emit_i64(&mut self, v: i64) {
603+
self.hasher.write_i64(v)
604+
}
605+
606+
fn emit_i32(&mut self, v: i32) {
607+
self.hasher.write_i32(v)
608+
}
609+
610+
fn emit_i16(&mut self, v: i16) {
611+
self.hasher.write_i16(v)
612+
}
613+
614+
fn emit_raw_bytes(&mut self, s: &[u8]) {
615+
self.hasher.write(s)
616+
}
617+
}
618+
619+
impl<H: std::hash::Hasher> SpanEncoder for HashEncoder<H> {
620+
fn encode_span(&mut self, span: Span) {
621+
span.hash(&mut self.hasher)
622+
}
623+
624+
fn encode_symbol(&mut self, symbol: Symbol) {
625+
symbol.hash(&mut self.hasher)
626+
}
627+
628+
fn encode_byte_symbol(&mut self, byte_sym: ByteSymbol) {
629+
byte_sym.hash(&mut self.hasher);
630+
}
631+
632+
fn encode_expn_id(&mut self, expn_id: rustc_span::ExpnId) {
633+
expn_id.hash(&mut self.hasher)
634+
}
635+
636+
fn encode_syntax_context(&mut self, syntax_context: rustc_span::SyntaxContext) {
637+
syntax_context.hash(&mut self.hasher)
638+
}
639+
640+
fn encode_crate_num(&mut self, crate_num: CrateNum) {
641+
crate_num.hash(&mut self.hasher)
642+
}
643+
644+
fn encode_def_index(&mut self, def_index: DefIndex) {
645+
def_index.hash(&mut self.hasher)
646+
}
647+
648+
fn encode_def_id(&mut self, def_id: rustc_span::def_id::DefId) {
649+
def_id.hash(&mut self.hasher)
650+
}
651+
}
652+
653+
/// TokenStream needs to be hashable because it is used as a query key for caching derive macro
654+
/// expansions.
655+
impl Hash for TokenStream {
656+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
657+
Encodable::encode(self, &mut HashEncoder { hasher: state });
658+
}
659+
}
660+
563661
/// Indicates whether a token can join with the following token to form a
564662
/// compound token. Used for conversions to `proc_macro::Spacing`. Also used to
565663
/// guide pretty-printing, which is where the `JointHidden` value (which isn't

compiler/rustc_expand/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@ rustc_hir = { path = "../rustc_hir" }
2121
rustc_lexer = { path = "../rustc_lexer" }
2222
rustc_lint_defs = { path = "../rustc_lint_defs" }
2323
rustc_macros = { path = "../rustc_macros" }
24+
rustc_middle = { path = "../rustc_middle" }
2425
rustc_parse = { path = "../rustc_parse" }
2526
# We must use the proc_macro version that we will compile proc-macros against,
2627
# not the one from our own sysroot.
2728
rustc_proc_macro = { path = "../rustc_proc_macro" }
2829
rustc_serialize = { path = "../rustc_serialize" }
2930
rustc_session = { path = "../rustc_session" }
3031
rustc_span = { path = "../rustc_span" }
32+
scoped-tls = "1.0"
3133
smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
3234
thin-vec = "0.2.12"
3335
tracing = "0.1"

compiler/rustc_expand/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,8 @@ pub mod module;
3131
#[allow(rustc::untranslatable_diagnostic)]
3232
pub mod proc_macro;
3333

34+
pub fn provide(providers: &mut rustc_middle::util::Providers) {
35+
providers.derive_macro_expansion = proc_macro::provide_derive_macro_expansion;
36+
}
37+
3438
rustc_fluent_macro::fluent_messages! { "../messages.ftl" }

compiler/rustc_expand/src/proc_macro.rs

Lines changed: 142 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
use rustc_ast::tokenstream::TokenStream;
2+
use rustc_data_structures::svh::Svh;
23
use rustc_errors::ErrorGuaranteed;
4+
use rustc_middle::ty::{self, TyCtxt};
35
use rustc_parse::parser::{ForceCollect, Parser};
6+
use rustc_session::Session;
47
use rustc_session::config::ProcMacroExecutionStrategy;
5-
use rustc_span::Span;
68
use rustc_span::profiling::SpannedEventArgRecorder;
9+
use rustc_span::{LocalExpnId, Span};
710
use {rustc_ast as ast, rustc_proc_macro as pm};
811

912
use crate::base::{self, *};
@@ -30,9 +33,9 @@ impl<T> pm::bridge::server::MessagePipe<T> for MessagePipe<T> {
3033
}
3134
}
3235

33-
fn exec_strategy(ecx: &ExtCtxt<'_>) -> impl pm::bridge::server::ExecutionStrategy + 'static {
36+
fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
3437
pm::bridge::server::MaybeCrossThread::<MessagePipe<_>>::new(
35-
ecx.sess.opts.unstable_opts.proc_macro_execution_strategy
38+
sess.opts.unstable_opts.proc_macro_execution_strategy
3639
== ProcMacroExecutionStrategy::CrossThread,
3740
)
3841
}
@@ -54,7 +57,7 @@ impl base::BangProcMacro for BangProcMacro {
5457
});
5558

5659
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
57-
let strategy = exec_strategy(ecx);
60+
let strategy = exec_strategy(ecx.sess);
5861
let server = proc_macro_server::Rustc::new(ecx);
5962
self.client.run(&strategy, server, input, proc_macro_backtrace).map_err(|e| {
6063
ecx.dcx().emit_err(errors::ProcMacroPanicked {
@@ -85,7 +88,7 @@ impl base::AttrProcMacro for AttrProcMacro {
8588
});
8689

8790
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
88-
let strategy = exec_strategy(ecx);
91+
let strategy = exec_strategy(ecx.sess);
8992
let server = proc_macro_server::Rustc::new(ecx);
9093
self.client.run(&strategy, server, annotation, annotated, proc_macro_backtrace).map_err(
9194
|e| {
@@ -101,7 +104,7 @@ impl base::AttrProcMacro for AttrProcMacro {
101104
}
102105

103106
pub struct DeriveProcMacro {
104-
pub client: pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>,
107+
pub client: DeriveClient,
105108
}
106109

107110
impl MultiItemModifier for DeriveProcMacro {
@@ -113,6 +116,13 @@ impl MultiItemModifier for DeriveProcMacro {
113116
item: Annotatable,
114117
_is_derive_const: bool,
115118
) -> ExpandResult<Vec<Annotatable>, Annotatable> {
119+
let _timer = ecx.sess.prof.generic_activity_with_arg_recorder(
120+
"expand_derive_proc_macro_outer",
121+
|recorder| {
122+
recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
123+
},
124+
);
125+
116126
// We need special handling for statement items
117127
// (e.g. `fn foo() { #[derive(Debug)] struct Bar; }`)
118128
let is_stmt = matches!(item, Annotatable::Stmt(..));
@@ -123,36 +133,39 @@ impl MultiItemModifier for DeriveProcMacro {
123133
// altogether. See #73345.
124134
crate::base::ann_pretty_printing_compatibility_hack(&item, &ecx.sess.psess);
125135
let input = item.to_tokens();
126-
let stream = {
127-
let _timer =
128-
ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
129-
recorder.record_arg_with_span(
130-
ecx.sess.source_map(),
131-
ecx.expansion_descr(),
132-
span,
133-
);
134-
});
135-
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
136-
let strategy = exec_strategy(ecx);
137-
let server = proc_macro_server::Rustc::new(ecx);
138-
match self.client.run(&strategy, server, input, proc_macro_backtrace) {
139-
Ok(stream) => stream,
140-
Err(e) => {
141-
ecx.dcx().emit_err({
142-
errors::ProcMacroDerivePanicked {
143-
span,
144-
message: e.as_str().map(|message| {
145-
errors::ProcMacroDerivePanickedHelp { message: message.into() }
146-
}),
147-
}
148-
});
149-
return ExpandResult::Ready(vec![]);
150-
}
151-
}
136+
137+
let invoc_id = ecx.current_expansion.id;
138+
139+
let res = if ecx.sess.opts.incremental.is_some()
140+
&& ecx.sess.opts.unstable_opts.cache_derive_macros
141+
{
142+
ty::tls::with(|tcx| {
143+
// FIXME(pr-time): Just using the crate hash to notice when the proc-macro code has
144+
// changed. How to *correctly* depend on exactly the macro definition?
145+
// I.e., depending on the crate hash is just a HACK, and ideally the dependency would be
146+
// more narrow.
147+
let invoc_expn_data = invoc_id.expn_data();
148+
let macro_def_id = invoc_expn_data.macro_def_id.unwrap();
149+
let proc_macro_crate_hash = tcx.crate_hash(macro_def_id.krate);
150+
151+
let input = tcx.arena.alloc(input) as &TokenStream;
152+
let key = (invoc_id, proc_macro_crate_hash, input);
153+
154+
QueryDeriveExpandCtx::enter(ecx, self.client, move || {
155+
tcx.derive_macro_expansion(key).cloned()
156+
})
157+
})
158+
} else {
159+
expand_derive_macro(invoc_id, input, ecx, self.client)
160+
};
161+
162+
let Ok(output) = res else {
163+
// error will already have been emitted
164+
return ExpandResult::Ready(vec![]);
152165
};
153166

154167
let error_count_before = ecx.dcx().err_count();
155-
let mut parser = Parser::new(&ecx.sess.psess, stream, Some("proc-macro derive"));
168+
let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
156169
let mut items = vec![];
157170

158171
loop {
@@ -180,3 +193,99 @@ impl MultiItemModifier for DeriveProcMacro {
180193
ExpandResult::Ready(items)
181194
}
182195
}
196+
197+
/// Provide a query for computing the output of a derive macro.
198+
pub(super) fn provide_derive_macro_expansion<'tcx>(
199+
tcx: TyCtxt<'tcx>,
200+
key: (LocalExpnId, Svh, &'tcx TokenStream),
201+
) -> Result<&'tcx TokenStream, ()> {
202+
let (invoc_id, _macro_crate_hash, input) = key;
203+
204+
eprintln!("Expanding derive macro in a query");
205+
206+
QueryDeriveExpandCtx::with(|ecx, client| {
207+
expand_derive_macro(invoc_id, input.clone(), ecx, client)
208+
.map(|ts| tcx.arena.alloc(ts) as &TokenStream)
209+
})
210+
}
211+
212+
/// Proc-macro bridge client for derive macros: takes one `pm::TokenStream` and produces one
/// `pm::TokenStream`.
type DeriveClient = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
213+
214+
fn expand_derive_macro(
215+
invoc_id: LocalExpnId,
216+
input: TokenStream,
217+
ecx: &mut ExtCtxt<'_>,
218+
client: DeriveClient,
219+
) -> Result<TokenStream, ()> {
220+
let invoc_expn_data = invoc_id.expn_data();
221+
let span = invoc_expn_data.call_site;
222+
let event_arg = invoc_expn_data.kind.descr();
223+
let _timer =
224+
ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
225+
recorder.record_arg_with_span(ecx.sess.source_map(), event_arg.clone(), span);
226+
});
227+
228+
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
229+
let strategy = exec_strategy(ecx.sess);
230+
let server = proc_macro_server::Rustc::new(ecx);
231+
232+
match client.run(&strategy, server, input, proc_macro_backtrace) {
233+
Ok(stream) => Ok(stream),
234+
Err(e) => {
235+
ecx.dcx().emit_err({
236+
errors::ProcMacroDerivePanicked {
237+
span,
238+
message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp {
239+
message: message.into(),
240+
}),
241+
}
242+
});
243+
Err(())
244+
}
245+
}
246+
}
247+
248+
/// Stores the context necessary to expand a derive proc macro via a query.
249+
struct QueryDeriveExpandCtx {
250+
/// Type-erased version of `&mut ExtCtxt`
251+
expansion_ctx: *mut (),
252+
client: DeriveClient,
253+
}
254+
255+
impl QueryDeriveExpandCtx {
256+
/// Store the extension context and the client into the thread local value.
257+
/// It will be accessible via the `with` method while `f` is active.
258+
fn enter<F, R>(ecx: &mut ExtCtxt<'_>, client: DeriveClient, f: F) -> R
259+
where
260+
F: FnOnce() -> R,
261+
{
262+
// We need erasure to get rid of the lifetime
263+
let ctx = Self { expansion_ctx: ecx as *mut _ as *mut (), client };
264+
DERIVE_EXPAND_CTX.set(&ctx, || f())
265+
}
266+
267+
/// Accesses the thread local value of the derive expansion context.
268+
/// Must be called while the `enter` function is active.
269+
fn with<F, R>(f: F) -> R
270+
where
271+
F: for<'a, 'b> FnOnce(&'b mut ExtCtxt<'a>, DeriveClient) -> R,
272+
{
273+
DERIVE_EXPAND_CTX.with(|ctx| {
274+
let ectx = {
275+
let casted = ctx.expansion_ctx.cast::<ExtCtxt<'_>>();
276+
// SAFETY: We can only get the value from `with` while the `enter` function
277+
// is active (on the callstack), and that function's signature ensures that the
278+
// lifetime is valid.
279+
// If `with` is called at some other time, it will panic due to usage of
280+
// `scoped_tls::with`.
281+
unsafe { casted.as_mut().unwrap() }
282+
};
283+
284+
f(ectx, ctx.client)
285+
})
286+
}
287+
}
288+
289+
// When we invoke a query to expand a derive proc macro, we need to provide it with the expansion
// context and derive Client. We do that using a thread-local.
// The slot is filled by `QueryDeriveExpandCtx::enter` and read by `QueryDeriveExpandCtx::with`.
scoped_tls::scoped_thread_local!(static DERIVE_EXPAND_CTX: QueryDeriveExpandCtx);

0 commit comments

Comments
 (0)