Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion kmir/src/kmir/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,35 @@ def _kmir_view(opts: ViewOpts) -> None:
viewer.run()


def _write_to_module(kmir: KMIR, proof: APRProof, to_module_path: Path) -> None:
"""Write proof KCFG as a K module to the specified path."""
import json

from pyk.kast.manip import remove_generated_cells
from pyk.kast.outer import KRule

# Generate K module using KCFG.to_module with defunc_with for proper function inlining
module_name = proof.id.upper().replace('.', '-').replace('_', '-') + '-SUMMARY'
k_module = proof.kcfg.to_module(module_name=module_name, defunc_with=kmir.definition)

if to_module_path.suffix == '.json':
# JSON format for --add-module: keep <generatedTop> for Kore conversion
# Note: We don't use minimize_rule_like here because it creates partial configs
# with dots that cannot be converted back to Kore
to_module_path.write_text(json.dumps(k_module.to_dict(), indent=2))
else:
# K text format for human readability: remove <generatedTop> and <generatedCounter>
def _process_sentence(sent): # type: ignore[no-untyped-def]
if isinstance(sent, KRule):
sent = sent.let(body=remove_generated_cells(sent.body))
return sent

k_module_readable = k_module.let(sentences=[_process_sentence(sent) for sent in k_module.sentences])
k_module_text = kmir.pretty_print(k_module_readable)
to_module_path.write_text(k_module_text)
_LOGGER.info(f'Module written to: {to_module_path}')


def _kmir_show(opts: ShowOpts) -> None:
from pyk.kast.pretty import PrettyPrinter

Expand All @@ -92,6 +121,13 @@ def _kmir_show(opts: ShowOpts) -> None:
kmir = KMIR(HASKELL_DEF_DIR, LLVM_LIB_DIR)
proof = APRProof.read_proof_data(opts.proof_dir, opts.id)

# Minimize proof KCFG if requested
if opts.minimize_proof:
_LOGGER.info('Minimizing proof KCFG...')
proof.minimize_kcfg()
proof.write_proof_data()
_LOGGER.info('Proof KCFG minimized and saved')

# Use custom KMIR printer by default, switch to standard printer if requested
if opts.use_default_printer:
printer = PrettyPrinter(kmir.definition)
Expand Down Expand Up @@ -119,6 +155,7 @@ def _kmir_show(opts: ShowOpts) -> None:
nodes=opts.nodes or (),
node_deltas=effective_node_deltas,
omit_cells=tuple(all_omit_cells),
to_module=opts.to_module is not None,
)
if opts.statistics:
if lines and lines[-1] != '':
Expand All @@ -132,7 +169,12 @@ def _kmir_show(opts: ShowOpts) -> None:
lines.append('')
lines.extend(render_leaf_k_cells(proof, node_printer.cterm_show))

print('\n'.join(lines))
# Handle --to-module output
if opts.to_module:
_write_to_module(kmir, proof, opts.to_module)
print(f'Module written to: {opts.to_module}')
else:
print('\n'.join(lines))


def _kmir_prune(opts: PruneOpts) -> None:
Expand Down Expand Up @@ -410,6 +452,17 @@ def _arg_parser() -> ArgumentParser:
)

show_parser.add_argument('--rules', metavar='EDGES', help='Comma separated list of edges in format "source:target"')
show_parser.add_argument(
'--to-module',
type=Path,
metavar='FILE',
help='Output path for K module file (.k for readable, .json for --add-module)',
)
show_parser.add_argument(
'--minimize-proof',
action='store_true',
help='Minimize the proof KCFG before exporting to module',
)

command_parser.add_parser(
'view', help='View proof information', parents=[kcli_args.logging_args, proof_args, display_args]
Expand Down Expand Up @@ -443,6 +496,12 @@ def _arg_parser() -> ArgumentParser:
prove_rs_parser.add_argument(
'--start-symbol', type=str, metavar='SYMBOL', default='main', help='Symbol name to begin execution from'
)
prove_rs_parser.add_argument(
'--add-module',
type=Path,
metavar='FILE',
help='K module file to include (.json format from --to-module)',
)

link_parser = command_parser.add_parser(
'link', help='Link together 2 or more SMIR JSON files', parents=[kcli_args.logging_args]
Expand Down Expand Up @@ -530,6 +589,7 @@ def _parse_args(ns: Namespace) -> KMirOpts:
break_every_terminator=ns.break_every_terminator,
break_every_step=ns.break_every_step,
terminate_on_thunk=ns.terminate_on_thunk,
add_module=ns.add_module,
)
case 'link':
return LinkOpts(
Expand Down
25 changes: 21 additions & 4 deletions kmir/src/kmir/kmir.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ def cut_point_rules(

@staticmethod
def from_kompiled_kore(
smir_info: SMIRInfo, target_dir: Path, bug_report: Path | None = None, symbolic: bool = True
smir_info: SMIRInfo,
target_dir: Path,
bug_report: Path | None = None,
symbolic: bool = True,
extra_module: Path | None = None,
) -> KMIR:
from .kompile import kompile_smir

Expand All @@ -128,6 +132,7 @@ def from_kompiled_kore(
target_dir=target_dir,
bug_report=bug_report,
symbolic=symbolic,
extra_module=extra_module,
)
return kompiled_smir.create_kmir(bug_report_file=bug_report)

Expand Down Expand Up @@ -213,7 +218,11 @@ def prove_rs(opts: ProveRSOpts) -> APRProof:

smir_info = SMIRInfo.from_file(target_path / 'smir.json')
kmir = KMIR.from_kompiled_kore(
smir_info, symbolic=True, bug_report=opts.bug_report, target_dir=target_path
smir_info,
symbolic=True,
bug_report=opts.bug_report,
target_dir=target_path,
extra_module=opts.add_module,
)
else:
_LOGGER.info(f'Constructing initial proof: {label}')
Expand All @@ -237,7 +246,11 @@ def prove_rs(opts: ProveRSOpts) -> APRProof:
_LOGGER.debug(f'Missing-body function symbols (first 5): {missing_body_syms[:5]}')

kmir = KMIR.from_kompiled_kore(
smir_info, symbolic=True, bug_report=opts.bug_report, target_dir=target_path
smir_info,
symbolic=True,
bug_report=opts.bug_report,
target_dir=target_path,
extra_module=opts.add_module,
)

apr_proof = kmir.apr_proof_from_smir(
Expand Down Expand Up @@ -267,7 +280,11 @@ def prove_rs(opts: ProveRSOpts) -> APRProof:
)

with kmir.kcfg_explore(label, terminate_on_thunk=opts.terminate_on_thunk) as kcfg_explore:
prover = APRProver(kcfg_explore, execute_depth=opts.max_depth, cut_point_rules=cut_point_rules)
prover = APRProver(
kcfg_explore,
execute_depth=opts.max_depth,
cut_point_rules=cut_point_rules,
)
prover.advance_proof(
apr_proof,
max_iterations=opts.max_iterations,
Expand Down
62 changes: 54 additions & 8 deletions kmir/src/kmir/kompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,46 @@ def _digest_file(target_dir: Path) -> Path:
return target_dir / 'smir-digest.json'


def _load_extra_module_rules(kmir: KMIR, module_path: Path) -> list[Sentence]:
"""Load a K module from JSON and convert rules to Kore axioms.

Args:
kmir: KMIR instance with the definition
module_path: Path to JSON module file (from --to-module output.json)

Returns:
List of Kore axioms converted from the module rules
"""
from pyk.kast.outer import KFlatModule, KRule
from pyk.konvert import krule_to_kore

_LOGGER.info(f'Loading extra module rules: {module_path}')

if module_path.suffix != '.json':
_LOGGER.warning(f'Only JSON format is supported for --add-module: {module_path}')
return []

module_dict = json.loads(module_path.read_text())
k_module = KFlatModule.from_dict(module_dict)

axioms: list[Sentence] = []
for sentence in k_module.sentences:
if isinstance(sentence, KRule):
try:
axiom = krule_to_kore(kmir.definition, sentence)
axioms.append(axiom)
except Exception:
_LOGGER.warning(f'Failed to convert rule to Kore: {sentence}', exc_info=True)

return axioms


def kompile_smir(
smir_info: SMIRInfo,
target_dir: Path,
bug_report: Path | None = None,
symbolic: bool = True,
extra_module: Path | None = None,
) -> KompiledSMIR:
kompile_digest: KompileDigest | None = None
try:
Expand All @@ -120,8 +155,18 @@ def kompile_smir(
target_dir.mkdir(parents=True, exist_ok=True)

kmir = KMIR(HASKELL_DEF_DIR)
rules = make_kore_rules(kmir, smir_info)
_LOGGER.info(f'Generated {len(rules)} function equations to add to `definition.kore')
smir_rules: list[Sentence] = list(make_kore_rules(kmir, smir_info))
_LOGGER.info(f'Generated {len(smir_rules)} function equations to add to `definition.kore')

# Load and convert extra module rules if provided
# These are kept separate because LLVM backend doesn't support configuration rewrites
extra_rules: list[Sentence] = []
if extra_module is not None:
extra_rules = _load_extra_module_rules(kmir, extra_module)
_LOGGER.info(f'Added {len(extra_rules)} rules from extra module: {extra_module}')

# Combined rules for Haskell backend (supports both function equations and rewrites)
all_rules = smir_rules + extra_rules

if symbolic:
# Create output directories
Expand All @@ -131,11 +176,12 @@ def kompile_smir(
target_llvmdt_path.mkdir(parents=True, exist_ok=True)
target_hs_path.mkdir(parents=True, exist_ok=True)

# Process LLVM definition
# Process LLVM definition (only SMIR rules, not extra module rules)
# Extra module rules are configuration rewrites that LLVM backend doesn't support
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is getting a bit too complicated IMO. We should extract two helpers for the two cases of if symbolic (not in this PR though).

_LOGGER.info('Writing LLVM definition file')
llvm_def_file = LLVM_LIB_DIR / 'definition.kore'
llvm_def_output = target_llvm_lib_path / 'definition.kore'
_insert_rules_and_write(llvm_def_file, rules, llvm_def_output)
_insert_rules_and_write(llvm_def_file, smir_rules, llvm_def_output)

# Run llvm-kompile-matching and llvm-kompile for LLVM
# TODO use pyk to do this if possible (subprocess wrapper, maybe llvm-kompile itself?)
Expand All @@ -161,10 +207,10 @@ def kompile_smir(
check=True,
)

# Process Haskell definition
# Process Haskell definition (includes both SMIR rules and extra module rules)
_LOGGER.info('Writing Haskell definition file')
hs_def_file = HASKELL_DEF_DIR / 'definition.kore'
_insert_rules_and_write(hs_def_file, rules, target_hs_path / 'definition.kore')
_insert_rules_and_write(hs_def_file, all_rules, target_hs_path / 'definition.kore')

# Copy all files except definition.kore and binary from HASKELL_DEF_DIR to out/hs
_LOGGER.info('Copying other artefacts into HS output directory')
Expand All @@ -183,11 +229,11 @@ def kompile_smir(
_LOGGER.info(f'Creating directory {target_llvmdt_path}')
target_llvmdt_path.mkdir(parents=True, exist_ok=True)

# Process LLVM definition
# Process LLVM definition (only SMIR rules for concrete execution)
_LOGGER.info('Writing LLVM definition file')
llvm_def_file = LLVM_LIB_DIR / 'definition.kore'
llvm_def_output = target_llvm_path / 'definition.kore'
_insert_rules_and_write(llvm_def_file, rules, llvm_def_output)
_insert_rules_and_write(llvm_def_file, smir_rules, llvm_def_output)

import subprocess

Expand Down
9 changes: 9 additions & 0 deletions kmir/src/kmir/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class ProveRSOpts(ProveOpts):
save_smir: bool
smir: bool
start_symbol: str
add_module: Path | None

def __init__(
self,
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
break_every_terminator: bool = False,
break_every_step: bool = False,
terminate_on_thunk: bool = False,
add_module: Path | None = None,
) -> None:
self.rs_file = rs_file
self.proof_dir = Path(proof_dir).resolve() if proof_dir is not None else None
Expand Down Expand Up @@ -170,6 +172,7 @@ def __init__(
self.break_every_terminator = break_every_terminator
self.break_every_step = break_every_step
self.terminate_on_thunk = terminate_on_thunk
self.add_module = add_module


@dataclass
Expand Down Expand Up @@ -204,6 +207,8 @@ class ShowOpts(DisplayOpts):
use_default_printer: bool
statistics: bool
leaves: bool
to_module: Path | None
minimize_proof: bool

def __init__(
self,
Expand All @@ -221,12 +226,16 @@ def __init__(
use_default_printer: bool = False,
statistics: bool = False,
leaves: bool = False,
to_module: Path | None = None,
minimize_proof: bool = False,
) -> None:
super().__init__(proof_dir, id, full_printer, smir_info, omit_current_body)
self.omit_static_info = omit_static_info
self.use_default_printer = use_default_printer
self.statistics = statistics
self.leaves = leaves
self.to_module = to_module
self.minimize_proof = minimize_proof
self.nodes = tuple(int(n.strip()) for n in nodes.split(',')) if nodes is not None else None

def _parse_pairs(text: str | None) -> tuple[tuple[int, int], ...] | None:
Expand Down
Loading