Skip to content

Commit c30ff77

Browse files
committed
use guarded bindings for optional where-clauses
Woohoo!
1 parent 5136925 commit c30ff77

File tree

62 files changed

+368
-271
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+368
-271
lines changed

crates/formality-core/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub mod language;
3232
pub mod parse;
3333
pub mod substitution;
3434
pub mod term;
35+
pub mod util;
3536
pub mod variable;
3637
pub mod visit;
3738

crates/formality-core/src/util.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
/// Returns true if `t` is the default value for `t`.
2+
/// Used by the "derive" code for `Debug`.
3+
pub fn is_default<T>(t: &T) -> bool
4+
where
5+
T: Default + Eq,
6+
{
7+
let default_value: T = Default::default();
8+
default_value == *t
9+
}

crates/formality-macros/src/debug.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,19 @@ fn debug_field_with_mode(name: &Ident, mode: &FieldMode) -> TokenStream {
256256
}
257257
}
258258
}
259+
260+
FieldMode::Guarded { guard, mode } => {
261+
let guard = as_literal(guard);
262+
let base = debug_field_with_mode(name, mode);
263+
264+
quote_spanned! { name.span() =>
265+
if !::formality_core::util::is_default(#name) {
266+
write!(fmt, "{}{}", sep, #guard)?;
267+
sep = " ";
268+
#base
269+
}
270+
}
271+
}
259272
}
260273
}
261274

crates/formality-macros/src/parse.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,18 @@ fn parse_field_mode(span: Span, mode: &FieldMode) -> TokenStream {
230230
__p.comma_nonterminal()?
231231
}
232232
}
233+
234+
FieldMode::Guarded { guard, mode } => {
235+
let guard_keyword = as_literal(guard);
236+
let initializer = parse_field_mode(span, mode);
237+
quote_spanned! {
238+
span =>
239+
match __p.expect_keyword(#guard_keyword) {
240+
Ok(()) => #initializer,
241+
Err(_) => Default::default(),
242+
}
243+
}
244+
}
233245
}
234246
}
235247

crates/formality-macros/src/spec.rs

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::iter::Peekable;
1+
use std::{iter::Peekable, sync::Arc};
22

33
use proc_macro2::{Group, Ident, Punct, TokenStream, TokenTree};
44
use syn::spanned::Spanned;
@@ -36,6 +36,11 @@ pub enum FieldMode {
3636
/// $x -- just parse `x`
3737
Single,
3838

39+
Guarded {
40+
guard: Ident,
41+
mode: Arc<FieldMode>,
42+
},
43+
3944
/// $<x> -- `x` is a `Vec<E>`, parse `<E0,...,En>`
4045
/// $[x] -- `x` is a `Vec<E>`, parse `[E0,...,En]`
4146
/// $(x) -- `x` is a `Vec<E>`, parse `(E0,...,En)`
@@ -100,7 +105,7 @@ fn token_stream_to_symbols(
100105
}
101106
proc_macro2::TokenTree::Ident(ident) => symbols.push(Keyword { ident }),
102107
proc_macro2::TokenTree::Punct(punct) => match punct.as_char() {
103-
'$' => symbols.push(parse_variable_binding(punct, &mut tokens)?),
108+
'$' => symbols.push(parse_variable_binding(&punct, &mut tokens)?),
104109
_ => symbols.push(Char { punct }),
105110
},
106111
proc_macro2::TokenTree::Literal(_) => {
@@ -120,14 +125,16 @@ fn token_stream_to_symbols(
120125
/// or we could also see a `$`, in which case user wrote `$$`, and we treat that as a single
121126
/// `$` sign.
122127
fn parse_variable_binding(
123-
dollar_token: Punct,
124-
tokens: &mut impl Iterator<Item = TokenTree>,
128+
dollar_token: &Punct,
129+
tokens: &mut dyn Iterator<Item = TokenTree>,
125130
) -> syn::Result<FormalitySpecSymbol> {
126-
let dollar_token = &dollar_token;
127131
let mut tokens = tokens.peekable();
128132

129133
let Some(token) = tokens.peek() else {
130-
return error(dollar_token);
134+
return error(
135+
dollar_token,
136+
"incomplete field reference; use `$$` if you just want a dollar sign",
137+
);
131138
};
132139

133140
return match token {
@@ -162,6 +169,12 @@ fn parse_variable_binding(
162169
parse_variable_binding_name(dollar_token, FieldMode::Optional, &mut tokens)
163170
}
164171

172+
// $:guard $x
173+
TokenTree::Punct(punct) if punct.as_char() == ':' => {
174+
let guard_token = tokens.next().unwrap();
175+
parse_guarded_variable_binding(dollar_token, guard_token, &mut tokens)
176+
}
177+
165178
// $<x> or $<?x>
166179
TokenTree::Punct(punct) if punct.as_char() == '<' => {
167180
tokens.next();
@@ -172,7 +185,7 @@ fn parse_variable_binding(
172185
// we should see a `>` next
173186
match tokens.next() {
174187
Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => Ok(result),
175-
_ => error(dollar_token),
188+
_ => error(dollar_token, "expected a `>` to end this field reference"),
176189
}
177190
}
178191

@@ -184,7 +197,7 @@ fn parse_variable_binding(
184197
unreachable!()
185198
};
186199
let Some((open, close)) = open_close(&group) else {
187-
return error(&group);
200+
return error(&group, "did not expect a spliced macro_rules group here");
188201
};
189202

190203
// consume `x` or `?x`
@@ -193,14 +206,65 @@ fn parse_variable_binding(
193206

194207
// there shouldn't be anything else in the token tree
195208
if let Some(t) = group_tokens.next() {
196-
return error(&t);
209+
return error(
210+
&t,
211+
"extra characters in delimited field reference after field name",
212+
);
197213
}
198214
Ok(result)
199215
}
200216

201-
_ => error(dollar_token),
217+
_ => error(dollar_token, "invalid field reference"),
202218
};
203219

220+
fn parse_guarded_variable_binding(
221+
dollar_token: &Punct,
222+
guard_token: TokenTree,
223+
tokens: &mut Peekable<impl Iterator<Item = TokenTree>>,
224+
) -> syn::Result<FormalitySpecSymbol> {
225+
// The next token should be an identifier
226+
let Some(TokenTree::Ident(guard_ident)) = tokens.next() else {
227+
return error(
228+
&guard_token,
229+
"expected an identifier after a `:` in a field reference",
230+
);
231+
};
232+
233+
// The next token should be a `$`, beginning another variable binding
234+
let next_dollar_token = match tokens.next() {
235+
Some(TokenTree::Punct(next_dollar_token)) if next_dollar_token.as_char() == '$' => {
236+
next_dollar_token
237+
}
238+
239+
_ => {
240+
return error(
241+
&dollar_token,
242+
"expected another `$` field reference to follow the `:` guard",
243+
);
244+
}
245+
};
246+
247+
// Then should come another field reference.
248+
let FormalitySpecSymbol::Field { name, mode } =
249+
parse_variable_binding(&next_dollar_token, tokens)?
250+
else {
251+
return error(
252+
&next_dollar_token,
253+
"`$:` must be followed by another field reference, not a `$$` literal",
254+
);
255+
};
256+
257+
let guard_mode = FieldMode::Guarded {
258+
guard: guard_ident,
259+
mode: Arc::new(mode),
260+
};
261+
262+
Ok(FormalitySpecSymbol::Field {
263+
name: name,
264+
mode: guard_mode,
265+
})
266+
}
267+
204268
fn parse_delimited(
205269
dollar_token: &Punct,
206270
open: char,
@@ -236,14 +300,17 @@ fn parse_variable_binding(
236300
// Extract the name of the field.
237301
let name = match tokens.next() {
238302
Some(TokenTree::Ident(name)) => name,
239-
_ => return error(dollar_token),
303+
_ => return error(dollar_token, "expected field name"),
240304
};
241305

242306
Ok(FormalitySpecSymbol::Field { name, mode })
243307
}
244308

245-
fn error(at_token: &impl Spanned) -> syn::Result<FormalitySpecSymbol> {
246-
let message = "invalid field reference in grammar";
309+
fn error(at_token: &impl Spanned, message: impl ToString) -> syn::Result<FormalitySpecSymbol> {
310+
let mut message = message.to_string();
311+
if message.is_empty() {
312+
message = "invalid field reference in grammar".into();
313+
}
247314
Err(syn::Error::new(at_token.span(), message))
248315
}
249316
}

crates/formality-prove/src/decls.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ pub struct ImplDecl {
111111
}
112112

113113
/// Data bound under the generics from [`ImplDecl`][]
114-
#[term($trait_ref where $where_clause)]
114+
#[term($trait_ref $:where $where_clause)]
115115
pub struct ImplDeclBoundData {
116116
/// The trait ref that is implemented
117117
pub trait_ref: TraitRef,
@@ -129,7 +129,7 @@ pub struct NegImplDecl {
129129
}
130130

131131
/// Data bound under the impl generics for a negative impl
132-
#[term(!$trait_ref where $where_clause)]
132+
#[term(!$trait_ref $:where $where_clause)]
133133
pub struct NegImplDeclBoundData {
134134
pub trait_ref: TraitRef,
135135
pub where_clause: Wcs,
@@ -209,7 +209,7 @@ pub struct TraitInvariantBoundData {
209209
}
210210

211211
/// The "bound data" for a [`TraitDecl`][] -- i.e., what is covered by the forall.
212-
#[term(where $where_clause)]
212+
#[term($:where $where_clause)]
213213
pub struct TraitDeclBoundData {
214214
/// The where-clauses declared on the trait
215215
pub where_clause: Wcs,
@@ -231,7 +231,7 @@ impl AliasEqDecl {
231231
}
232232

233233
/// Data bound under the impl generics for a [`AliasEqDecl`][]
234-
#[term($alias = $ty where $where_clause)]
234+
#[term($alias = $ty $:where $where_clause)]
235235
pub struct AliasEqDeclBoundData {
236236
/// The alias that is equal
237237
pub alias: AliasTy,
@@ -258,7 +258,7 @@ impl AliasBoundDecl {
258258
}
259259
}
260260

261-
#[term($alias : $ensures where $where_clause)]
261+
#[term($alias : $ensures $:where $where_clause)]
262262
pub struct AliasBoundDeclBoundData {
263263
pub alias: AliasTy,
264264
// FIXME: this is currently encoded as something like `<T> [T: Foo]` where
@@ -281,7 +281,7 @@ pub struct AdtDecl {
281281
}
282282

283283
/// The "bound data" for a [`AdtDecl`][].
284-
#[term(where $where_clause)]
284+
#[term($:where $where_clause)]
285285
pub struct AdtDeclBoundData {
286286
/// The where-clauses declared on the ADT,
287287
pub where_clause: Wcs,

crates/formality-rust/src/grammar.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ impl Struct {
9696
}
9797
}
9898

99-
#[term(where $where_clauses { $,fields })]
99+
#[term($:where $,where_clauses { $,fields })]
100100
pub struct StructBoundData {
101101
pub where_clauses: Vec<WhereClause>,
102102
pub fields: Vec<Field>,
@@ -148,7 +148,7 @@ pub struct Adt {
148148
pub binder: Binder<AdtBoundData>,
149149
}
150150

151-
#[term(where $where_clauses { $,variants })]
151+
#[term($:where $,where_clauses { $,variants })]
152152
pub struct AdtBoundData {
153153
pub where_clauses: Vec<WhereClause>,
154154
pub variants: Vec<Variant>,
@@ -179,7 +179,7 @@ impl<T: Term> TraitBinder<T> {
179179
}
180180
}
181181

182-
#[term(where $where_clauses { $*trait_items })]
182+
#[term($:where $,where_clauses { $*trait_items })]
183183
pub struct TraitBoundData {
184184
pub where_clauses: Vec<WhereClause>,
185185
pub trait_items: Vec<TraitItem>,
@@ -199,7 +199,7 @@ pub struct Fn {
199199
pub binder: Binder<FnBoundData>,
200200
}
201201

202-
#[term($(input_tys) -> $output_ty where $where_clauses $body)]
202+
#[term($(input_tys) -> $output_ty $:where $,where_clauses $body)]
203203
pub struct FnBoundData {
204204
pub input_tys: Vec<Ty>,
205205
pub output_ty: Ty,
@@ -232,7 +232,7 @@ pub struct AssociatedTy {
232232
pub binder: Binder<AssociatedTyBoundData>,
233233
}
234234

235-
#[term(: $ensures where $where_clauses)]
235+
#[term(: $ensures $:where $,where_clauses)]
236236
pub struct AssociatedTyBoundData {
237237
/// So e.g. `type Item : [Sized]` would be encoded as `<type I> (I: Sized)`.
238238
pub ensures: Vec<WhereBound>,
@@ -252,7 +252,7 @@ impl TraitImpl {
252252
}
253253
}
254254

255-
#[term($trait_id $<?trait_parameters> for $self_ty where $where_clauses { $*impl_items })]
255+
#[term($trait_id $<?trait_parameters> for $self_ty $:where $,where_clauses { $*impl_items })]
256256
pub struct TraitImplBoundData {
257257
pub trait_id: TraitId,
258258
pub self_ty: Ty,
@@ -272,7 +272,7 @@ pub struct NegTraitImpl {
272272
pub binder: Binder<NegTraitImplBoundData>,
273273
}
274274

275-
#[term(!$trait_id $<?trait_parameters> for $self_ty where $where_clauses { })]
275+
#[term(!$trait_id $<?trait_parameters> for $self_ty $:where $,where_clauses { })]
276276
pub struct NegTraitImplBoundData {
277277
pub trait_id: TraitId,
278278
pub self_ty: Ty,
@@ -300,7 +300,7 @@ pub struct AssociatedTyValue {
300300
pub binder: Binder<AssociatedTyValueBoundData>,
301301
}
302302

303-
#[term(= $ty where $where_clauses)]
303+
#[term(= $ty $:where $,where_clauses)]
304304
pub struct AssociatedTyValueBoundData {
305305
pub where_clauses: Vec<WhereClause>,
306306
pub ty: Ty,

0 commit comments

Comments
 (0)