diff --git a/pyk/src/pyk/kast/inner.py b/pyk/src/pyk/kast/inner.py index 957981fd67f..18c3360a6ae 100644 --- a/pyk/src/pyk/kast/inner.py +++ b/pyk/src/pyk/kast/inner.py @@ -877,6 +877,22 @@ def _var_occurence(_term: KInner) -> None: return _var_occurrences +def keep_vars_sorted(occurrences: dict[str, list[KVariable]]) -> dict[str, KVariable]: + """Keep the sort of variables from the occurrences dictionary.""" + occurrences_sorted: dict[str, KVariable] = {} + for k, vs in occurrences.items(): + sort = None + for v in vs: + if v.sort is not None: + if sort is None: + sort = v.sort + elif sort != v.sort: + sort = None + break + occurrences_sorted[k] = KVariable(k, sort) + return occurrences_sorted + + def collect(callback: Callable[[KInner], None], kinner: KInner) -> None: """Collect information about a given term traversing it top-down using a function with side effects. diff --git a/pyk/src/pyk/kast/manip.py b/pyk/src/pyk/kast/manip.py index 4f4056d703a..dd34a624497 100644 --- a/pyk/src/pyk/kast/manip.py +++ b/pyk/src/pyk/kast/manip.py @@ -21,6 +21,7 @@ bottom_up, collect, flatten_label, + keep_vars_sorted, top_down, var_occurrences, ) @@ -780,23 +781,24 @@ def build_rule( lhs_vars = free_vars(init_term) rhs_vars = free_vars(final_term) - var_occurrences = count_vars( + occurrences = var_occurrences( mlAnd( [push_down_rewrites(KRewrite(init_config, final_config))] + init_constraints + final_constraints, GENERATED_TOP_CELL, ) ) + sorted_vars = keep_vars_sorted(occurrences) v_subst: dict[str, KVariable] = {} vremap_subst: dict[str, KVariable] = {} - for v in var_occurrences: + for v in occurrences: new_v = v - if var_occurrences[v] == 1: + if len(occurrences[v]) == 1: new_v = '_' + new_v if v in rhs_vars and v not in lhs_vars: new_v = '?' + new_v if new_v != v: - v_subst[v] = KVariable(new_v) - vremap_subst[new_v] = KVariable(v) + v_subst[v] = KVariable(new_v, sorted_vars[v].sort) + vremap_subst[new_v] = sorted_vars[v] new_init_config = Subst(v_subst)(init_config) new_init_constraints = [Subst(v_subst)(c) for c in init_constraints] diff --git a/pyk/src/tests/integration/proof/test_refute_node.py b/pyk/src/tests/integration/proof/test_refute_node.py index 34f99d4977c..7c887322689 100644 --- a/pyk/src/tests/integration/proof/test_refute_node.py +++ b/pyk/src/tests/integration/proof/test_refute_node.py @@ -295,7 +295,7 @@ def test_apr_proof_refute_node_to_claim( expected = KClaim( body=KRewrite(KApply('_<=Int_', KVariable('N', 'Int'), KToken('0', 'Int')), KToken('false', 'Bool')), requires=KApply( - '_<=Int_', KApply('_+Int_', KVariable('_L', None), KVariable('N', 'Int')), KToken('0', 'Int') + '_<=Int_', KApply('_+Int_', KVariable('_L', 'Int'), KVariable('N', 'Int')), KToken('0', 'Int') ), att=KAtt(entries=[Atts.LABEL('refute-node-claim')]), ) diff --git a/pyk/src/tests/unit/kast/test_inner.py b/pyk/src/tests/unit/kast/test_inner.py index 35c84dbe8ae..d34bc5fd342 100644 --- a/pyk/src/tests/unit/kast/test_inner.py +++ b/pyk/src/tests/unit/kast/test_inner.py @@ -5,7 +5,7 @@ import pytest -from pyk.kast.inner import flatten_label +from pyk.kast.inner import KVariable, flatten_label, keep_vars_sorted from ..utils import a, f, g, x, y, z @@ -40,3 +40,32 @@ def test_flatten_label(label: str, kast: KInner, expected: list[KInner]) -> None # Then assert actual == expected + + +KEEP_VARS_SORTED_DATA: Final[tuple[tuple[dict[str, list[KVariable]], dict[str, KVariable]], ...]] = ( + ( + {'a': [KVariable('a'), KVariable('a')], 'b': [KVariable('b'), KVariable('b')]}, + {'a': KVariable('a'), 'b': KVariable('b')}, + ), + ( + {'a': [KVariable('a', 'K'), KVariable('a', 'X')], 'b': [KVariable('b', 'K'), KVariable('b', 'X')]}, + {'a': KVariable('a'), 'b': KVariable('b')}, + ), + ( + {'a': [KVariable('a', 'K'), KVariable('a')], 'b': [KVariable('b', 'K'), KVariable('b', 'K')]}, + {'a': KVariable('a', 'K'), 'b': KVariable('b', 'K')}, + ), + ( + {'a': [KVariable('a', 'A'), KVariable('a'), KVariable('a', 'B')]}, + {'a': KVariable('a')}, + ), +) + + +@pytest.mark.parametrize('occurrences,expected', KEEP_VARS_SORTED_DATA, ids=count()) +def test_keep_vars_sorted(occurrences: dict[str, list[KVariable]], expected: dict[str, KVariable]) -> None: + # When + actual = keep_vars_sorted(occurrences) + + # Then + assert actual == expected diff --git a/pyk/src/tests/unit/test_cterm.py b/pyk/src/tests/unit/test_cterm.py index 50490870d35..82674b677d4 100644 --- a/pyk/src/tests/unit/test_cterm.py +++ b/pyk/src/tests/unit/test_cterm.py @@ -103,14 +103,14 @@ def test_cterm_match_with_constraint(t1: CTerm, t2: CTerm) -> None: BUILD_RULE_TEST_DATA: Final = ( ( - T(k(KVariable('K_CELL')), mem(KVariable('MEM_CELL'))), + T(k(KVariable('K_CELL', 'K')), mem(KVariable('MEM_CELL'))), T( k(KVariable('K_CELL')), mem(KApply('_[_<-_]', [KVariable('MEM_CELL'), KVariable('KEY'), KVariable('VALUE')])), ), ['K_CELL'], T( - k(KVariable('_K_CELL')), + k(KVariable('_K_CELL', 'K')), mem( KRewrite( KVariable('MEM_CELL'),