4
4
5
5
use proc_macro2:: { Ident , Span , TokenStream } ;
6
6
use quote:: ToTokens ;
7
- use syn:: { spanned:: Spanned , visit_mut as visitor, Attribute , Expr , ExprCall , ReturnType } ;
7
+ use syn:: {
8
+ spanned:: Spanned ,
9
+ visit:: { visit_return_type, Visit } ,
10
+ visit_mut:: { self as visitor, visit_block_mut, visit_expr_mut, VisitMut } ,
11
+ Attribute , Expr , ExprCall , ReturnType , TypeImplTrait ,
12
+ } ;
8
13
9
14
use crate :: implementation:: { Contract , ContractMode , ContractType , FuncWithContracts } ;
10
15
@@ -98,7 +103,6 @@ pub(crate) fn extract_old_calls(contracts: &mut [Contract]) -> Vec<OldExpr> {
98
103
}
99
104
100
105
for assertion in & mut contract. assertions {
101
- use visitor:: VisitMut ;
102
106
extractor. visit_expr_mut ( assertion) ;
103
107
}
104
108
}
@@ -169,11 +173,11 @@ pub(crate) fn generate(
169
173
// creates an assertion appropriate for the current mode
170
174
let make_assertion = |mode : ContractMode ,
171
175
ctype : ContractType ,
172
- display : proc_macro2 :: TokenStream ,
176
+ display : TokenStream ,
173
177
exec_expr : & Expr ,
174
178
desc : & str | {
175
179
let span = display. span ( ) ;
176
- let mut result = proc_macro2 :: TokenStream :: new ( ) ;
180
+ let mut result = TokenStream :: new ( ) ;
177
181
178
182
let format_args = quote:: quote_spanned! { span=>
179
183
concat!( concat!( #desc, ": " ) , stringify!( #display) )
@@ -208,7 +212,7 @@ pub(crate) fn generate(
208
212
// generate assertion code for pre-conditions
209
213
//
210
214
211
- let pre: proc_macro2 :: TokenStream = func
215
+ let pre = func
212
216
. contracts
213
217
. iter ( )
214
218
. filter ( |c| c. ty == ContractType :: Requires || c. ty == ContractType :: Invariant )
@@ -222,7 +226,7 @@ pub(crate) fn generate(
222
226
let desc = if let Some ( desc) = c. desc . as_ref ( ) {
223
227
format ! ( "{} of {} violated: {}" , contract_type_name, func_name, desc)
224
228
} else {
225
- format ! ( "{} of {} violated" , c . ty . message_name ( ) , func_name)
229
+ format ! ( "{} of {} violated" , contract_type_name , func_name)
226
230
} ;
227
231
228
232
c. assertions
@@ -240,13 +244,13 @@ pub(crate) fn generate(
240
244
)
241
245
} )
242
246
} )
243
- . collect ( ) ;
247
+ . collect :: < TokenStream > ( ) ;
244
248
245
249
//
246
250
// generate assertion code for post-conditions
247
251
//
248
252
249
- let post: proc_macro2 :: TokenStream = func
253
+ let post = func
250
254
. contracts
251
255
. iter ( )
252
256
. filter ( |c| c. ty == ContractType :: Ensures || c. ty == ContractType :: Invariant )
@@ -260,7 +264,7 @@ pub(crate) fn generate(
260
264
let desc = if let Some ( desc) = c. desc . as_ref ( ) {
261
265
format ! ( "{} of {} violated: {}" , contract_type_name, func_name, desc)
262
266
} else {
263
- format ! ( "{} of {} violated" , c . ty . message_name ( ) , func_name)
267
+ format ! ( "{} of {} violated" , contract_type_name , func_name)
264
268
} ;
265
269
266
270
c. assertions
@@ -278,14 +282,14 @@ pub(crate) fn generate(
278
282
)
279
283
} )
280
284
} )
281
- . collect ( ) ;
285
+ . collect :: < TokenStream > ( ) ;
282
286
283
287
//
284
288
// bind "old()" expressions
285
289
//
286
290
287
291
let olds = {
288
- let mut toks = proc_macro2 :: TokenStream :: new ( ) ;
292
+ let mut toks = TokenStream :: new ( ) ;
289
293
290
294
for old in olds {
291
295
let span = old. expr . span ( ) ;
@@ -305,14 +309,27 @@ pub(crate) fn generate(
305
309
} ;
306
310
307
311
//
308
- // wrap the function body in a closure if we have any postconditions
312
+ // wrap the function body in a block so that we can use its return value
309
313
//
310
314
311
- let body = if post. is_empty ( ) {
312
- let block = & func. function . block ;
313
- quote:: quote! { let ret = #block; }
314
- } else {
315
- create_body_closure ( & func. function )
315
+ let body = ' blk: {
316
+ let mut block = func. function . block . clone ( ) ;
317
+ visit_block_mut ( & mut ReturnReplacer , & mut block) ;
318
+
319
+ let mut impl_detector = ImplDetector { found_impl : false } ;
320
+ visit_return_type ( & mut impl_detector, & func. function . sig . output ) ;
321
+
322
+ if !impl_detector. found_impl {
323
+ if let ReturnType :: Type ( .., ref return_type) = func. function . sig . output {
324
+ break ' blk quote:: quote! {
325
+ let ret: #return_type = ' run: #block;
326
+ } ;
327
+ }
328
+ }
329
+
330
+ quote:: quote! {
331
+ let ret = ' run: #block;
332
+ }
316
333
} ;
317
334
318
335
//
@@ -346,177 +363,25 @@ pub(crate) fn generate(
346
363
func. function . into_token_stream ( )
347
364
}
348
365
349
- struct SelfReplacer < ' a > {
350
- replace_with : & ' a syn:: Ident ,
351
- }
366
+ struct ReturnReplacer ;
352
367
353
- impl syn:: visit_mut:: VisitMut for SelfReplacer < ' _ > {
354
- fn visit_ident_mut ( & mut self , i : & mut Ident ) {
355
- if i == "self" {
356
- * i = self . replace_with . clone ( ) ;
368
+ impl VisitMut for ReturnReplacer {
369
+ fn visit_expr_mut ( & mut self , node : & mut Expr ) {
370
+ if let Expr :: Return ( ret_expr) = node {
371
+ let ret_expr_expr = ret_expr. expr . clone ( ) ;
372
+ * node = syn:: parse_quote!( break ' run #ret_expr_expr) ;
357
373
}
358
- }
359
- }
360
-
361
- fn ty_contains_impl_trait ( ty : & syn:: Type ) -> bool {
362
- use syn:: visit:: Visit ;
363
-
364
- struct TyContainsImplTrait {
365
- seen_impl_trait : bool ,
366
- }
367
374
368
- impl syn:: visit:: Visit < ' _ > for TyContainsImplTrait {
369
- fn visit_type_impl_trait ( & mut self , _: & syn:: TypeImplTrait ) {
370
- self . seen_impl_trait = true ;
371
- }
375
+ visit_expr_mut ( self , node) ;
372
376
}
373
-
374
- let mut vis = TyContainsImplTrait {
375
- seen_impl_trait : false ,
376
- } ;
377
- vis. visit_type ( ty) ;
378
- vis. seen_impl_trait
379
377
}
380
378
381
- fn create_body_closure ( func : & syn:: ItemFn ) -> TokenStream {
382
- let is_method = func. sig . receiver ( ) . is_some ( ) ;
383
-
384
- // If the function has a receiver (e.g. `self`, `&mut self`, or `self: Box<Self>`) rename it
385
- // to `this` within the closure
386
-
387
- let mut block = func. block . clone ( ) ;
388
- let mut closure_args = vec ! [ ] ;
389
- let mut arg_names = vec ! [ ] ;
390
-
391
- if is_method {
392
- // `mixed_site` makes the identifier hygienic so it won't collide with a local variable
393
- // named `this`.
394
- let this_ident = syn:: Ident :: new ( "this" , Span :: mixed_site ( ) ) ;
395
-
396
- let mut receiver = func. sig . inputs [ 0 ] . clone ( ) ;
397
- match receiver {
398
- // self, &self, &mut self
399
- syn:: FnArg :: Receiver ( rcv) => {
400
- // The `Self` type.
401
- let self_ty = Box :: new ( syn:: Type :: Path ( syn:: TypePath {
402
- qself : None ,
403
- path : syn:: Path :: from ( syn:: Ident :: new ( "Self" , rcv. span ( ) ) ) ,
404
- } ) ) ;
405
-
406
- let ty = if let Some ( ( and_token, ref lifetime) ) = rcv. reference {
407
- Box :: new ( syn:: Type :: Reference ( syn:: TypeReference {
408
- and_token,
409
- lifetime : lifetime. clone ( ) ,
410
- mutability : rcv. mutability ,
411
- elem : self_ty,
412
- } ) )
413
- } else {
414
- self_ty
415
- } ;
416
-
417
- let pat_mut = if rcv. reference . is_none ( ) {
418
- rcv. mutability
419
- } else {
420
- None
421
- } ;
422
-
423
- // this: [& [mut]] Self
424
- let new_rcv = syn:: PatType {
425
- attrs : rcv. attrs . clone ( ) ,
426
- pat : Box :: new ( syn:: Pat :: Ident ( syn:: PatIdent {
427
- attrs : vec ! [ ] ,
428
- by_ref : None ,
429
- mutability : pat_mut,
430
- ident : this_ident. clone ( ) ,
431
- subpat : None ,
432
- } ) ) ,
433
- colon_token : syn:: Token ![ : ] ( rcv. span ( ) ) ,
434
- ty,
435
- } ;
436
-
437
- receiver = syn:: FnArg :: Typed ( new_rcv) ;
438
- }
439
-
440
- // self: Box<Self>
441
- syn:: FnArg :: Typed ( ref mut pat) => {
442
- if let syn:: Pat :: Ident ( ref mut ident) = * pat. pat {
443
- if ident. ident == "self" {
444
- ident. ident = this_ident. clone ( ) ;
445
- }
446
- }
447
- }
448
- }
449
-
450
- closure_args. push ( receiver) ;
451
-
452
- match & func. sig . inputs [ 0 ] {
453
- syn:: FnArg :: Receiver ( receiver) => {
454
- arg_names. push ( syn:: Ident :: new ( "self" , receiver. self_token . span ( ) ) ) ;
455
- }
456
- syn:: FnArg :: Typed ( pat) => {
457
- if let syn:: Pat :: Ident ( ident) = & * pat. pat {
458
- arg_names. push ( ident. ident . clone ( ) ) ;
459
- } else {
460
- // Non-trivial receiver pattern => do not capture
461
- closure_args. pop ( ) ;
462
- }
463
- }
464
- } ;
465
-
466
- // Replace any references to `self` in the function body with references to `this`.
467
- syn:: visit_mut:: visit_block_mut (
468
- & mut SelfReplacer {
469
- replace_with : & this_ident,
470
- } ,
471
- & mut block,
472
- ) ;
473
- }
474
-
475
- // Any function arguments of the form `ident: ty` become closure arguments and get passed
476
- // explicitly. More complex ones, e.g. pattern matching like `(a, b): (u32, u32)`, are
477
- // captured instead.
478
- let args = func. sig . inputs . iter ( ) . skip ( if is_method { 1 } else { 0 } ) ;
479
- for arg in args {
480
- match arg {
481
- syn:: FnArg :: Receiver ( _) => unreachable ! ( "Multiple receivers?" ) ,
482
-
483
- syn:: FnArg :: Typed ( syn:: PatType { pat, ty, .. } ) => {
484
- if !ty_contains_impl_trait ( ty) {
485
- if let syn:: Pat :: Ident ( ident) = & * * pat {
486
- let ident_str = ident. ident . to_string ( ) ;
487
-
488
- // Any function argument identifier starting with '_' signals
489
- // that the binding will not be used.
490
- if !ident_str. starts_with ( '_' ) || ident_str. starts_with ( "__" ) {
491
- arg_names. push ( ident. ident . clone ( ) ) ;
492
- closure_args. push ( arg. clone ( ) ) ;
493
- }
494
- }
495
- }
496
- }
497
- }
498
- }
499
-
500
- let ret_ty = match & func. sig . output {
501
- ReturnType :: Type ( _, ty) => {
502
- let span = ty. span ( ) ;
503
- match ty. as_ref ( ) {
504
- syn:: Type :: ImplTrait ( _) => quote:: quote! { } ,
505
- ty => quote:: quote_spanned! { span=>
506
- -> #ty
507
- } ,
508
- }
509
- }
510
- ReturnType :: Default => quote:: quote! { } ,
511
- } ;
512
-
513
- let closure_args = closure_args. iter ( ) ;
514
- let arg_names = arg_names. iter ( ) ;
515
-
516
- quote:: quote! {
517
- #[ allow( unused_mut) ]
518
- let mut run = |#( #closure_args) , * | #ret_ty #block;
379
+ struct ImplDetector {
380
+ found_impl : bool ,
381
+ }
519
382
520
- let ret = run( #( #arg_names) , * ) ;
383
+ impl < ' a > Visit < ' a > for ImplDetector {
384
+ fn visit_type_impl_trait ( & mut self , _node : & ' a TypeImplTrait ) {
385
+ self . found_impl = true ;
521
386
}
522
387
}
0 commit comments