diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py index 182e2e2ff2..3a10680690 100644 --- a/tests/functional/codegen/features/test_constructor.py +++ b/tests/functional/codegen/features/test_constructor.py @@ -3,7 +3,6 @@ import pytest from tests.evm_backends.base_env import _compile -from vyper.exceptions import StackTooDeep from vyper.utils import method_id @@ -216,7 +215,6 @@ def get_foo() -> DynArray[DynArray[uint256, 3], 3]: assert c.get_foo() == [[37, 41, 73], [37041, 41073, 73037], [146, 123, 148]] -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_initialise_nested_dynamic_array_2(env, get_contract): code = """ foo: DynArray[DynArray[DynArray[int128, 3], 3], 3] diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py index 4707291662..0cca231182 100644 --- a/tests/functional/codegen/features/test_immutable.py +++ b/tests/functional/codegen/features/test_immutable.py @@ -1,7 +1,6 @@ import pytest from vyper.compiler.settings import OptimizationLevel -from vyper.exceptions import StackTooDeep @pytest.mark.parametrize( @@ -199,7 +198,6 @@ def get_idx_two() -> uint256: assert c.get_idx_two() == expected_values[2][2] -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_dynarray_immutable(get_contract): code = """ my_list: immutable(DynArray[DynArray[DynArray[int128, 3], 3], 3]) diff --git a/tests/functional/codegen/features/test_transient.py b/tests/functional/codegen/features/test_transient.py index 370e269cf9..2532def85b 100644 --- a/tests/functional/codegen/features/test_transient.py +++ b/tests/functional/codegen/features/test_transient.py @@ -2,7 +2,7 @@ from tests.utils import ZERO_ADDRESS from vyper.compiler import compile_code -from vyper.exceptions import EvmVersionException, StackTooDeep, VyperException +from vyper.exceptions import 
EvmVersionException, VyperException pytestmark = pytest.mark.requires_evm_version("cancun") @@ -343,7 +343,6 @@ def get_idx_two(_a: uint256, _b: uint256, _c: uint256) -> uint256: assert c.get_idx_two(*values) == expected_values[2][2] -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_dynarray_transient(get_contract, tx_failed, env): set_list = """ self.my_list = [ diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index a981987ce6..26cd16ed32 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -7,7 +7,7 @@ from tests.utils import check_precompile_asserts, decimal_to_int from vyper.compiler.settings import OptimizationLevel from vyper.evm.opcodes import version_check -from vyper.exceptions import ArrayIndexException, OverflowException, StackTooDeep, TypeMismatch +from vyper.exceptions import ArrayIndexException, OverflowException, TypeMismatch def _map_nested(f, xs): @@ -597,7 +597,6 @@ def bar(_baz: Foo[3]) -> String[96]: assert c.bar(c_input) == "Hello world!!!!" 
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_list_of_nested_struct_arrays(get_contract): code = """ struct Ded:
diff --git a/tests/unit/compiler/venom/test_stack_spill.py b/tests/unit/compiler/venom/test_stack_spill.py
new file mode 100644
index 0000000000..361ee2cc42
--- /dev/null
+++ b/tests/unit/compiler/venom/test_stack_spill.py
@@ -0,0 +1,277 @@
+import pytest
+
+from vyper.ir.compile_ir import Label
+from vyper.venom.basicblock import IRLiteral, IRVariable
+from vyper.venom.context import IRContext
+from vyper.venom.parser import parse_venom
+from vyper.venom.stack_model import StackModel
+from vyper.venom.stack_spiller import StackSpiller
+from vyper.venom.venom_to_assembly import VenomCompiler
+
+
+def _build_stack(count: int) -> tuple[StackModel, list[IRLiteral]]:
+    """Return a stack holding the literals 0..count-1 (last one on top) and those literals."""
+    stack = StackModel()
+    ops = [IRLiteral(i) for i in range(count)]
+    for op in ops:
+        stack.push(op)
+    return stack, ops
+
+
+def _ops_only_strings(assembly) -> list[str]:
+    """Return only the `str` items of `assembly`, in their original order."""
+    return [op for op in assembly if isinstance(op, str)]
+
+
+def _dummy_dfg():
+    """Return a DFG stand-in for which no two operands are ever equivalent."""
+
+    class _DummyDFG:
+        def are_equivalent(self, a, b):
+            return False
+
+    return _DummyDFG()
+
+
+def test_swap_spills_deep_stack() -> None:
+    """A swap deeper than SWAP16 goes through memory and never emits SWAP17+."""
+    compiler = VenomCompiler(IRContext())
+    stack, ops = _build_stack(40)
+    assembly: list = []
+
+    target = ops[-18]
+    before = stack._stack.copy()
+
+    depth = stack.get_depth(target)
+    assert isinstance(depth, int) and depth < -16
+    swap_idx = -depth
+
+    compiler.spiller.swap(assembly, stack, depth)
+
+    # only the top and the target may have changed places
+    expected = before.copy()
+    top_index = len(expected) - 1
+    target_index = expected.index(target)
+    expected[top_index], expected[target_index] = expected[target_index], expected[top_index]
+    assert stack._stack == expected
+
+    # the target, the top and everything in between are stored and reloaded once
+    ops_str = _ops_only_strings(assembly)
+    assert ops_str.count("MSTORE") == swap_idx + 1
+    assert ops_str.count("MLOAD") == swap_idx + 1
+    assert all(int(op[4:]) <= 16 for op in ops_str if op.startswith("SWAP"))
+
+
+def test_dup_spills_deep_stack() -> None:
+    """A dup deeper than DUP16 goes through memory and never emits DUP17+."""
+    compiler = VenomCompiler(IRContext())
+    stack, ops = _build_stack(40)
+    assembly: list = []
+
+    target = ops[-18]
+    before = stack._stack.copy()
+
+    depth = stack.get_depth(target)
+    assert isinstance(depth, int) and depth < -16
+    dup_idx = 1 - depth
+
+    compiler.spiller.dup(assembly, stack, depth)
+
+    expected = before.copy()
+    expected.append(target)
+    assert stack._stack == expected
+
+    # every item down to the target is stored once; the target is loaded twice
+    ops_str = _ops_only_strings(assembly)
+    assert ops_str.count("MSTORE") == dup_idx
+    assert ops_str.count("MLOAD") == dup_idx + 1
+    assert all(int(op[3:]) <= 16 for op in ops_str if op.startswith("DUP"))
+
+
+def test_stack_reorder_spills_before_swap() -> None:
+    """`_stack_reorder` spills shallow values until a deep target is within SWAP16 reach."""
+    ctx = IRContext()
+    compiler = VenomCompiler(ctx)
+    compiler.dfg = _dummy_dfg()
+
+    compiler.spiller = StackSpiller(ctx, initial_offset=0x10000)
+
+    stack = StackModel()
+    vars_on_stack = [IRVariable(f"%v{i}") for i in range(40)]
+    for var in vars_on_stack:
+        stack.push(var)
+
+    spilled: dict = {}
+    assembly: list = []
+
+    target = vars_on_stack[21]  # depth 18 from top for 40 items
+
+    compiler._stack_reorder(assembly, stack, [target], spilled, dry_run=False)
+
+    assert stack.get_depth(target) == 0
+    assert len(spilled) == 2  # spilled top two values to reduce depth to <= 16
+
+    ops_str = _ops_only_strings(assembly)
+    assert ops_str.count("MSTORE") == 2
+    assert all(int(op[4:]) <= 16 for op in ops_str if op.startswith("SWAP"))
+
+    # restoring a spilled variable should reload it via MLOAD
+    restore_assembly: list = []
+    spilled_var = next(iter(spilled))
+    compiler.spiller.restore_spilled_operand(restore_assembly, stack, spilled, spilled_var)
+    restore_ops = _ops_only_strings(restore_assembly)
+    assert restore_ops.count("MLOAD") == 1
+    assert spilled_var not in spilled
+    assert stack.get_depth(spilled_var) == 0
+
+
+def test_branch_spill_integration() -> None:
+    """Spilling works across a branch/join when more than 16 values are live."""
+    venom_src = """
+    function spill_demo {
+    main:
+        %v0 = mload 0
+        %v1 = mload 32
+        %v2 = mload 64
+        %v3 = mload 96
+        %v4 = mload 128
+        %v5 = mload 160
+        %v6 = mload 192
+        %v7 = mload 224
+        %v8 = mload 256
+        %v9 = mload 288
+        %v10 = mload 320
+        %v11 = mload 352
+        %v12 = mload 384
+        %v13 = mload 416
+        %v14 = mload 448
+        %v15 = mload 480
+        %v16 = mload 512
+        %v17 = mload 544
+        %v18 = mload 576
+        %v19 = mload 608
+        %cond = mload 640
+        jnz %cond, @then, @else
+    then:
+        %then_sum = add %v0, %v19
+        %res_then = add %then_sum, %cond
+        jmp @join
+    else:
+        %else_sum = add %v1, %v19
+        %res_else = add %else_sum, %cond
+        jmp @join
+    join:
+        %phi = phi @then, %res_then, @else, %res_else
+        %acc1 = add %phi, %v1
+        %acc2 = add %acc1, %v2
+        %acc3 = add %acc2, %v3
+        %acc4 = add %acc3, %v4
+        %acc5 = add %acc4, %v5
+        %acc6 = add %acc5, %v6
+        %acc7 = add %acc6, %v7
+        %acc8 = add %acc7, %v8
+        %acc9 = add %acc8, %v9
+        %acc10 = add %acc9, %v10
+        %acc11 = add %acc10, %v11
+        %acc12 = add %acc11, %v12
+        %acc13 = add %acc12, %v13
+        %acc14 = add %acc13, %v14
+        %acc15 = add %acc14, %v15
+        %acc16 = add %acc15, %v16
+        %acc17 = add %acc16, %v17
+        %acc18 = add %acc17, %v18
+        return %acc18
+    }
+    """
+
+    ctx = parse_venom(venom_src)
+    compiler = VenomCompiler(ctx)
+    # the first compilation is only used to check that spill slots show up as
+    # `alloca` instructions in the entry block; the assembly under test is from a second run
+    compiler.generate_evm_assembly()
+
+    fn = next(iter(ctx.functions.values()))
+    assert any(inst.opcode == "alloca" for inst in fn.entry.instructions)
+
+    asm = compiler.generate_evm_assembly()
+    opcodes = _ops_only_strings(asm)
+
+    for op in opcodes:
+        if op.startswith("SWAP"):
+            assert int(op[4:]) <= 16
+        if op.startswith("DUP"):
+            assert int(op[3:]) <= 16
+
+    def _find_spill_ops(kind: str) -> list[int]:
+        """Return indices of PUSHn items that are followed, n items later, by `kind`."""
+        matches: list[int] = []
+        idx = 0
+        while idx < len(asm):
+            item = asm[idx]
+            if isinstance(item, str) and item.startswith("PUSH"):
+                try:
+                    push_bytes = int(item[4:])
+                except ValueError:
+                    push_bytes = 0
+                target_idx = idx + 1 + push_bytes
+                if target_idx < len(asm) and asm[target_idx] == kind:
+                    matches.append(idx)
+                idx = target_idx + 1
+            else:
+                idx += 1
+        return matches
+
+    store_indices = _find_spill_ops("MSTORE")
+    load_indices = _find_spill_ops("MLOAD")
+    assert store_indices
+    assert load_indices
+
+    join_idx = next(
+        idx for idx, op in enumerate(asm) if isinstance(op, Label) and str(op) == "LABEL join"
+    )
+
+    assert any(idx < join_idx for idx in store_indices)
+    assert any(idx > join_idx for idx in store_indices)
+    assert any(idx < join_idx for idx in load_indices)
+    assert any(idx > join_idx for idx in load_indices)
+
+
+def test_dup_op_operand_not_in_stack() -> None:
+    """`dup_op` fails an assertion for an operand that is not on the stack."""
+    compiler = VenomCompiler(IRContext())
+    stack = StackModel()
+    assembly: list = []
+
+    ops = [IRVariable(f"%{i}") for i in range(5)]
+    for op in ops:
+        stack.push(op)
+
+    not_in_stack = IRVariable("%99")
+
+    with pytest.raises(AssertionError):
+        compiler.dup_op(assembly, stack, not_in_stack)
+
+
+def test_stack_reorder_operand_not_in_stack_but_spilled() -> None:
+    """`_stack_reorder` reloads a target that is only recorded in `spilled`."""
+    ctx = IRContext()
+    compiler = VenomCompiler(ctx)
+    compiler.dfg = _dummy_dfg()
+
+    stack = StackModel()
+    for i in range(5):
+        stack.push(IRVariable(f"%{i}"))
+
+    spilled_var = IRVariable("%spilled")
+    spilled: dict = {spilled_var: 0x10000}
+
+    assembly: list = []
+
+    # Try to reorder with spilled_var as target (should restore it from memory)
+    compiler._stack_reorder(assembly, stack, [spilled_var], spilled, dry_run=False)
+
+    # Should have restored the spilled variable
+    assert stack.get_depth(spilled_var) == 0  # Should be on top of stack
+    assert spilled_var not in spilled  # Should have been removed from spilled dict
+    # Assembly should contain PUSH and MLOAD to restore
+    assert "MLOAD" in assembly
diff --git a/vyper/utils.py b/vyper/utils.py index 5036af4d3b..d6c1a75d90 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -430,6 +430,7 @@ class MemoryPositions: FREE_VAR_SPACE = 0 FREE_VAR_SPACE2 = 32 RESERVED_MEMORY = 64 + STACK_SPILL_BASE = 0x10000 # scratch space used for spilling deep stacks # Sizes of different data types. Used to clamp types.
diff --git a/vyper/venom/stack_spiller.py b/vyper/venom/stack_spiller.py
new file mode 100644
index 0000000000..6170c9add1
--- /dev/null
+++ b/vyper/venom/stack_spiller.py
@@ -0,0 +1,303 @@
+from vyper.exceptions import StackTooDeep
+from vyper.ir.compile_ir import PUSH
+from vyper.utils import MemoryPositions, OrderedSet
+from vyper.venom.basicblock import IRInstruction, IRLiteral, IROperand, IRVariable
+from vyper.venom.context import IRContext
+from vyper.venom.function import IRFunction
+from vyper.venom.stack_model import StackModel
+
+
+class StackSpiller:
+    """
+    Manages stack spilling operations for deep stacks.
+    - Spilling operands to memory
+    - Restoring spilled operands from memory
+    - Managing spill slot allocation and deallocation
+
+    With `dry_run=True` the methods still update the `assembly`, `stack`
+    and `spilled` arguments they are given, but they never allocate a new
+    spill slot and never modify the free-slot list.
+    """
+
+    def __init__(self, ctx: IRContext, initial_offset: int | None = None):
+        """
+        `initial_offset`, when given, replaces `MemoryPositions.STACK_SPILL_BASE`
+        as the address of the first spill slot.
+        """
+        self.ctx = ctx
+        # offsets of released slots; `_acquire_spill_offset` reuses them LIFO
+        self._spill_free_slots: list[int] = []
+        # per function: offsets of the slots allocated through `alloca`
+        self._spill_slot_offsets: dict[IRFunction, list[int]] = {}
+        # per function: entry-block index where the next `alloca` is inserted
+        self._spill_insert_index: dict[IRFunction, int] = {}
+        self._next_spill_offset = MemoryPositions.STACK_SPILL_BASE
+        if initial_offset is not None:
+            self._next_spill_offset = initial_offset
+        self._next_spill_alloca_id = 0
+        self._current_function: IRFunction | None = None
+
+    def set_current_function(self, fn: IRFunction | None) -> None:
+        """Set the current function being processed."""
+        self._current_function = fn
+        if fn is not None:
+            self._prepare_spill_state(fn)
+
+    def reset_spill_slots(self) -> None:
+        """Forget all released slots so that none of them is reused."""
+        self._spill_free_slots = []
+
+    def _prepare_spill_state(self, fn: IRFunction) -> None:
+        """Initialise the per-function bookkeeping the first time `fn` is seen."""
+        if fn in self._spill_slot_offsets:
+            return
+
+        # spill allocas are inserted after the leading `param` instructions
+        entry = fn.entry
+        insert_idx = 0
+        for inst in entry.instructions:
+            if inst.opcode == "param":
+                insert_idx += 1
+            else:
+                break
+
+        self._spill_slot_offsets[fn] = []
+        self._spill_insert_index[fn] = insert_idx
+
+    def spill_operand(
+        self,
+        assembly: list,
+        stack: StackModel,
+        spilled: dict[IROperand, int],
+        depth: int,
+        dry_run: bool = False,
+    ) -> None:
+        """
+        Spill the variable at `depth` from the stack to memory.
+
+        The variable is swapped to the top if necessary, stored with MSTORE,
+        popped from `stack`, and its offset is recorded in `spilled`.
+        """
+        operand = stack.peek(depth)
+        assert isinstance(operand, IRVariable), operand
+
+        if depth != 0:
+            self.swap(assembly, stack, depth, dry_run)
+
+        offset = self._get_spill_slot(dry_run)
+        assembly.extend(PUSH(offset))
+        assembly.append("MSTORE")
+        stack.pop()
+        spilled[operand] = offset
+
+    def restore_spilled_operand(
+        self,
+        assembly: list,
+        stack: StackModel,
+        spilled: dict[IROperand, int],
+        op: IRVariable,
+        dry_run: bool = False,
+    ) -> None:
+        """
+        Restore a spilled operand from memory to the top of the stack.
+
+        `op` is removed from `spilled`; outside of dry runs its slot becomes
+        available for reuse.
+        """
+        offset = spilled.pop(op)
+        if not dry_run:
+            self._spill_free_slots.append(offset)
+        assembly.extend(PUSH(offset))
+        assembly.append("MLOAD")
+        stack.push(op)
+
+    def release_dead_spills(
+        self, spilled: dict[IROperand, int], live_set: OrderedSet[IRVariable]
+    ) -> None:
+        """Release memory slots for operands that are no longer live."""
+        for op in list(spilled.keys()):
+            if isinstance(op, IRVariable) and op in live_set:
+                continue
+            offset = spilled.pop(op)
+            self._spill_free_slots.append(offset)
+
+    def swap(self, assembly: list, stack: StackModel, depth: int, dry_run: bool = False) -> int:
+        """
+        Swap operation that handles deep stacks via spilling.
+
+        For swaps beyond SWAP16, spills the stack segment to memory,
+        then restores it in swapped order.
+
+        Returns a cost estimate: 0 for a no-op, 1 for a native SWAP, and
+        2 for every item spilled or reloaded otherwise.
+        Raises StackTooDeep if `depth` is positive.
+        """
+        # Swapping the top with itself is a no-op
+        if depth == 0:
+            return 0
+
+        swap_idx = -depth
+        if swap_idx < 1:
+            raise StackTooDeep(f"Unsupported swap depth {swap_idx}")
+
+        if swap_idx <= 16:
+            stack.swap(depth)
+            assembly.append(f"SWAP{swap_idx}")
+            return 1
+
+        # For deep stacks, use spill/restore technique
+        chunk_size = swap_idx + 1
+        spill_ops, offsets, cost = self._spill_stack_segment(assembly, stack, chunk_size, dry_run)
+
+        # index 0 is the old top and index -1 the swap target: exchange the
+        # two ends and keep everything in between in place. chunk_size is at
+        # least 18 here, so the two ends are always distinct.
+        indices = list(range(chunk_size))
+        desired_indices = [indices[-1]] + indices[1:-1] + [indices[0]]
+
+        cost += self._restore_spilled_segment(
+            assembly, stack, spill_ops, offsets, desired_indices, dry_run
+        )
+        return cost
+
+    def dup(self, assembly: list, stack: StackModel, depth: int, dry_run: bool = False) -> None:
+        """
+        Dup operation that handles deep stacks via spilling.
+
+        For dups beyond DUP16, spills the stack segment to memory,
+        then restores it with duplication.
+
+        Raises StackTooDeep if `depth` is positive.
+        """
+        dup_idx = 1 - depth
+        if dup_idx < 1:
+            raise StackTooDeep(f"Unsupported dup depth {dup_idx}")
+
+        if dup_idx <= 16:
+            stack.dup(depth)
+            assembly.append(f"DUP{dup_idx}")
+            return
+
+        # For deep stacks, use spill/restore technique
+        chunk_size = dup_idx
+        spill_ops, offsets, _ = self._spill_stack_segment(assembly, stack, chunk_size, dry_run)
+
+        # reload the whole segment in its original order, then load the
+        # deepest item (index -1) once more so that its copy ends up on top
+        indices = list(range(chunk_size))
+        desired_indices = [indices[-1]] + indices
+
+        self._restore_spilled_segment(assembly, stack, spill_ops, offsets, desired_indices, dry_run)
+
+    def _spill_stack_segment(
+        self, assembly: list, stack: StackModel, count: int, dry_run: bool
+    ) -> tuple[list[IROperand], list[int], int]:
+        """
+        Spill the top `count` stack items to memory, top first.
+
+        Returns the spilled operands, their offsets (both in spill order)
+        and the cost of the emitted sequence.
+        """
+        spill_ops: list[IROperand] = []
+        offsets: list[int] = []
+        cost = 0
+
+        for _ in range(count):
+            op = stack.peek(0)
+            spill_ops.append(op)
+
+            offset = self._acquire_spill_offset(dry_run)
+            offsets.append(offset)
+
+            assembly.extend(PUSH(offset))
+            assembly.append("MSTORE")
+            stack.pop()
+            cost += 2
+
+        return spill_ops, offsets, cost
+
+    def _restore_spilled_segment(
+        self,
+        assembly: list,
+        stack: StackModel,
+        spill_ops: list[IROperand],
+        offsets: list[int],
+        desired_indices: list[int],
+        dry_run: bool,
+    ) -> int:
+        """
+        Restore a spilled segment from memory to the stack.
+
+        `desired_indices` lists indices into `spill_ops`/`offsets` from the
+        new top of the stack downwards. Outside of dry runs all `offsets`
+        are released afterwards. Returns the cost of the emitted sequence.
+        """
+        cost = 0
+
+        # push bottom-up so that desired_indices[0] ends up on top
+        for idx in reversed(desired_indices):
+            assembly.extend(PUSH(offsets[idx]))
+            assembly.append("MLOAD")
+            stack.push(spill_ops[idx])
+            cost += 2
+
+        if not dry_run:
+            for offset in offsets:
+                self._spill_free_slots.append(offset)
+
+        return cost
+
+    def _peek_spill_offset(self) -> int:
+        """
+        Return the offset the next spill would use, without reserving it.
+
+        Used by dry runs, which must leave the allocator untouched.
+        """
+        if self._spill_free_slots:
+            return self._spill_free_slots[-1]
+        return self._next_spill_offset
+
+    def _get_spill_slot(self, dry_run: bool) -> int:
+        """Return a fresh spill slot, or only a preview of one in a dry run."""
+        if dry_run:
+            # must not defer to _acquire_spill_offset: with an empty free
+            # list the two methods would call each other forever
+            return self._peek_spill_offset()
+        if self._current_function is None:
+            offset = self._next_spill_offset
+            self._next_spill_offset += 32
+            return offset
+        return self._allocate_spill_slot(self._current_function)
+
+    def _acquire_spill_offset(self, dry_run: bool) -> int:
+        """Return a spill slot, preferring the most recently released one."""
+        if dry_run:
+            return self._peek_spill_offset()
+        if self._spill_free_slots:
+            return self._spill_free_slots.pop()
+        return self._get_spill_slot(dry_run)
+
+    def _allocate_spill_slot(self, fn: IRFunction) -> int:
+        """
+        Reserve the next 32-byte slot and record it as an `alloca`
+        instruction in the entry block of `fn`.
+        """
+        entry = fn.entry
+        insert_idx = self._spill_insert_index[fn]
+
+        offset = self._next_spill_offset
+        self._next_spill_offset += 32
+
+        offset_lit = IRLiteral(offset)
+        size_lit = IRLiteral(32)
+        id_lit = IRLiteral(self._next_spill_alloca_id)
+        self._next_spill_alloca_id += 1
+
+        output_var = fn.get_next_variable()
+
+        inst = IRInstruction("alloca", [offset_lit, size_lit, id_lit], [output_var])
+        entry.instructions.insert(insert_idx, inst)
+        self._spill_insert_index[fn] += 1
+
+        self._spill_slot_offsets[fn].append(offset)
+        return offset
diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index fd9c0e5678..c102b3f648 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -3,7 +3,7 @@ from typing import Any, Iterable from vyper.evm.assembler.instructions import DATA_ITEM, PUSH, DataHeader -from vyper.exceptions import CompilerPanic, StackTooDeep +from vyper.exceptions import CompilerPanic from vyper.ir.compile_ir import ( PUSH_OFST, PUSHLABEL, @@ -26,6 +26,7 @@ ) from vyper.venom.context import IRContext, IRFunction from vyper.venom.stack_model import StackModel +from vyper.venom.stack_spiller import StackSpiller DEBUG_SHOW_COST = False if DEBUG_SHOW_COST: @@ -153,6 +154,7 @@ def __init__(self, ctx: IRContext): self.ctx = ctx self.label_counter = 0 self.visited_basicblocks = OrderedSet() + self.spiller = StackSpiller(ctx) def mklabel(self, name: str) -> Label: self.label_counter += 1 @@ -173,7 +175,11 @@ def generate_evm_assembly(self, no_optimize: bool = False) -> list[AssemblyInstr assert self.cfg.is_normalized(), "Non-normalized CFG!"
- self._generate_evm_for_basicblock_r(asm, fn.entry, StackModel()) + self.spiller.set_current_function(fn) + self.spiller.reset_spill_slots() + + self._generate_evm_for_basicblock_r(asm, fn.entry, StackModel(), {}) + self.spiller.set_current_function(None) asm.extend(_REVERT_POSTAMBLE) # Append data segment @@ -197,11 +203,18 @@ def generate_evm_assembly(self, no_optimize: bool = False) -> list[AssemblyInstr return asm def _stack_reorder( - self, assembly: list, stack: StackModel, stack_ops: list[IROperand], dry_run: bool = False + self, + assembly: list, + stack: StackModel, + stack_ops: list[IROperand], + spilled: dict[IROperand, int], + dry_run: bool = False, ) -> int: if dry_run: assert len(assembly) == 0, "Dry run should not work on assembly" stack = stack.copy() + spilled = spilled.copy() + spill_free_snapshot = self.spiller._spill_free_slots.copy() if len(stack_ops) == 0: return 0 @@ -213,10 +226,25 @@ def _stack_reorder( cost = 0 for i, op in enumerate(stack_ops): final_stack_depth = -(len(stack_ops) - i - 1) + depth = stack.get_depth(op) if depth == StackModel.NOT_IN_STACK: - raise CompilerPanic(f"Variable {op} not in stack") + if isinstance(op, IRVariable) and op in spilled: + self.spiller.restore_spilled_operand( + assembly, stack, spilled, op, dry_run=dry_run + ) + depth = stack.get_depth(op) + else: + raise CompilerPanic(f"Variable {op} not in stack") + + if depth < -16: + if not self._reduce_depth_via_spill( + assembly, stack, spilled, stack_ops, op, depth, dry_run + ): + depth = stack.get_depth(op) + else: + depth = stack.get_depth(op) if depth == final_stack_depth: continue @@ -228,13 +256,59 @@ def _stack_reorder( stack.poke(depth, to_swap) continue - cost += self.swap(assembly, stack, depth) - cost += self.swap(assembly, stack, final_stack_depth) + cost += self.spiller.swap(assembly, stack, depth, dry_run) + cost += self.spiller.swap(assembly, stack, final_stack_depth, dry_run) assert stack._stack[-len(stack_ops) :] == stack_ops, (stack, 
stack_ops) + if dry_run: + self.spiller._spill_free_slots = spill_free_snapshot + return cost + def _reduce_depth_via_spill( + self, + assembly: list, + stack: StackModel, + spilled: dict[IROperand, int], + stack_ops: list[IROperand], + target_op: IROperand, + depth: int, + dry_run: bool, + ) -> bool: + while depth < -16: + candidate_depth = self._select_spill_candidate(stack, stack_ops, depth) + if candidate_depth is None: + return False + self.spiller.spill_operand(assembly, stack, spilled, candidate_depth, dry_run) + depth = stack.get_depth(target_op) + if depth == StackModel.NOT_IN_STACK: + if isinstance(target_op, IRVariable) and target_op in spilled: + self.spiller.restore_spilled_operand( + assembly, stack, spilled, target_op, dry_run + ) + depth = stack.get_depth(target_op) + else: + return False + return True + + def _select_spill_candidate( + self, stack: StackModel, stack_ops: list[IROperand], target_depth: int + ) -> int | None: + forbidden = set(stack_ops) + max_offset = min(16, -target_depth - 1, stack.height - 1) + if max_offset < 0: + return None + for offset in range(0, max_offset + 1): + depth = -offset + candidate = stack.peek(depth) + if candidate in forbidden: + continue + if not isinstance(candidate, IRVariable): + continue + return depth + return None + def _emit_input_operands( self, assembly: list, @@ -242,6 +316,7 @@ def _emit_input_operands( ops: list[IROperand], stack: StackModel, next_liveness: OrderedSet[IRVariable], + spilled: dict[IROperand, int], ) -> None: # PRE: we already have all the items on the stack that have # been scheduled to be killed. 
now it's just a matter of emitting @@ -252,6 +327,9 @@ def _emit_input_operands( seen: set[IROperand] = set() for op in ops: + if isinstance(op, IRVariable) and op in spilled: + self.spiller.restore_spilled_operand(assembly, stack, spilled, op) + if isinstance(op, IRLabel): # invoke emits the actual instruction itself so we don't need # to emit it here but we need to add it to the stack map @@ -314,7 +392,7 @@ def popmany(self, asm, to_pop: Iterable[IRVariable], stack): deepest = min(depths) expected = list(range(deepest, 0)) if deepest < 0 and -deepest <= 16 and sorted(depths) == expected: - self.swap(asm, stack, deepest) + self.spiller.swap(asm, stack, deepest) self.pop(asm, stack, len(to_pop)) return @@ -327,11 +405,11 @@ def popmany(self, asm, to_pop: Iterable[IRVariable], stack): depth = stack.get_depth(var) if depth != 0: - self.swap(asm, stack, depth) + self.spiller.swap(asm, stack, depth) self.pop(asm, stack) def _generate_evm_for_basicblock_r( - self, asm: list, basicblock: IRBasicBlock, stack: StackModel + self, asm: list, basicblock: IRBasicBlock, stack: StackModel, spilled: dict[IROperand, int] ) -> None: if basicblock in self.visited_basicblocks: return @@ -361,7 +439,7 @@ def _generate_evm_for_basicblock_r( else: next_liveness = self.liveness.out_vars(basicblock) - asm.extend(self._generate_evm_for_instruction(inst, stack, next_liveness)) + asm.extend(self._generate_evm_for_instruction(inst, stack, next_liveness, spilled)) if DEBUG_SHOW_COST: print(" ".join(map(str, asm)), file=sys.stderr) @@ -370,7 +448,7 @@ def _generate_evm_for_basicblock_r( ref.extend(asm) for bb in self.cfg.cfg_out(basicblock): - self._generate_evm_for_basicblock_r(ref, bb, stack.copy()) + self._generate_evm_for_basicblock_r(ref, bb, stack.copy(), spilled.copy()) # pop values from stack at entry to bb # note this produces the same result(!) 
no matter which basic block @@ -395,7 +473,11 @@ def clean_stack_from_cfg_in( self.popmany(asm, to_pop, stack) def _generate_evm_for_instruction( - self, inst: IRInstruction, stack: StackModel, next_liveness: OrderedSet + self, + inst: IRInstruction, + stack: StackModel, + next_liveness: OrderedSet, + spilled: dict[IROperand, int], ) -> list[str]: assembly: list[AssemblyInstruction] = [] opcode = inst.opcode @@ -446,7 +528,7 @@ def _generate_evm_for_instruction( if to_be_replaced in next_liveness: # this branch seems unreachable (maybe due to make_ssa) # %13/%14 is still live(!), so we make a copy of it - self.dup(assembly, stack, depth) + self.spiller.dup(assembly, stack, depth) stack.poke(0, ret) else: stack.poke(depth, ret) @@ -460,7 +542,7 @@ def _generate_evm_for_instruction( return apply_line_numbers(inst, assembly) # Step 2: Emit instruction's input operands - self._emit_input_operands(assembly, inst, operands, stack, next_liveness) + self._emit_input_operands(assembly, inst, operands, stack, next_liveness, spilled) # Step 3: Reorder stack before join points if opcode == "jmp": @@ -475,16 +557,16 @@ def _generate_evm_for_instruction( assert len(self.cfg.cfg_in(next_bb)) > 1 target_stack = self.liveness.input_vars_from(inst.parent, next_bb) - self._stack_reorder(assembly, stack, list(target_stack)) + self._stack_reorder(assembly, stack, list(target_stack), spilled) if inst.is_commutative: - cost_no_swap = self._stack_reorder([], stack, operands, dry_run=True) + cost_no_swap = self._stack_reorder([], stack, operands, spilled, dry_run=True) operands[-1], operands[-2] = operands[-2], operands[-1] - cost_with_swap = self._stack_reorder([], stack, operands, dry_run=True) + cost_with_swap = self._stack_reorder([], stack, operands, spilled, dry_run=True) if cost_with_swap > cost_no_swap: operands[-1], operands[-2] = operands[-2], operands[-1] - cost = self._stack_reorder([], stack, operands, dry_run=True) + cost = self._stack_reorder([], stack, operands, spilled, 
dry_run=True) if DEBUG_SHOW_COST and cost: print("ENTER", inst, file=sys.stderr) print(" HAVE", stack, file=sys.stderr) @@ -493,7 +575,7 @@ def _generate_evm_for_instruction( # final step to get the inputs to this instruction ordered # correctly on the stack - self._stack_reorder(assembly, stack, operands) + self._stack_reorder(assembly, stack, operands, spilled) # some instructions (i.e. invoke) need to do stack manipulations # with the stack model containing the return value(s), so we fiddle @@ -616,6 +698,8 @@ def _generate_evm_for_instruction( # Use the top-most surviving output to schedule self._optimistic_swap(assembly, inst, next_liveness, stack) + self.spiller.release_dead_spills(spilled, next_liveness) + return apply_line_numbers(inst, assembly) def _optimistic_swap(self, assembly, inst, next_liveness, stack): @@ -640,7 +724,9 @@ def _optimistic_swap(self, assembly, inst, next_liveness, stack): if len(inst_outputs) > 0: current_top_out = inst_outputs[-1] if not self.dfg.are_equivalent(current_top_out, next_scheduled): - cost = self.swap_op(assembly, stack, next_scheduled) + depth = stack.get_depth(next_scheduled) + if depth is not StackModel.NOT_IN_STACK: + cost = self.spiller.swap(assembly, stack, depth) if DEBUG_SHOW_COST and cost != 0: print("ENTER", inst, file=sys.stderr) @@ -652,39 +738,12 @@ def pop(self, assembly, stack, num=1): stack.pop(num) assembly.extend(["POP"] * num) - def swap(self, assembly, stack, depth) -> int: - # Swaps of the top is no op - if depth == 0: - return 0 - - stack.swap(depth) - assembly.append(_evm_swap_for(depth)) - return 1 - - def dup(self, assembly, stack, depth): - stack.dup(depth) - assembly.append(_evm_dup_for(depth)) - def swap_op(self, assembly, stack, op): depth = stack.get_depth(op) assert depth is not StackModel.NOT_IN_STACK, f"Cannot swap non-existent operand {op}" - return self.swap(assembly, stack, depth) + return self.spiller.swap(assembly, stack, depth) def dup_op(self, assembly, stack, op): depth = 
stack.get_depth(op) assert depth is not StackModel.NOT_IN_STACK, f"Cannot dup non-existent operand {op}" - self.dup(assembly, stack, depth) - - -def _evm_swap_for(depth: int) -> str: - swap_idx = -depth - if not (1 <= swap_idx <= 16): - raise StackTooDeep(f"Unsupported swap depth {swap_idx}") - return f"SWAP{swap_idx}" - - -def _evm_dup_for(depth: int) -> str: - dup_idx = 1 - depth - if not (1 <= dup_idx <= 16): - raise StackTooDeep(f"Unsupported dup depth {dup_idx}") - return f"DUP{dup_idx}" + self.spiller.dup(assembly, stack, depth)