Skip to content

Commit 91ae169

Browse files
author
karroffel
committed
Merge branch 'arg-forwarding' into 'master'
Pass simple args to the closure and preserve their types See merge request karroffel/contracts!5
2 parents 7936c39 + ad9a75c commit 91ae169

File tree

4 files changed

+212
-23
lines changed

4 files changed

+212
-23
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ override_log = []
3535
mirai_assertions = []
3636

3737
[dependencies]
38-
syn = { version = "1.0", features = ["extra-traits", "full", "visit-mut"] }
38+
syn = { version = "1.0", features = ["extra-traits", "full", "visit", "visit-mut"] }
3939
quote = "1.0"
4040
proc-macro2 = "1.0"
4141

src/implementation/codegen.rs

Lines changed: 160 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -316,29 +316,14 @@ pub(crate) fn generate(
316316
};
317317

318318
//
319-
// wrap the function body in a closure
319+
// wrap the function body in a closure if we have any postconditions
320320
//
321321

322-
let block = func.function.block.clone();
323-
324-
let ret_ty = match &func.function.sig.output {
325-
ReturnType::Type(_, ty) => {
326-
let span = ty.span();
327-
match ty.as_ref() {
328-
syn::Type::ImplTrait(_) => quote::quote! {},
329-
ty => quote::quote_spanned! { span=>
330-
-> #ty
331-
},
332-
}
333-
}
334-
ReturnType::Default => quote::quote! {},
335-
};
336-
337-
let body = quote::quote! {
338-
#[allow(unused_mut)]
339-
let mut run = || #ret_ty #block;
340-
341-
let ret = run();
322+
let body = if post.is_empty() {
323+
let block = &func.function.block;
324+
quote::quote! { let ret = #block; }
325+
} else {
326+
create_closure_body_and_adjust_signature(&func.function)
342327
};
343328

344329
//
@@ -371,3 +356,157 @@ pub(crate) fn generate(
371356

372357
func.function.into_token_stream()
373358
}
359+
360+
struct SelfReplacer<'a> {
361+
replace_with: &'a syn::Ident,
362+
}
363+
364+
impl syn::visit_mut::VisitMut for SelfReplacer<'_> {
365+
fn visit_ident_mut(&mut self, i: &mut Ident) {
366+
if i == "self" {
367+
*i = self.replace_with.clone();
368+
}
369+
}
370+
}
371+
372+
fn ty_contains_impl_trait(ty: &syn::Type) -> bool {
373+
use syn::visit::Visit;
374+
375+
struct TyContainsImplTrait {
376+
seen_impl_trait: bool,
377+
}
378+
379+
impl syn::visit::Visit<'_> for TyContainsImplTrait {
380+
fn visit_type_impl_trait(&mut self, _: &syn::TypeImplTrait) {
381+
self.seen_impl_trait = true;
382+
}
383+
}
384+
385+
let mut vis = TyContainsImplTrait {
386+
seen_impl_trait: false,
387+
};
388+
vis.visit_type(ty);
389+
vis.seen_impl_trait
390+
}
391+
392+
fn create_closure_body_and_adjust_signature(func: &syn::ItemFn) -> TokenStream {
393+
let is_method = func.sig.receiver().is_some();
394+
395+
// If the function has a receiver (e.g. `self`, `&mut self`, or `self: Box<Self>`) rename it
396+
// to `this__` within the closure
397+
398+
let mut block = func.block.clone();
399+
let mut closure_args = vec![];
400+
let mut arg_names = vec![];
401+
402+
if is_method {
403+
let this_ident = syn::Ident::new("this__", Span::call_site());
404+
405+
let mut receiver = func.sig.inputs[0].clone();
406+
match receiver {
407+
// self, &self, &mut self
408+
syn::FnArg::Receiver(rcv) => {
409+
// The `Self` type.
410+
let self_ty = Box::new(syn::Type::Path(syn::TypePath {
411+
qself: None,
412+
path: syn::Path::from(syn::Ident::new("Self", rcv.span())),
413+
}));
414+
415+
let ty = if let Some((and_token, ref lifetime)) = rcv.reference
416+
{
417+
Box::new(syn::Type::Reference(syn::TypeReference {
418+
and_token,
419+
lifetime: lifetime.clone(),
420+
mutability: rcv.mutability,
421+
elem: self_ty,
422+
}))
423+
} else {
424+
self_ty
425+
};
426+
427+
let pat_mut = if rcv.reference.is_none() {
428+
rcv.mutability
429+
} else {
430+
None
431+
};
432+
433+
// this__: [& [mut]] Self
434+
let new_rcv = syn::PatType {
435+
attrs: rcv.attrs.clone(),
436+
pat: Box::new(syn::Pat::Ident(syn::PatIdent {
437+
attrs: vec![],
438+
by_ref: None,
439+
mutability: pat_mut,
440+
ident: this_ident.clone(),
441+
subpat: None,
442+
})),
443+
colon_token: syn::Token![:](rcv.span()),
444+
ty,
445+
};
446+
447+
receiver = syn::FnArg::Typed(new_rcv);
448+
}
449+
450+
// self: Box<Self>
451+
syn::FnArg::Typed(ref mut pat) => {
452+
if let syn::Pat::Ident(ref mut ident) = *pat.pat {
453+
if ident.ident == "self" {
454+
ident.ident = this_ident.clone();
455+
}
456+
}
457+
}
458+
}
459+
460+
closure_args.push(receiver);
461+
arg_names.push(syn::Ident::new("self", Span::call_site()));
462+
463+
// Replace any references to `self` in the function body with references to `this__`.
464+
syn::visit_mut::visit_block_mut(
465+
&mut SelfReplacer {
466+
replace_with: &this_ident,
467+
},
468+
&mut block,
469+
);
470+
}
471+
472+
// Replace any pattern matching in the function signature with placeholder identifiers.
473+
// Pattern matching gets done in the closure instead.
474+
let args = func.sig.inputs.iter().skip(if is_method { 1 } else { 0 });
475+
for arg in args {
476+
match arg {
477+
syn::FnArg::Receiver(_) => unreachable!("Multiple receivers?"),
478+
479+
syn::FnArg::Typed(syn::PatType { pat, ty, .. }) => {
480+
if !ty_contains_impl_trait(ty) {
481+
if let syn::Pat::Ident(ident) = &**pat {
482+
arg_names.push(ident.ident.clone());
483+
closure_args.push(arg.clone());
484+
}
485+
}
486+
}
487+
}
488+
}
489+
490+
let ret_ty = match &func.sig.output {
491+
ReturnType::Type(_, ty) => {
492+
let span = ty.span();
493+
match ty.as_ref() {
494+
syn::Type::ImplTrait(_) => quote::quote! {},
495+
ty => quote::quote_spanned! { span=>
496+
-> #ty
497+
},
498+
}
499+
}
500+
ReturnType::Default => quote::quote! {},
501+
};
502+
503+
let closure_args = closure_args.iter();
504+
let arg_names = arg_names.iter();
505+
506+
quote::quote! {
507+
#[allow(unused_mut)]
508+
let mut run = |#(#closure_args),*| #ret_ty #block;
509+
510+
let ret = run(#(#arg_names),*);
511+
}
512+
}

tests/functions.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,52 @@ fn test_early_return() {
8989
abs(-4);
9090
}
9191

92+
#[test]
93+
fn test_mut_ref_and_lifetimes() {
94+
#[requires(i < s.len())]
95+
#[ensures(*ret == 0)]
96+
fn insert_zero<'a>(s: &'a mut [u8], i: usize) -> &'a mut u8 {
97+
s[i] = 0;
98+
&mut s[i]
99+
}
100+
101+
insert_zero(&mut [4, 2], 1);
102+
}
103+
104+
#[test]
105+
fn test_pattern_match() {
106+
#[ensures(ret > a && ret > b)]
107+
fn add((a, b): (u8, u8)) -> u8 {
108+
a.saturating_add(b)
109+
}
110+
111+
assert_eq!(add((4, 2)), 6);
112+
}
113+
92114
#[test]
93115
fn test_impl_trait_return() {
94116
// make sure that compiling functions that return existentially
95117
// qualified types works properly.
96118

97119
#[requires(x >= 10)]
98-
fn impl_test(x: isize) -> impl Clone + std::fmt::Debug {
120+
#[ensures(ret.clone() == ret)]
121+
fn impl_test(x: isize) -> impl Clone + PartialEq + std::fmt::Debug {
122+
"it worked"
123+
}
124+
125+
let x = impl_test(200);
126+
let y = x.clone();
127+
assert_eq!(
128+
format!("{:?} and {:?}", x, y),
129+
r#""it worked" and "it worked""#
130+
);
131+
}
132+
133+
#[test]
134+
fn test_impl_trait_arg() {
135+
#[requires(x.clone() == x)]
136+
#[ensures(ret.clone() == ret)]
137+
fn impl_test(x: impl Clone + PartialEq + std::fmt::Debug) -> &'static str {
99138
"it worked"
100139
}
101140

tests/methods.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@ fn methods() {
2626
self.count += 2;
2727
}
2828

29+
// Manually express the invariant in terms of `ret` since `self.count` is mutably borrowed.
30+
#[requires(is_even(self.count))]
31+
#[ensures(is_even(*ret))]
32+
#[ensures(*ret == old(self.count) + 2)]
33+
fn next_even_and_get<'a>(&'a mut self) -> &'a mut usize {
34+
self.count += 2;
35+
&mut self.count
36+
}
37+
2938
#[invariant(is_even(self.count))]
3039
#[requires(self.count >= 2)]
3140
#[ensures(self.count == old(self.count) - 2)]
@@ -41,6 +50,8 @@ fn methods() {
4150

4251
adder.prev_even();
4352
adder.prev_even();
53+
54+
assert_eq!(*adder.next_even_and_get(), 2);
4455
}
4556

4657
#[test]

0 commit comments

Comments
 (0)