Skip to content

Commit 2372de9

Browse files
volodeykaVexobenzqy1018
authored
Refactoring Velvet macros (#20)
* fix macros * fix macros * fix macros * add comment on VelvetM.extract * add some examples containing continue and break * remove debug info * fix PBT --------- Co-authored-by: Vexoben <Vexoben@gmail.com> Co-authored-by: zqy1018 <zqy1018@hotmail.com>
1 parent 51dcb70 commit 2372de9

14 files changed

+207
-95
lines changed

CaseStudies/Cashmere/Syntax_Cashmere.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ elab_rules : command
297297
ret := ret
298298
pre := pre
299299
post := post
300+
modIds := #[]
300301
}
301302
return (defCmd, obligation)
302303
elabCommand defCmd

CaseStudies/Extension.lean

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ initialize
2626

2727
structure VelvetObligation where
2828
binderIdents : TSyntaxArray `Lean.Parser.Term.bracketedBinder
29+
modIds : Array Ident
2930
ids : Array Ident
3031
retId : Ident
3132
ret : Term
@@ -48,8 +49,9 @@ initialize velvetObligations :
4849

4950
/-- Storing slightly more information than `VelvetObligation`. -/
5051
structure VelvetTestingCtx extends VelvetObligation where
51-
newIds : Array Ident
52+
/-- The _original_ (i.e., without `Old` suffix) binders of mutable arguments. -/
5253
modBinders : Array (TSyntax `Lean.Parser.Term.bracketedBinder)
54+
retType : Term
5355
deriving Inhabited
5456

5557
abbrev VelvetTestingContextMap := Std.HashMap Name VelvetTestingCtx

CaseStudies/TestingUtil.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def deriveDecidableNatUpperBound (tms : List <| TSyntax `term)
2323
constructor
2424
next => (intro $h:ident ; intros ; apply $h:ident <;> (try split_ands) <;> (solve
2525
| omega
26-
| aesop))
27-
next => aesop)
26+
| grind))
27+
next => grind)
2828

2929
macro "decidable_by_nat_upperbound" "[" tms:term,* "]" : term => do
3030
let res ← deriveDecidableNatUpperBound tms.getElems.toList

CaseStudies/Velvet/Syntax.lean

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ private def toBracketedBinderArrayLeafny (stx : Array (TSyntax `leafny_binder))
8888
let fb ← `(bracketedBinder| ($id : $tp:term))
8989
binders := binders.push fb
9090
| `(leafny_binder| (mut $id:ident : $tp:term)) => do
91-
let fb ← `(bracketedBinder| ($id : $tp:term))
91+
let idOld := mkIdent <| id.getId.appendAfter "Old"
92+
let fb ← `(bracketedBinder| ($idOld : $tp:term))
9293
binders := binders.push fb
9394
| _ => throwError "unexpected syntax in leafny binder: {b}"
9495
return binders
@@ -118,7 +119,8 @@ def getIds (stx : Array (TSyntax `leafny_binder)) : MetaM (Array Ident) := do
118119
for b in stx do
119120
match b with
120121
| `(leafny_binder| (mut $id:ident : $_:term)) => do
121-
ids := ids.push id
122+
let idOld := mkIdent <| id.getId.appendAfter "Old"
123+
ids := ids.push idOld
122124
| `(leafny_binder| ($id:ident : $_:term)) => do
123125
ids := ids.push id
124126
| _ => throwError "unexpected syntax in leafny binder: {b}"
@@ -139,15 +141,19 @@ partial def expandLeafnyDoSeqItem (modIds : Array Ident) (stx : doSeqItem) : Ter
139141
| `(Term.doSeqItem| $stx ;) => expandLeafnyDoSeqItem modIds $ <- `(Term.doSeqItem| $stx:doElem)
140142
| `(Term.doSeqItem| return) => expandLeafnyDoSeqItem modIds $ <- `(Term.doSeqItem| return ())
141143
| `(Term.doSeqItem| return $t) =>
142-
let mut ret <- `(term| ())
143-
for modId in modIds do
144-
ret <- `(term| ⟨$modId, $ret⟩)
145-
return #[<-`(Term.doSeqItem| return ⟨$t, $ret⟩)]
144+
let ret <-
145+
if modIds.size = 0 then
146+
`(term| $t)
147+
else
148+
`(term| ($t, $[$modIds:term],*))
149+
return #[<-`(Term.doSeqItem| return $ret)]
146150
| `(Term.doSeqItem| pure $t) =>
147-
let mut ret <- `(term| ())
148-
for modId in modIds do
149-
ret <- `(term| ⟨$modId, $ret⟩)
150-
return #[<-`(Term.doSeqItem| pure ⟨$t, $ret⟩)]
151+
let ret <-
152+
if modIds.size = 0 then
153+
`(term| $t)
154+
else
155+
`(term| ($t, $[$modIds:term],*))
156+
return #[<-`(Term.doSeqItem| pure $ret)]
151157
| `(Term.doSeqItem| if $h:ident : $t:term then $thn:doSeq else $els:doSeq) =>
152158
let thn <- expandLeafnyDoSeq modIds thn
153159
let els <- expandLeafnyDoSeq modIds els
@@ -227,6 +233,13 @@ private def Array.andList (ts : Array (TSyntax `term)) : TermElabM (TSyntax `ter
227233
t <- `(term| $t' ∧ $t)
228234
return t
229235

236+
private def addPreludeToPreCond (pre : Term) (modIds : Array Ident) : CoreM (TSyntax `term) := do
237+
let mut pre := pre
238+
for modId in modIds do
239+
let modIdOld := mkIdent <| modId.getId.appendAfter "Old"
240+
pre ← `(term| let $modId:ident := $modIdOld:ident; $pre)
241+
pure pre
242+
230243
elab_rules : command
231244
| `(command|
232245
method $name:ident $binders:leafny_binder* return ( $retId:ident : $type:term )
@@ -244,17 +257,22 @@ elab_rules : command
244257

245258
let mut mods := #[]
246259
for modId in modIds do
247-
-- let modIdOld := mkIdent <| modId.getId.appendAfter "Old"
260+
let modIdOld := mkIdent <| modId.getId.appendAfter "Old"
248261
-- let modOld <- `(Term.doSeqItem| let $modIdOld:ident := $modId:ident)
249-
let mod <- `(Term.doSeqItem| let mut $modId:ident := $modId:ident)
262+
let mod <- `(Term.doSeqItem| let mut $modId:ident := $modIdOld:ident)
250263
mods := mods.push mod
251264
let mutTypes ← getMutTypes binders
252-
let mut retType <- `(Unit)
253-
for mutType in mutTypes, modId in modIds do
254-
retType <- `(($modId:ident : $mutType) × $retType)
265+
let mut retType : Term <- `($type)
266+
if mutTypes.size != 0 then
267+
let lastMutType := mutTypes[mutTypes.size - 1]!
268+
let mutTypes := mutTypes.pop.reverse
269+
let mut mutTypeProd := lastMutType
270+
for mutType in mutTypes do
271+
mutTypeProd <- `($mutType × $mutTypeProd)
272+
retType <- `($retType × $mutTypeProd)
255273
let defCmd <- `(command|
256274
set_option linter.unusedVariables false in
257-
def $name $bindersIdents* : VelvetM (($retId:ident : $type) × $retType) := do $mods* $doSeq*
275+
def $name $bindersIdents* : VelvetM $retType:term := do $mods* $doSeq*
258276
$suf:suffix)
259277
-- let lemmaName := mkIdent <| name.getId.appendAfter "_correct"
260278

@@ -264,12 +282,13 @@ elab_rules : command
264282
let post <- ens.andListWithName ensName
265283

266284
let namelessPre <- req.andList
285+
let namelessPre <- addPreludeToPreCond namelessPre modIds
267286
let namelessPost <- ens.andList
268287

269-
let mut ret <- `(term| ())
270-
for modId in modIds do
271-
let modId := mkIdent <| modId.getId.appendAfter "New"
272-
ret <- `(term| ⟨$modId, $ret⟩)
288+
let ret <- if modIds.size = 0 then
289+
`(term| $retId)
290+
else
291+
`(term| ($retId, $[$modIds:term],*))
273292

274293
let ids ← getIds binders
275294
let obligation : VelvetObligation := {
@@ -279,17 +298,16 @@ elab_rules : command
279298
ret := ret
280299
pre := pre
281300
post := post
301+
modIds := modIds
282302
}
283-
let newIds := modIds.map (fun x => Lean.mkIdent <| x.getId.appendAfter "New")
284-
let modBinders ← newIds.zip mutTypes |>.mapM fun (newId, mutType) =>
285-
`(bracketedBinder| ($newId : $mutType))
286-
return (defCmd, obligation, { obligation with pre := namelessPre , post := namelessPost , modBinders , newIds })
303+
let modBinders ← modIds.zip mutTypes |>.mapM fun (mId, mutType) =>
304+
`(bracketedBinder| ($mId : $mutType))
305+
return (defCmd, obligation, { obligation with pre := namelessPre , post := namelessPost , modBinders , retType := type })
287306
elabCommand defCmd
288307
velvetObligations.modify (·.insert name.getId obligation)
289308
velvetTestingContextMap.modify (·.insert name.getId testingCtx)
290309

291310
notation "{" P "}" c "{" v "," Q "}" => triple P c (fun v => Q)
292-
293311
/-
294312
example:
295313
open TotalCorrectness DemonicChoice
@@ -305,22 +323,24 @@ elab_rules : command
305323
let .some obligation := ctx[name.getId]? | throwError "no obligation found"
306324
let bindersIdents := obligation.binderIdents
307325
let ids := obligation.ids
308-
let retId := obligation.retId
326+
-- let retId := obligation.retId
309327
let ret := obligation.ret
310-
let pre := obligation.pre
328+
let pre ← liftCoreM <| addPreludeToPreCond obligation.pre obligation.modIds
311329
let post := obligation.post
312330
let lemmaName := mkIdent <| name.getId.appendAfter "_correct"
313331
-- let proof <- withRef tkp ``()
314332
let proofSeq ← withRef tkp `(tacticSeq|
315333
unfold $name
316334
($proof))
335+
317336
let thmCmd <- withRef tkp `(command|
318337
@[loomSpec]
319338
lemma $lemmaName $bindersIdents* :
320339
triple
321340
$pre
322341
($name $ids*)
323-
(fun ⟨$retId, $ret⟩ => $post) := by $proofSeq $suf:suffix)
342+
(fun $ret => $post) := by $proofSeq $suf:suffix)
343+
trace[Loom] "{thmCmd}"
324344
Command.elabCommand thmCmd
325345
velvetObligations.modify (·.erase name.getId)
326346

@@ -358,7 +378,7 @@ def elabDefiningDecidableInstancesForVelvetSpec (nameRaw : Ident)
358378
let (target, suffix, binders) :=
359379
if pre?
360380
then (ctx.pre, "PreDecidable", bindersIdents)
361-
else (ctx.post, "PostDecidable", bindersIdents ++ ctx.modBinders)
381+
else (ctx.post, "PostDecidable", bindersIdents ++ ctx.modBinders |>.push ⟨mkExplicitBinder ctx.retId ctx.retType⟩)
362382
let decidableInstName := name.appendAfter suffix
363383
-- let proof := tac.getD (← `(term| (by infer_instance) ))
364384
let tac := tac.getD (← `(Lean.Parser.Tactic.tacticSeq| skip ))
@@ -394,7 +414,7 @@ elab_rules : command
394414
let bindersIdents := ctx.binderIdents
395415
let bundle (pre? : Bool) := if pre?
396416
then (ctx.pre, name.appendAfter "PreDecidable", ids)
397-
else (ctx.post, name.appendAfter "PostDecidable", ids ++ ctx.newIds)
417+
else (ctx.post, name.appendAfter "PostDecidable", ids ++ ctx.modIds |>.push retId)
398418
let decideTerm bundled : CommandElabM (TSyntax `term) := do
399419
let (target, instname, args) := bundled
400420
try
@@ -404,7 +424,7 @@ elab_rules : command
404424
`(term| ($(mkIdent ``decide) ($target)))
405425
let matcherTerm ← `(term|
406426
match ($(Syntax.mkApp (mkIdent execName) ids)) with
407-
| $(mkIdent ``DivM.res) ⟨$retId, $ret => $(← decideTerm <| bundle false)
427+
| $(mkIdent ``DivM.res) $ret => $(← decideTerm <| bundle false)
408428
| _ => false)
409429
let ifTerm ← `(term| if $(← decideTerm <| bundle true) then $matcherTerm else true)
410430
let testerName := name.appendAfter "Tester"

CaseStudies/Velvet/VelvetExamples/Examples.lean

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,25 +56,23 @@ attribute [grind] Array.multiset_swap
5656
method insertionSort
5757
(mut arr: Array Int) return (u: Unit)
5858
require 1 ≤ arr.size
59-
ensures forall i j, 0 ≤ i ∧ i ≤ j ∧ j < arr.size → arrNew[i]! ≤ arrNew[j]!
60-
ensures arr.toMultiset = arrNew.toMultiset
59+
ensures forall i j, 0 ≤ i ∧ i ≤ j ∧ j < arr.size → arr[i]! ≤ arr[j]!
60+
ensures arr.toMultiset = arrOld.toMultiset
6161
do
62-
let arr₀ := arr
63-
let arr_size := arr.size
6462
let mut n := 1
6563
while n ≠ arr.size
66-
invariant arr.size = arr_size
64+
invariant arr.size = arrOld.size
6765
invariant 1 ≤ n ∧ n ≤ arr.size
6866
invariant forall i j, 0 ≤ i ∧ i < j ∧ j <= n - 1 → arr[i]! ≤ arr[j]!
69-
invariant arr.toMultiset = arr₀.toMultiset
67+
invariant arr.toMultiset = arrOld.toMultiset
7068
done_with n = arr.size
7169
do
7270
let mut mind := n
7371
while mind ≠ 0
74-
invariant arr.size = arr_size
72+
invariant arr.size = arrOld.size
7573
invariant mind ≤ n
7674
invariant forall i j, 0 ≤ i ∧ i < j ∧ j ≤ n ∧ j ≠ mind → arr[i]! ≤ arr[j]!
77-
invariant arr.toMultiset = arr₀.toMultiset
75+
invariant arr.toMultiset = arrOld.toMultiset
7876
done_with mind = 0
7977
do
8078
if arr[mind]! < arr[mind - 1]! then
@@ -119,14 +117,13 @@ method sqrt (x: ℕ) return (res: ℕ)
119117
do
120118
if x = 0 then
121119
return 0
122-
else
123-
let mut i := 0
124-
while i * i ≤ x
125-
invariant ∀ j, j < i → j * j ≤ x
126-
done_with x < i * i
127-
do
128-
i := i + 1
129-
return i - 1
120+
let mut i := 0
121+
while i * i ≤ x
122+
invariant ∀ j, j < i → j * j ≤ x
123+
done_with x < i * i
124+
do
125+
i := i + 1
126+
return i - 1
130127

131128
set_option auto.smt.trust true
132129
set_option auto.smt true

CaseStudies/Velvet/VelvetExamples/Examples_Total.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ method insertionSort(arr: array<int>)
5454
method insertionSort_total
5555
(mut arr: Array Int) return (u: Unit)
5656
require 1 ≤ arr.size
57-
ensures forall i j, 0 ≤ i ∧ i ≤ j ∧ j < arr.size → arrNew[i]! ≤ arrNew[j]!
58-
ensures arr.toMultiset = arrNew.toMultiset
57+
ensures forall i j, 0 ≤ i ∧ i ≤ j ∧ j < arr.size → arr[i]! ≤ arr[j]!
58+
ensures arrOld.toMultiset = arr.toMultiset
5959
do
6060
let arr₀ := arr
6161
let arr_size := arr.size

CaseStudies/Velvet/VelvetExamples/GCD.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ method gcd (a : Nat) (b : Nat) return (res : Nat)
2222
else
2323
let remainder := a % b
2424
let result ← gcd b remainder
25-
return result.1
25+
return result
2626
termination_by b
2727
decreasing_by
2828
apply Nat.mod_lt

CaseStudies/Velvet/VelvetExamples/Recursion.lean

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ method simple_recursion (x : Nat) return (res: Nat)
2121
return 0
2222
else
2323
let pre_res ← simple_recursion (x - 1)
24-
return pre_res.1 + 1
24+
return pre_res + 1
2525

2626
prove_correct simple_recursion by
2727
loom_solve
@@ -45,8 +45,8 @@ method pickGreaterN (n: Nat) return (res: Nat)
4545
return 0
4646
else
4747
let pre_res ← pickGreaterN (n - 1)
48-
let pre_res_big ← pickGreater pre_res.1
49-
return pre_res_big.1
48+
let pre_res_big ← pickGreater pre_res
49+
return pre_res_big
5050

5151
prove_correct pickGreaterN by
5252
loom_solve
@@ -69,11 +69,11 @@ method calc_fact (n: Nat) return (res: Nat)
6969
let pre_res_n ← calc_fact (n - 1)
7070
while i < n
7171
invariant i <= n
72-
invariant i * pre_res_n.1 = ans
72+
invariant i * pre_res_n = ans
7373
decreasing n - i
7474
do
7575
let pre_res ← calc_fact (n - 1)
76-
ans := ans + pre_res.1
76+
ans := ans + pre_res
7777
i := i + 1
7878
return ans
7979

@@ -101,7 +101,7 @@ method SimpleList (li: List Nat) return (res: Nat)
101101
match li with
102102
| x :: xs =>
103103
let prev ← SimpleList xs
104-
return (prev.1 + x)
104+
return (prev + x)
105105
| [] =>
106106
return 1
107107

@@ -139,10 +139,10 @@ method insertTree (tree: mt1 Nat) (elem: Nat) return (res: mt1 Nat)
139139
else
140140
if el < elem then
141141
let right_res ← insertTree r elem
142-
pure (.Node l right_res.1 el)
142+
pure (.Node l right_res el)
143143
else
144144
let left_res ← insertTree l elem
145-
pure (.Node left_res.1 r el)
145+
pure (.Node left_res r el)
146146
| .Leaf el =>
147147
if el = elem then
148148
pure tree
@@ -170,10 +170,10 @@ method complex_measure_binsearch (l : Nat) (r: Nat) (x: Nat) return (res: Nat)
170170
let m := l + (r - l) / 2
171171
if m * m ≤ x then
172172
let pre_res_l ← complex_measure_binsearch m r x
173-
return pre_res_l.1
173+
return pre_res_l
174174
else
175175
let pre_res_r ← complex_measure_binsearch l m x
176-
return pre_res_r.1
176+
return pre_res_r
177177

178178
prove_correct complex_measure_binsearch by
179179
loom_solve
@@ -199,7 +199,7 @@ method pow2 (n: Nat) return (res: Nat)
199199
decreasing n - i
200200
do
201201
let pre_res ← pow2 i
202-
ans := ans + pre_res.1
202+
ans := ans + pre_res
203203
i := i + 1
204204
return ans
205205

0 commit comments

Comments
 (0)