Skip to content

Commit f26f394

Browse files
committed
Refactor TLS access
1 parent 5220403 commit f26f394

File tree

3 files changed

+80
-86
lines changed

3 files changed

+80
-86
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3839,6 +3839,7 @@ dependencies = [
38393839
"rustc_serialize",
38403840
"rustc_session",
38413841
"rustc_span",
3842+
"scoped-tls",
38423843
"smallvec",
38433844
"thin-vec",
38443845
"tracing",

compiler/rustc_expand/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ rustc_proc_macro = { path = "../rustc_proc_macro" }
2929
rustc_serialize = { path = "../rustc_serialize" }
3030
rustc_session = { path = "../rustc_session" }
3131
rustc_span = { path = "../rustc_span" }
32+
scoped-tls = "1.0"
3233
smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
3334
thin-vec = "0.2.12"
3435
tracing = "0.1"

compiler/rustc_expand/src/proc_macro.rs

Lines changed: 78 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
use std::cell::Cell;
2-
use std::ptr::NonNull;
3-
41
use rustc_ast::tokenstream::TokenStream;
52
use rustc_data_structures::svh::Svh;
63
use rustc_errors::ErrorGuaranteed;
@@ -36,7 +33,7 @@ impl<T> pm::bridge::server::MessagePipe<T> for MessagePipe<T> {
3633
}
3734
}
3835

39-
pub fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
36+
fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
4037
pm::bridge::server::MaybeCrossThread::<MessagePipe<_>>::new(
4138
sess.opts.unstable_opts.proc_macro_execution_strategy
4239
== ProcMacroExecutionStrategy::CrossThread,
@@ -107,7 +104,7 @@ impl base::AttrProcMacro for AttrProcMacro {
107104
}
108105

109106
pub struct DeriveProcMacro {
110-
pub client: pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>,
107+
pub client: DeriveClient,
111108
}
112109

113110
impl MultiItemModifier for DeriveProcMacro {
@@ -136,32 +133,31 @@ impl MultiItemModifier for DeriveProcMacro {
136133
// altogether. See #73345.
137134
crate::base::ann_pretty_printing_compatibility_hack(&item, &ecx.sess.psess);
138135
let input = item.to_tokens();
139-
let res = ty::tls::with(|tcx| {
140-
let input = tcx.arena.alloc(input) as &TokenStream;
141-
let invoc_id = ecx.current_expansion.id;
142-
let invoc_expn_data = invoc_id.expn_data();
143-
144-
assert_eq!(invoc_expn_data.call_site, span);
145-
146-
// FIXME(pr-time): Is this the correct way to check for incremental compilation (as
147-
// well as for `cache_proc_macros`)?
148-
if tcx.sess.opts.incremental.is_some()
149-
&& tcx.sess.opts.unstable_opts.cache_derive_macros
150-
{
136+
137+
let invoc_id = ecx.current_expansion.id;
138+
139+
let res = if ecx.sess.opts.incremental.is_some()
140+
&& ecx.sess.opts.unstable_opts.cache_derive_macros
141+
{
142+
ty::tls::with(|tcx| {
151143
// FIXME(pr-time): Just using the crate hash to notice when the proc-macro code has
152144
// changed. How to *correctly* depend on exactly the macro definition?
153145
// I.e., depending on the crate hash is just a HACK, and ideally the dependency would be
154146
// more narrow.
147+
let invoc_expn_data = invoc_id.expn_data();
155148
let macro_def_id = invoc_expn_data.macro_def_id.unwrap();
156149
let proc_macro_crate_hash = tcx.crate_hash(macro_def_id.krate);
157150

151+
let input = tcx.arena.alloc(input) as &TokenStream;
158152
let key = (invoc_id, proc_macro_crate_hash, input);
159153

160-
enter_context((ecx, self.client), move || tcx.derive_macro_expansion(key).cloned())
161-
} else {
162-
expand_derive_macro(tcx, invoc_id, input, ecx, self.client).cloned()
163-
}
164-
});
154+
QueryDeriveExpandCtx::enter(ecx, self.client, move || {
155+
tcx.derive_macro_expansion(key).cloned()
156+
})
157+
})
158+
} else {
159+
expand_derive_macro(invoc_id, input, ecx, self.client)
160+
};
165161

166162
let Ok(output) = res else {
167163
// error will already have been emitted
@@ -205,36 +201,38 @@ pub(super) fn provide_derive_macro_expansion<'tcx>(
205201
) -> Result<&'tcx TokenStream, ()> {
206202
let (invoc_id, _macro_crate_hash, input) = key;
207203

208-
with_context(|(ecx, client)| expand_derive_macro(tcx, invoc_id, input, ecx, *client))
204+
eprintln!("Expanding derive macro in a query");
205+
206+
QueryDeriveExpandCtx::with(|ecx, client| {
207+
expand_derive_macro(invoc_id, input.clone(), ecx, client)
208+
.map(|ts| tcx.arena.alloc(ts) as &TokenStream)
209+
})
209210
}
210211

211-
type CLIENT = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
212+
type DeriveClient = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
212213

213-
fn expand_derive_macro<'tcx>(
214-
tcx: TyCtxt<'tcx>,
214+
fn expand_derive_macro(
215215
invoc_id: LocalExpnId,
216-
input: &'tcx TokenStream,
216+
input: TokenStream,
217217
ecx: &mut ExtCtxt<'_>,
218-
client: CLIENT,
219-
) -> Result<&'tcx TokenStream, ()> {
218+
client: DeriveClient,
219+
) -> Result<TokenStream, ()> {
220220
let invoc_expn_data = invoc_id.expn_data();
221221
let span = invoc_expn_data.call_site;
222222
let event_arg = invoc_expn_data.kind.descr();
223-
let _timer = tcx.sess.prof.generic_activity_with_arg_recorder(
224-
"expand_derive_proc_macro_inner",
225-
|recorder| {
226-
recorder.record_arg_with_span(tcx.sess.source_map(), event_arg.clone(), span);
227-
},
228-
);
223+
let _timer =
224+
ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
225+
recorder.record_arg_with_span(ecx.sess.source_map(), event_arg.clone(), span);
226+
});
229227

230228
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
231-
let strategy = crate::proc_macro::exec_strategy(tcx.sess);
232-
let server = crate::proc_macro_server::Rustc::new(ecx);
229+
let strategy = exec_strategy(ecx.sess);
230+
let server = proc_macro_server::Rustc::new(ecx);
233231

234-
match client.run(&strategy, server, input.clone(), proc_macro_backtrace) {
235-
Ok(stream) => Ok(tcx.arena.alloc(stream) as &TokenStream),
232+
match client.run(&strategy, server, input, proc_macro_backtrace) {
233+
Ok(stream) => Ok(stream),
236234
Err(e) => {
237-
tcx.dcx().emit_err({
235+
ecx.dcx().emit_err({
238236
errors::ProcMacroDerivePanicked {
239237
span,
240238
message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp {
@@ -247,53 +245,47 @@ fn expand_derive_macro<'tcx>(
247245
}
248246
}
249247

250-
// based on rust/compiler/rustc_middle/src/ty/context/tls.rs
251-
thread_local! {
252-
/// A thread local variable that stores a pointer to the current `CONTEXT`.
253-
static TLV: Cell<(*mut (), Option<CLIENT>)> = const { Cell::new((std::ptr::null_mut(), None)) };
254-
}
255-
256-
/// Sets `context` as the new current `CONTEXT` for the duration of the function `f`.
257-
#[inline]
258-
pub(crate) fn enter_context<'a, F, R>(context: (&mut ExtCtxt<'a>, CLIENT), f: F) -> R
259-
where
260-
F: FnOnce() -> R,
261-
{
262-
let (ectx, client) = context;
263-
let erased = (ectx as *mut _ as *mut (), Some(client));
264-
TLV.with(|tlv| {
265-
let old = tlv.replace(erased);
266-
let _reset = rustc_data_structures::defer(move || tlv.set(old));
267-
f()
268-
})
248+
/// Stores the context necessary to expand a derive proc macro via a query.
249+
struct QueryDeriveExpandCtx {
250+
/// Type-erased version of `&mut ExtCtxt`
251+
expansion_ctx: *mut (),
252+
client: DeriveClient,
269253
}
270254

271-
/// Allows access to the current `CONTEXT`.
272-
/// Panics if there is no `CONTEXT` available.
273-
#[inline]
274-
#[track_caller]
275-
fn with_context<F, R>(f: F) -> R
276-
where
277-
F: for<'a, 'b> FnOnce(&'b mut (&mut ExtCtxt<'a>, CLIENT)) -> R,
278-
{
279-
let (ectx, client_opt) = TLV.get();
280-
let ectx = NonNull::new(ectx).expect("no CONTEXT stored in tls");
281-
282-
// We could get an `CONTEXT` pointer from another thread.
283-
// Ensure that `CONTEXT` is `DynSync`.
284-
// FIXME(pr-time): we should not be able to?
285-
// sync::assert_dyn_sync::<CONTEXT<'_>>();
286-
287-
// prevent double entering, as that would allow creating two `&mut ExtCtxt`s
288-
// FIXME(pr-time): probably use a RefCell instead (which checks this properly)?
289-
TLV.with(|tlv| {
290-
let old = tlv.replace((std::ptr::null_mut(), None));
291-
let _reset = rustc_data_structures::defer(move || tlv.set(old));
292-
let ectx = {
293-
let mut casted = ectx.cast::<ExtCtxt<'_>>();
294-
unsafe { casted.as_mut() }
295-
};
255+
impl QueryDeriveExpandCtx {
256+
/// Store the extension context and the client into the thread local value.
257+
/// It will be accessible via the `with` method while `f` is active.
258+
fn enter<F, R>(ecx: &mut ExtCtxt<'_>, client: DeriveClient, f: F) -> R
259+
where
260+
F: FnOnce() -> R,
261+
{
262+
// We need erasure to get rid of the lifetime
263+
let ctx = Self { expansion_ctx: ecx as *mut _ as *mut (), client };
264+
DERIVE_EXPAND_CTX.set(&ctx, || f())
265+
}
296266

297-
f(&mut (ectx, client_opt.unwrap()))
298-
})
267+
/// Accesses the thread local value of the derive expansion context.
268+
/// Must be called while the `enter` function is active.
269+
fn with<F, R>(f: F) -> R
270+
where
271+
F: for<'a, 'b> FnOnce(&'b mut ExtCtxt<'a>, DeriveClient) -> R,
272+
{
273+
DERIVE_EXPAND_CTX.with(|ctx| {
274+
let ectx = {
275+
let casted = ctx.expansion_ctx.cast::<ExtCtxt<'_>>();
276+
// SAFETY: We can only get the value from `with` while the `enter` function
277+
// is active (on the callstack), and that function's signature ensures that the
278+
// lifetime is valid.
279+
// If `with` is called at some other time, it will panic due to usage of
280+
// `scoped_tls::with`.
281+
unsafe { casted.as_mut().unwrap() }
282+
};
283+
284+
f(ectx, ctx.client)
285+
})
286+
}
299287
}
288+
289+
// When we invoke a query to expand a derive proc macro, we need to provide it with the expansion
290+
// context and derive Client. We do that using a thread-local.
291+
scoped_tls::scoped_thread_local!(static DERIVE_EXPAND_CTX: QueryDeriveExpandCtx);

0 commit comments

Comments
 (0)