Skip to content

Commit 8c40191

Browse files
authored
Generate Lean 4 type definitions from a KORE definition (#4717)
* Add a prelude for basic primitive sorts * Generate and `inductive` for each constructed sort * Generate an `abbrev` for each collection sort
1 parent 06ffc88 commit 8c40191

File tree

5 files changed

+453
-4
lines changed

5 files changed

+453
-4
lines changed

pyk/src/pyk/k2lean4/Prelude.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
abbrev SortBool : Type := Int
2+
abbrev SortBytes: Type := ByteArray
3+
abbrev SortId : Type := String
4+
abbrev SortInt : Type := Int
5+
abbrev SortString : Type := String
6+
abbrev SortStringBuffer : Type := String
7+
8+
abbrev ListHook (E : Type) : Type := List E
9+
abbrev MapHook (K : Type) (V : Type) : Type := List (K × V)
10+
abbrev SetHook (E : Type) : Type := List E

pyk/src/pyk/k2lean4/__init__.py

Whitespace-only changes.

pyk/src/pyk/k2lean4/k2lean4.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING
5+
6+
from ..kore.internal import CollectionKind
7+
from ..kore.syntax import SortApp
8+
from ..utils import check_type
9+
from .model import Abbrev, Ctor, ExplBinder, Inductive, Module, Signature, Term
10+
11+
if TYPE_CHECKING:
12+
from ..kore.internal import KoreDefn
13+
from .model import Command
14+
15+
16+
@dataclass(frozen=True)
17+
class K2Lean4:
18+
defn: KoreDefn
19+
20+
def sort_module(self) -> Module:
21+
commands = []
22+
commands += self._inductives()
23+
commands += self._collections()
24+
return Module(commands=commands)
25+
26+
def _inductives(self) -> list[Command]:
27+
def is_inductive(sort: str) -> bool:
28+
decl = self.defn.sorts[sort]
29+
return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key
30+
31+
sorts = sorted(sort for sort in self.defn.sorts if is_inductive(sort))
32+
return [self._inductive(sort) for sort in sorts]
33+
34+
def _inductive(self, sort: str) -> Inductive:
35+
subsorts = sorted(self.defn.subsorts.get(sort, ()))
36+
symbols = sorted(self.defn.constructors.get(sort, ()))
37+
ctors: list[Ctor] = []
38+
ctors.extend(self._inj_ctor(sort, subsort) for subsort in subsorts)
39+
ctors.extend(self._symbol_ctor(sort, symbol) for symbol in symbols)
40+
return Inductive(sort, Signature((), Term('Type')), ctors=ctors)
41+
42+
def _inj_ctor(self, sort: str, subsort: str) -> Ctor:
43+
return Ctor(f'inj_{subsort}', Signature((ExplBinder(('x',), Term(subsort)),), Term(sort)))
44+
45+
def _symbol_ctor(self, sort: str, symbol: str) -> Ctor:
46+
param_sorts = (
47+
check_type(sort, SortApp).name for sort in self.defn.symbols[symbol].param_sorts
48+
) # TODO eliminate check_type
49+
binders = tuple(ExplBinder((f'x{i}',), Term(sort)) for i, sort in enumerate(param_sorts))
50+
symbol = symbol.replace('-', '_')
51+
return Ctor(symbol, Signature(binders, Term(sort)))
52+
53+
def _collections(self) -> list[Command]:
54+
return [self._collection(sort) for sort in sorted(self.defn.collections)]
55+
56+
def _collection(self, sort: str) -> Abbrev:
57+
coll = self.defn.collections[sort]
58+
elem = self.defn.symbols[coll.element]
59+
sorts = ' '.join(check_type(sort, SortApp).name for sort in elem.param_sorts) # TODO eliminate check_type
60+
assert sorts
61+
match coll.kind:
62+
case CollectionKind.LIST:
63+
val = Term(f'ListHook {sorts}')
64+
case CollectionKind.MAP:
65+
val = Term(f'MapHook {sorts}')
66+
case CollectionKind.SET:
67+
val = Term(f'SetHook {sorts}')
68+
return Abbrev(sort, val, Signature((), Term('Type')))

pyk/src/pyk/k2lean4/model.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC
4+
from dataclasses import dataclass
5+
from enum import Enum
6+
from typing import TYPE_CHECKING, final
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Iterable
10+
11+
12+
def indent(text: str, n: int) -> str:
13+
indent = n * ' '
14+
res = []
15+
for line in text.splitlines():
16+
res.append(f'{indent}{line}' if line else '')
17+
return '\n'.join(res)
18+
19+
20+
@final
21+
@dataclass(frozen=True)
22+
class Module:
23+
commands: tuple[Command, ...]
24+
25+
def __init__(self, commands: Iterable[Command] | None = None):
26+
commands = tuple(commands) if commands is not None else ()
27+
object.__setattr__(self, 'commands', commands)
28+
29+
def __str__(self) -> str:
30+
return '\n\n'.join(str(command) for command in self.commands)
31+
32+
33+
class Command(ABC): ...
34+
35+
36+
class Mutual(Command):
37+
commands: tuple[Command, ...]
38+
39+
def __init__(self, commands: Iterable[Command] | None = None):
40+
commands = tuple(commands) if commands is not None else ()
41+
object.__setattr__(self, 'commands', commands)
42+
43+
def __str__(self) -> str:
44+
commands = '\n\n'.join(indent(str(command), 2) for command in self.commands)
45+
return f'mutual\n{commands}\nend'
46+
47+
48+
class Declaration(Command, ABC):
49+
modifiers: Modifiers | None
50+
51+
52+
@final
53+
@dataclass
54+
class Abbrev(Declaration):
55+
ident: DeclId
56+
val: Term # declVal
57+
signature: Signature | None
58+
modifiers: Modifiers | None
59+
60+
def __init__(
61+
self,
62+
ident: str | DeclId,
63+
val: Term,
64+
signature: Signature | None = None,
65+
modifiers: Modifiers | None = None,
66+
):
67+
ident = DeclId(ident) if isinstance(ident, str) else ident
68+
object.__setattr__(self, 'ident', ident)
69+
object.__setattr__(self, 'val', val)
70+
object.__setattr__(self, 'signature', signature)
71+
object.__setattr__(self, 'modifiers', modifiers)
72+
73+
def __str__(self) -> str:
74+
modifiers = f'{self.modifiers} ' if self.modifiers else ''
75+
signature = f' {self.signature}' if self.signature else ''
76+
return f'{modifiers} abbrev {self.ident}{signature} := {self.val}'
77+
78+
79+
@final
80+
@dataclass(frozen=True)
81+
class Inductive(Declaration):
82+
ident: DeclId
83+
signature: Signature | None
84+
ctors: tuple[Ctor, ...]
85+
deriving: tuple[str, ...]
86+
modifiers: Modifiers | None
87+
88+
def __init__(
89+
self,
90+
ident: str | DeclId,
91+
signature: Signature | None = None,
92+
ctors: Iterable[Ctor] | None = None,
93+
deriving: Iterable[str] | None = None,
94+
modifiers: Modifiers | None = None,
95+
):
96+
ident = DeclId(ident) if isinstance(ident, str) else ident
97+
ctors = tuple(ctors) if ctors is not None else ()
98+
deriving = tuple(deriving) if deriving is not None else ()
99+
object.__setattr__(self, 'ident', ident)
100+
object.__setattr__(self, 'signature', signature)
101+
object.__setattr__(self, 'ctors', ctors)
102+
object.__setattr__(self, 'deriving', deriving)
103+
object.__setattr__(self, 'modifiers', modifiers)
104+
105+
def __str__(self) -> str:
106+
modifiers = f'{self.modifiers} ' if self.modifiers else ''
107+
signature = f' {self.signature}' if self.signature else ''
108+
where = ' where' if self.ctors else ''
109+
deriving = ', '.join(self.deriving)
110+
111+
lines = []
112+
lines.append(f'{modifiers}inductive {self.ident}{signature}{where}')
113+
for ctor in self.ctors:
114+
lines.append(f' | {ctor}')
115+
if deriving:
116+
lines.append(f' deriving {deriving}')
117+
return '\n'.join(lines)
118+
119+
120+
@final
121+
@dataclass(frozen=True)
122+
class DeclId:
123+
val: str
124+
uvars: tuple[str, ...]
125+
126+
def __init__(self, val: str, uvars: Iterable[str] | None = None):
127+
uvars = tuple(uvars) if uvars is not None else ()
128+
object.__setattr__(self, 'val', val)
129+
object.__setattr__(self, 'uvars', uvars)
130+
131+
def __str__(self) -> str:
132+
uvars = ', '.join(self.uvars)
133+
uvars = '.{' + uvars + '}' if uvars else ''
134+
return f'{self.val}{uvars}'
135+
136+
137+
@final
138+
@dataclass(frozen=True)
139+
class Ctor:
140+
ident: str
141+
signature: Signature | None = None
142+
modifiers: Modifiers | None = None
143+
144+
def __str__(self) -> str:
145+
modifiers = f'{self.modifiers} ' if self.modifiers else ''
146+
signature = f' {self.signature}' if self.signature else ''
147+
return f'{modifiers}{self.ident}{signature}'
148+
149+
150+
@final
151+
@dataclass(frozen=True)
152+
class Signature:
153+
binders: tuple[Binder, ...]
154+
ty: Term | None
155+
156+
def __init__(self, binders: Iterable[Binder] | None = None, ty: Term | None = None):
157+
binders = tuple(binders) if binders is not None else ()
158+
object.__setattr__(self, 'binders', binders)
159+
object.__setattr__(self, 'ty', ty)
160+
161+
def __str__(self) -> str:
162+
binders = ' '.join(str(binder) for binder in self.binders)
163+
sep = ' ' if self.binders else ''
164+
ty = f'{sep}: {self.ty}' if self.ty else ''
165+
return f'{binders}{ty}'
166+
167+
168+
class Binder(ABC): ...
169+
170+
171+
class BracketBinder(Binder, ABC): ...
172+
173+
174+
@final
175+
@dataclass(frozen=True)
176+
class ExplBinder(BracketBinder):
177+
idents: tuple[str, ...]
178+
ty: Term | None
179+
180+
def __init__(self, idents: Iterable[str], ty: Term | None = None):
181+
object.__setattr__(self, 'idents', tuple(idents))
182+
object.__setattr__(self, 'ty', ty)
183+
184+
def __str__(self) -> str:
185+
idents = ' '.join(self.idents)
186+
ty = '' if self.ty is None else f' : {self.ty}'
187+
return f'({idents}{ty})'
188+
189+
190+
@final
191+
@dataclass(frozen=True)
192+
class ImplBinder(BracketBinder):
193+
idents: tuple[str, ...]
194+
ty: Term | None
195+
strict: bool
196+
197+
def __init__(self, idents: Iterable[str], ty: Term | None = None, *, strict: bool | None = None):
198+
object.__setattr__(self, 'idents', tuple(idents))
199+
object.__setattr__(self, 'ty', ty)
200+
object.__setattr__(self, 'strict', bool(strict))
201+
202+
def __str__(self) -> str:
203+
ldelim, rdelim = ['⦃', '⦄'] if self.strict else ['{', '}']
204+
idents = ' '.join(self.idents)
205+
ty = '' if self.ty is None else f' : {self.ty}'
206+
return f'{ldelim}{idents}{ty}{rdelim}'
207+
208+
209+
@final
210+
@dataclass(frozen=True)
211+
class InstBinder(BracketBinder):
212+
ty: Term
213+
ident: str | None
214+
215+
def __init__(self, ty: Term, ident: str | None = None):
216+
object.__setattr__(self, 'ty', ty)
217+
object.__setattr__(self, 'ident', ident)
218+
219+
def __str__(self) -> str:
220+
ident = f'{self.ident} : ' if self.ident else ''
221+
return f'[{ident}{self.ty}]'
222+
223+
224+
@final
225+
@dataclass(frozen=True)
226+
class Term:
227+
term: str # TODO: refine
228+
229+
def __str__(self) -> str:
230+
return self.term
231+
232+
233+
@final
234+
@dataclass(frozen=True)
235+
class Modifiers:
236+
attrs: tuple[Attr, ...]
237+
visibility: Visibility | None
238+
noncomputable: bool
239+
unsafe: bool
240+
totality: Totality | None
241+
242+
def __init__(
243+
self,
244+
*,
245+
attrs: Iterable[str | Attr] | None = None,
246+
visibility: str | Visibility | None = None,
247+
noncomputable: bool | None = None,
248+
unsafe: bool | None = None,
249+
totality: str | Totality | None = None,
250+
):
251+
attrs = tuple(Attr(attr) if isinstance(attr, str) else attr for attr in attrs) if attrs is not None else ()
252+
visibility = Visibility(visibility) if isinstance(visibility, str) else visibility
253+
noncomputable = bool(noncomputable)
254+
unsafe = bool(unsafe)
255+
totality = Totality(totality) if isinstance(totality, str) else totality
256+
object.__setattr__(self, 'attrs', attrs)
257+
object.__setattr__(self, 'visibility', visibility)
258+
object.__setattr__(self, 'noncomputable', noncomputable)
259+
object.__setattr__(self, 'unsafe', unsafe)
260+
object.__setattr__(self, 'totality', totality)
261+
262+
def __str__(self) -> str:
263+
chunks = []
264+
if self.attrs:
265+
attrs = ', '.join(str(attr) for attr in self.attrs)
266+
chunks.append(f'@[{attrs}]')
267+
if self.visibility:
268+
chunks.append(self.visibility.value)
269+
if self.noncomputable:
270+
chunks.append('noncomputable')
271+
if self.unsafe:
272+
chunks.append('unsafe')
273+
if self.totality:
274+
chunks.append(self.totality.value)
275+
return ' '.join(chunks)
276+
277+
278+
@final
279+
@dataclass(frozen=True)
280+
class Attr:
281+
attr: str
282+
kind: AttrKind | None
283+
284+
def __init__(self, attr: str, kind: AttrKind | None = None):
285+
object.__setattr__(self, 'attr', attr)
286+
object.__setattr__(self, 'kind', kind)
287+
288+
def __str__(self) -> str:
289+
if self.kind:
290+
return f'{self.kind.value} {self.attr}'
291+
return self.attr
292+
293+
294+
class AttrKind(Enum):
295+
SCOPED = 'scoped'
296+
LOCAL = 'local'
297+
298+
299+
class Visibility(Enum):
300+
PRIVATE = 'private'
301+
PROTECTED = 'protected'
302+
303+
304+
class Totality(Enum):
305+
PARTIAL = 'partial'
306+
NONREC = 'nonrec'

0 commit comments

Comments
 (0)