diff --git a/tests/functional/syntax/warnings/test_contract_size_limit_warning.py b/tests/functional/syntax/warnings/test_contract_size_limit_warning.py index 3e27304266..3a7b457b5d 100644 --- a/tests/functional/syntax/warnings/test_contract_size_limit_warning.py +++ b/tests/functional/syntax/warnings/test_contract_size_limit_warning.py @@ -15,9 +15,9 @@ def huge_bytestring(): def test_contract_size_exceeded(huge_bytestring): code = f""" @external -def a() -> bool: +def a() -> Bytes[24577]: q: Bytes[24577] = {huge_bytestring} - return True + return q """ with pytest.warns(vyper.warnings.ContractSizeLimit): vyper.compile_code(code, output_formats=["bytecode_runtime"]) diff --git a/tests/unit/compiler/venom/test_memory_copy_elision.py b/tests/unit/compiler/venom/test_memory_copy_elision.py new file mode 100644 index 0000000000..249f9ff5f8 --- /dev/null +++ b/tests/unit/compiler/venom/test_memory_copy_elision.py @@ -0,0 +1,704 @@ +import pytest + +from tests.venom_utils import PrePostChecker +from vyper.evm.opcodes import version_check +from vyper.venom.passes import MemoryCopyElisionPass + +_check_pre_post = PrePostChecker([MemoryCopyElisionPass], default_hevm=False) + + +def _check_no_change(pre): + _check_pre_post(pre, pre) + + +def test_load_store_no_elision(): + """ + Basic load-store test - single word copy is already optimal. + mload followed by mstore should NOT be changed. + """ + pre = """ + _global: + %1 = mload 100 + mstore %1, 200 + stop + """ + _check_no_change(pre) + + +def test_redundant_copy_elimination(): + """ + Test that copying to the same location is eliminated entirely. + """ + pre = """ + _global: + %1 = mload 100 + mstore 100, %1 + stop + """ + + post = """ + _global: + nop ; mstore 100, %1 [memory copy elision - redundant store] + nop ; %1 = mload 100 [memory copy elision - redundant load] + stop + """ + _check_pre_post(pre, post) + + +def test_mcopy_chain_optimization(): + """ + Test that mcopy chains are optimized. + A->B followed by B->C should become A->C. + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 200, 100, 32 + mcopy 300, 200, 32 + %1 = mload 300 + sink %1 + """ + + post = """ + _global: + nop ; mcopy 200, 100, 32 [memory copy elision - merged mcopy] + mcopy 300, 100, 32 + %1 = mload 300 + sink %1 + """ + _check_pre_post(pre, post) + + +def test_mcopy_redundant_elimination(): + """ + Test that mcopy with same src and dst is eliminated. + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 100, 100, 32 + stop + """ + + post = """ + _global: + nop ; mcopy 100, 100, 32 [memory copy elision - redundant mcopy] + stop + """ + _check_pre_post(pre, post) + + +def test_no_elision_with_intermediate_write(): + """ + Test that copy elision doesn't happen if there's an intermediate write + to the source location. + """ + pre = """ + _global: + %1 = mload 100 + mstore 100, 42 ; BARRIER - writes to source + mstore 200, %1 + %2 = mload 100 + %3 = mload 200 + sink %3, %2 + """ + _check_no_change(pre) + + +def test_no_elision_with_multiple_uses(): + """ + Test that copy elision doesn't happen if the loaded value has multiple uses. + """ + pre = """ + _global: + %1 = mload 100 + mstore 200, %1 + %2 = add %1, 1 ; Another use of %1 + %3 = mload 200 + sink %3 + """ + _check_no_change(pre) + + +def test_mcopy_chain_with_intermediate_read(): + """ + Test that mcopy chain optimization doesn't happen with intermediate reads + from the intermediate location. 
+ """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 200, 100, 32 + %1 = mload 200 ; BARRIER - read from intermediate location + mcopy 300, 200, 32 + mstore 400, %1 + %2 = mload 300 + %3 = mload 400 + sink %3, %2 + """ + _check_no_change(pre) + + +def test_mcopy_chain_with_size_mismatch(): + """ + Test that mcopy chains with different sizes are not merged. + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 200, 100, 32 + mcopy 300, 200, 64 ; Different size + %1 = mload 300 + sink %3 + """ + _check_no_change(pre) + + +def test_overlapping_memory_regions(): + """ + Test that overlapping memory regions prevent optimization. + """ + pre = """ + _global: + %1 = mload 100 + mstore 116, 42 ; BARRIER - overlaps with source [100-131] + mstore 200, %1 + %2 = mload 116 + %3 = mload 200 + sink %3, %2 + """ + _check_no_change(pre) + + +def test_call_instruction_clears_optimization(): + """ + Test that call instructions clear all tracked optimizations. + """ + pre = """ + _global: + %1 = mload 100 + %2 = call 0, 0, 0, 0, 0, 0, 0 ; BARRIER - can modify any memory + mstore 200, %1 + %3 = mload 200 + sink %3 + """ + _check_no_change(pre) + + +def test_multiple_load_store_pairs(): + """ + Test that multiple independent load-store pairs are not changed. + Single word copies are already optimal. + """ + pre = """ + _global: + %1 = mload 100 + %2 = mload 200 + mstore 300, %1 + mstore 400, %2 + %3 = mload 300 + %4 = mload 400 + sink %4, %3 + """ + _check_no_change(pre) + + +def test_mcopy_chain_longer(): + """ + Test longer mcopy chains: A->B->C->D should become A->D. + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 200, 100, 32 + mcopy 300, 200, 32 + mcopy 400, 300, 32 + %1 = mload 400 + sink %1 + """ + + post = """ + _global: + nop ; mcopy 200, 100, 32 [memory copy elision - merged mcopy] + nop ; mcopy 300, 200, 32 [memory copy elision - merged mcopy] + mcopy 400, 100, 32 + %1 = mload 400 + sink %1 + """ + _check_pre_post(pre, post) + + +def test_calldatacopy_barrier(): + """ + Test that calldatacopy acts as a barrier for optimizations. + """ + pre = """ + _global: + %1 = mload 100 + calldatacopy 200, 0, 32 ; BARRIER - writes to memory + mstore %1, 300 + %2 = mload 200 + sink %2 + """ + _check_no_change(pre) + + +def test_dloadbytes_barrier(): + """ + Test that dloadbytes acts as a barrier for optimizations. + """ + pre = """ + _global: + %1 = mload 100 + dloadbytes 200, 0, 32 ; BARRIER - writes to memory + mstore 300, %1 + %2 = mload 200 + %3 = mload 300 + sink %3, %2 + """ + _check_no_change(pre) + + +def test_calldatacopy_mcopy_chain(): + """ + Test that calldatacopy followed by mcopy can be optimized. + calldatacopy -> A, mcopy A -> B should become calldatacopy -> B. + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + calldatacopy 100, 0, 32 ; Copy 32 bytes from calldata offset 0 to memory 100 + mcopy 200, 100, 32 ; Copy from 100 to 200 + %1 = mload 200 + sink %1 + """ + + post = """ + _global: + nop ; calldatacopy 100, 0, 32 [memory copy elision - merged calldatacopy] + calldatacopy 200, 0, 32 + %1 = mload 200 + sink %1 + """ + _check_pre_post(pre, post) + + +def test_codecopy_mcopy_chain(): + """ + Test that codecopy followed by mcopy can be optimized. 
+ """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + codecopy 100, 10, 64 ; Copy 64 bytes from code offset 10 to memory 100 + mcopy 300, 100, 64 ; Copy from 100 to 300 + %1 = mload 300 + sink %1 + """ + + post = """ + _global: + nop ; codecopy 100, 10, 64 [memory copy elision - merged codecopy] + codecopy 300, 10, 64 + %1 = mload 300 + sink %1 + """ + _check_pre_post(pre, post) + + +def test_dloadbytes_mcopy_chain(): + """ + Test that dloadbytes followed by mcopy can be optimized. + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + dloadbytes 100, 0, 32 ; Load 32 bytes from transient offset 0 to memory 100 + mcopy 200, 100, 32 ; Copy from 100 to 200 + %1 = mload 200 + sink %1 + """ + + post = """ + _global: + nop ; dloadbytes 100, 0, 32 [memory copy elision - merged dloadbytes] + dloadbytes 200, 0, 32 + %1 = mload 200 + sink %1 + """ + _check_pre_post(pre, post) + + +def test_special_copy_mcopy_chain_with_read(): + """ + Test that special copy + mcopy chain is NOT optimized if intermediate location is read. + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + calldatacopy 100, 0, 32 ; Copy to intermediate location + %1 = mload 100 ; Read from intermediate location - BARRIER + mcopy 200, 100, 32 ; This cannot be merged + mstore 300, %1 + %1 = mload 300 + %2 = mload 200 + sink %2, %1 + """ + _check_no_change(pre) + + +def test_special_copy_mcopy_chain_size_mismatch(): + """ + Test that special copy + mcopy chain with different sizes are not merged. + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + calldatacopy 100, 0, 32 ; Copy 32 bytes + mcopy 200, 100, 64 ; Try to copy 64 bytes - size mismatch + %1 = mload 200 + sink %1 + """ + _check_no_change(pre) + + +def test_special_copy_multiple_mcopy_chain(): + """ + Test that special copy followed by multiple mcopies can be optimized. + calldatacopy -> A, mcopy A -> B, mcopy B -> C should become calldatacopy -> C. + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + calldatacopy 100, 0, 32 ; Copy from calldata to 100 + mcopy 200, 100, 32 ; Copy from 100 to 200 + mcopy 300, 200, 32 ; Copy from 200 to 300 + %1 = mload 300 + sink %1 + """ + + post = """ + _global: + nop ; calldatacopy 100, 0, 32 [memory copy elision - merged calldatacopy] + nop ; mcopy 200, 100, 32 [memory copy elision - merged mcopy] + calldatacopy 300, 0, 32 + %1 = mload 300 + sink %1 + """ + _check_pre_post(pre, post) + + +def test_inter_block_no_optimization(): + """ + Test that optimizations don't cross basic block boundaries. + Load and store in different blocks should not be optimized. + """ + pre = """ + _global: + %1 = mload 100 + jmp @label1 + + label1: + mstore 100, %1 ; Even though this is redundant, it's in a different block + %2 = mload 100 + sink %2 + """ + + # No optimization should happen - load and store are in different blocks + post = pre + + _check_pre_post(pre, post) + + +def test_mcopy_chain_across_blocks(): + """ + Test that mcopy chains don't merge across basic block boundaries. + """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + mcopy 200, 100, 32 + jmp @label1 + + label1: + mcopy 300, 200, 32 + %1 = mload 300 + sink %1 + """ + + # No optimization should happen - mcopies are in different blocks + post = pre + + _check_pre_post(pre, post) + + +def test_special_copy_chain_across_blocks(): + """ + Test that special copy + mcopy chains don't merge across basic block boundaries. 
+ """ + if not version_check(begin="cancun"): + return + + pre = """ + _global: + calldatacopy 100, 0, 32 + jmp @label1 + + label1: + mcopy 200, 100, 32 + %1 = mload 200 + sink %1 + """ + + # No optimization should happen - copies are in different blocks + post = pre + + _check_pre_post(pre, post) + + +def test_conditional_branch_no_optimization(): + """ + Test that optimizations are conservative with conditional branches. + """ + pre = """ + _global: + %1 = mload 100 + %2 = iszero %1 + jnz %2, @label1, @label2 + + label1: + mstore 100, %1 ; Can't optimize - control flow dependent + %1 = mload 100 + sink %1 + + label2: + mstore 200, %1 ; Can't optimize - control flow dependent + %2 = mload 200 + sink %2 + """ + + # No optimization should happen + post = pre + + _check_pre_post(pre, post) + + +def test_special_copy_not_merged_with_hazard(): + """Test that special copy + mcopy chain is not merged when there's a hazard.""" + pre = """ + _global: + calldatacopy 100, 200, 32 + %1 = mload 100 + add %1, 1 + mstore 100, %1 + mcopy 200, 100, 32 + %2 = mload 200 + sink %2 + """ + + post = pre # No change - hazard prevents optimization + + _check_pre_post(pre, post) + + +def test_mem_elision_load_needed(): + pre = """ + main: + ; cannot remove this copy since + ; the mload uses this data + calldatacopy 100, 200, 64 + mcopy 300, 100, 64 + %1 = mload 100 + %2 = mload 300 + sink %2, %1 + """ + + post = """ + main: + calldatacopy 100, 200, 64 + calldatacopy 300, 200, 64 + %1 = mload 100 + %2 = mload 300 + sink %2, %1 + """ + + _check_pre_post(pre, post) + + +def test_mem_elision_load_needed_not_precise(): + pre = """ + main: + ; cannot remove this copy since + ; the mload uses this data + calldatacopy 100, 200, 64 + mcopy 300, 100, 64 + %1 = mload 132 + %2 = mload 332 + sink %2, %1 + """ + + post = """ + main: + calldatacopy 100, 200, 64 + calldatacopy 300, 200, 64 + %1 = mload 132 + %2 = mload 332 + sink %2, %1 + """ + + _check_pre_post(pre, post) + + +@pytest.mark.xfail +def test_mem_elision_msize(): + pre = """ + main: + ; you cannot nop both of + ; them since you need correct + ; msize (currently it does that) + %1 = mload 100 + mstore 100, %1 + %2 = msize + sink %2 + """ + + post = """ + main: + %1 = mload 100 + nop + %2 = msize + sink %2 + """ + + _check_pre_post(pre, post) + + +def test_remove_unused_writes(): + pre = """ + main: + %par = param + mstore 100, %par + mstore 300, %par + %cond = iszero %par + jnz %cond, @then, @else + then: + stop + ;%1 = mload 100 + ;sink %1 + else: + stop + ;%2 = mload 200 + ;sink %2 + """ + + post = """ + main: + %par = param + nop + nop + %cond = iszero %par + jnz %cond, @then, @else + then: + stop + else: + stop + """ + + _check_pre_post(pre, post) + + +def test_remove_unused_writes_with_read(): + pre = """ + main: + %par = param + mstore 100, %par + mstore 300, %par + %cond = iszero %par + jnz %cond, @then, @else + then: + %1 = mload 100 + sink %1 + else: + %2 = mload 100 + sink %2 + """ + + post = """ + main: + %par = param + mstore 100, %par + nop + %cond = iszero %par + jnz %cond, @then, @else + then: + %1 = mload 100 + sink %1 + else: + %2 = mload 100 + sink %2 + """ + + _check_pre_post(pre, post) + + +@pytest.mark.xfail +def test_remove_unused_writes_with_read_loop(): + pre = """ + main: + %par = param + mstore 100, %par + mstore 300, %par + jmp @cond + cond: + %cond = iszero %par + jnz %cond, @body, @after + body: + %1 = mload 100 + jmp @cond + after: + %2 = mload 100 + sink %2 + """ + + post = """ + main: + %par = param + mstore 100, %par + nop + jmp 
@cond + cond: + %cond = iszero %par + jnz %cond, @body, @after + body: + %1 = mload 100 + jmp @cond + after: + %2 = mload 100 + sink %2 + """ + + _check_pre_post(pre, post) diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index 0b13d8a4ec..a0020f7ecc 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -26,6 +26,7 @@ MakeSSA, Mem2Var, MemMergePass, + MemoryCopyElisionPass, PhiEliminationPass, ReduceLiteralsCodesize, RemoveUnusedVariablesPass, @@ -87,6 +88,7 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel, ac: IRAnalysesCache SimplifyCFGPass(ac, fn).run_pass() MemMergePass(ac, fn).run_pass() + MemoryCopyElisionPass(ac, fn).run_pass() RemoveUnusedVariablesPass(ac, fn).run_pass() DeadStoreElimination(ac, fn).run_pass(addr_space=MEMORY) diff --git a/vyper/venom/analysis/__init__.py b/vyper/venom/analysis/__init__.py index 2b6b722d7b..6b0ddf4180 100644 --- a/vyper/venom/analysis/__init__.py +++ b/vyper/venom/analysis/__init__.py @@ -5,6 +5,7 @@ from .fcg import FCGAnalysis from .liveness import LivenessAnalysis from .mem_alias import MemoryAliasAnalysis +from .mem_overwrite_analysis import MemOverwriteAnalysis from .mem_ssa import MemSSA from .reachable import ReachableAnalysis from .var_definition import VarDefinition diff --git a/vyper/venom/analysis/mem_overwrite_analysis.py b/vyper/venom/analysis/mem_overwrite_analysis.py new file mode 100644 index 0000000000..c9b3ce6257 --- /dev/null +++ b/vyper/venom/analysis/mem_overwrite_analysis.py @@ -0,0 +1,107 @@ +from typing import Iterator + +from vyper.evm.address_space import MEMORY +from vyper.utils import OrderedSet +from vyper.venom.analysis import CFGAnalysis +from vyper.venom.analysis.analysis import IRAnalysis +from vyper.venom.basicblock import IRBasicBlock, IRInstruction +from vyper.venom.memory_location import MemoryLocation, get_read_location, get_write_location + +LatticeItem = OrderedSet[MemoryLocation] + + +def join(a: LatticeItem, b: LatticeItem) -> LatticeItem: + assert isinstance(a, OrderedSet) and isinstance(b, OrderedSet) + tmp = OrderedSet.intersection(a, b) + assert isinstance(tmp, OrderedSet) + return tmp + + +def carve_out(write: MemoryLocation, read: MemoryLocation) -> list[MemoryLocation]: + if not MemoryLocation.may_overlap(read, write): + return [write] + if not read.is_fixed or not write.is_fixed: + return [MemoryLocation.EMPTY] + + assert read.offset is not None + assert write.offset is not None + assert read.size is not None + assert write.size is not None + + a = (write.offset, read.offset) + b = (read.offset + read.size, write.offset + write.size) + res = [] + if a[0] < a[1]: + res.append(MemoryLocation(offset=a[0], size=a[1] - a[0])) + if b[0] < b[1]: + res.append(MemoryLocation(offset=b[0], size=b[1] - b[0])) + return res + + +class MemOverwriteAnalysis(IRAnalysis): + mem_rewriten: dict[IRBasicBlock, LatticeItem] + mem_start: dict[IRBasicBlock, LatticeItem] + + def analyze(self): + self.mem_rewriten = {bb: OrderedSet() for bb in self.function.get_basic_blocks()} + self.mem_start = {bb: OrderedSet() for bb in self.function.get_basic_blocks()} + self.cfg = self.analyses_cache.request_analysis(CFGAnalysis) + + order = self.cfg.dfs_post_walk + + while True: + change = False + for bb in order: + res = self._handle_bb(bb) + if self.mem_rewriten[bb] != res: + change = True + self.mem_rewriten[bb] = res + + if not change: + break + + def _handle_bb(self, bb: IRBasicBlock) -> LatticeItem: + succs = self.cfg.cfg_out(bb) + if len(succs) > 0: + lattice_item: LatticeItem = 
self.mem_rewriten[succs.first()].copy() + for succ in self.cfg.cfg_out(bb): + lattice_item = join(lattice_item, self.mem_rewriten[succ]) + elif bb.instructions[-1].opcode in ("stop", "sink"): + lattice_item: LatticeItem = OrderedSet([MemoryLocation.ALL]) + else: + lattice_item: LatticeItem = OrderedSet([]) + + + self.mem_start[bb] = lattice_item + + for inst in reversed(bb.instructions): + read_loc = get_read_location(inst, MEMORY) + write_loc = get_write_location(inst, MEMORY) + if write_loc != MemoryLocation.EMPTY and write_loc.is_fixed: + lattice_item.add(write_loc) + if not read_loc.is_fixed: + lattice_item = OrderedSet() + if read_loc.is_fixed: + tmp: LatticeItem = OrderedSet() + for loc in lattice_item: + tmp.addmany(carve_out(write=loc, read=read_loc)) + lattice_item = tmp + + return lattice_item + + def bb_iterator(self, bb: IRBasicBlock) -> Iterator[tuple[IRInstruction, LatticeItem]]: + lattice_item = self.mem_start[bb] + print(bb.label, lattice_item) + for inst in reversed(bb.instructions): + yield (inst, lattice_item) + read_loc = get_read_location(inst, MEMORY) + write_loc = get_write_location(inst, MEMORY) + if write_loc != MemoryLocation.EMPTY and write_loc.is_fixed: + lattice_item.add(write_loc) + if not read_loc.is_fixed: + lattice_item = OrderedSet() + if read_loc.is_fixed: + tmp: LatticeItem = OrderedSet() + for loc in lattice_item: + tmp.addmany(carve_out(write=loc, read=read_loc)) + lattice_item = tmp diff --git a/vyper/venom/memory_location.py b/vyper/venom/memory_location.py index 977c8a1c76..7558d7d3ec 100644 --- a/vyper/venom/memory_location.py +++ b/vyper/venom/memory_location.py @@ -21,6 +21,7 @@ class MemoryLocation: # Initialize after class definition EMPTY: ClassVar[MemoryLocation] UNDEFINED: ClassVar[MemoryLocation] + ALL: ClassVar[MemoryLocation] @property def is_offset_fixed(self) -> bool: @@ -126,6 +127,7 @@ def may_overlap(loc1: MemoryLocation, loc2: MemoryLocation) -> bool: MemoryLocation.EMPTY = MemoryLocation(offset=0, size=0) MemoryLocation.UNDEFINED = MemoryLocation(offset=None, size=None) +MemoryLocation.ALL = MemoryLocation(offset=0, size=2**256) def get_write_location(inst, addr_space: AddrSpace) -> MemoryLocation: diff --git a/vyper/venom/passes/__init__.py b/vyper/venom/passes/__init__.py index fe56bb823e..f0794519a0 100644 --- a/vyper/venom/passes/__init__.py +++ b/vyper/venom/passes/__init__.py @@ -12,6 +12,7 @@ from .make_ssa import MakeSSA from .mem2var import Mem2Var from .memmerging import MemMergePass +from .memory_copy_elision import MemoryCopyElisionPass from .normalization import NormalizationPass from .phi_elimination import PhiEliminationPass from .remove_unused_variables import RemoveUnusedVariablesPass diff --git a/vyper/venom/passes/memory_copy_elision.py b/vyper/venom/passes/memory_copy_elision.py new file mode 100644 index 0000000000..8391a29d72 --- /dev/null +++ b/vyper/venom/passes/memory_copy_elision.py @@ -0,0 +1,461 @@ +from vyper.evm.address_space import MEMORY +from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, LivenessAnalysis, MemOverwriteAnalysis +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLiteral, IRVariable +from vyper.venom.effects import Effects +from vyper.venom.memory_location import MemoryLocation, get_read_location, get_write_location +from vyper.venom.passes.base_pass import InstUpdater, IRPass + + +class MemoryCopyElisionPass(IRPass): + """ + This pass elides useless memory copies. It identifies patterns where: + 1. 
A value is loaded from memory and immediately stored to another location + 2. The source memory is not modified between the load and store + 3. The value loaded is not used elsewhere + 4. Intermediate mcopy operations that can be combined or eliminated + + Common patterns optimized: + - %1 = mload src; mstore %1, dst -> mcopy 32, src, dst (or direct copy) + - Redundant copies where src and dst are the same + - Load-store pairs that can be eliminated entirely + - mcopy chains: mcopy A->B followed by mcopy B->C -> mcopy A->C + """ + + def run_pass(self): + self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) + self.cfg = self.analyses_cache.request_analysis(CFGAnalysis) + self.updater = InstUpdater(self.dfg) + + for bb in self.function.get_basic_blocks(): + self._process_basic_block(bb) + + self._remove_unnecessary_effects() + + self.analyses_cache.invalidate_analysis(LivenessAnalysis) + + def _remove_unnecessary_effects(self): + self.mem_overwrite = self.analyses_cache.request_analysis(MemOverwriteAnalysis) + for bb in self.function.get_basic_blocks(): + self._remove_unnecessary_effects_bb(bb) + + def _remove_unnecessary_effects_bb(self, bb: IRBasicBlock): + for inst, state in self.mem_overwrite.bb_iterator(bb): + if inst.output is not None: + continue + write_loc = get_write_location(inst, MEMORY) + if write_loc == MemoryLocation.EMPTY: + continue + if not write_loc.is_fixed: + continue + overlap = [loc for loc in state if loc.completely_contains(write_loc)] + if len(overlap) > 0: + self.updater.nop(inst, annotation="remove unnecessery effects") + + def _process_basic_block(self, bb: IRBasicBlock): + """Process a basic block to find and elide memory copies.""" + # Track loads that could potentially be elided + # Maps variable -> (load_inst, src_location) + available_loads: dict[IRVariable, tuple[IRInstruction, MemoryLocation]] = {} + + # Track mcopy operations for chain optimization + # Maps destination location -> (mcopy_inst, src_location) + mcopy_chain: dict[int, tuple[IRInstruction, MemoryLocation]] = {} + + # Track memory writes to invalidate loads + for inst in bb.instructions.copy(): + if inst.opcode == "mload": + assert inst.output is not None + # Track the load if it has a literal source + if isinstance(inst.operands[0], IRLiteral): + src_loc = MemoryLocation.from_operands(inst.operands[0], 32) + available_loads[inst.output] = (inst, src_loc) + + elif inst.opcode == "mstore": + var, dst = inst.operands + + # Check if this is a load-store pair we can optimize + if isinstance(var, IRVariable) and isinstance(dst, IRLiteral): + if var in available_loads: + load_inst, src_loc = available_loads[var] + dst_loc = MemoryLocation.from_operands(dst, 32) + + # Check if we can elide this copy + if self._can_elide_copy(inst, src_loc, dst_loc, var): + self._elide_copy(load_inst, inst, src_loc, dst_loc) + # Remove from available loads since we've processed it + del available_loads[var] + continue + + # This store invalidates any loads that may alias with the destination + self._invalidate_aliasing_loads(available_loads, inst) + self._invalidate_mcopy_chain(mcopy_chain, inst) + + elif inst.opcode == "mcopy": + # Handle mcopy operations + src_loc = get_read_location(inst, MEMORY) + dst_loc = get_write_location(inst, MEMORY) + + # Only process if we have fixed locations + if src_loc.is_fixed and dst_loc.is_fixed: + assert src_loc.offset is not None # help mypy + assert dst_loc.offset is not None # help mypy + # Check for redundant copy (src == dst) + if src_loc.offset == dst_loc.offset and 
src_loc.size == dst_loc.size: + self.updater.nop(inst, annotation="[memory copy elision - redundant mcopy]") + continue + + # Check if this forms a chain with a previous copy + if src_loc.offset in mcopy_chain: + prev_inst, prev_src_loc = mcopy_chain[src_loc.offset] + + # Check if previous instruction is a special copy (calldatacopy, etc) + if prev_inst.opcode in ( + "calldatacopy", + "codecopy", + "returndatacopy", + "dloadbytes", + ): + # Can merge if sizes match and no hazards + if ( + prev_src_loc.size == src_loc.size == dst_loc.size + and self._can_merge_special_copy_chain(bb, prev_inst, inst, src_loc) + ): + # Replace mcopy with the special copy directly to final destination + # Need to update the destination operand + new_operands = list(prev_inst.operands) + # For these instructions, dst is the last operand + new_operands[-1] = inst.operands[2] # Use mcopy's destination + self.updater.update( + inst, + prev_inst.opcode, + new_operands, + annotation="[memory copy elision - merged special copy]", + ) + # Update chain tracking + del mcopy_chain[src_loc.offset] + # Track the new special copy in the chain for + # potential future merging + mcopy_chain[dst_loc.offset] = (inst, prev_src_loc) + continue + else: + # Regular mcopy chain + # Check if we can merge: A->B followed by B->C becomes A->C + if ( + prev_src_loc.size == src_loc.size == dst_loc.size + and self._can_merge_mcopy_chain( + bb, prev_inst, inst, prev_src_loc, src_loc, dst_loc + ) + ): + # Update current mcopy to copy from original source + # Internal order is [size, src, dst] + size_op = inst.operands[0] + dst_op = inst.operands[2] + assert prev_src_loc.offset is not None # help mypy + self.updater.update( + inst, "mcopy", [size_op, IRLiteral(prev_src_loc.offset), dst_op] + ) + # Update chain tracking + del mcopy_chain[src_loc.offset] + mcopy_chain[dst_loc.offset] = (inst, prev_src_loc) + continue + + # Track this mcopy for potential future chaining + mcopy_chain[dst_loc.offset] = (inst, src_loc) + + # mcopy invalidates overlapping loads but not mcopy chains + # (we handle mcopy chain invalidation separately) + self._invalidate_aliasing_loads_by_inst(available_loads, inst) + + elif inst.opcode in ("calldatacopy", "codecopy", "returndatacopy", "dloadbytes"): + # These also perform memory copies and can start chains + dst_loc = get_write_location(inst, MEMORY) + + # Only process if we have fixed destination + if dst_loc.is_fixed: + assert dst_loc.offset is not None # help mypy + # Track this copy for potential future chaining with mcopy + # For these instructions, src_loc represents the source data location + # which is not a memory location but rather calldata/code/returndata + # We'll use a special marker to indicate the source + src_marker = MemoryLocation(offset=-1, size=dst_loc.size) # Special marker + mcopy_chain[dst_loc.offset] = (inst, src_marker) + + self._invalidate_aliasing_loads_by_inst(available_loads, inst) + self._invalidate_mcopy_chain(mcopy_chain, inst, exclude_current=True) + + elif self._modifies_memory(inst): + # Conservative: clear all available loads if memory is modified + available_loads.clear() + mcopy_chain.clear() + + def _can_elide_copy( + self, + store_inst: IRInstruction, + src_loc: MemoryLocation, + dst_loc: MemoryLocation, + var: IRVariable, + ) -> bool: + """ + Check if a load-store pair can be elided. + + Conditions: + 1. The loaded value is only used by the store (no other uses) + 2. No memory writes between load and store that could alias with src + 3. 
The source and destination don't overlap (unless they're identical) + """ + # Check if the loaded value is only used by the store + uses = self.dfg.get_uses(var) + if len(uses) != 1 or store_inst not in uses: + return False + + # Check if src and dst are the same (redundant copy) + if src_loc.is_fixed and dst_loc.is_fixed: + if src_loc.offset == dst_loc.offset and src_loc.size == dst_loc.size: + # Redundant copy - can be eliminated entirely + return True + + return False + + def _can_merge_special_copy_chain( + self, + bb: IRBasicBlock, + special_copy: IRInstruction, + mcopy: IRInstruction, + intermediate_loc: MemoryLocation, + ) -> bool: + """ + Check if a special copy (calldatacopy, etc) followed by mcopy can be merged. + + Conditions: + 1. No memory writes between the two copies that alias with intermediate location + 2. The intermediate location is not read between the copies + """ + first_idx = bb.instructions.index(special_copy) + second_idx = bb.instructions.index(mcopy) + + # Check for operations between the two copies + for i in range(first_idx + 1, second_idx): + inst = bb.instructions[i] + + # Check if intermediate location is modified + if self._modifies_memory_at(inst, intermediate_loc): + return False + + # Check if intermediate location is read + if self._reads_memory_at(inst, intermediate_loc): + return False + + return True + + def _can_merge_mcopy_chain( + self, + bb: IRBasicBlock, + first_mcopy: IRInstruction, + second_mcopy: IRInstruction, + orig_src_loc: MemoryLocation, + intermediate_loc: MemoryLocation, + final_dst_loc: MemoryLocation, + ) -> bool: + """ + Check if two mcopy operations can be merged into one. + + Conditions: + 1. No memory writes between the two mcopies that alias with intermediate location + 2. The intermediate location is not read between the mcopies + 3. 
No overlap issues that would change semantics + """ + first_idx = bb.instructions.index(first_mcopy) + second_idx = bb.instructions.index(second_mcopy) + + # Check for operations between the two mcopies + for i in range(first_idx + 1, second_idx): + inst = bb.instructions[i] + + # Check if intermediate location is modified + if self._modifies_memory_at(inst, intermediate_loc): + return False + + # Check if intermediate location is read + if self._reads_memory_at(inst, intermediate_loc): + return False + + # Check if original source is modified + if self._modifies_memory_at(inst, orig_src_loc): + return False + + # Check for overlap issues + # If final destination overlaps with original source, merging could change semantics + if MemoryLocation.may_overlap(final_dst_loc, orig_src_loc): + return False + + return True + + def _elide_copy( + self, + load_inst: IRInstruction, + store_inst: IRInstruction, + src_loc: MemoryLocation, + dst_loc: MemoryLocation, + ): + """Elide a load-store pair by converting to a more efficient form.""" + # Check if this is a redundant copy (src == dst) + assert src_loc.offset == dst_loc.offset and src_loc.size == dst_loc.size + # Completely redundant - remove both instructions + # Must nop store first since it uses the load's output + self.updater.nop(store_inst, annotation="[memory copy elision - redundant store]") + self.updater.nop(load_inst, annotation="[memory copy elision - redundant load]") + + def _modifies_memory(self, inst: IRInstruction) -> bool: + """Check if an instruction modifies memory.""" + write_effects = inst.get_write_effects() + return Effects.MEMORY in write_effects or Effects.MSIZE in write_effects + + def _reads_memory(self, inst: IRInstruction) -> bool: + """Check if an instruction reads memory.""" + read_effects = inst.get_read_effects() + return Effects.MEMORY in read_effects + + def _modifies_memory_at(self, inst: IRInstruction, loc: MemoryLocation) -> bool: + """Check if an instruction modifies memory at a specific location.""" + if not self._modifies_memory(inst): + return False + + # For stores, check if they write to an aliasing location + if inst.opcode == "mstore": + _, dst = inst.operands + if isinstance(dst, IRLiteral): + write_loc = MemoryLocation.from_operands(dst, 32) + return MemoryLocation.may_overlap(write_loc, loc) + + elif inst.opcode in ("mcopy", "calldatacopy", "codecopy", "returndatacopy", "dloadbytes"): + assert len(inst.operands) == 3 + size_op = inst.operands[0] + dst_op = inst.operands[2] + if isinstance(size_op, IRLiteral) and isinstance(dst_op, IRLiteral): + write_loc = MemoryLocation.from_operands(dst_op, size_op) + return MemoryLocation.may_overlap(write_loc, loc) + + # Conservative: assume any other memory write could alias + return True + + def _reads_memory_at(self, inst: IRInstruction, loc: MemoryLocation) -> bool: + """Check if an instruction reads memory at a specific location.""" + if not self._reads_memory(inst): + return False + + if inst.opcode == "mload": + src = inst.operands[0] + if isinstance(src, IRLiteral): + read_loc = MemoryLocation.from_operands(src, 32) + return MemoryLocation.may_overlap(read_loc, loc) + + elif inst.opcode == "mcopy": + if len(inst.operands) >= 3: + size_op = inst.operands[0] + src_op = inst.operands[1] + if isinstance(size_op, IRLiteral) and isinstance(src_op, IRLiteral): + read_loc = MemoryLocation.from_operands(src_op, size_op) + return MemoryLocation.may_overlap(read_loc, loc) + + # Conservative: assume any other memory read could alias + return True + + def 
_invalidate_aliasing_loads( + self, + available_loads: dict[IRVariable, tuple[IRInstruction, MemoryLocation]], + store_inst: IRInstruction, + ): + """Remove any tracked loads that may alias with a store.""" + if store_inst.opcode != "mstore": + return + + _, dst = store_inst.operands + if not isinstance(dst, IRLiteral): + # Conservative: clear all if we can't determine the destination + available_loads.clear() + return + + store_loc = MemoryLocation.from_operands(dst, 32) + + # Remove any loads that may alias with this store + to_remove = [] + for var, (_, src_loc) in available_loads.items(): + if MemoryLocation.may_overlap(src_loc, store_loc): + to_remove.append(var) + + for var in to_remove: + del available_loads[var] + + def _invalidate_aliasing_loads_by_inst( + self, + available_loads: dict[IRVariable, tuple[IRInstruction, MemoryLocation]], + inst: IRInstruction, + ): + """Remove any tracked loads that may alias with a memory-writing instruction.""" + if inst.opcode not in ("mcopy", "calldatacopy", "codecopy", "returndatacopy", "dloadbytes"): + # Conservative: clear all for unknown memory writes + available_loads.clear() + return + + assert len(inst.operands) == 3 + size_op = inst.operands[0] + dst_op = inst.operands[2] + write_loc = MemoryLocation.from_operands(dst_op, size_op) + + if not write_loc.is_fixed: + # Conservative: clear all if we can't determine the destination + available_loads.clear() + return + + to_remove = [] + for var, (_, src_loc) in available_loads.items(): + if MemoryLocation.may_overlap(src_loc, write_loc): + to_remove.append(var) + + for var in to_remove: + del available_loads[var] + + def _invalidate_mcopy_chain( + self, + mcopy_chain: dict[int, tuple[IRInstruction, MemoryLocation]], + inst: IRInstruction, + exclude_current: bool = False, + ): + if inst.opcode not in ( + "mstore", + "mcopy", + "calldatacopy", + "codecopy", + "returndatacopy", + "dloadbytes", + ): + # Conservative: clear all + mcopy_chain.clear() + return + + write_loc = get_write_location(inst, MEMORY) + if not write_loc.is_fixed: + # Conservative: clear all + mcopy_chain.clear() + return + + to_remove = [] + for dst_offset, (tracked_inst, src_loc) in mcopy_chain.items(): + # Skip if this is the current instruction and exclude_current is True + if exclude_current and tracked_inst is inst: + continue + + dst_loc = MemoryLocation(offset=dst_offset, size=src_loc.size) + # Invalidate if the write aliases with either source or destination + # For special copies, src_loc.offset == -1, so we only check destination + if src_loc.offset == -1: # Special marker for non-memory sources + if MemoryLocation.may_overlap(write_loc, dst_loc): + to_remove.append(dst_offset) + else: + if MemoryLocation.may_overlap(write_loc, src_loc) or MemoryLocation.may_overlap( + write_loc, dst_loc + ): + to_remove.append(dst_offset) + + for offset in to_remove: + del mcopy_chain[offset]
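
Note for reviewers (illustrative sketch, not part of the patch): the core helper in the new
MemOverwriteAnalysis is `carve_out`, which removes the bytes a read observes from a region that
is otherwise guaranteed to be overwritten before it is read again, so earlier writes into the
removed bytes are no longer treated as dead. The offsets below are arbitrary; the import paths
and the `MemoryLocation(offset=..., size=...)` constructor are the ones introduced in this diff.

    from vyper.venom.analysis.mem_overwrite_analysis import carve_out
    from vyper.venom.memory_location import MemoryLocation

    write = MemoryLocation(offset=100, size=64)   # tracked "will be overwritten" region [100, 164)
    read = MemoryLocation(offset=116, size=32)    # an intervening read of [116, 148)

    # Per the definition of carve_out above, the read punches a hole in the
    # tracked region and should leave two stubs:
    #   [100, 116) -> MemoryLocation(offset=100, size=16)
    #   [148, 164) -> MemoryLocation(offset=148, size=16)
    print(carve_out(write=write, read=read))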