Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyk/src/pyk/k2lean4/Prelude.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ abbrev SortId : Type := String
abbrev SortInt : Type := Int
abbrev SortString : Type := String
abbrev SortStringBuffer : Type := String

class Inj (From To : Type) : Type where
inj (x : From) : To

def inj {From To : Type} [inst : Inj From To] := inst.inj
67 changes: 65 additions & 2 deletions pyk/src/pyk/k2lean4/k2lean4.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,28 @@
from ..kore.internal import CollectionKind
from ..kore.syntax import SortApp
from ..utils import POSet
from .model import Ctor, ExplBinder, Inductive, Module, Mutual, Signature, Term
from .model import (
Alt,
AltsFieldVal,
Ctor,
ExplBinder,
Inductive,
Instance,
InstField,
Module,
Mutual,
Signature,
SimpleFieldVal,
StructVal,
Term,
)

if TYPE_CHECKING:
from typing import Final

from ..kore.internal import KoreDefn
from ..kore.syntax import SymbolDecl
from .model import Command, Declaration
from .model import Command, Declaration, FieldVal


_VALID_LEAN_IDENT: Final = re.compile(
Expand Down Expand Up @@ -113,6 +127,55 @@ def _collection(self, sort: str) -> Inductive:
ctor = Ctor('mk', Signature((ExplBinder(('coll',), val),), Term(sort)))
return Inductive(sort, Signature((), Term('Type')), ctors=(ctor,))

def inj_module(self) -> Module:
return Module(commands=self._inj_commands())

def _inj_commands(self) -> tuple[Command, ...]:
return tuple(
self._inj_instance(subsort, supersort)
for supersort, subsorts in self.defn.subsorts.items()
for subsort in subsorts
if not supersort.endswith(
'CellMap'
) # Strangely, cell collections can be injected from their value sort in KORE
)

def _inj_instance(self, subsort: str, supersort: str) -> Instance:
ty = Term(f'Inj {subsort} {supersort}')
field = self._inj_field(subsort, supersort)
return Instance(Signature((), ty), StructVal((field,)))

def _inj_field(self, subsort: str, supersort: str) -> InstField:
val = self._inj_val(subsort, supersort)
return InstField('inj', val)

def _inj_val(self, subsort: str, supersort: str) -> FieldVal:
subsubsorts: list[str]
if subsort.endswith('CellMap'):
subsubsorts = [] # Disregard injection from value sort to cell map sort
else:
subsubsorts = sorted(self.defn.subsorts.get(subsort, []))

if not subsubsorts:
return SimpleFieldVal(Term(f'{supersort}.inj_{subsort}'))
else:
return AltsFieldVal(self._inj_alts(subsort, supersort, subsubsorts))

def _inj_alts(self, subsort: str, supersort: str, subsubsorts: list[str]) -> list[Alt]:
def inj(subsort: str, supersort: str, x: str) -> Term:
return Term(f'{supersort}.inj_{subsort} {x}')

res = []
for subsubsort in subsubsorts:
res.append(Alt((inj(subsubsort, subsort, 'x'),), inj(subsubsort, supersort, 'x')))

if self.defn.constructors.get(subsort, []):
# Has actual constructors, not only subsorts
default = Alt((Term('x'),), inj(subsort, supersort, 'x'))
res.append(default)

return res


def _param_sorts(decl: SymbolDecl) -> list[str]:
from ..utils import check_type
Expand Down
143 changes: 143 additions & 0 deletions pyk/src/pyk/k2lean4/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,149 @@ def __str__(self) -> str:
return '\n'.join(lines)


@final
@dataclass(frozen=True)
class Instance(Declaration):
signature: Signature
val: DeclVal
attr_kind: AttrKind | None
priority: int | None
ident: DeclId | None
modifiers: Modifiers | None

def __init__(
self,
signature: Signature,
val: DeclVal,
attr_kind: AttrKind | None = None,
priority: int | None = None,
ident: str | DeclId | None = None,
modifiers: Modifiers | None = None,
):
if priority and priority < 0:
raise ValueError('Priority must be non-negative')
if not signature.ty:
# TODO refine type to avoid this check
raise ValueError('Missing type from signature')
ident = DeclId(ident) if isinstance(ident, str) else ident
object.__setattr__(self, 'signature', signature)
object.__setattr__(self, 'val', val)
object.__setattr__(self, 'attr_kind', attr_kind)
object.__setattr__(self, 'priority', priority)
object.__setattr__(self, 'ident', ident)
object.__setattr__(self, 'modifiers', modifiers)

def __str__(self) -> str:
modifiers = f'{self.modifiers} ' if self.modifiers else ''
attr_kind = f'{self.attr_kind.value} ' if self.attr_kind else ''
priority = f' (priority := {self.priority})' if self.priority is not None else ''
ident = f' {self.ident}' if self.ident else ''
signature = f' {self.signature}' if self.signature else ''

decl = f'{modifiers}{attr_kind}instance{priority}{ident}{signature}'

match self.val:
case SimpleVal():
return f'{decl} := {self.val}'
case StructVal(fields):
lines = []
lines.append(f'{decl} where')
lines.extend(indent(str(field), 2) for field in fields)
return '\n'.join(lines)
case _:
raise AssertionError()


class DeclVal(ABC): ...


@final
@dataclass(frozen=True)
class SimpleVal(DeclVal):
term: Term

def __str__(self) -> str:
return str(self.term)


@final
@dataclass(frozen=True)
class StructVal(DeclVal):
fields: tuple[InstField, ...]

def __init__(self, fields: Iterable[InstField]):
object.__setattr__(self, 'fields', tuple(fields))

def __str__(self) -> str:
return indent('\n'.join(str(field) for field in self.fields), 2)


@final
@dataclass(frozen=True)
class InstField:
lval: str
val: FieldVal
signature: Signature | None

def __init__(self, lval: str, val: FieldVal, signature: Signature | None = None):
object.__setattr__(self, 'lval', lval)
object.__setattr__(self, 'val', val)
object.__setattr__(self, 'signature', signature)

def __str__(self) -> str:
signature = f' {self.signature}' if self.signature else ''
decl = f'{self.lval}{signature}'
match self.val:
case SimpleFieldVal():
return f'{decl} := {self.val}'
case AltsFieldVal(alts):
lines = []
lines.append(f'{decl}')
lines.extend(indent(str(alt), 2) for alt in alts)
return '\n'.join(lines)
case _:
raise AssertionError()


class FieldVal(ABC): ...


@final
@dataclass(frozen=True)
class SimpleFieldVal(FieldVal):
term: Term

def __str__(self) -> str:
return str(self.term)


@final
@dataclass(frozen=True)
class AltsFieldVal(FieldVal):
alts: tuple[Alt, ...]

def __init__(self, alts: Iterable[Alt]):
object.__setattr__(self, 'alts', tuple(alts))

def __str__(self) -> str:
return indent('\n'.join(str(alt) for alt in self.alts), 2)


@final
@dataclass(frozen=True)
class Alt:
patterns: tuple[Term, ...]
rhs: Term

def __init__(self, patterns: Iterable[Term], rhs: Term):
object.__setattr__(self, 'patterns', tuple(patterns))
object.__setattr__(self, 'rhs', rhs)

def __str__(self) -> str:
patterns = ', '.join(str(pattern) for pattern in self.patterns)
return f'| {patterns} => {self.rhs}'


@final
@dataclass(frozen=True)
class DeclId:
Expand Down
Loading