Cache derive proc macro expansion with incremental query #145354

Open: wants to merge 6 commits into base: master
Showing changes from 5 commits
2 changes: 2 additions & 0 deletions Cargo.lock
@@ -3833,11 +3833,13 @@ dependencies = [
"rustc_lexer",
"rustc_lint_defs",
"rustc_macros",
"rustc_middle",
"rustc_parse",
"rustc_proc_macro",
"rustc_serialize",
"rustc_session",
"rustc_span",
"scoped-tls",
"smallvec",
"thin-vec",
"tracing",
102 changes: 100 additions & 2 deletions compiler/rustc_ast/src/tokenstream.rs
@@ -14,15 +14,17 @@
//! ownership of the original.

use std::borrow::Cow;
use std::hash::Hash;
use std::ops::Range;
use std::sync::Arc;
use std::{cmp, fmt, iter, mem};

use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
use rustc_data_structures::sync;
use rustc_macros::{Decodable, Encodable, HashStable_Generic, Walkable};
use rustc_serialize::{Decodable, Encodable};
use rustc_span::{DUMMY_SP, Span, SpanDecoder, SpanEncoder, Symbol, sym};
use rustc_serialize::{Decodable, Encodable, Encoder};
use rustc_span::def_id::{CrateNum, DefIndex};
use rustc_span::{ByteSymbol, DUMMY_SP, Span, SpanDecoder, SpanEncoder, Symbol, sym};
use thin_vec::ThinVec;

use crate::ast::AttrStyle;
@@ -560,6 +562,102 @@ pub struct AttrsTarget {
#[derive(Clone, Debug, Default, Encodable, Decodable)]
pub struct TokenStream(pub(crate) Arc<Vec<TokenTree>>);

struct HashEncoder<H: std::hash::Hasher> {
hasher: H,
}

impl<H: std::hash::Hasher> Encoder for HashEncoder<H> {
fn emit_usize(&mut self, v: usize) {
self.hasher.write_usize(v)
}

fn emit_u128(&mut self, v: u128) {
self.hasher.write_u128(v)
}

fn emit_u64(&mut self, v: u64) {
self.hasher.write_u64(v)
}

fn emit_u32(&mut self, v: u32) {
self.hasher.write_u32(v)
}

fn emit_u16(&mut self, v: u16) {
self.hasher.write_u16(v)
}

fn emit_u8(&mut self, v: u8) {
self.hasher.write_u8(v)
}

fn emit_isize(&mut self, v: isize) {
self.hasher.write_isize(v)
}

fn emit_i128(&mut self, v: i128) {
self.hasher.write_i128(v)
}

fn emit_i64(&mut self, v: i64) {
self.hasher.write_i64(v)
}

fn emit_i32(&mut self, v: i32) {
self.hasher.write_i32(v)
}

fn emit_i16(&mut self, v: i16) {
self.hasher.write_i16(v)
}

fn emit_raw_bytes(&mut self, s: &[u8]) {
self.hasher.write(s)
}
}

impl<H: std::hash::Hasher> SpanEncoder for HashEncoder<H> {
Review comment (Member): Shouldn't this use `StableHasher`? The regular `Hash` impl is only valid for a single rustc session.

Reply (Contributor): This is something that wasn't answered in #129102 (comment).

Reply (Member Author): Hmm, I need to read up on how the usage of `Hash` and StableHash differs for incremental compilation. I examined a few other query key types, and a lot of them just `#[derive(Hash)]`. I did the same for `TokenStream`, but would appreciate guidance on whether it is the Right Thing to do :)

fn encode_span(&mut self, span: Span) {
span.hash(&mut self.hasher)
}

fn encode_symbol(&mut self, symbol: Symbol) {
symbol.hash(&mut self.hasher)
}

fn encode_byte_symbol(&mut self, byte_sym: ByteSymbol) {
byte_sym.hash(&mut self.hasher);
}

fn encode_expn_id(&mut self, expn_id: rustc_span::ExpnId) {
expn_id.hash(&mut self.hasher)
}

fn encode_syntax_context(&mut self, syntax_context: rustc_span::SyntaxContext) {
syntax_context.hash(&mut self.hasher)
}

fn encode_crate_num(&mut self, crate_num: CrateNum) {
crate_num.hash(&mut self.hasher)
}

fn encode_def_index(&mut self, def_index: DefIndex) {
def_index.hash(&mut self.hasher)
}

fn encode_def_id(&mut self, def_id: rustc_span::def_id::DefId) {
def_id.hash(&mut self.hasher)
}
}

/// TokenStream needs to be hashable because it is used as a query key for caching derive macro
/// expansions.
impl Hash for TokenStream {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
Encodable::encode(self, &mut HashEncoder { hasher: state });
}
}
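
The review question about `StableHasher` comes down to whether the hashed values mean the same thing in a later compiler session. The following is a toy model, not rustc code: `ToyInterner`, `hash_of`, and `content_hash` are hypothetical names, and the only assumption is that interned ids depend on the order in which strings were interned. Under that assumption, the id for the same string can differ between two "sessions", so a hash over the id is not comparable across them, while a hash over the contents is.

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

/// Hypothetical interner: the id of a string depends on interning order.
#[derive(Default)]
struct ToyInterner {
    ids: HashMap<String, u32>,
    names: Vec<String>,
}

impl ToyInterner {
    /// Returns the existing id for `s`, or assigns the next free one.
    fn intern(&mut self, s: &str) -> u32 {
        if let Some(&id) = self.ids.get(s) {
            return id;
        }
        let id = self.names.len() as u32;
        self.ids.insert(s.to_owned(), id);
        self.names.push(s.to_owned());
        id
    }

    fn resolve(&self, id: u32) -> &str {
        &self.names[id as usize]
    }
}

/// Hashes any value with a freshly created `DefaultHasher`.
fn hash_of<T: Hash + ?Sized>(value: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    value.hash(&mut hasher);
    hasher.finish()
}

/// Hashes what the id refers to, not the id itself.
fn content_hash(interner: &ToyInterner, id: u32) -> u64 {
    hash_of(interner.resolve(id))
}

fn main() {
    // "Session" A interns another string first, so "Clone" gets id 1.
    let mut a = ToyInterner::default();
    a.intern("Debug");
    let clone_a = a.intern("Clone");

    // "Session" B interns "Clone" first, so it gets id 0.
    let mut b = ToyInterner::default();
    let clone_b = b.intern("Clone");

    // The raw ids differ, so anything hashed from them is session-local.
    assert_ne!(clone_a, clone_b);
    // Hashing the contents gives the same value in both sessions.
    assert_eq!(content_hash(&a, clone_a), content_hash(&b, clone_b));
}
```

Whether this concern actually applies to each value that `HashEncoder` forwards to `Hash` is the open question in the thread; the sketch only shows the failure mode being asked about.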

/// Indicates whether a token can join with the following token to form a
/// compound token. Used for conversions to `proc_macro::Spacing`. Also used to
/// guide pretty-printing, which is where the `JointHidden` value (which isn't
2 changes: 2 additions & 0 deletions compiler/rustc_expand/Cargo.toml
@@ -21,13 +21,15 @@ rustc_hir = { path = "../rustc_hir" }
rustc_lexer = { path = "../rustc_lexer" }
rustc_lint_defs = { path = "../rustc_lint_defs" }
rustc_macros = { path = "../rustc_macros" }
rustc_middle = { path = "../rustc_middle" }
rustc_parse = { path = "../rustc_parse" }
# We must use the proc_macro version that we will compile proc-macros against,
# not the one from our own sysroot.
rustc_proc_macro = { path = "../rustc_proc_macro" }
rustc_serialize = { path = "../rustc_serialize" }
rustc_session = { path = "../rustc_session" }
rustc_span = { path = "../rustc_span" }
scoped-tls = "1.0"
smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
thin-vec = "0.2.12"
tracing = "0.1"
4 changes: 4 additions & 0 deletions compiler/rustc_expand/src/lib.rs
@@ -31,4 +31,8 @@ pub mod module;
#[allow(rustc::untranslatable_diagnostic)]
pub mod proc_macro;

pub fn provide(providers: &mut rustc_middle::util::Providers) {
providers.derive_macro_expansion = proc_macro::provide_derive_macro_expansion;
}

rustc_fluent_macro::fluent_messages! { "../messages.ftl" }
175 changes: 142 additions & 33 deletions compiler/rustc_expand/src/proc_macro.rs
@@ -1,9 +1,12 @@
use rustc_ast::tokenstream::TokenStream;
use rustc_data_structures::svh::Svh;
use rustc_errors::ErrorGuaranteed;
use rustc_middle::ty::{self, TyCtxt};
use rustc_parse::parser::{ForceCollect, Parser};
use rustc_session::Session;
use rustc_session::config::ProcMacroExecutionStrategy;
use rustc_span::Span;
use rustc_span::profiling::SpannedEventArgRecorder;
use rustc_span::{LocalExpnId, Span};
use {rustc_ast as ast, rustc_proc_macro as pm};

use crate::base::{self, *};
@@ -30,9 +33,9 @@ impl<T> pm::bridge::server::MessagePipe<T> for MessagePipe<T> {
}
}

fn exec_strategy(ecx: &ExtCtxt<'_>) -> impl pm::bridge::server::ExecutionStrategy + 'static {
fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
pm::bridge::server::MaybeCrossThread::<MessagePipe<_>>::new(
ecx.sess.opts.unstable_opts.proc_macro_execution_strategy
sess.opts.unstable_opts.proc_macro_execution_strategy
== ProcMacroExecutionStrategy::CrossThread,
)
}
@@ -54,7 +57,7 @@ impl base::BangProcMacro for BangProcMacro {
});

let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
let strategy = exec_strategy(ecx);
let strategy = exec_strategy(ecx.sess);
let server = proc_macro_server::Rustc::new(ecx);
self.client.run(&strategy, server, input, proc_macro_backtrace).map_err(|e| {
ecx.dcx().emit_err(errors::ProcMacroPanicked {
@@ -85,7 +88,7 @@ impl base::AttrProcMacro for AttrProcMacro {
});

let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
let strategy = exec_strategy(ecx);
let strategy = exec_strategy(ecx.sess);
let server = proc_macro_server::Rustc::new(ecx);
self.client.run(&strategy, server, annotation, annotated, proc_macro_backtrace).map_err(
|e| {
@@ -101,7 +104,7 @@ impl base::AttrProcMacro for AttrProcMacro {
}

pub struct DeriveProcMacro {
pub client: pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>,
pub client: DeriveClient,
}

impl MultiItemModifier for DeriveProcMacro {
@@ -113,6 +116,13 @@ impl MultiItemModifier for DeriveProcMacro {
item: Annotatable,
_is_derive_const: bool,
) -> ExpandResult<Vec<Annotatable>, Annotatable> {
let _timer = ecx.sess.prof.generic_activity_with_arg_recorder(
"expand_derive_proc_macro_outer",
|recorder| {
recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
},
);

// We need special handling for statement items
// (e.g. `fn foo() { #[derive(Debug)] struct Bar; }`)
let is_stmt = matches!(item, Annotatable::Stmt(..));
@@ -123,36 +133,39 @@ impl MultiItemModifier for DeriveProcMacro {
// altogether. See #73345.
crate::base::ann_pretty_printing_compatibility_hack(&item, &ecx.sess.psess);
let input = item.to_tokens();
let stream = {
let _timer =
ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
recorder.record_arg_with_span(
ecx.sess.source_map(),
ecx.expansion_descr(),
span,
);
});
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
let strategy = exec_strategy(ecx);
let server = proc_macro_server::Rustc::new(ecx);
match self.client.run(&strategy, server, input, proc_macro_backtrace) {
Ok(stream) => stream,
Err(e) => {
ecx.dcx().emit_err({
errors::ProcMacroDerivePanicked {
span,
message: e.as_str().map(|message| {
errors::ProcMacroDerivePanickedHelp { message: message.into() }
}),
}
});
return ExpandResult::Ready(vec![]);
}
}

let invoc_id = ecx.current_expansion.id;

let res = if ecx.sess.opts.incremental.is_some()
&& ecx.sess.opts.unstable_opts.cache_derive_macros
{
ty::tls::with(|tcx| {
// FIXME(pr-time): Just using the crate hash to notice when the proc-macro code has
// changed. How to *correctly* depend on exactly the macro definition?
// I.e., depending on the crate hash is just a HACK, and ideally the dependency would be
// more narrow.
let invoc_expn_data = invoc_id.expn_data();
let macro_def_id = invoc_expn_data.macro_def_id.unwrap();
let proc_macro_crate_hash = tcx.crate_hash(macro_def_id.krate);
Review comment on lines +148 to +149 (Contributor): This can be done from inside the query, to force a dependency on `crate_hash`, rather than passing the hash as part of the key.

Reply (Member Author): Sorry, my knowledge of the query infrastructure is quite limited 😅 How would that work? The key would be just `LocalExpnId` and `&TokenStream`, and then inside the query I would just call `tcx.crate_hash` and throw away the result? 🤔


let input = tcx.arena.alloc(input) as &TokenStream;
let key = (invoc_id, proc_macro_crate_hash, input);

QueryDeriveExpandCtx::enter(ecx, self.client, move || {
tcx.derive_macro_expansion(key).cloned()
})
})
} else {
expand_derive_macro(invoc_id, input, ecx, self.client)
};

let Ok(output) = res else {
// error will already have been emitted
return ExpandResult::Ready(vec![]);
};

let error_count_before = ecx.dcx().err_count();
let mut parser = Parser::new(&ecx.sess.psess, stream, Some("proc-macro derive"));
let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
let mut items = vec![];

loop {
@@ -180,3 +193,99 @@ impl MultiItemModifier for DeriveProcMacro {
ExpandResult::Ready(items)
}
}

/// Provide a query for computing the output of a derive macro.
pub(super) fn provide_derive_macro_expansion<'tcx>(
tcx: TyCtxt<'tcx>,
key: (LocalExpnId, Svh, &'tcx TokenStream),
) -> Result<&'tcx TokenStream, ()> {
let (invoc_id, _macro_crate_hash, input) = key;

eprintln!("Expanding derive macro in a query");

QueryDeriveExpandCtx::with(|ecx, client| {
expand_derive_macro(invoc_id, input.clone(), ecx, client)
.map(|ts| tcx.arena.alloc(ts) as &TokenStream)
})
}
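
The reviewer's suggestion of reading `crate_hash` inside the query instead of carrying it in the key can be illustrated with a toy memoizing "database". This is a sketch under loud assumptions: `ToyDb`, `Entry`, `Key`, and every method here are hypothetical, and the validation scheme (re-checking each recorded read) is a simplification that does not reproduce rustc's actual dependency tracking. It shows only the idea that a read made inside the provider, even one whose result is discarded, can be recorded and later used to invalidate the cached output.

```rust
use std::collections::HashMap;

/// Toy key: (crate number of the macro, input tokens). No crate hash in the key.
type Key = (u32, String);

/// A cached output plus the inputs the provider read while computing it.
struct Entry {
    deps: Vec<(u32, u64)>, // (crate number, crate hash observed at that time)
    output: String,
}

#[derive(Default)]
struct ToyDb {
    crate_hashes: HashMap<u32, u64>,
    cache: HashMap<Key, Entry>,
    runs: u32, // how many times the expensive provider actually ran
}

impl ToyDb {
    /// Tracked read: returns the hash and records it as a dependency.
    fn crate_hash(&self, krate: u32, deps: &mut Vec<(u32, u64)>) -> u64 {
        let hash = *self.crate_hashes.get(&krate).expect("unknown crate");
        deps.push((krate, hash));
        hash
    }

    /// Stand-in for the expensive macro expansion.
    fn provider(&mut self, key: &Key, deps: &mut Vec<(u32, u64)>) -> String {
        // The value is thrown away; the point is that the read is recorded.
        let _ = self.crate_hash(key.0, deps);
        self.runs += 1;
        format!("impl Debug for {}", key.1)
    }

    fn derive_expansion(&mut self, key: Key) -> String {
        if let Some(entry) = self.cache.get(&key) {
            // Reuse the cached output only if every recorded read is unchanged.
            let still_valid = entry
                .deps
                .iter()
                .all(|(krate, hash)| self.crate_hashes.get(krate) == Some(hash));
            if still_valid {
                return entry.output.clone();
            }
        }
        let mut deps = Vec::new();
        let output = self.provider(&key, &mut deps);
        self.cache.insert(key, Entry { deps, output: output.clone() });
        output
    }
}

fn main() {
    let mut db = ToyDb::default();
    db.crate_hashes.insert(7, 0xAAAA);
    let key: Key = (7, "Foo".to_string());

    db.derive_expansion(key.clone());
    db.derive_expansion(key.clone());
    assert_eq!(db.runs, 1); // second call was served from the cache

    db.crate_hashes.insert(7, 0xBBBB); // the macro crate changed
    db.derive_expansion(key.clone());
    assert_eq!(db.runs, 2); // the recorded dependency forced a re-run
}
```

In this toy model the key stays small and stable, and a change to the macro crate still invalidates the result because the provider observed its hash.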

type DeriveClient = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;

fn expand_derive_macro(
invoc_id: LocalExpnId,
input: TokenStream,
ecx: &mut ExtCtxt<'_>,
client: DeriveClient,
) -> Result<TokenStream, ()> {
let invoc_expn_data = invoc_id.expn_data();
let span = invoc_expn_data.call_site;
let event_arg = invoc_expn_data.kind.descr();
let _timer =
ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
recorder.record_arg_with_span(ecx.sess.source_map(), event_arg.clone(), span);
});

let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
let strategy = exec_strategy(ecx.sess);
let server = proc_macro_server::Rustc::new(ecx);

match client.run(&strategy, server, input, proc_macro_backtrace) {
Ok(stream) => Ok(stream),
Err(e) => {
ecx.dcx().emit_err({
errors::ProcMacroDerivePanicked {
span,
message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp {
message: message.into(),
}),
}
});
Err(())
}
}
}

/// Stores the context necessary to expand a derive proc macro via a query.
struct QueryDeriveExpandCtx {
/// Type-erased version of `&mut ExtCtxt`
expansion_ctx: *mut (),
client: DeriveClient,
}

impl QueryDeriveExpandCtx {
/// Store the extension context and the client into the thread local value.
/// It will be accessible via the `with` method while `f` is active.
fn enter<F, R>(ecx: &mut ExtCtxt<'_>, client: DeriveClient, f: F) -> R
where
F: FnOnce() -> R,
{
// We need erasure to get rid of the lifetime
let ctx = Self { expansion_ctx: ecx as *mut _ as *mut (), client };
DERIVE_EXPAND_CTX.set(&ctx, || f())
}

/// Accesses the thread local value of the derive expansion context.
/// Must be called while the `enter` function is active.
fn with<F, R>(f: F) -> R
where
F: for<'a, 'b> FnOnce(&'b mut ExtCtxt<'a>, DeriveClient) -> R,
Review comment (Member): The `ExtCtxt` contains a bunch of things that would need to be tracked by the query system for sound caching. The rest could either be retrieved from the `tcx` or be created from scratch, which would avoid having to use a thread local to bypass the query system.

Reply (Contributor): This was discussed a bit in #129102 (comment) and a few comments right above it.

{
DERIVE_EXPAND_CTX.with(|ctx| {
let ectx = {
let casted = ctx.expansion_ctx.cast::<ExtCtxt<'_>>();
// SAFETY: We can only get the value from `with` while the `enter` function
// is active (on the callstack), and that function's signature ensures that the
// lifetime is valid.
// If `with` is called at some other time, it will panic due to usage of
// `scoped_tls::with`.
unsafe { casted.as_mut().unwrap() }
};

f(ectx, ctx.client)
})
}
}

// When we invoke a query to expand a derive proc macro, we need to provide it with the expansion
// context and derive Client. We do that using a thread-local.
scoped_tls::scoped_thread_local!(static DERIVE_EXPAND_CTX: QueryDeriveExpandCtx);
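
The `enter`/`with` pattern above can be sketched in isolation. The version below is an illustration only: it swaps the `scoped_tls` crate for a plain `std` `thread_local!` holding a raw pointer, and `ToyCtx`, `ACTIVE_CTX`, and `Reset` are hypothetical names that do not appear in the PR. As an extra precaution that the PR code does not take, `with` removes the pointer while the callback runs, so a nested `with` panics instead of producing two aliasing `&mut` references.

```rust
use std::cell::Cell;
use std::ptr;

/// Hypothetical stand-in for `ExtCtxt<'a>`: it borrows data, so it carries a lifetime.
struct ToyCtx<'a> {
    log: &'a mut Vec<String>,
}

thread_local! {
    /// Type-erased pointer to the active `ToyCtx`, or null when none is available.
    static ACTIVE_CTX: Cell<*mut ()> = Cell::new(ptr::null_mut());
}

/// Writes the saved pointer back on drop, including during unwinding.
struct Reset(*mut ());

impl Drop for Reset {
    fn drop(&mut self) {
        ACTIVE_CTX.with(|slot| slot.set(self.0));
    }
}

/// Makes `ctx` reachable through `with` for the duration of `f`.
fn enter<F, R>(ctx: &mut ToyCtx<'_>, f: F) -> R
where
    F: FnOnce() -> R,
{
    // Erase the type so the lifetime does not have to appear in the static.
    let erased = ctx as *mut _ as *mut ();
    let previous = ACTIVE_CTX.with(|slot| slot.replace(erased));
    let _reset = Reset(previous);
    f()
}

/// Accesses the context stored by `enter`; panics if none is active.
fn with<F, R>(f: F) -> R
where
    F: for<'a, 'b> FnOnce(&'b mut ToyCtx<'a>) -> R,
{
    // Take the pointer out so a nested `with` cannot alias the `&mut`.
    let erased = ACTIVE_CTX.with(|slot| slot.replace(ptr::null_mut()));
    assert!(!erased.is_null(), "`with` called outside of `enter`");
    let _reset = Reset(erased);
    // SAFETY: the pointer is non-null only while `enter` is on the call stack,
    // where it holds the exclusive borrow the pointer was created from, and
    // the slot is emptied above, so no second `&mut` to the context exists.
    let ctx = unsafe { &mut *erased.cast::<ToyCtx<'_>>() };
    f(ctx)
}

fn main() {
    let mut log: Vec<String> = Vec::new();
    let mut ctx = ToyCtx { log: &mut log };
    let len = enter(&mut ctx, || {
        with(|ctx| {
            ctx.log.push("expanded".to_string());
            ctx.log.len()
        })
    });
    assert_eq!(len, 1);
    assert_eq!(log, vec!["expanded".to_string()]);
}
```

The sketch makes the reviewer's concern concrete: whatever `with` hands to the callback arrives through a side channel, so a caching layer keyed only on explicit arguments has no way to see it.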