diff --git a/pyk/src/pyk/k2lean4/Prelude.lean b/pyk/src/pyk/k2lean4/Prelude.lean index 0c7ad32a778..ef5d0a3c8bb 100644 --- a/pyk/src/pyk/k2lean4/Prelude.lean +++ b/pyk/src/pyk/k2lean4/Prelude.lean @@ -4,3 +4,8 @@ abbrev SortId : Type := String abbrev SortInt : Type := Int abbrev SortString : Type := String abbrev SortStringBuffer : Type := String + +class Inj (From To : Type) : Type where + inj (x : From) : To + +def inj {From To : Type} [inst : Inj From To] := inst.inj diff --git a/pyk/src/pyk/k2lean4/k2lean4.py b/pyk/src/pyk/k2lean4/k2lean4.py index 1dd7b76bfde..3992d39e1ec 100644 --- a/pyk/src/pyk/k2lean4/k2lean4.py +++ b/pyk/src/pyk/k2lean4/k2lean4.py @@ -10,14 +10,28 @@ from ..kore.internal import CollectionKind from ..kore.syntax import SortApp from ..utils import POSet -from .model import Ctor, ExplBinder, Inductive, Module, Mutual, Signature, Term +from .model import ( + Alt, + AltsFieldVal, + Ctor, + ExplBinder, + Inductive, + Instance, + InstField, + Module, + Mutual, + Signature, + SimpleFieldVal, + StructVal, + Term, +) if TYPE_CHECKING: from typing import Final from ..kore.internal import KoreDefn from ..kore.syntax import SymbolDecl - from .model import Command, Declaration + from .model import Command, Declaration, FieldVal _VALID_LEAN_IDENT: Final = re.compile( @@ -113,6 +127,55 @@ def _collection(self, sort: str) -> Inductive: ctor = Ctor('mk', Signature((ExplBinder(('coll',), val),), Term(sort))) return Inductive(sort, Signature((), Term('Type')), ctors=(ctor,)) + def inj_module(self) -> Module: + return Module(commands=self._inj_commands()) + + def _inj_commands(self) -> tuple[Command, ...]: + return tuple( + self._inj_instance(subsort, supersort) + for supersort, subsorts in self.defn.subsorts.items() + for subsort in subsorts + if not supersort.endswith( + 'CellMap' + ) # Strangely, cell collections can be injected from their value sort in KORE + ) + + def _inj_instance(self, subsort: str, supersort: str) -> Instance: + ty = Term(f'Inj {subsort} {supersort}') + field = self._inj_field(subsort, supersort) + return Instance(Signature((), ty), StructVal((field,))) + + def _inj_field(self, subsort: str, supersort: str) -> InstField: + val = self._inj_val(subsort, supersort) + return InstField('inj', val) + + def _inj_val(self, subsort: str, supersort: str) -> FieldVal: + subsubsorts: list[str] + if subsort.endswith('CellMap'): + subsubsorts = [] # Disregard injection from value sort to cell map sort + else: + subsubsorts = sorted(self.defn.subsorts.get(subsort, [])) + + if not subsubsorts: + return SimpleFieldVal(Term(f'{supersort}.inj_{subsort}')) + else: + return AltsFieldVal(self._inj_alts(subsort, supersort, subsubsorts)) + + def _inj_alts(self, subsort: str, supersort: str, subsubsorts: list[str]) -> list[Alt]: + def inj(subsort: str, supersort: str, x: str) -> Term: + return Term(f'{supersort}.inj_{subsort} {x}') + + res = [] + for subsubsort in subsubsorts: + res.append(Alt((inj(subsubsort, subsort, 'x'),), inj(subsubsort, supersort, 'x'))) + + if self.defn.constructors.get(subsort, []): + # Has actual constructors, not only subsorts + default = Alt((Term('x'),), inj(subsort, supersort, 'x')) + res.append(default) + + return res + def _param_sorts(decl: SymbolDecl) -> list[str]: from ..utils import check_type diff --git a/pyk/src/pyk/k2lean4/model.py b/pyk/src/pyk/k2lean4/model.py index fdbc6300145..39e763e8011 100644 --- a/pyk/src/pyk/k2lean4/model.py +++ b/pyk/src/pyk/k2lean4/model.py @@ -117,6 +117,149 @@ def __str__(self) -> str: return '\n'.join(lines) +@final +@dataclass(frozen=True) +class Instance(Declaration): + signature: Signature + val: DeclVal + attr_kind: AttrKind | None + priority: int | None + ident: DeclId | None + modifiers: Modifiers | None + + def __init__( + self, + signature: Signature, + val: DeclVal, + attr_kind: AttrKind | None = None, + priority: int | None = None, + ident: str | DeclId | None = None, + modifiers: Modifiers | None = None, + ): + if priority and priority < 0: + raise ValueError('Priority must be non-negative') + if not signature.ty: + # TODO refine type to avoid this check + raise ValueError('Missing type from signature') + ident = DeclId(ident) if isinstance(ident, str) else ident + object.__setattr__(self, 'signature', signature) + object.__setattr__(self, 'val', val) + object.__setattr__(self, 'attr_kind', attr_kind) + object.__setattr__(self, 'priority', priority) + object.__setattr__(self, 'ident', ident) + object.__setattr__(self, 'modifiers', modifiers) + + def __str__(self) -> str: + modifiers = f'{self.modifiers} ' if self.modifiers else '' + attr_kind = f'{self.attr_kind.value} ' if self.attr_kind else '' + priority = f' (priority := {self.priority})' if self.priority is not None else '' + ident = f' {self.ident}' if self.ident else '' + signature = f' {self.signature}' if self.signature else '' + + decl = f'{modifiers}{attr_kind}instance{priority}{ident}{signature}' + + match self.val: + case SimpleVal(): + return f'{decl} := {self.val}' + case StructVal(fields): + lines = [] + lines.append(f'{decl} where') + lines.extend(indent(str(field), 2) for field in fields) + return '\n'.join(lines) + case _: + raise AssertionError() + + +class DeclVal(ABC): ... + + +@final +@dataclass(frozen=True) +class SimpleVal(DeclVal): + term: Term + + def __str__(self) -> str: + return str(self.term) + + +@final +@dataclass(frozen=True) +class StructVal(DeclVal): + fields: tuple[InstField, ...] + + def __init__(self, fields: Iterable[InstField]): + object.__setattr__(self, 'fields', tuple(fields)) + + def __str__(self) -> str: + return indent('\n'.join(str(field) for field in self.fields), 2) + + +@final +@dataclass(frozen=True) +class InstField: + lval: str + val: FieldVal + signature: Signature | None + + def __init__(self, lval: str, val: FieldVal, signature: Signature | None = None): + object.__setattr__(self, 'lval', lval) + object.__setattr__(self, 'val', val) + object.__setattr__(self, 'signature', signature) + + def __str__(self) -> str: + signature = f' {self.signature}' if self.signature else '' + decl = f'{self.lval}{signature}' + match self.val: + case SimpleFieldVal(): + return f'{decl} := {self.val}' + case AltsFieldVal(alts): + lines = [] + lines.append(f'{decl}') + lines.extend(indent(str(alt), 2) for alt in alts) + return '\n'.join(lines) + case _: + raise AssertionError() + + +class FieldVal(ABC): ... + + +@final +@dataclass(frozen=True) +class SimpleFieldVal(FieldVal): + term: Term + + def __str__(self) -> str: + return str(self.term) + + +@final +@dataclass(frozen=True) +class AltsFieldVal(FieldVal): + alts: tuple[Alt, ...] + + def __init__(self, alts: Iterable[Alt]): + object.__setattr__(self, 'alts', tuple(alts)) + + def __str__(self) -> str: + return indent('\n'.join(str(alt) for alt in self.alts), 2) + + +@final +@dataclass(frozen=True) +class Alt: + patterns: tuple[Term, ...] + rhs: Term + + def __init__(self, patterns: Iterable[Term], rhs: Term): + object.__setattr__(self, 'patterns', tuple(patterns)) + object.__setattr__(self, 'rhs', rhs) + + def __str__(self) -> str: + patterns = ', '.join(str(pattern) for pattern in self.patterns) + return f'| {patterns} => {self.rhs}' + + @final @dataclass(frozen=True) class DeclId: