Skip to content

Commit 0733fbd

Browse files
committed
wip: So yeah, tests pass, but still eval_always (not far from disk caching though I think)
1 parent 2507253 commit 0733fbd

File tree

13 files changed

+287
-33
lines changed

13 files changed

+287
-33
lines changed

compiler/rustc_ast/src/tokenstream.rs

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
//! ownership of the original.
1515
1616
use std::borrow::Cow;
17+
use std::hash::Hash;
1718
use std::{cmp, fmt, iter};
1819

1920
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
2021
use rustc_data_structures::sync::{self, Lrc};
2122
use rustc_macros::{Decodable, Encodable, HashStable_Generic};
22-
use rustc_serialize::{Decodable, Encodable};
23+
use rustc_serialize::{Decodable, Encodable, Encoder};
24+
use rustc_span::def_id::{CrateNum, DefIndex};
2325
use rustc_span::{DUMMY_SP, Span, SpanDecoder, SpanEncoder, Symbol, sym};
2426

2527
use crate::ast::{AttrStyle, StmtKind};
@@ -138,8 +140,10 @@ impl fmt::Debug for LazyAttrTokenStream {
138140
}
139141

140142
impl<S: SpanEncoder> Encodable<S> for LazyAttrTokenStream {
141-
fn encode(&self, _s: &mut S) {
142-
panic!("Attempted to encode LazyAttrTokenStream");
143+
fn encode(&self, s: &mut S) {
144+
// TODO: welp
145+
// TODO: (also) `.flattened()` here?
146+
self.to_attr_token_stream().encode(s)
143147
}
144148
}
145149

@@ -295,6 +299,96 @@ pub struct AttrsTarget {
295299
#[derive(Clone, Debug, Default, Encodable, Decodable)]
296300
pub struct TokenStream(pub(crate) Lrc<Vec<TokenTree>>);
297301

302+
struct HashEncoder<H: std::hash::Hasher> {
303+
hasher: H,
304+
}
305+
306+
impl<H: std::hash::Hasher> Encoder for HashEncoder<H> {
307+
fn emit_usize(&mut self, v: usize) {
308+
self.hasher.write_usize(v)
309+
}
310+
311+
fn emit_u128(&mut self, v: u128) {
312+
self.hasher.write_u128(v)
313+
}
314+
315+
fn emit_u64(&mut self, v: u64) {
316+
self.hasher.write_u64(v)
317+
}
318+
319+
fn emit_u32(&mut self, v: u32) {
320+
self.hasher.write_u32(v)
321+
}
322+
323+
fn emit_u16(&mut self, v: u16) {
324+
self.hasher.write_u16(v)
325+
}
326+
327+
fn emit_u8(&mut self, v: u8) {
328+
self.hasher.write_u8(v)
329+
}
330+
331+
fn emit_isize(&mut self, v: isize) {
332+
self.hasher.write_isize(v)
333+
}
334+
335+
fn emit_i128(&mut self, v: i128) {
336+
self.hasher.write_i128(v)
337+
}
338+
339+
fn emit_i64(&mut self, v: i64) {
340+
self.hasher.write_i64(v)
341+
}
342+
343+
fn emit_i32(&mut self, v: i32) {
344+
self.hasher.write_i32(v)
345+
}
346+
347+
fn emit_i16(&mut self, v: i16) {
348+
self.hasher.write_i16(v)
349+
}
350+
351+
fn emit_raw_bytes(&mut self, s: &[u8]) {
352+
self.hasher.write(s)
353+
}
354+
}
355+
356+
impl<H: std::hash::Hasher> SpanEncoder for HashEncoder<H> {
357+
fn encode_span(&mut self, span: Span) {
358+
span.hash(&mut self.hasher)
359+
}
360+
361+
fn encode_symbol(&mut self, symbol: Symbol) {
362+
symbol.hash(&mut self.hasher)
363+
}
364+
365+
fn encode_expn_id(&mut self, expn_id: rustc_span::ExpnId) {
366+
expn_id.hash(&mut self.hasher)
367+
}
368+
369+
fn encode_syntax_context(&mut self, syntax_context: rustc_span::SyntaxContext) {
370+
syntax_context.hash(&mut self.hasher)
371+
}
372+
373+
fn encode_crate_num(&mut self, crate_num: CrateNum) {
374+
crate_num.hash(&mut self.hasher)
375+
}
376+
377+
fn encode_def_index(&mut self, def_index: DefIndex) {
378+
def_index.hash(&mut self.hasher)
379+
}
380+
381+
fn encode_def_id(&mut self, def_id: rustc_span::def_id::DefId) {
382+
def_id.hash(&mut self.hasher)
383+
}
384+
}
385+
386+
impl Hash for TokenStream {
387+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
388+
Encodable::encode(self, &mut HashEncoder { hasher: state });
389+
}
390+
}
391+
298392
/// Indicates whether a token can join with the following token to form a
299393
/// compound token. Used for conversions to `proc_macro::Spacing`. Also used to
300394
/// guide pretty-printing, which is where the `JointHidden` value (which isn't

compiler/rustc_expand/src/base.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,9 @@ pub trait ResolverExpand {
10741074
trait_def_id: DefId,
10751075
impl_def_id: LocalDefId,
10761076
) -> Result<Vec<(Ident, Option<Ident>)>, Indeterminate>;
1077+
1078+
fn register_proc_macro_invoc(&mut self, invoc_id: LocalExpnId, ext: Lrc<SyntaxExtension>);
1079+
fn unregister_proc_macro_invoc(&mut self, invoc_id: LocalExpnId);
10771080
}
10781081

10791082
pub trait LintStoreExpand {
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// TODO: remove
2+
#![allow(dead_code)]
3+
4+
use std::cell::Cell;
5+
use std::ptr;
6+
7+
use rustc_ast::tokenstream::TokenStream;
8+
use rustc_middle::ty::TyCtxt;
9+
use rustc_span::profiling::SpannedEventArgRecorder;
10+
use rustc_span::LocalExpnId;
11+
12+
use crate::base::ExtCtxt;
13+
use crate::errors;
14+
15+
pub(super) fn expand<'tcx>(
16+
tcx: TyCtxt<'tcx>,
17+
key: (LocalExpnId, &'tcx TokenStream),
18+
) -> Result<&'tcx TokenStream, ()> {
19+
let (invoc_id, input) = key;
20+
21+
let res = with_context(|(ecx, client)| {
22+
let span = invoc_id.expn_data().call_site;
23+
let _timer =
24+
ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
25+
recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
26+
});
27+
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
28+
let strategy = crate::proc_macro::exec_strategy(ecx);
29+
let server = crate::proc_macro_server::Rustc::new(ecx);
30+
let res = match client.run(&strategy, server, input.clone(), proc_macro_backtrace) {
31+
// TODO: without flattened some (weird) tests fail, but no idea if it's correct/enough
32+
Ok(stream) => Ok(tcx.arena.alloc(stream.flattened()) as &TokenStream),
33+
Err(e) => {
34+
ecx.dcx().emit_err({
35+
errors::ProcMacroDerivePanicked {
36+
span,
37+
message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp {
38+
message: message.into(),
39+
}),
40+
}
41+
});
42+
Err(())
43+
}
44+
};
45+
res
46+
});
47+
48+
res
49+
}
50+
51+
type CLIENT = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
52+
53+
// based on rust/compiler/rustc_middle/src/ty/context/tls.rs
54+
// #[cfg(not(parallel_compiler))]
55+
thread_local! {
56+
/// A thread local variable that stores a pointer to the current `CONTEXT`.
57+
static TLV: Cell<(*mut (), Option<CLIENT>)> = const { Cell::new((ptr::null_mut(), None)) };
58+
}
59+
60+
#[inline]
61+
fn erase(context: &mut ExtCtxt<'_>) -> *mut () {
62+
context as *mut _ as *mut ()
63+
}
64+
65+
#[inline]
66+
unsafe fn downcast<'a>(context: *mut ()) -> &'a mut ExtCtxt<'a> {
67+
unsafe { &mut *(context as *mut ExtCtxt<'a>) }
68+
}
69+
70+
/// Sets `context` as the new current `CONTEXT` for the duration of the function `f`.
71+
#[inline]
72+
pub fn enter_context<'a, F, R>(context: (&mut ExtCtxt<'a>, CLIENT), f: F) -> R
73+
where
74+
F: FnOnce() -> R,
75+
{
76+
let (ectx, client) = context;
77+
let erased = (erase(ectx), Some(client));
78+
TLV.with(|tlv| {
79+
let old = tlv.replace(erased);
80+
let _reset = rustc_data_structures::defer(move || tlv.set(old));
81+
f()
82+
})
83+
}
84+
85+
/// Allows access to the current `CONTEXT` in a closure if one is available.
86+
#[inline]
87+
#[track_caller]
88+
pub fn with_context_opt<F, R>(f: F) -> R
89+
where
90+
F: for<'a, 'b> FnOnce(Option<&'b mut (&mut ExtCtxt<'a>, CLIENT)>) -> R,
91+
{
92+
let (ectx, client_opt) = TLV.get();
93+
if ectx.is_null() {
94+
f(None)
95+
} else {
96+
// We could get an `CONTEXT` pointer from another thread.
97+
// Ensure that `CONTEXT` is `DynSync`.
98+
// TODO: we should not be able to?
99+
// sync::assert_dyn_sync::<CONTEXT<'_>>();
100+
101+
unsafe { f(Some(&mut (downcast(ectx), client_opt.unwrap()))) }
102+
}
103+
}
104+
105+
/// Allows access to the current `CONTEXT`.
106+
/// Panics if there is no `CONTEXT` available.
107+
#[inline]
108+
pub fn with_context<F, R>(f: F) -> R
109+
where
110+
F: for<'a, 'b> FnOnce(&'b mut (&mut ExtCtxt<'a>, CLIENT)) -> R,
111+
{
112+
with_context_opt(|opt_context| f(opt_context.expect("no CONTEXT stored in tls")))
113+
}

compiler/rustc_expand/src/expand.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,9 +796,14 @@ impl<'a, 'b> MacroExpander<'a, 'b> {
796796
span,
797797
path,
798798
};
799+
self.cx
800+
.resolver
801+
.register_proc_macro_invoc(invoc.expansion_data.id, ext.clone());
802+
invoc.expansion_data.id.expn_data();
799803
let items = match expander.expand(self.cx, span, &meta, item, is_const) {
800804
ExpandResult::Ready(items) => items,
801805
ExpandResult::Retry(item) => {
806+
self.cx.resolver.unregister_proc_macro_invoc(invoc.expansion_data.id);
802807
// Reassemble the original invocation for retrying.
803808
return ExpandResult::Retry(Invocation {
804809
kind: InvocationKind::Derive { path: meta.path, item, is_const },

compiler/rustc_expand/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,15 @@ mod proc_macro_server;
2929
pub use mbe::macro_rules::compile_declarative_macro;
3030
pub mod base;
3131
pub mod config;
32+
pub(crate) mod derive_macro_expansion;
3233
pub mod expand;
3334
pub mod module;
3435
// FIXME(Nilstrieb) Translate proc_macro diagnostics
3536
#[allow(rustc::untranslatable_diagnostic)]
3637
pub mod proc_macro;
3738

39+
pub fn provide(providers: &mut rustc_middle::util::Providers) {
40+
providers.derive_macro_expansion = derive_macro_expansion::expand;
41+
}
42+
3843
rustc_fluent_macro::fluent_messages! { "../messages.ftl" }

compiler/rustc_expand/src/proc_macro.rs

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use rustc_ast as ast;
22
use rustc_ast::ptr::P;
33
use rustc_ast::tokenstream::TokenStream;
44
use rustc_errors::ErrorGuaranteed;
5+
use rustc_middle::ty;
56
use rustc_parse::parser::{ForceCollect, Parser};
67
use rustc_session::config::ProcMacroExecutionStrategy;
78
use rustc_span::Span;
@@ -31,7 +32,7 @@ impl<T> pm::bridge::server::MessagePipe<T> for MessagePipe<T> {
3132
}
3233
}
3334

34-
fn exec_strategy(ecx: &ExtCtxt<'_>) -> impl pm::bridge::server::ExecutionStrategy {
35+
pub fn exec_strategy(ecx: &ExtCtxt<'_>) -> impl pm::bridge::server::ExecutionStrategy {
3536
pm::bridge::server::MaybeCrossThread::<MessagePipe<_>>::new(
3637
ecx.sess.opts.unstable_opts.proc_macro_execution_strategy
3738
== ProcMacroExecutionStrategy::CrossThread,
@@ -124,36 +125,27 @@ impl MultiItemModifier for DeriveProcMacro {
124125
// altogether. See #73345.
125126
crate::base::ann_pretty_printing_compatibility_hack(&item, &ecx.sess);
126127
let input = item.to_tokens();
127-
let stream = {
128-
let _timer =
129-
ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
130-
recorder.record_arg_with_span(
131-
ecx.sess.source_map(),
132-
ecx.expansion_descr(),
133-
span,
134-
);
135-
});
136-
let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
137-
let strategy = exec_strategy(ecx);
138-
let server = proc_macro_server::Rustc::new(ecx);
139-
match self.client.run(&strategy, server, input, proc_macro_backtrace) {
140-
Ok(stream) => stream,
141-
Err(e) => {
142-
ecx.dcx().emit_err({
143-
errors::ProcMacroDerivePanicked {
144-
span,
145-
message: e.as_str().map(|message| {
146-
errors::ProcMacroDerivePanickedHelp { message: message.into() }
147-
}),
148-
}
149-
});
150-
return ExpandResult::Ready(vec![]);
151-
}
152-
}
128+
let res = ty::tls::with(|tcx| {
129+
// TODO: without flattened some (weird) tests fail, but no idea if it's correct/enough
130+
let input = tcx.arena.alloc(input.flattened()) as &TokenStream;
131+
let invoc_id = ecx.current_expansion.id;
132+
133+
assert_eq!(invoc_id.expn_data().call_site, span);
134+
135+
let res = crate::derive_macro_expansion::enter_context((ecx, self.client), move || {
136+
let res = tcx.derive_macro_expansion((invoc_id, input)).cloned();
137+
res
138+
});
139+
140+
res
141+
});
142+
let Ok(output) = res else {
143+
// error will already have been emitted
144+
return ExpandResult::Ready(vec![]);
153145
};
154146

155147
let error_count_before = ecx.dcx().err_count();
156-
let mut parser = Parser::new(&ecx.sess.psess, stream, Some("proc-macro derive"));
148+
let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
157149
let mut items = vec![];
158150

159151
loop {

compiler/rustc_interface/src/passes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ pub static DEFAULT_QUERY_PROVIDERS: LazyLock<Providers> = LazyLock::new(|| {
686686
providers.resolutions = |tcx, ()| tcx.resolver_for_lowering_raw(()).1;
687687
providers.early_lint_checks = early_lint_checks;
688688
proc_macro_decls::provide(providers);
689+
rustc_expand::provide(providers);
689690
rustc_const_eval::provide(providers);
690691
rustc_middle::hir::provide(providers);
691692
rustc_borrowck::provide(providers);

compiler/rustc_middle/src/arena.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ macro_rules! arena_types {
114114
[decode] specialization_graph: rustc_middle::traits::specialization_graph::Graph,
115115
[] crate_inherent_impls: rustc_middle::ty::CrateInherentImpls,
116116
[] hir_owner_nodes: rustc_hir::OwnerNodes<'tcx>,
117+
[] token_stream: rustc_ast::tokenstream::TokenStream,
117118
]);
118119
)
119120
}

compiler/rustc_middle/src/query/erase.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::intrinsics::transmute_unchecked;
22
use std::mem::MaybeUninit;
33

4+
use rustc_ast::tokenstream::TokenStream;
45
use rustc_span::ErrorGuaranteed;
56

67
use crate::query::CyclePlaceholder;
@@ -172,6 +173,10 @@ impl EraseType for Result<ty::EarlyBinder<'_, Ty<'_>>, CyclePlaceholder> {
172173
type Result = [u8; size_of::<Result<ty::EarlyBinder<'static, Ty<'_>>, CyclePlaceholder>>()];
173174
}
174175

176+
impl EraseType for Result<&'_ TokenStream, ()> {
177+
type Result = [u8; size_of::<Result<&'static TokenStream, ()>>()];
178+
}
179+
175180
impl<T> EraseType for Option<&'_ T> {
176181
type Result = [u8; size_of::<Option<&'static ()>>()];
177182
}

0 commit comments

Comments
 (0)