Skip to content

Commit 4c144cc

Browse files
committed
Add class Instance to model
1 parent 10bdfe0 commit 4c144cc

File tree

1 file changed

+141
-0
lines changed

1 file changed

+141
-0
lines changed

pyk/src/pyk/k2lean4/model.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,147 @@ def __str__(self) -> str:
117117
return '\n'.join(lines)
118118

119119

120+
@final
121+
@dataclass(frozen=True)
122+
class Instance(Declaration):
123+
signature: Signature
124+
val: DeclVal
125+
attr_kind: AttrKind | None
126+
priority: int | None
127+
ident: DeclId | None
128+
modifiers: Modifiers | None
129+
130+
def __init__(
131+
self,
132+
signature: Signature,
133+
val: DeclVal,
134+
attr_kind: AttrKind | None = None,
135+
priority: int | None = None,
136+
ident: str | DeclId | None = None,
137+
modifiers: Modifiers | None = None,
138+
):
139+
if not signature.ty:
140+
# TODO refine type to avoid this check
141+
raise ValueError('Missing type from signature')
142+
ident = DeclId(ident) if isinstance(ident, str) else ident
143+
object.__setattr__(self, 'signature', signature)
144+
object.__setattr__(self, 'val', val)
145+
object.__setattr__(self, 'attr_kind', attr_kind)
146+
object.__setattr__(self, 'priority', priority)
147+
object.__setattr__(self, 'ident', ident)
148+
object.__setattr__(self, 'modifiers', modifiers)
149+
150+
def __str__(self) -> str:
151+
modifiers = f'{self.modifiers} ' if self.modifiers else ''
152+
attr_kind = f'{self.attr_kind.value} ' if self.attr_kind else ''
153+
priority = f' priority := {self.priority}' if self.priority is not None else ''
154+
ident = f' {self.ident}' if self.ident else ''
155+
signature = f' {self.signature}' if self.signature else ''
156+
157+
decl = f'{modifiers}{attr_kind}instance{priority}{ident}{signature}'
158+
159+
match self.val:
160+
case SimpleVal():
161+
return f'{decl} := {self.val}'
162+
case StructVal(fields):
163+
lines = []
164+
lines.append(f'{decl} where')
165+
lines.extend(indent(str(field), 2) for field in fields)
166+
return '\n'.join(lines)
167+
case _:
168+
raise AssertionError()
169+
170+
171+
class DeclVal(ABC): ...
172+
173+
174+
@final
175+
@dataclass(frozen=True)
176+
class SimpleVal(DeclVal):
177+
term: Term
178+
179+
def __str__(self) -> str:
180+
return str(self.term)
181+
182+
183+
@final
184+
@dataclass(frozen=True)
185+
class StructVal(DeclVal):
186+
fields: tuple[InstField, ...]
187+
188+
def __init__(self, fields: Iterable[InstField]):
189+
object.__setattr__(self, 'fields', tuple(fields))
190+
191+
def __str__(self) -> str:
192+
return indent('\n'.join(str(field) for field in self.fields), 2)
193+
194+
195+
@final
196+
@dataclass(frozen=True)
197+
class InstField:
198+
lval: str
199+
val: FieldVal
200+
signature: Signature | None
201+
202+
def __init__(self, lval: str, val: FieldVal, signature: Signature | None = None):
203+
object.__setattr__(self, 'lval', lval)
204+
object.__setattr__(self, 'val', val)
205+
object.__setattr__(self, 'signature', signature)
206+
207+
def __str__(self) -> str:
208+
signature = f' {self.signature}' if self.signature else ''
209+
decl = f'{self.lval}{signature}'
210+
match self.val:
211+
case SimpleFieldVal():
212+
return f'{decl} := {self.val}'
213+
case AltsFieldVal(alts):
214+
lines = []
215+
lines.append(f'{decl}')
216+
lines.extend(indent(str(alt), 2) for alt in alts)
217+
return '\n'.join(lines)
218+
case _:
219+
raise AssertionError()
220+
221+
222+
class FieldVal(ABC): ...
223+
224+
225+
@final
226+
@dataclass(frozen=True)
227+
class SimpleFieldVal(FieldVal):
228+
term: Term
229+
230+
def __str__(self) -> str:
231+
return str(self.term)
232+
233+
234+
@final
235+
@dataclass(frozen=True)
236+
class AltsFieldVal(FieldVal):
237+
alts: tuple[Alt, ...]
238+
239+
def __init__(self, alts: Iterable[Alt]):
240+
object.__setattr__(self, 'alts', tuple(alts))
241+
242+
def __str__(self) -> str:
243+
return indent('\n'.join(str(alt) for alt in self.alts), 2)
244+
245+
246+
@final
247+
@dataclass(frozen=True)
248+
class Alt:
249+
patterns: tuple[Term, ...]
250+
rhs: Term
251+
252+
def __init__(self, patterns: Iterable[Term], rhs: Term):
253+
object.__setattr__(self, 'patterns', tuple(patterns))
254+
object.__setattr__(self, 'rhs', rhs)
255+
256+
def __str__(self) -> str:
257+
patterns = ', '.join(str(pattern) for pattern in self.patterns)
258+
return f'| {patterns} => {self.rhs}'
259+
260+
120261
@final
121262
@dataclass(frozen=True)
122263
class DeclId:

0 commit comments

Comments
 (0)