diff --git a/tests/functional/builtins/codegen/test_create_functions.py b/tests/functional/builtins/codegen/test_create_functions.py index 085d012759..3b7200f218 100644 --- a/tests/functional/builtins/codegen/test_create_functions.py +++ b/tests/functional/builtins/codegen/test_create_functions.py @@ -5,9 +5,8 @@ import vyper.ir.compile_ir as compile_ir from tests.utils import ZERO_ADDRESS -from vyper.codegen.ir_node import IRnode from vyper.compiler import compile_code -from vyper.compiler.settings import OptimizationLevel +from vyper.ir.compile_ir import DATA_ITEM, PUSH, PUSHLABEL, DataHeader, Label from vyper.utils import EIP_170_LIMIT, ERC5202_PREFIX, checksum_encode, keccak256 @@ -295,10 +294,19 @@ def test(code_ofst: uint256) -> address: # deploy a blueprint contract whose contained initcode contains only # zeroes (so no matter which offset, create_from_blueprint will # return empty code) - ir = IRnode.from_list(["deploy", 0, ["seq"] + ["stop"] * initcode_len, 0]) - bytecode, _ = compile_ir.assembly_to_evm( - compile_ir.compile_to_assembly(ir, optimize=OptimizationLevel.NONE) - ) + asm = [ + *PUSH(initcode_len), + PUSHLABEL(Label("end")), + *PUSH(0), + "CODECOPY", + *PUSH(initcode_len), + *PUSH(0), + "RETURN", + DataHeader(Label("end")), + DATA_ITEM(b"\x00" * initcode_len), + ] + bytecode, _ = compile_ir.assembly_to_evm(asm) + # manually deploy the bytecode c = env.deploy(abi=[], bytecode=bytecode) blueprint_address = c.address diff --git a/tests/functional/codegen/features/test_init.py b/tests/functional/codegen/features/test_init.py index 84d224f632..1e37b89701 100644 --- a/tests/functional/codegen/features/test_init.py +++ b/tests/functional/codegen/features/test_init.py @@ -15,13 +15,13 @@ def __init__(a: uint256): assert c.val() == 123 # Make sure the init code does not access calldata - assembly = vyper.compile_code(code, output_formats=["asm"])["asm"].split(" ") - ir_return_idx_start = assembly.index("{") - ir_return_idx_end = assembly.index("}") + compiler_output = vyper.compile_code(code, output_formats=["asm", "asm_runtime"]) + asm_deploytime = compiler_output["asm"] + asm_runtime = compiler_output["asm_runtime"] - assert "CALLDATALOAD" in assembly - assert "CALLDATACOPY" not in assembly[:ir_return_idx_start] + assembly[ir_return_idx_end:] - assert "CALLDATALOAD" not in assembly[:ir_return_idx_start] + assembly[ir_return_idx_end:] + assert "CALLDATALOAD" in asm_runtime + assert "CALLDATACOPY" not in asm_deploytime + assert "CALLDATALOAD" not in asm_deploytime def test_init_calls_internal(get_contract, assert_compile_failed, tx_failed): diff --git a/tests/functional/codegen/test_selector_table_stability.py b/tests/functional/codegen/test_selector_table_stability.py index 4416b5f5ea..a1f58d6852 100644 --- a/tests/functional/codegen/test_selector_table_stability.py +++ b/tests/functional/codegen/test_selector_table_stability.py @@ -9,13 +9,14 @@ def test_dense_jumptable_stability(): code = "\n".join(f"@external\ndef {name}():\n pass" for name in function_names) output = compile_code( - code, output_formats=["asm"], settings=Settings(optimize=OptimizationLevel.CODESIZE) + code, output_formats=["asm_runtime"], settings=Settings(optimize=OptimizationLevel.CODESIZE) ) # test that the selector table data is stable across different runs # (xdist should provide different PYTHONHASHSEEDs). - expected_asm = """{ DATA _sym_BUCKET_HEADERS b\'\\x0bB\' _sym_bucket_0 b\'\\n\' b\'+\\x8d\' _sym_bucket_1 b\'\\x0c\' b\'\\x00\\x85\' _sym_bucket_2 b\'\\x08\' } { DATA _sym_bucket_1 b\'\\xd8\\xee\\xa1\\xe8\' _sym_external 6 foo6()3639517672 b\'\\x05\' b\'\\xd2\\x9e\\xe0\\xf9\' _sym_external 0 foo0()3533627641 b\'\\x05\' b\'\\x05\\xf1\\xe0_\' _sym_external 2 foo2()99737695 b\'\\x05\' b\'\\x91\\t\\xb4{\' _sym_external 23 foo23()2433332347 b\'\\x05\' b\'np3\\x7f\' _sym_external 11 foo11()1852846975 b\'\\x05\' b\'&\\xf5\\x96\\xf9\' _sym_external 13 foo13()653629177 b\'\\x05\' b\'\\x04ga\\xeb\' _sym_external 14 foo14()73884139 b\'\\x05\' b\'\\x89\\x06\\xad\\xc6\' _sym_external 17 foo17()2298916294 b\'\\x05\' b\'\\xe4%\\xac\\xd1\' _sym_external 4 foo4()3827674321 b\'\\x05\' b\'yj\\x01\\xac\' _sym_external 7 foo7()2036990380 b\'\\x05\' b\'\\xf1\\xe6K\\xe5\' _sym_external 29 foo29()4058401765 b\'\\x05\' b\'\\xd2\\x89X\\xb8\' _sym_external 3 foo3()3532216504 b\'\\x05\' } { DATA _sym_bucket_2 b\'\\x06p\\xffj\' _sym_external 25 foo25()108068714 b\'\\x05\' b\'\\x964\\x99I\' _sym_external 24 foo24()2520029513 b\'\\x05\' b\'s\\x81\\xe7\\xc1\' _sym_external 10 foo10()1937893313 b\'\\x05\' b\'\\x85\\xad\\xc11\' _sym_external 28 foo28()2242756913 b\'\\x05\' b\'\\xfa"\\xb1\\xed\' _sym_external 5 foo5()4196577773 b\'\\x05\' b\'A\\xe7[\\x05\' _sym_external 22 foo22()1105681157 b\'\\x05\' b\'\\xd3\\x89U\\xe8\' _sym_external 1 foo1()3548993000 b\'\\x05\' b\'hL\\xf8\\xf3\' _sym_external 20 foo20()1749874931 b\'\\x05\' } { DATA _sym_bucket_0 b\'\\xee\\xd9\\x1d\\xe3\' _sym_external 9 foo9()4007206371 b\'\\x05\' b\'a\\xbc\\x1ch\' _sym_external 16 foo16()1639717992 b\'\\x05\' b\'\\xd3*\\xa7\\x0c\' _sym_external 21 foo21()3542787852 b\'\\x05\' b\'\\x18iG\\xd9\' _sym_external 19 foo19()409552857 b\'\\x05\' b\'\\n\\xf1\\xf9\\x7f\' _sym_external 18 foo18()183630207 b\'\\x05\' b\')\\xda\\xd7`\' _sym_external 27 foo27()702207840 b\'\\x05\' b\'2\\xf6\\xaa\\xda\' _sym_external 12 foo12()855026394 b\'\\x05\' b\'\\xbe\\xb5\\x05\\xf5\' _sym_external 15 foo15()3199534581 b\'\\x05\' b\'\\xfc\\xa7_\\xe6\' _sym_external 8 foo8()4238827494 b\'\\x05\' b\'\\x1b\\x12C8\' _sym_external 26 foo26()454181688 b\'\\x05\' } }""" # noqa: E501, FS003 - assert expected_asm in output["asm"] + expected_asm = """DATA BUCKET_HEADERS:\n DATABYTES 0b42\n DATALABEL bucket_0\n DATABYTES 0a\n DATABYTES 2b8d\n DATALABEL bucket_1\n DATABYTES 0c\n DATABYTES 0085\n DATALABEL bucket_2\n DATABYTES 08\n\nDATA bucket_1:\n DATABYTES d8eea1e8\n DATALABEL external 6 foo6()3639517672\n DATABYTES 05\n DATABYTES d29ee0f9\n DATALABEL external 0 foo0()3533627641\n DATABYTES 05\n DATABYTES 05f1e05f\n DATALABEL external 2 foo2()99737695\n DATABYTES 05\n DATABYTES 9109b47b\n DATALABEL external 23 foo23()2433332347\n DATABYTES 05\n DATABYTES 6e70337f\n DATALABEL external 11 foo11()1852846975\n DATABYTES 05\n DATABYTES 26f596f9\n DATALABEL external 13 foo13()653629177\n DATABYTES 05\n DATABYTES 046761eb\n DATALABEL external 14 foo14()73884139\n DATABYTES 05\n DATABYTES 8906adc6\n DATALABEL external 17 foo17()2298916294\n DATABYTES 05\n DATABYTES e425acd1\n DATALABEL external 4 foo4()3827674321\n DATABYTES 05\n DATABYTES 796a01ac\n DATALABEL external 7 foo7()2036990380\n DATABYTES 05\n DATABYTES f1e64be5\n DATALABEL external 29 foo29()4058401765\n DATABYTES 05\n DATABYTES d28958b8\n DATALABEL external 3 foo3()3532216504\n DATABYTES 05\n\nDATA bucket_2:\n DATABYTES 0670ff6a\n DATALABEL external 25 foo25()108068714\n DATABYTES 05\n DATABYTES 96349949\n DATALABEL external 24 foo24()2520029513\n DATABYTES 05\n DATABYTES 7381e7c1\n DATALABEL external 10 foo10()1937893313\n DATABYTES 05\n DATABYTES 85adc131\n DATALABEL external 28 foo28()2242756913\n DATABYTES 05\n DATABYTES fa22b1ed\n DATALABEL external 5 foo5()4196577773\n DATABYTES 05\n DATABYTES 41e75b05\n DATALABEL external 22 foo22()1105681157\n DATABYTES 05\n DATABYTES d38955e8\n DATALABEL external 1 foo1()3548993000\n DATABYTES 05\n DATABYTES 684cf8f3\n DATALABEL external 20 foo20()1749874931\n DATABYTES 05\n\nDATA bucket_0:\n DATABYTES eed91de3\n DATALABEL external 9 foo9()4007206371\n DATABYTES 05\n DATABYTES 61bc1c68\n DATALABEL external 16 foo16()1639717992\n DATABYTES 05\n DATABYTES d32aa70c\n DATALABEL external 21 foo21()3542787852\n DATABYTES 05\n DATABYTES 186947d9\n DATALABEL external 19 foo19()409552857\n DATABYTES 05\n DATABYTES 0af1f97f\n DATALABEL external 18 foo18()183630207\n DATABYTES 05\n DATABYTES 29dad760\n DATALABEL external 27 foo27()702207840\n DATABYTES 05\n DATABYTES 32f6aada\n DATALABEL external 12 foo12()855026394\n DATABYTES 05\n DATABYTES beb505f5\n DATALABEL external 15 foo15()3199534581\n DATABYTES 05\n DATABYTES fca75fe6\n DATALABEL external 8 foo8()4238827494\n DATABYTES 05\n DATABYTES 1b124338\n DATALABEL external 26 foo26()454181688\n DATABYTES 05""" # noqa: E501 + + assert expected_asm in output["asm_runtime"] def test_sparse_jumptable_stability(): diff --git a/tests/functional/venom/test_venom_label_variables.py b/tests/functional/venom/test_venom_label_variables.py index ac101d7039..0f34f073d6 100644 --- a/tests/functional/venom/test_venom_label_variables.py +++ b/tests/functional/venom/test_venom_label_variables.py @@ -82,4 +82,4 @@ def test_labels_as_variables(): run_passes_on(ctx, OptimizationLevel.default()) asm = generate_assembly_experimental(ctx) - generate_bytecode(asm, compiler_metadata=None) + generate_bytecode(asm) diff --git a/tests/functional/venom/test_venom_repr.py b/tests/functional/venom/test_venom_repr.py index d08f71c2b9..c8bfc16229 100644 --- a/tests/functional/venom/test_venom_repr.py +++ b/tests/functional/venom/test_venom_repr.py @@ -104,7 +104,7 @@ def _helper1(vyper_source, optimize): # test we can generate assembly+bytecode asm = generate_assembly_experimental(ctx) - generate_bytecode(asm, compiler_metadata=None) + generate_bytecode(asm) def _helper2(vyper_source, optimize, compiler_settings): @@ -126,7 +126,7 @@ def _helper2(vyper_source, optimize, compiler_settings): # test we can generate assembly+bytecode asm = generate_assembly_experimental(ctx, optimize=optimize) - bytecode = generate_bytecode(asm, compiler_metadata=None) + bytecode, _ = generate_bytecode(asm) out = compile_code(vyper_source, settings=settings, output_formats=["bytecode_runtime"]) assert "0x" + bytecode.hex() == out["bytecode_runtime"] diff --git a/tests/hevm.py b/tests/hevm.py index da104db60a..f8d7f58a80 100644 --- a/tests/hevm.py +++ b/tests/hevm.py @@ -66,8 +66,8 @@ def _prep_hevm_venom_ctx(ctx, verbose=False): LowerDloadPass(ac, fn).run_pass() SingleUseExpansion(ac, fn).run_pass() - compiler = VenomCompiler([ctx]) - asm = compiler.generate_evm(no_optimize=False) + compiler = VenomCompiler(ctx) + asm = compiler.generate_evm_assembly(no_optimize=False) return assembly_to_evm(asm)[0].hex() diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index dfbb53ad5a..fff8a961bf 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -3,7 +3,7 @@ from vyper.compiler import compile_code from vyper.compiler.phases import CompilerData from vyper.compiler.settings import OptimizationLevel, Settings -from vyper.ir.compile_ir import _merge_jumpdests +from vyper.ir.compile_ir import PUSHLABEL, Label, _merge_jumpdests codes = [ """ @@ -82,18 +82,18 @@ def test_dead_code_eliminator(code): c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE)) # get the labels - initcode_asm = [i for i in c.assembly if isinstance(i, str)] - runtime_asm = [i for i in c.assembly_runtime if isinstance(i, str)] + initcode_labels = [i for i in c.assembly if isinstance(i, Label)] + runtime_labels = [i for i in c.assembly_runtime if isinstance(i, Label)] ctor_only = "ctor_only()" runtime_only = "runtime_only()" # qux reachable from unoptimized initcode, foo not reachable. - assert any(ctor_only in instr for instr in initcode_asm) - assert all(runtime_only not in instr for instr in initcode_asm) + assert any(ctor_only in label.label for label in initcode_labels) + assert all(runtime_only not in label.label for label in initcode_labels) - assert any(runtime_only in instr for instr in runtime_asm) - assert all(ctor_only not in instr for instr in runtime_asm) + assert any(runtime_only in label.label for label in runtime_labels) + assert all(ctor_only not in label.label for label in runtime_labels) def test_library_code_eliminator(make_input_bundle, experimental_codegen): @@ -118,8 +118,8 @@ def foo(): library.some_function() """ input_bundle = make_input_bundle({"library.vy": library}) - res = compile_code(code, input_bundle=input_bundle, output_formats=["asm"]) - asm = res["asm"] + res = compile_code(code, input_bundle=input_bundle, output_formats=["asm_runtime"]) + asm = res["asm_runtime"] if not experimental_codegen: assert "some_function()" in asm # Venom function inliner will remove this @@ -129,6 +129,6 @@ def foo(): def test_merge_jumpdests(): - asm = ["_sym_label_0", "JUMP", "PUSH0", "_sym_label_0", "JUMPDEST", "_sym_label_0", "JUMPDEST"] + asm = [PUSHLABEL(Label("label_0")), "JUMP", "PUSH0", Label("label_0"), Label("_label_0")] assert _merge_jumpdests(asm) is False, "should not return True as no changes were made" diff --git a/tests/unit/compiler/ir/test_repeat.py b/tests/unit/compiler/ir/test_repeat.py index e134be087d..f7fe869dea 100644 --- a/tests/unit/compiler/ir/test_repeat.py +++ b/tests/unit/compiler/ir/test_repeat.py @@ -1,5 +1,5 @@ def test_repeat(get_contract_from_ir, assert_compile_failed): - good_ir = ["repeat", 0, 0, 1, 1, ["seq"]] + good_ir = ["repeat", "i", 0, 1, 1, ["seq"]] bad_ir_1 = ["repeat", 0, 0, 0, 0, ["seq"]] bad_ir_2 = ["repeat", 0, 0, -1, -1, ["seq"]] get_contract_from_ir(good_ir) diff --git a/tests/unit/compiler/test_sha3_32.py b/tests/unit/compiler/test_sha3_32.py deleted file mode 100644 index e1cbf9c843..0000000000 --- a/tests/unit/compiler/test_sha3_32.py +++ /dev/null @@ -1,12 +0,0 @@ -from vyper.codegen.ir_node import IRnode -from vyper.evm.opcodes import version_check -from vyper.ir import compile_ir, optimizer - - -def test_sha3_32(): - ir = ["sha3_32", 0] - evm = ["PUSH1", 0, "PUSH1", 0, "MSTORE", "PUSH1", 32, "PUSH1", 0, "SHA3"] - if version_check(begin="shanghai"): - evm = ["PUSH0", "PUSH0", "MSTORE", "PUSH1", 32, "PUSH0", "SHA3"] - assert compile_ir.compile_to_assembly(IRnode.from_list(ir)) == evm - assert compile_ir.compile_to_assembly(optimizer.optimize(IRnode.from_list(ir))) == evm diff --git a/tests/unit/compiler/test_source_map.py b/tests/unit/compiler/test_source_map.py index bd5d75a447..0272ea9044 100644 --- a/tests/unit/compiler/test_source_map.py +++ b/tests/unit/compiler/test_source_map.py @@ -33,7 +33,8 @@ def foo(a: uint256) -> int128: def test_jump_map(optimize, experimental_codegen): - source_map = compile_code(TEST_CODE, output_formats=["source_map"])["source_map"] + compiler_output = compile_code(TEST_CODE, output_formats=["source_map_runtime"]) + source_map = compiler_output["source_map_runtime"] pos_map = source_map["pc_pos_map"] jump_map = source_map["pc_jump_map"] @@ -75,7 +76,8 @@ def test_jump_map(optimize, experimental_codegen): def test_pos_map_offsets(): - source_map = compile_code(TEST_CODE, output_formats=["source_map"])["source_map"] + compiler_output = compile_code(TEST_CODE, output_formats=["source_map_runtime"]) + source_map = compiler_output["source_map_runtime"] expanded = expand_source_map(source_map["pc_pos_map_compressed"]) pc_iter = iter(source_map["pc_pos_map"][i] for i in sorted(source_map["pc_pos_map"])) @@ -105,7 +107,9 @@ def test_error_map(experimental_codegen): def update_foo(): self.foo += 1 """ - error_map = compile_code(code, output_formats=["source_map"])["source_map"]["error_map"] + compiler_output = compile_code(code, output_formats=["source_map_runtime"]) + error_map = compiler_output["source_map_runtime"]["error_map"] + assert "safeadd" in error_map.values() if experimental_codegen: @@ -121,7 +125,8 @@ def test_error_map_with_user_error(): def foo(): raise "some error" """ - error_map = compile_code(code, output_formats=["source_map"])["source_map"]["error_map"] + compiler_output = compile_code(code, output_formats=["source_map_runtime"]) + error_map = compiler_output["source_map_runtime"]["error_map"] assert "user revert with reason" in error_map.values() @@ -132,7 +137,8 @@ def foo(i: uint256): a: DynArray[uint256, 10] = [1] a[i % 10] = 2 """ - error_map = compile_code(code, output_formats=["source_map"])["source_map"]["error_map"] + compiler_output = compile_code(code, output_formats=["source_map_runtime"]) + error_map = compiler_output["source_map_runtime"]["error_map"] assert "safemod" in error_map.values() @@ -147,7 +153,8 @@ def bar(i: uint256) -> String[85]: # ensure the mod doesn't get erased return concat("foo foo", uint2str(i)) """ - error_map = compile_code(code, output_formats=["source_map"])["source_map"]["error_map"] + compiler_output = compile_code(code, output_formats=["source_map_runtime"]) + error_map = compiler_output["source_map_runtime"]["error_map"] assert "user revert with reason" in error_map.values() assert "safemod" in error_map.values() @@ -196,10 +203,11 @@ def _construct_node_id_map(ast_struct): def test_node_id_map(): code = TEST_CODE - out = compile_code(code, output_formats=["annotated_ast_dict", "source_map", "ir"]) - assert out["source_map"]["pc_ast_map_item_keys"] == ("source_id", "node_id") + out = compile_code(code, output_formats=["annotated_ast_dict", "source_map_runtime", "ir"]) + source_map = out["source_map_runtime"] + assert source_map["pc_ast_map_item_keys"] == ("source_id", "node_id") - pc_ast_map = out["source_map"]["pc_ast_map"] + pc_ast_map = source_map["pc_ast_map"] ast_node_map = _construct_node_id_map(out["annotated_ast_dict"]) diff --git a/tests/unit/compiler/venom/test_venom_to_assembly.py b/tests/unit/compiler/venom/test_venom_to_assembly.py index ba520c06d1..73162485b2 100644 --- a/tests/unit/compiler/venom/test_venom_to_assembly.py +++ b/tests/unit/compiler/venom/test_venom_to_assembly.py @@ -13,7 +13,7 @@ def test_dead_params(): """ ctx = parse_venom(code) - asm = VenomCompiler([ctx]).generate_evm() + asm = VenomCompiler(ctx).generate_evm_assembly() assert asm == ["SWAP1", "POP", "JUMP"] @@ -32,5 +32,5 @@ def test_optimistic_swap_params(): """ ctx = parse_venom(code) - asm = VenomCompiler([ctx]).generate_evm() + asm = VenomCompiler(ctx).generate_evm_assembly() assert asm == ["SWAP2", "PUSH1", 117, "POP", "MSTORE", "MSTORE", "JUMP"] diff --git a/vyper/cli/venom_main.py b/vyper/cli/venom_main.py index d6b7bcec50..0ceeae73f6 100755 --- a/vyper/cli/venom_main.py +++ b/vyper/cli/venom_main.py @@ -61,7 +61,7 @@ def _parse_args(argv: list[str]): run_passes_on(ctx, OptimizationLevel.default()) asm = generate_assembly_experimental(ctx) - bytecode = generate_bytecode(asm, compiler_metadata=None) + bytecode, _ = generate_bytecode(asm) print(f"0x{bytecode.hex()}") diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 98da9bc45c..14e7b08195 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -933,7 +933,7 @@ def FAIL(): # pragma: no cover _label = 0 -# TODO might want to coalesce with Context.fresh_varname and compile_ir.mksymbol +# TODO might want to coalesce with Context.fresh_varname def _freshname(name): global _label _label += 1 diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py index cd79fee663..e83017e9d3 100644 --- a/vyper/compiler/__init__.py +++ b/vyper/compiler/__init__.py @@ -37,6 +37,7 @@ # requires assembly "abi": output.build_abi_output, "asm": output.build_asm_output, + "asm_runtime": output.build_asm_runtime_output, "source_map": output.build_source_map_output, "source_map_runtime": output.build_source_map_runtime_output, # requires bytecode diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py index 668ca55294..e31327a2c2 100644 --- a/vyper/compiler/output.py +++ b/vyper/compiler/output.py @@ -166,19 +166,19 @@ def build_interface_output(compiler_data: CompilerData) -> str: def build_bb_output(compiler_data: CompilerData) -> IRnode: - return compiler_data.venom_functions[0] + return compiler_data.venom_deploytime def build_bb_runtime_output(compiler_data: CompilerData) -> IRnode: - return compiler_data.venom_functions[1] + return compiler_data.venom_runtime def build_cfg_output(compiler_data: CompilerData) -> str: - return compiler_data.venom_functions[0].as_graph() + return compiler_data.venom_deploytime.as_graph() def build_cfg_runtime_output(compiler_data: CompilerData) -> str: - return compiler_data.venom_functions[1].as_graph() + return compiler_data.venom_runtime.as_graph() def build_ir_output(compiler_data: CompilerData) -> IRnode: @@ -320,6 +320,10 @@ def build_asm_output(compiler_data: CompilerData) -> str: return _build_asm(compiler_data.assembly) +def build_asm_runtime_output(compiler_data: CompilerData) -> str: + return _build_asm(compiler_data.assembly_runtime) + + def build_layout_output(compiler_data: CompilerData) -> StorageLayout: # in the future this might return (non-storage) layout, # for now only storage layout is returned. @@ -327,26 +331,24 @@ def build_layout_output(compiler_data: CompilerData) -> StorageLayout: def _build_asm(asm_list): - output_string = "" + output_string = "__entry__:" in_push = 0 - for node in asm_list: - if isinstance(node, list): - output_string += "{ " + _build_asm(node) + "} " + for item in asm_list: + if isinstance(item, (compile_ir.Label, compile_ir.DataHeader)): + output_string += f"\n\n{item}:" continue if in_push > 0: - assert isinstance(node, int), node - output_string += hex(node)[2:].rjust(2, "0") - if in_push == 1: - output_string += " " + assert isinstance(item, int), item + output_string += hex(item)[2:].rjust(2, "0") in_push -= 1 else: - output_string += str(node) + " " + output_string += f"\n {item}" - if isinstance(node, str) and node.startswith("PUSH") and node != "PUSH0": + if isinstance(item, str) and item.startswith("PUSH") and item != "PUSH0": assert in_push == 0 - in_push = int(node[4:]) - output_string += "0x" + in_push = int(item[4:]) + output_string += " 0x" return output_string @@ -356,6 +358,10 @@ def _build_node_identifier(ast_node): return (ast_node.module_node.source_id, ast_node.node_id) +def _getpos(node): + return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset) + + def _build_source_map_output(compiler_data, bytecode, pc_maps): """ Generate source map output in various formats. Note that integrations @@ -376,7 +382,7 @@ def _build_source_map_output(compiler_data, bytecode, pc_maps): # tag it with source id ast_map[0] = compiler_data.annotated_vyper_module - pc_pos_map = {k: compile_ir.getpos(v) for (k, v) in ast_map.items()} + pc_pos_map = {k: _getpos(v) for (k, v) in ast_map.items()} node_id_map = {k: _build_node_identifier(v) for (k, v) in ast_map.items()} compressed_map = _compress_source_map(ast_map, out["pc_jump_map"], len(bytecode)) out["pc_pos_map_compressed"] = compressed_map @@ -388,15 +394,15 @@ def _build_source_map_output(compiler_data, bytecode, pc_maps): def build_source_map_output(compiler_data: CompilerData) -> dict: - bytecode, pc_maps = compile_ir.assembly_to_evm(compiler_data.assembly, compiler_metadata=None) - return _build_source_map_output(compiler_data, bytecode, pc_maps) + bytecode = compiler_data.bytecode + source_map = compiler_data.source_map + return _build_source_map_output(compiler_data, bytecode, source_map) def build_source_map_runtime_output(compiler_data: CompilerData) -> dict: - bytecode, pc_maps = compile_ir.assembly_to_evm( - compiler_data.assembly_runtime, compiler_metadata=None - ) - return _build_source_map_output(compiler_data, bytecode, pc_maps) + bytecode = compiler_data.bytecode_runtime + source_map = compiler_data.source_map_runtime + return _build_source_map_output(compiler_data, bytecode, source_map) # generate a solidity-style source map. this functionality is deprecated diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 1ff9ed1844..790d35b02d 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -18,7 +18,6 @@ should_run_legacy_optimizer, ) from vyper.ir import compile_ir, optimizer -from vyper.ir.compile_ir import reset_symbols from vyper.semantics import analyze_module, set_data_positions, validate_compilation_target from vyper.semantics.analysis.data_positions import generate_layout_export from vyper.semantics.analysis.imports import resolve_imports @@ -26,7 +25,7 @@ from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout from vyper.utils import ERC5202_PREFIX, sha256sum -from vyper.venom import generate_assembly_experimental, generate_ir +from vyper.venom import generate_assembly_experimental, generate_venom from vyper.warnings import VyperWarning, vyper_warn DEFAULT_CONTRACT_PATH = PurePath("VyperContract.vy") @@ -256,42 +255,91 @@ def function_signatures(self) -> dict[str, ContractFunctionT]: return {f.name: f._metadata["func_type"] for f in fs} @cached_property - def venom_functions(self): - deploy_ir, runtime_ir = self._ir_output - deploy_venom = generate_ir(deploy_ir, self.settings) - runtime_venom = generate_ir(runtime_ir, self.settings) - return deploy_venom, runtime_venom + def venom_runtime(self): + runtime_venom = generate_venom(self.ir_runtime, self.settings) + return runtime_venom + + @cached_property + def venom_deploytime(self): + data_sections = {"runtime_begin": self.bytecode_runtime} + if self.bytecode_metadata is not None: + data_sections["cbor_metadata"] = self.bytecode_metadata + + constants = { + "runtime_codesize": len(self.bytecode_runtime), + "immutables_len": self.compilation_target._metadata["type"].immutable_section_bytes, + } + + venom_ctx = generate_venom( + self.ir_nodes, self.settings, constants=constants, data_sections=data_sections + ) + return venom_ctx @cached_property def assembly(self) -> list: + metadata = None + if not self.no_bytecode_metadata: + metadata = bytes.fromhex(self.integrity_sum) + if self.settings.experimental_codegen: - deploy_code, runtime_code = self.venom_functions assert self.settings.optimize is not None # mypy hint return generate_assembly_experimental( - runtime_code, deploy_code=deploy_code, optimize=self.settings.optimize + self.venom_deploytime, optimize=self.settings.optimize ) else: - return generate_assembly(self.ir_nodes, self.settings.optimize) + return generate_assembly( + self.ir_nodes, self.settings.optimize, compiler_metadata=metadata + ) + + @cached_property + def bytecode_metadata(self) -> Optional[bytes]: + if self.no_bytecode_metadata: + return None + + runtime_asm = self.assembly_runtime + runtime_data_segment_lengths = compile_ir.get_data_segment_lengths(runtime_asm) + + immutables_len = self.compilation_target._metadata["type"].immutable_section_bytes + runtime_codesize = len(self.bytecode_runtime) + + metadata = bytes.fromhex(self.integrity_sum) + return compile_ir.generate_cbor_metadata( + metadata, runtime_codesize, runtime_data_segment_lengths, immutables_len + ) @cached_property def assembly_runtime(self) -> list: if self.settings.experimental_codegen: - _, runtime_code = self.venom_functions assert self.settings.optimize is not None # mypy hint - return generate_assembly_experimental(runtime_code, optimize=self.settings.optimize) + return generate_assembly_experimental( + self.venom_runtime, optimize=self.settings.optimize + ) else: return generate_assembly(self.ir_runtime, self.settings.optimize) @cached_property + def _bytecode(self) -> tuple[bytes, dict[str, Any]]: + return generate_bytecode(self.assembly) + + @property def bytecode(self) -> bytes: - metadata = None - if not self.no_bytecode_metadata: - metadata = bytes.fromhex(self.integrity_sum) - return generate_bytecode(self.assembly, compiler_metadata=metadata) + return self._bytecode[0] + + @property + def source_map(self) -> dict[str, Any]: + return self._bytecode[1] @cached_property + def _bytecode_runtime(self) -> tuple[bytes, dict[str, Any]]: + return generate_bytecode(self.assembly_runtime) + + @property def bytecode_runtime(self) -> bytes: - return generate_bytecode(self.assembly_runtime, compiler_metadata=None) + return self._bytecode_runtime[0] + + @property + def source_map_runtime(self) -> dict[str, Any]: + return self._bytecode_runtime[1] @cached_property def blueprint_bytecode(self) -> bytes: @@ -326,7 +374,6 @@ def generate_ir_nodes(global_ctx: ModuleT, settings: Settings) -> tuple[IRnode, """ # make IR output the same between runs codegen.reset_names() - reset_symbols() with anchor_settings(settings): ir_nodes, ir_runtime = module.generate_ir_for_module(global_ctx) @@ -338,7 +385,11 @@ def generate_ir_nodes(global_ctx: ModuleT, settings: Settings) -> tuple[IRnode, return ir_nodes, ir_runtime -def generate_assembly(ir_nodes: IRnode, optimize: Optional[OptimizationLevel] = None) -> list: +def generate_assembly( + ir_nodes: IRnode, + optimize: Optional[OptimizationLevel] = None, + compiler_metadata: Optional[Any] = None, +) -> list: """ Generate assembly instructions from IR. @@ -353,9 +404,11 @@ def generate_assembly(ir_nodes: IRnode, optimize: Optional[OptimizationLevel] = List of assembly instructions. """ optimize = optimize or OptimizationLevel.default() - assembly = compile_ir.compile_to_assembly(ir_nodes, optimize=optimize) + assembly = compile_ir.compile_to_assembly( + ir_nodes, optimize=optimize, compiler_metadata=compiler_metadata + ) - if _find_nested_opcode(assembly, "DEBUG"): + if "DEBUG" in assembly: vyper_warn( VyperWarning( "This code contains DEBUG opcodes! The DEBUG opcode will only work in " @@ -365,15 +418,7 @@ def generate_assembly(ir_nodes: IRnode, optimize: Optional[OptimizationLevel] = return assembly -def _find_nested_opcode(assembly, key): - if key in assembly: - return True - else: - sublists = [sub for sub in assembly if isinstance(sub, list)] - return any(_find_nested_opcode(x, key) for x in sublists) - - -def generate_bytecode(assembly: list, compiler_metadata: Optional[Any]) -> bytes: +def generate_bytecode(assembly: list) -> tuple[bytes, dict[str, Any]]: """ Generate bytecode from assembly instructions. @@ -386,5 +431,7 @@ def generate_bytecode(assembly: list, compiler_metadata: Optional[Any]) -> bytes ------- bytes Final compiled bytecode. + dict + Source map """ - return compile_ir.assembly_to_evm(assembly, compiler_metadata=compiler_metadata)[0] + return compile_ir.assembly_to_evm(assembly) diff --git a/vyper/evm/opcodes.py b/vyper/evm/opcodes.py index 3049d7f911..3c6a80d33e 100644 --- a/vyper/evm/opcodes.py +++ b/vyper/evm/opcodes.py @@ -220,9 +220,12 @@ def _gas(value: OpcodeValue, idx: int) -> Optional[OpcodeRulesetValue]: def _mk_version_opcodes(opcodes: OpcodeMap, idx: int) -> OpcodeRulesetMap: - return dict( - (k, _gas(v, idx)) for k, v in opcodes.items() if _gas(v, idx) is not None # type: ignore - ) + ret = {} + for k, v in opcodes.items(): + gas = _gas(v, idx) + if gas is not None: + ret[k] = gas + return ret _evm_opcodes: Dict[int, OpcodeRulesetMap] = { diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 936e6d5d72..0b6a21a8fe 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -1,7 +1,9 @@ +from __future__ import annotations + +import contextlib import copy -import functools -import math from dataclasses import dataclass +from typing import Any, Optional, TypeVar import cbor2 @@ -10,7 +12,7 @@ from vyper.evm.opcodes import get_opcodes, version_check from vyper.exceptions import CodegenPanic, CompilerPanic from vyper.ir.optimizer import COMMUTATIVE_OPS -from vyper.utils import MemoryPositions +from vyper.utils import MemoryPositions, OrderedSet from vyper.version import version_tuple PUSH_OFFSET = 0x5F @@ -44,40 +46,166 @@ def PUSH_N(x, n): return [f"PUSH{len(o)}"] + o -_next_symbol = 0 +##################################### +# assembly data structures and utils +##################################### + + +class Label: + def __init__(self, label: str): + assert isinstance(label, str) + self.label = label + + def __repr__(self): + return f"LABEL {self.label}" + + def __eq__(self, other): + if not isinstance(other, Label): + return False + return self.label == other.label + + def __hash__(self): + return hash(self.label) + + +@dataclass +class DataHeader: + label: Label + + def __repr__(self): + return f"DATA {self.label.label}" + + +# this could be fused with Label, the only difference is if +# it gets looked up from const_map or symbol_map. +class CONSTREF: + def __init__(self, label: str): + assert isinstance(label, str) + self.label = label + + def __repr__(self): + return f"CONSTREF {self.label}" + + def __eq__(self, other): + if not isinstance(other, CONSTREF): + return False + return self.label == other.label + + def __hash__(self): + return hash(self.label) + + +class CONST: + def __init__(self, name: str, value: int): + assert isinstance(name, str) + assert isinstance(value, int) + self.name = name + self.value = value + + def __repr__(self): + return f"CONST {self.name} {self.value}" + + def __eq__(self, other): + if not isinstance(other, CONST): + return False + return self.name == other.name and self.value == other.value + +class PUSHLABEL: + def __init__(self, label: Label): + assert isinstance(label, Label), label + self.label = label -def mksymbol(name=""): - global _next_symbol - _next_symbol += 1 + def __repr__(self): + return f"PUSHLABEL {self.label.label}" + + def __eq__(self, other): + if not isinstance(other, PUSHLABEL): + return False + return self.label == other.label + + def __hash__(self): + return hash(self.label) + + +# push the result of an addition (which might be resolvable at compile-time) +class PUSH_OFST: + def __init__(self, label: Label | CONSTREF, ofst: int): + # label can be Label or CONSTREF + assert isinstance(label, (Label, CONSTREF)) + self.label = label + self.ofst = ofst + + def __repr__(self): + label = self.label + if isinstance(label, Label): + label = label.label # str + return f"PUSH_OFST({label}, {self.ofst})" + + def __eq__(self, other): + if not isinstance(other, PUSH_OFST): + return False + return self.label == other.label and self.ofst == other.ofst + + def __hash__(self): + return hash((self.label, self.ofst)) + + +class DATA_ITEM: + def __init__(self, item: bytes | Label): + self.data = item + + def __repr__(self): + if isinstance(self.data, bytes): + return f"DATABYTES {self.data.hex()}" + elif isinstance(self.data, Label): + return f"DATALABEL {self.data.label}" - return f"_sym_{name}{_next_symbol}" +def JUMP(label: Label): + return [PUSHLABEL(label), "JUMP"] -def reset_symbols(): - global _next_symbol - _next_symbol = 0 + +def JUMPI(label: Label): + return [PUSHLABEL(label), "JUMPI"] def mkdebug(pc_debugger, ast_source): - i = Instruction("DEBUG", ast_source) + # compile debug instructions + # (this is dead code -- CMC 2025-05-08) + i = TaggedInstruction("DEBUG", ast_source) i.pc_debugger = pc_debugger return [i] def is_symbol(i): - return isinstance(i, str) and i.startswith("_sym_") + return isinstance(i, Label) + +def is_ofst(assembly_item): + return isinstance(assembly_item, PUSH_OFST) -# basically something like a symbol which gets resolved -# during assembly, but requires 4 bytes of space. -# (should only happen in deploy code) -def is_mem_sym(i): - return isinstance(i, str) and i.startswith("_mem_") +def generate_cbor_metadata( + compiler_metadata: Any, + runtime_codesize: int, + runtime_data_segment_lengths: list[int], + immutables_len: int, +) -> bytes: + metadata = ( + compiler_metadata, + runtime_codesize, + runtime_data_segment_lengths, + immutables_len, + {"vyper": version_tuple}, + ) + ret = cbor2.dumps(metadata) + # append the length of the footer, *including* the length + # of the length bytes themselves. + suffix_len = len(ret) + 2 + ret += suffix_len.to_bytes(2, "big") -def is_ofst(sym): - return isinstance(sym, str) and sym == "_OFST" + return ret def _runtime_code_offsets(ctor_mem_size, runtime_codelen): @@ -88,10 +216,10 @@ def _runtime_code_offsets(ctor_mem_size, runtime_codelen): # of the runtime code. # after the ctor has run but before copying runtime code to # memory, the layout is - # ... | data section + # | | # and after copying runtime code to memory (immediately before # returning the runtime code): - # ... | data section + # | | # since the ctor memory variables and runtime code overlap, # we start allocating the data section from # `max(ctor_mem_size, runtime_code_size)` @@ -102,11 +230,17 @@ def _runtime_code_offsets(ctor_mem_size, runtime_codelen): return runtime_code_start, runtime_code_end -# Calculate the size of PUSH instruction we need to handle all -# mem offsets in the code. For instance, if we only see mem symbols -# up to size 256, we can use PUSH1. -def calc_mem_ofst_size(ctor_mem_size): - return math.ceil(math.log(ctor_mem_size + 1, 256)) +# Calculate the size of PUSH instruction +def calc_push_size(val: int): + # stupid implementation. this is "slow", but its correctness is + # obvious verify, as opposed to + # ``` + # (val.bit_length() + 7) // 8 + # + (1 + # if (val > 0 or version_check(begin="shanghai")) + # else 0) + # ``` + return len(PUSH(val)) # temporary optimization to handle stack items for return sequences @@ -147,643 +281,694 @@ def _rewrite_return_sequences(ir_node, label_params=None): _rewrite_return_sequences(t, label_params) -def _assert_false(): - global _revert_label - # use a shared failure block for common case of assert(x). - # in the future we might want to change the code - # at _sym_revert0 to: INVALID - return [_revert_label, "JUMPI"] +# a string (assembly instruction) but with additional metadata from the source code +class TaggedInstruction(str): + def __new__(cls, sstr, *args, **kwargs): + return super().__new__(cls, sstr) + + def __init__(self, sstr, ast_source=None, error_msg=None): + self.error_msg = error_msg + self.pc_debugger = False + + self.ast_source = ast_source + +############################## +# IRnode to assembly +############################## -def _add_postambles(asm_ops): - to_append = [] - global _revert_label +# external entry point to `IRnode.compile_to_assembly()` +def compile_to_assembly( + code: IRnode, + optimize: OptimizationLevel = OptimizationLevel.GAS, + compiler_metadata: Optional[Any] = None, +): + """ + Parameters: + code: IRnode to compile + optimize: Optimization level + compiler_metadata: + any compiler metadata to add as the final data segment. pass + `None` to indicate no metadata to be added (should always + be `None` for runtime code). the value is opaque, and will be + passed directly to `cbor2.dumps()`. + """ - _revert_string = [_revert_label, "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] + # don't mutate the ir since the original might need to be output, e.g. `-f ir,asm` + code = copy.deepcopy(code) + _rewrite_return_sequences(code) - if _revert_label in asm_ops: - # shared failure block - to_append.extend(_revert_string) + res = _IRnodeLowerer(optimize, compiler_metadata).compile_to_assembly(code) - if len(to_append) > 0: - # insert the postambles *before* runtime code - # so the data section of the runtime code can't bork the postambles. - runtime = None - if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], RuntimeHeader): - runtime = asm_ops.pop() + if optimize != OptimizationLevel.NONE: + optimize_assembly(res) + return res - # for some reason there might not be a STOP at the end of asm_ops. - # (generally vyper programs will have it but raw IR might not). - asm_ops.append("STOP") - asm_ops.extend(to_append) - if runtime: - asm_ops.append(runtime) +# TODO: move all these assembly data structures to own module, like +# vyper.evm.assembly +AssemblyInstruction = ( + str | TaggedInstruction | int | PUSHLABEL | Label | PUSH_OFST | DATA_ITEM | DataHeader | CONST +) - # need to do this recursively since every sublist is basically - # treated as its own program (there are no global labels.) - for t in asm_ops: - if isinstance(t, list): - _add_postambles(t) +class _IRnodeLowerer: + # map from variable names to height in stack + withargs: dict[str, int] -class Instruction(str): - def __new__(cls, sstr, *args, **kwargs): - return super().__new__(cls, sstr) + # set of all existing labels in the IRnodes + existing_labels: set[str] - def __init__(self, sstr, ast_source=None, error_msg=None): - self.error_msg = error_msg - self.pc_debugger = False + # break destination when inside loops + # continue_dest, break_dest, height + break_dest: tuple[Label, Label, int] - self.ast_source = ast_source + # current height in stack + height: int + code_instructions: list[AssemblyInstruction] + data_segments: list[list[AssemblyInstruction]] -def apply_line_numbers(func): - @functools.wraps(func) - def apply_line_no_wrapper(*args, **kwargs): - code = args[0] - ret = func(*args, **kwargs) + optimize: OptimizationLevel - new_ret = [ - Instruction(i, code.ast_source, code.error_msg) - if isinstance(i, str) and not isinstance(i, Instruction) - else i - for i in ret - ] - return new_ret + symbol_counter: int = 0 - return apply_line_no_wrapper + def __init__(self, optimize: OptimizationLevel = OptimizationLevel.GAS, compiler_metadata=None): + self.optimize = optimize + self.compiler_metadata = compiler_metadata + def compile_to_assembly(self, code): + self.withargs = {} + self.existing_labels = set() + self.break_dest = None + self.height = 0 -@apply_line_numbers -def compile_to_assembly(code, optimize=OptimizationLevel.GAS): - global _revert_label - _revert_label = mksymbol("revert") + self.global_revert_label = None - # don't overwrite ir since the original might need to be output, e.g. `-f ir,asm` - code = copy.deepcopy(code) - _rewrite_return_sequences(code) + self.data_segments = [] + self.freeze_data_segments = False - res = _compile_to_assembly(code) + ret = self._compile_r(code, height=0) - _add_postambles(res) + # append postambles before data segments + ret.extend(self._create_postambles()) - _relocate_segments(res) + for data in self.data_segments: + ret.extend(self._compile_data_segment(data)) - if optimize != OptimizationLevel.NONE: - optimize_assembly(res) - return res + return ret + @contextlib.contextmanager + def modify_breakdest(self, continue_dest: Label, exit_dest: Label, height: int): + tmp = self.break_dest + try: + self.break_dest = continue_dest, exit_dest, height + yield + finally: + self.break_dest = tmp -# Compiles IR to assembly -@apply_line_numbers -def _compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=None, height=0): - if withargs is None: - withargs = {} - if not isinstance(withargs, dict): - raise CompilerPanic(f"Incorrect type for withargs: {type(withargs)}") + def mksymbol(self, name: str) -> Label: + self.symbol_counter += 1 + + return Label(f"{name}_{self.symbol_counter}") + + def _data_ofst_of( + self, symbol: Label | CONSTREF, ofst: IRnode, height: int + ) -> list[AssemblyInstruction]: + # e.g. PUSHOFST foo 32 + assert isinstance(symbol, (Label, CONSTREF)), symbol - def _data_ofst_of(sym, ofst, height_): - # e.g. _OFST _sym_foo 32 - assert is_symbol(sym) or is_mem_sym(sym) if isinstance(ofst.value, int): - # resolve at compile time using magic _OFST op - return ["_OFST", sym, ofst.value] + # resolve at compile time using magic PUSH_OFST op + return [PUSH_OFST(symbol, ofst.value)] + + # if we can't resolve at compile time, resolve at runtime + pushsym: PUSHLABEL | PUSH_OFST + if isinstance(symbol, Label): + pushsym = PUSHLABEL(symbol) else: - # if we can't resolve at compile time, resolve at runtime - ofst = _compile_to_assembly(ofst, withargs, existing_labels, break_dest, height_) - return ofst + [sym, "ADD"] - - def _height_of(witharg): - ret = height - withargs[witharg] - if ret > 16: - raise Exception("With statement too deep") - return ret + # magic for mem syms + assert isinstance(symbol, CONSTREF) # clarity + # we don't have a PUSHCONST instruction, use PUSH_OFST with ofst of 0 + pushsym = PUSH_OFST(symbol, 0) + + ofst_asm = self._compile_r(ofst, height) + return ofst_asm + [pushsym, "ADD"] + + def _compile_r(self, code: IRnode, height: int) -> list[AssemblyInstruction]: + asm = self._step_r(code, height) + for i, item in enumerate(asm): + if isinstance(item, str) and not isinstance(item, TaggedInstruction): + # CMC 2025-05-08 this is O(n^2).. :'( + asm[i] = TaggedInstruction(item, code.ast_source, code.error_msg) + + return asm + + def _step_r(self, code: IRnode, height: int) -> list[AssemblyInstruction]: + def _height_of(varname): + ret = height - self.withargs[varname] + if ret > 16: # pragma: nocover + raise Exception("With statement too deep") + return ret + + if isinstance(code.value, str) and code.value.upper() in get_opcodes(): + o = [] + for i, c in enumerate(reversed(code.args)): + o.extend(self._compile_r(c, height + i)) + o.append(code.value.upper()) + return o + + # Numbers + if isinstance(code.value, int): + if code.value < -(2**255): # pragma: nocover + raise Exception(f"Value too low: {code.value}") + elif code.value >= 2**256: # pragma: nocover + raise Exception(f"Value too high: {code.value}") + + return PUSH(code.value % 2**256) + + # Variables connected to with statements + if isinstance(code.value, str) and code.value in self.withargs: + return ["DUP" + str(_height_of(code.value))] + + # Setting variables connected to with statements + if code.value == "set": + varname = code.args[0].value + assert isinstance(varname, str) + if len(code.args) != 2 or varname not in self.withargs: + raise Exception("Set expects two arguments, the first being a stack variable") + # TODO: use _height_of + if height - self.withargs[varname] > 16: + raise Exception("With statement too deep") + swap_instr = "SWAP" + str(height - self.withargs[varname]) + return self._compile_r(code.args[1], height) + [swap_instr, "POP"] + + # Pass statements + # TODO remove "dummy"; no longer needed + if code.value in ("pass", "dummy"): + return [] + + # "mload" from data section of the currently executing code + if code.value == "dload": + loc = code.args[0] + + o = [] + # codecopy 32 bytes to FREE_VAR_SPACE, then mload from FREE_VAR_SPACE + o.extend(PUSH(32)) + + o.extend(self._data_ofst_of(Label("code_end"), loc, height + 1)) + + o.extend(PUSH(MemoryPositions.FREE_VAR_SPACE) + ["CODECOPY"]) + o.extend(PUSH(MemoryPositions.FREE_VAR_SPACE) + ["MLOAD"]) + return o + + # batch copy from data section of the currently executing code to memory + # (probably should have named this dcopy but oh well) + if code.value == "dloadbytes": + dst = code.args[0] + src = code.args[1] + len_ = code.args[2] + + o = [] + o.extend(self._compile_r(len_, height)) + o.extend(self._data_ofst_of(Label("code_end"), src, height + 1)) + o.extend(self._compile_r(dst, height + 2)) + o.extend(["CODECOPY"]) + return o + + # "mload" from the data section of (to-be-deployed) runtime code + if code.value == "iload": + loc = code.args[0] + + o = [] + o.extend(self._data_ofst_of(CONSTREF("mem_deploy_end"), loc, height)) + o.append("MLOAD") + + return o + + # "mstore" to the data section of (to-be-deployed) runtime code + if code.value == "istore": + loc = code.args[0] + val = code.args[1] + + o = [] + o.extend(self._compile_r(val, height)) + o.extend(self._data_ofst_of(CONSTREF("mem_deploy_end"), loc, height + 1)) + o.append("MSTORE") + + return o + + # batch copy from memory to the data section of runtime code + if code.value == "istorebytes": + raise Exception("unimplemented") + + # If statements (2 arguments, ie. if x: y) + if code.value == "if" and len(code.args) == 2: + o = [] + o.extend(self._compile_r(code.args[0], height)) + end_symbol = self.mksymbol("join") + o.extend(["ISZERO", *JUMPI(end_symbol)]) + o.extend(self._compile_r(code.args[1], height)) + o.extend([end_symbol]) + return o + + # If statements (3 arguments, ie. if x: y, else: z) + if code.value == "if" and len(code.args) == 3: + o = [] + o.extend(self._compile_r(code.args[0], height)) + mid_symbol = self.mksymbol("else") + end_symbol = self.mksymbol("join") + o.extend(["ISZERO", *JUMPI(mid_symbol)]) + o.extend(self._compile_r(code.args[1], height)) + o.extend([*JUMP(end_symbol), mid_symbol]) + o.extend(self._compile_r(code.args[2], height)) + o.extend([end_symbol]) + return o + + # repeat(counter_location, start, rounds, rounds_bound, body) + # basically a do-while loop: + # + # assert(rounds <= rounds_bound) + # if (rounds > 0) { + # do { + # body; + # } while (++i != start + rounds) + # } + if code.value == "repeat": + o = [] + if len(code.args) != 5: # pragma: nocover + raise CompilerPanic("bad number of repeat args") + + i_name = code.args[0] + start = code.args[1] + rounds = code.args[2] + rounds_bound = code.args[3] + body = code.args[4] + + assert isinstance(i_name.value, str) # help mypy + + entry_dest = self.mksymbol("loop_start") + continue_dest = self.mksymbol("loop_continue") + exit_dest = self.mksymbol("loop_exit") + + # stack: [] + o.extend(self._compile_r(start, height)) + + o.extend(self._compile_r(rounds, height + 1)) + + # stack: i + + # assert rounds <= round_bound + if rounds != rounds_bound: + # stack: i, rounds + o.extend(self._compile_r(rounds_bound, height + 2)) + # stack: i, rounds, rounds_bound + # assert 0 <= rounds <= rounds_bound (for rounds_bound < 2**255) + # TODO this runtime assertion shouldn't fail for + # internally generated repeats. + o.extend(["DUP2", "GT"] + self._assert_false()) + + # stack: i, rounds + # if (0 == rounds) { goto end_dest; } + o.extend(["DUP1", "ISZERO", *JUMPI(exit_dest)]) + + # stack: start, rounds + if start.value != 0: + o.extend(["DUP2", "ADD"]) + + # stack: i, exit_i + o.extend(["SWAP1"]) + + if i_name.value in self.withargs: # pragma: nocover + raise CompilerPanic(f"shadowed loop variable {i_name}") + self.withargs[i_name.value] = height + 1 + + # stack: exit_i, i + o.extend([entry_dest]) + + with self.modify_breakdest(exit_dest, continue_dest, height + 2): + o.extend(self._compile_r(body, height + 2)) + + del self.withargs[i_name.value] + + # clean up any stack items left by body + o.extend(["POP"] * body.valency) + + # stack: exit_i, i + # increment i: + o.extend([continue_dest, "PUSH1", 1, "ADD"]) + + # stack: exit_i, i+1 (new_i) + # if (exit_i != new_i) { goto entry_dest } + o.extend(["DUP2", "DUP2", "XOR", *JUMPI(entry_dest)]) + o.extend([exit_dest, "POP", "POP"]) + + return o + + # Continue to the next iteration of the for loop + if code.value == "continue": + if not self.break_dest: # pragma: nocover + raise CompilerPanic("Invalid break") + _dest, continue_dest, _break_height = self.break_dest + return [*JUMP(continue_dest)] + + # Break from inside a for loop + if code.value == "break": + if not self.break_dest: # pragma: nocover + raise CompilerPanic("Invalid break") + dest, _continue_dest, break_height = self.break_dest + + n_local_vars = height - break_height + # clean up any stack items declared in the loop body + cleanup_local_vars = ["POP"] * n_local_vars + return cleanup_local_vars + [*JUMP(dest)] + + # Break from inside one or more for loops prior to a return statement inside the loop + if code.value == "cleanup_repeat": + if not self.break_dest: # pragma: nocover + raise CompilerPanic("Invalid break") + # clean up local vars and internal loop vars + _, _, break_height = self.break_dest + # except don't pop label params + if "return_buffer" in self.withargs: + break_height -= 1 + if "return_pc" in self.withargs: + break_height -= 1 + return ["POP"] * break_height + + # With statements + if code.value == "with": + varname = code.args[0].value + assert isinstance(varname, str) + + o = [] + o.extend(self._compile_r(code.args[1], height)) + old = self.withargs.get(varname, None) + self.withargs[varname] = height + o.extend(self._compile_r(code.args[2], height + 1)) + if code.args[2].valency: + o.extend(["SWAP1", "POP"]) + else: + o.extend(["POP"]) + if old is not None: + self.withargs[varname] = old + else: + del self.withargs[varname] + return o + + # runtime statement (used to deploy runtime code) + if code.value == "deploy": + # used to calculate where to copy the runtime code to memory + memsize = code.args[0].value + ir = code.args[1] + immutables_len = code.args[2].value + assert isinstance(memsize, int), "non-int memsize" + assert isinstance(immutables_len, int), "non-int immutables_len" + + runtime_assembly = _IRnodeLowerer( + self.optimize, self.compiler_metadata + ).compile_to_assembly(ir) + + if self.optimize != OptimizationLevel.NONE: + optimize_assembly(runtime_assembly) + + runtime_data_segment_lengths = get_data_segment_lengths(runtime_assembly) - if existing_labels is None: - existing_labels = set() - if not isinstance(existing_labels, set): - raise CompilerPanic(f"must be set(), but got {type(existing_labels)}") - - # Opcodes - if isinstance(code.value, str) and code.value.upper() in get_opcodes(): - o = [] - for i, c in enumerate(code.args[::-1]): - o.extend(_compile_to_assembly(c, withargs, existing_labels, break_dest, height + i)) - o.append(code.value.upper()) - return o - - # Numbers - elif isinstance(code.value, int): - if code.value < -(2**255): - raise Exception(f"Value too low: {code.value}") - elif code.value >= 2**256: - raise Exception(f"Value too high: {code.value}") - return PUSH(code.value % 2**256) - - # Variables connected to with statements - elif isinstance(code.value, str) and code.value in withargs: - return ["DUP" + str(_height_of(code.value))] - - # Setting variables connected to with statements - elif code.value == "set": - if len(code.args) != 2 or code.args[0].value not in withargs: - raise Exception("Set expects two arguments, the first being a stack variable") - if height - withargs[code.args[0].value] > 16: - raise Exception("With statement too deep") - return _compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height) + [ - "SWAP" + str(height - withargs[code.args[0].value]), - "POP", - ] - - # Pass statements - # TODO remove "dummy"; no longer needed - elif code.value in ("pass", "dummy"): - return [] - - # "mload" from data section of the currently executing code - elif code.value == "dload": - loc = code.args[0] - - o = [] - # codecopy 32 bytes to FREE_VAR_SPACE, then mload from FREE_VAR_SPACE - o.extend(PUSH(32)) - o.extend(_data_ofst_of("_sym_code_end", loc, height + 1)) - o.extend(PUSH(MemoryPositions.FREE_VAR_SPACE) + ["CODECOPY"]) - o.extend(PUSH(MemoryPositions.FREE_VAR_SPACE) + ["MLOAD"]) - return o - - # batch copy from data section of the currently executing code to memory - # (probably should have named this dcopy but oh well) - elif code.value == "dloadbytes": - dst = code.args[0] - src = code.args[1] - len_ = code.args[2] - - o = [] - o.extend(_compile_to_assembly(len_, withargs, existing_labels, break_dest, height)) - o.extend(_data_ofst_of("_sym_code_end", src, height + 1)) - o.extend(_compile_to_assembly(dst, withargs, existing_labels, break_dest, height + 2)) - o.extend(["CODECOPY"]) - return o - - # "mload" from the data section of (to-be-deployed) runtime code - elif code.value == "iload": - loc = code.args[0] - - o = [] - o.extend(_data_ofst_of("_mem_deploy_end", loc, height)) - o.append("MLOAD") - - return o - - # "mstore" to the data section of (to-be-deployed) runtime code - elif code.value == "istore": - loc = code.args[0] - val = code.args[1] - - o = [] - o.extend(_compile_to_assembly(val, withargs, existing_labels, break_dest, height)) - o.extend(_data_ofst_of("_mem_deploy_end", loc, height + 1)) - o.append("MSTORE") - - return o - - # batch copy from memory to the data section of runtime code - elif code.value == "istorebytes": - raise Exception("unimplemented") - - # If statements (2 arguments, ie. if x: y) - elif code.value == "if" and len(code.args) == 2: - o = [] - o.extend(_compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height)) - end_symbol = mksymbol("join") - o.extend(["ISZERO", end_symbol, "JUMPI"]) - o.extend(_compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) - o.extend([end_symbol, "JUMPDEST"]) - return o - # If statements (3 arguments, ie. if x: y, else: z) - elif code.value == "if" and len(code.args) == 3: - o = [] - o.extend(_compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height)) - mid_symbol = mksymbol("else") - end_symbol = mksymbol("join") - o.extend(["ISZERO", mid_symbol, "JUMPI"]) - o.extend(_compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) - o.extend([end_symbol, "JUMP", mid_symbol, "JUMPDEST"]) - o.extend(_compile_to_assembly(code.args[2], withargs, existing_labels, break_dest, height)) - o.extend([end_symbol, "JUMPDEST"]) - return o - - # repeat(counter_location, start, rounds, rounds_bound, body) - # basically a do-while loop: - # - # assert(rounds <= rounds_bound) - # if (rounds > 0) { - # do { - # body; - # } while (++i != start + rounds) - # } - elif code.value == "repeat": - o = [] - if len(code.args) != 5: # pragma: nocover - raise CompilerPanic("bad number of repeat args") - - i_name = code.args[0] - start = code.args[1] - rounds = code.args[2] - rounds_bound = code.args[3] - body = code.args[4] - - entry_dest, continue_dest, exit_dest = ( - mksymbol("loop_start"), - mksymbol("loop_continue"), - mksymbol("loop_exit"), - ) - - # stack: [] - o.extend(_compile_to_assembly(start, withargs, existing_labels, break_dest, height)) - - o.extend(_compile_to_assembly(rounds, withargs, existing_labels, break_dest, height + 1)) - - # stack: i - - # assert rounds <= round_bound - if rounds != rounds_bound: - # stack: i, rounds + runtime_bytecode, _ = assembly_to_evm(runtime_assembly) + + runtime_begin = Label("runtime_begin") + o = [] + + runtime_codesize = len(runtime_bytecode) + + mem_deploy_start, mem_deploy_end = _runtime_code_offsets(memsize, runtime_codesize) + + # COPY the code to memory for deploy o.extend( - _compile_to_assembly( - rounds_bound, withargs, existing_labels, break_dest, height + 2 - ) - ) - # stack: i, rounds, rounds_bound - # assert 0 <= rounds <= rounds_bound (for rounds_bound < 2**255) - # TODO this runtime assertion shouldn't fail for - # internally generated repeats. - o.extend(["DUP2", "GT"] + _assert_false()) - - # stack: i, rounds - # if (0 == rounds) { goto end_dest; } - o.extend(["DUP1", "ISZERO", exit_dest, "JUMPI"]) - - # stack: start, rounds - if start.value != 0: - o.extend(["DUP2", "ADD"]) - - # stack: i, exit_i - o.extend(["SWAP1"]) - - if i_name.value in withargs: - raise CompilerPanic(f"shadowed loop variable {i_name}") - withargs[i_name.value] = height + 1 - - # stack: exit_i, i - o.extend([entry_dest, "JUMPDEST"]) - o.extend( - _compile_to_assembly( - body, withargs, existing_labels, (exit_dest, continue_dest, height + 2), height + 2 + [ + *PUSH(runtime_codesize), + PUSHLABEL(runtime_begin), + *PUSH(mem_deploy_start), + "CODECOPY", + ] ) - ) - - del withargs[i_name.value] - - # clean up any stack items left by body - o.extend(["POP"] * body.valency) - - # stack: exit_i, i - # increment i: - o.extend([continue_dest, "JUMPDEST", "PUSH1", 1, "ADD"]) - - # stack: exit_i, i+1 (new_i) - # if (exit_i != new_i) { goto entry_dest } - o.extend(["DUP2", "DUP2", "XOR", entry_dest, "JUMPI"]) - o.extend([exit_dest, "JUMPDEST", "POP", "POP"]) - - return o - - # Continue to the next iteration of the for loop - elif code.value == "continue": - if not break_dest: - raise CompilerPanic("Invalid break") - dest, continue_dest, break_height = break_dest - return [continue_dest, "JUMP"] - # Break from inside a for loop - elif code.value == "break": - if not break_dest: - raise CompilerPanic("Invalid break") - dest, continue_dest, break_height = break_dest - - n_local_vars = height - break_height - # clean up any stack items declared in the loop body - cleanup_local_vars = ["POP"] * n_local_vars - return cleanup_local_vars + [dest, "JUMP"] - # Break from inside one or more for loops prior to a return statement inside the loop - elif code.value == "cleanup_repeat": - if not break_dest: - raise CompilerPanic("Invalid break") - # clean up local vars and internal loop vars - _, _, break_height = break_dest - # except don't pop label params - if "return_buffer" in withargs: - break_height -= 1 - if "return_pc" in withargs: - break_height -= 1 - return ["POP"] * break_height - # With statements - elif code.value == "with": - o = [] - o.extend(_compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height)) - old = withargs.get(code.args[0].value, None) - withargs[code.args[0].value] = height - o.extend( - _compile_to_assembly(code.args[2], withargs, existing_labels, break_dest, height + 1) - ) - if code.args[2].valency: - o.extend(["SWAP1", "POP"]) - else: - o.extend(["POP"]) - if old is not None: - withargs[code.args[0].value] = old - else: - del withargs[code.args[0].value] - return o - - # runtime statement (used to deploy runtime code) - elif code.value == "deploy": - memsize = code.args[0].value # used later to calculate _mem_deploy_start - ir = code.args[1] - immutables_len = code.args[2].value - assert isinstance(memsize, int), "non-int memsize" - assert isinstance(immutables_len, int), "non-int immutables_len" - - runtime_begin = mksymbol("runtime_begin") - - subcode = _compile_to_assembly(ir) - - o = [] - - # COPY the code to memory for deploy - o.extend(["_sym_subcode_size", runtime_begin, "_mem_deploy_start", "CODECOPY"]) - - # calculate the len of runtime code - o.extend(["_OFST", "_sym_subcode_size", immutables_len]) # stack: len - o.extend(["_mem_deploy_start"]) # stack: len mem_ofst - o.extend(["RETURN"]) - - # since the asm data structures are very primitive, to make sure - # assembly_to_evm is able to calculate data offsets correctly, - # we pass the memsize via magic opcodes to the subcode - subcode = [RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode - - # append the runtime code after the ctor code - # `append(...)` call here is intentional. - # each sublist is essentially its own program with its - # own symbols. - # in the later step when the "ir" block compiled to EVM, - # symbols in subcode are resolved to position from start of - # runtime-code (instead of position from start of bytecode). - o.append(subcode) - - return o - - # Seq (used to piece together multiple statements) - elif code.value == "seq": - o = [] - for arg in code.args: - o.extend(_compile_to_assembly(arg, withargs, existing_labels, break_dest, height)) - if arg.valency == 1 and arg != code.args[-1]: - o.append("POP") - return o - # Seq without popping. - # unreachable keyword produces INVALID opcode - elif code.value == "assert_unreachable": - o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) - end_symbol = mksymbol("reachable") - o.extend([end_symbol, "JUMPI", "INVALID", end_symbol, "JUMPDEST"]) - return o - # Assert (if false, exit) - elif code.value == "assert": - o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) - o.extend(["ISZERO"]) - o.extend(_assert_false()) - return o - - # SHA3 a single value - elif code.value == "sha3_32": - o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) - o.extend( - [ - *PUSH(MemoryPositions.FREE_VAR_SPACE), - "MSTORE", - *PUSH(32), - *PUSH(MemoryPositions.FREE_VAR_SPACE), - "SHA3", - ] - ) - return o - # SHA3 a 64 byte value - elif code.value == "sha3_64": - o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height) - o.extend( - _compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height + 1) - ) - o.extend( - [ - *PUSH(MemoryPositions.FREE_VAR_SPACE2), - "MSTORE", - *PUSH(MemoryPositions.FREE_VAR_SPACE), - "MSTORE", - *PUSH(64), - *PUSH(MemoryPositions.FREE_VAR_SPACE), - "SHA3", - ] - ) - return o - elif code.value == "select": - # b ^ ((a ^ b) * cond) where cond is 1 or 0 - # let t = a ^ b - cond = code.args[0] - a = code.args[1] - b = code.args[2] - - o = [] - o.extend(_compile_to_assembly(b, withargs, existing_labels, break_dest, height)) - o.extend(_compile_to_assembly(a, withargs, existing_labels, break_dest, height + 1)) - # stack: b a - o.extend(["DUP2", "XOR"]) - # stack: b t - o.extend(_compile_to_assembly(cond, withargs, existing_labels, break_dest, height + 2)) - # stack: b t cond - o.extend(["MUL", "XOR"]) - - # stack: b ^ (t * cond) - return o - - # <= operator - elif code.value == "le": - return _compile_to_assembly( - IRnode.from_list(["iszero", ["gt", code.args[0], code.args[1]]]), - withargs, - existing_labels, - break_dest, - height, - ) - # >= operator - elif code.value == "ge": - return _compile_to_assembly( - IRnode.from_list(["iszero", ["lt", code.args[0], code.args[1]]]), - withargs, - existing_labels, - break_dest, - height, - ) - # <= operator - elif code.value == "sle": - return _compile_to_assembly( - IRnode.from_list(["iszero", ["sgt", code.args[0], code.args[1]]]), - withargs, - existing_labels, - break_dest, - height, - ) - # >= operator - elif code.value == "sge": - return _compile_to_assembly( - IRnode.from_list(["iszero", ["slt", code.args[0], code.args[1]]]), - withargs, - existing_labels, - break_dest, - height, - ) - # != operator - elif code.value == "ne": - return _compile_to_assembly( - IRnode.from_list(["iszero", ["eq", code.args[0], code.args[1]]]), - withargs, - existing_labels, - break_dest, - height, - ) - - # e.g. 95 -> 96, 96 -> 96, 97 -> 128 - elif code.value == "ceil32": - # floor32(x) = x - x % 32 == x & 0b11..100000 == x & (~31) - # ceil32(x) = floor32(x + 31) == (x + 31) & (~31) - x = code.args[0] - return _compile_to_assembly( - IRnode.from_list(["and", ["add", x, 31], ["not", 31]]), - withargs, - existing_labels, - break_dest, - height, - ) - - elif code.value == "data": - data_node = [DataHeader("_sym_" + code.args[0].value)] - - for c in code.args[1:]: - if isinstance(c.value, int): - assert 0 <= c < 256, f"invalid data byte {c}" - data_node.append(c.value) - elif isinstance(c.value, bytes): - data_node.append(c.value) - elif isinstance(c, IRnode): - assert c.value == "symbol" - data_node.extend( - _compile_to_assembly(c, withargs, existing_labels, break_dest, height) + + o.append(CONST("mem_deploy_end", mem_deploy_end)) + + # calculate the len of runtime code + immutables size + amount_to_return = runtime_codesize + immutables_len + o.extend([*PUSH(amount_to_return)]) # stack: len + o.extend([*PUSH(mem_deploy_start)]) # stack: len mem_ofst + + o.extend(["RETURN"]) + + self.data_segments.append([DataHeader(runtime_begin), DATA_ITEM(runtime_bytecode)]) + + if self.compiler_metadata is not None: + # we should issue the cbor-encoded metadata. + bytecode_suffix = generate_cbor_metadata( + self.compiler_metadata, + runtime_codesize, + runtime_data_segment_lengths, + immutables_len, ) + + segment: list[AssemblyInstruction] = [DataHeader(Label("cbor_metadata"))] + segment.append(DATA_ITEM(bytecode_suffix)) + self.data_segments.append(segment) + + return o + + # Seq (used to piece together multiple statements) + if code.value == "seq": + o = [] + for arg in code.args: + o.extend(self._compile_r(arg, height)) + if arg.valency == 1 and arg != code.args[-1]: + o.append("POP") + return o + + # Seq without popping. + # unreachable keyword produces INVALID opcode + if code.value == "assert_unreachable": + o = self._compile_r(code.args[0], height) + end_symbol = self.mksymbol("reachable") + o.extend([*JUMPI(end_symbol), "INVALID", end_symbol]) + return o + + # Assert (if false, exit) + if code.value == "assert": + o = self._compile_r(code.args[0], height) + o.extend(["ISZERO"]) + o.extend(self._assert_false()) + return o + + # SHA3 a 64 byte value + if code.value == "sha3_64": + o = self._compile_r(code.args[0], height) + o.extend(self._compile_r(code.args[1], height + 1)) + o.extend( + [ + *PUSH(MemoryPositions.FREE_VAR_SPACE2), + "MSTORE", + *PUSH(MemoryPositions.FREE_VAR_SPACE), + "MSTORE", + *PUSH(64), + *PUSH(MemoryPositions.FREE_VAR_SPACE), + "SHA3", + ] + ) + return o + + if code.value == "select": + # b ^ ((a ^ b) * cond) where cond is 1 or 0 + # let t = a ^ b + cond = code.args[0] + a = code.args[1] + b = code.args[2] + + o = [] + o.extend(self._compile_r(b, height)) + o.extend(self._compile_r(a, height + 1)) + # stack: b a + o.extend(["DUP2", "XOR"]) + # stack: b t + o.extend(self._compile_r(cond, height + 2)) + # stack: b t cond + o.extend(["MUL", "XOR"]) + + # stack: b ^ (t * cond) + return o + + # <= operator + if code.value == "le": + expanded_ir = IRnode.from_list(["iszero", ["gt", code.args[0], code.args[1]]]) + return self._compile_r(expanded_ir, height) + + # >= operator + if code.value == "ge": + expanded_ir = IRnode.from_list(["iszero", ["lt", code.args[0], code.args[1]]]) + return self._compile_r(expanded_ir, height) + # <= operator + if code.value == "sle": + expanded_ir = IRnode.from_list(["iszero", ["sgt", code.args[0], code.args[1]]]) + return self._compile_r(expanded_ir, height) + # >= operator + if code.value == "sge": + expanded_ir = IRnode.from_list(["iszero", ["slt", code.args[0], code.args[1]]]) + return self._compile_r(expanded_ir, height) + + # != operator + if code.value == "ne": + expanded_ir = IRnode.from_list(["iszero", ["eq", code.args[0], code.args[1]]]) + return self._compile_r(expanded_ir, height) + + # e.g. 95 -> 96, 96 -> 96, 97 -> 128 + if code.value == "ceil32": + # floor32(x) = x - x % 32 == x & 0b11..100000 == x & (~31) + # ceil32(x) = floor32(x + 31) == (x + 31) & (~31) + x = code.args[0] + expanded_ir = IRnode.from_list(["and", ["add", x, 31], ["not", 31]]) + return self._compile_r(expanded_ir, height) + + if code.value == "data": + assert isinstance(code.args[0].value, str) # help mypy + + data_header = DataHeader(Label(code.args[0].value)) + data_items = [] + + for c in code.args[1:]: + if isinstance(c.value, bytes): + data_items.append(DATA_ITEM(c.value)) + elif isinstance(c, IRnode): + assert c.value == "symbol" + assert len(c.args) == 1 + assert isinstance(c.args[0].value, str), (type(c.args[0].value), c) + data_items.append(DATA_ITEM(Label(c.args[0].value))) + else: # pragma: nocover + raise ValueError(f"Invalid data: {type(c)} {c}") + + self.data_segments.append([data_header, *data_items]) + return [] + + # jump to a symbol, and push variable # of arguments onto stack + if code.value == "goto": + o = [] + for i, c in enumerate(reversed(code.args[1:])): + o.extend(self._compile_r(c, height + i)) + target = code.args[0].value + assert isinstance(target, str) # help mypy + o.extend([*JUMP(Label(target))]) + return o + + if code.value == "djump": + o = [] + # "djump" compiles to a raw EVM jump instruction + jump_target = code.args[0] + o.extend(self._compile_r(jump_target, height)) + o.append("JUMP") + return o + # push a literal symbol + if code.value == "symbol": + label = code.args[0].value + assert isinstance(label, str) + return [PUSHLABEL(Label(label))] + + # set a symbol as a location. + if code.value == "label": + label_name = code.args[0].value + assert isinstance(label_name, str) + + if label_name in self.existing_labels: # pragma: nocover + raise Exception(f"Label with name {label_name} already exists!") else: - raise ValueError(f"Invalid data: {type(c)} {c}") - - # intentionally return a sublist. - return [data_node] - - # jump to a symbol, and push variable # of arguments onto stack - elif code.value == "goto": - o = [] - for i, c in enumerate(reversed(code.args[1:])): - o.extend(_compile_to_assembly(c, withargs, existing_labels, break_dest, height + i)) - o.extend(["_sym_" + code.args[0].value, "JUMP"]) - return o - elif code.value == "djump": - o = [] - # "djump" compiles to a raw EVM jump instruction - jump_target = code.args[0] - o.extend(_compile_to_assembly(jump_target, withargs, existing_labels, break_dest, height)) - o.append("JUMP") - return o - # push a literal symbol - elif code.value == "symbol": - return ["_sym_" + code.args[0].value] - # set a symbol as a location. - elif code.value == "label": - label_name = code.args[0].value - assert isinstance(label_name, str) - - if label_name in existing_labels: - raise Exception(f"Label with name {label_name} already exists!") - else: - existing_labels.add(label_name) + self.existing_labels.add(label_name) - if code.args[1].value != "var_list": - raise CodegenPanic("2nd arg to label must be var_list") - var_args = code.args[1].args + if code.args[1].value != "var_list": # pragma: nocover + raise CodegenPanic("2nd arg to label must be var_list") + var_args = code.args[1].args - body = code.args[2] + body = code.args[2] - # new scope - height = 0 - withargs = {} + # new scope + height = 0 + old_withargs = self.withargs - for arg in reversed(var_args): - assert isinstance( - arg.value, str - ) # already checked for higher up but only the paranoid survive - withargs[arg.value] = height - height += 1 + self.withargs = {} - body_asm = _compile_to_assembly( - body, withargs=withargs, existing_labels=existing_labels, height=height - ) - # pop_scoped_vars = ["POP"] * height - # for now, _rewrite_return_sequences forces - # label params to be consumed implicitly - pop_scoped_vars = [] + for arg in reversed(var_args): + assert isinstance(arg.value, str) # sanity + self.withargs[arg.value] = height + height += 1 - return ["_sym_" + label_name, "JUMPDEST"] + body_asm + pop_scoped_vars + body_asm = self._compile_r(body, height) + # pop_scoped_vars = ["POP"] * height + # for now, _rewrite_return_sequences forces + # label params to be consumed implicitly + pop_scoped_vars: list = [] - elif code.value == "unique_symbol": - symbol = code.args[0].value - assert isinstance(symbol, str) + self.withargs = old_withargs - if symbol in existing_labels: - raise Exception(f"symbol {symbol} already exists!") - else: - existing_labels.add(symbol) + return [Label(label_name)] + body_asm + pop_scoped_vars - return [] + if code.value == "unique_symbol": + symbol = code.args[0].value + assert isinstance(symbol, str) - elif code.value == "exit_to": - raise CodegenPanic("exit_to not implemented yet!") + if symbol in self.existing_labels: # pragma: nocover + raise Exception(f"symbol {symbol} already exists!") + else: + self.existing_labels.add(symbol) + + return [] + + if code.value == "exit_to": + # currently removed by _rewrite_return_sequences + raise CodegenPanic("exit_to not implemented yet!") + + # inject debug opcode. + if code.value == "debugger": + return mkdebug(pc_debugger=False, ast_source=code.ast_source) + + # inject debug opcode. + if code.value == "pc_debugger": + return mkdebug(pc_debugger=True, ast_source=code.ast_source) + + raise CompilerPanic(f"invalid IRnode: {type(code)} {code}") # pragma: no cover + + def _create_postambles(self): + ret = [] + # for some reason there might not be a STOP at the end of asm_ops. + # (generally vyper programs will have it but raw IR might not). + ret.append("STOP") + + # common revert block + if self.global_revert_label is not None: + ret.extend([self.global_revert_label, *PUSH(0), "DUP1", "REVERT"]) + + return ret + + def _compile_data_segment( + self, segment: list[AssemblyInstruction] + ) -> list[AssemblyInstruction]: + return segment - # inject debug opcode. - elif code.value == "debugger": - return mkdebug(pc_debugger=False, ast_source=code.ast_source) - # inject debug opcode. - elif code.value == "pc_debugger": - return mkdebug(pc_debugger=True, ast_source=code.ast_source) - else: # pragma: no cover - raise ValueError(f"Weird code element: {type(code)} {code}") + def _assert_false(self): + if self.global_revert_label is None: + self.global_revert_label = self.mksymbol("revert") + # use a shared failure block for common case of assert(x). + return JUMPI(self.global_revert_label) -def getpos(node): - return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset) +############################## +# assembly to evm utils +############################## def note_line_num(line_number_map, pc, item): # Record AST attached to pc - if isinstance(item, Instruction): + if isinstance(item, TaggedInstruction): if (ast_node := item.ast_source) is not None: ast_node = ast_node.get_original_node() if hasattr(ast_node, "node_id"): @@ -806,6 +991,10 @@ def note_breakpoint(line_number_map, pc, item): line_number_map["breakpoints"].add(item.lineno + 1) +############################## +# assembly optimizer +############################## + _TERMINAL_OPS = ("JUMP", "RETURN", "REVERT", "STOP", "INVALID") @@ -816,15 +1005,10 @@ def _prune_unreachable_code(assembly): i = 0 while i < len(assembly) - 1: if assembly[i] in _TERMINAL_OPS: - # find the next jumpdest or sublist + # find the next jumpdest or data section for j in range(i + 1, len(assembly)): - next_is_jumpdest = ( - j < len(assembly) - 1 - and is_symbol(assembly[j]) - and assembly[j + 1] == "JUMPDEST" - ) - next_is_list = isinstance(assembly[j], list) - if next_is_jumpdest or next_is_list: + next_is_reachable = isinstance(assembly[j], (Label, DataHeader)) + if next_is_reachable: break else: # fixup an off-by-one if we made it to the end of the assembly @@ -839,17 +1023,17 @@ def _prune_unreachable_code(assembly): def _prune_inefficient_jumps(assembly): - # prune sequences `_sym_x JUMP _sym_x JUMPDEST` to `_sym_x JUMPDEST` + # prune sequences `PUSHLABEL x JUMP LABEL x` to `LABEL x` changed = False i = 0 - while i < len(assembly) - 4: + while i < len(assembly) - 2: if ( - is_symbol(assembly[i]) + isinstance(assembly[i], PUSHLABEL) and assembly[i + 1] == "JUMP" - and assembly[i] == assembly[i + 2] - and assembly[i + 3] == "JUMPDEST" + and is_symbol(assembly[i + 2]) + and assembly[i + 2] == assembly[i].label ): - # delete _sym_x JUMP + # delete PUSHLABEL x JUMP changed = True del assembly[i : i + 2] else: @@ -859,18 +1043,19 @@ def _prune_inefficient_jumps(assembly): def _optimize_inefficient_jumps(assembly): - # optimize sequences `_sym_common JUMPI _sym_x JUMP _sym_common JUMPDEST` - # to `ISZERO _sym_x JUMPI _sym_common JUMPDEST` + # optimize sequences + # `PUSHLABEL common JUMPI PUSHLABEL x JUMP LABEL common` + # to `ISZERO PUSHLABEL x JUMPI LABEL common` changed = False i = 0 - while i < len(assembly) - 6: + while i < len(assembly) - 4: if ( - is_symbol(assembly[i]) + isinstance(assembly[i], PUSHLABEL) and assembly[i + 1] == "JUMPI" - and is_symbol(assembly[i + 2]) + and isinstance(assembly[i + 2], PUSHLABEL) and assembly[i + 3] == "JUMP" - and assembly[i] == assembly[i + 4] - and assembly[i + 5] == "JUMPDEST" + and isinstance(assembly[i + 4], Label) + and assembly[i].label == assembly[i + 4] ): changed = True assembly[i] = "ISZERO" @@ -891,27 +1076,29 @@ def _merge_jumpdests(assembly): # or some nested if statements.) changed = False i = 0 - while i < len(assembly) - 3: - if is_symbol(assembly[i]) and assembly[i + 1] == "JUMPDEST": + while i < len(assembly) - 2: + # if is_symbol(assembly[i]) and assembly[i + 1] == "JUMPDEST": + if is_symbol(assembly[i]): current_symbol = assembly[i] - if is_symbol(assembly[i + 2]) and assembly[i + 3] == "JUMPDEST": - # _sym_x JUMPDEST _sym_y JUMPDEST - # replace all instances of _sym_x with _sym_y - # (except for _sym_x JUMPDEST - don't want duplicate labels) - new_symbol = assembly[i + 2] + if is_symbol(assembly[i + 1]): + # LABEL x LABEL y + # replace all instances of PUSHLABEL x with PUSHLABEL y + new_symbol = assembly[i + 1] if new_symbol != current_symbol: for j in range(len(assembly)): - if assembly[j] == current_symbol and i != j: - assembly[j] = new_symbol + if ( + isinstance(assembly[j], PUSHLABEL) + and assembly[j].label == current_symbol + ): + assembly[j].label = new_symbol changed = True - elif is_symbol(assembly[i + 2]) and assembly[i + 3] == "JUMP": - # _sym_x JUMPDEST _sym_y JUMP - # replace all instances of _sym_x with _sym_y - # (except for _sym_x JUMPDEST - don't want duplicate labels) - new_symbol = assembly[i + 2] + elif isinstance(assembly[i + 1], PUSHLABEL) and assembly[i + 2] == "JUMP": + # LABEL x PUSHLABEL y JUMP + # replace all instances of PUSHLABEL x with PUSHLABEL y + new_symbol = assembly[i + 1].label for j in range(len(assembly)): - if assembly[j] == current_symbol and i != j: - assembly[j] = new_symbol + if isinstance(assembly[j], PUSHLABEL) and assembly[j].label == current_symbol: + assembly[j].label = new_symbol changed = True i += 1 @@ -955,7 +1142,7 @@ def _merge_iszero(assembly): # but it could also just be a no-op before JUMPI. if ( assembly[i : i + 2] == ["ISZERO", "ISZERO"] - and is_symbol(assembly[i + 2]) + and isinstance(assembly[i + 2], PUSHLABEL) and assembly[i + 3] == "JUMPI" ): changed = True @@ -966,38 +1153,27 @@ def _merge_iszero(assembly): return changed -# a symbol _sym_x in assembly can either mean to push _sym_x to the stack, -# or it can precede a location in code which we want to add to symbol map. -# this helper function tells us if we want to add the previous instruction -# to the symbol map. -def is_symbol_map_indicator(asm_node): - return asm_node == "JUMPDEST" - - def _prune_unused_jumpdests(assembly): changed = False - used_jumpdests = set() + used_jumpdests: set[Label] = set() # find all used jumpdests - for i in range(len(assembly) - 1): - if is_symbol(assembly[i]) and not is_symbol_map_indicator(assembly[i + 1]): - used_jumpdests.add(assembly[i]) - for item in assembly: - if isinstance(item, list) and isinstance(item[0], DataHeader): + if isinstance(item, PUSHLABEL): + used_jumpdests.add(item.label) + + if isinstance(item, DATA_ITEM) and isinstance(item.data, Label): # add symbols used in data sections as they are likely # used for a jumptable. - for t in item: - if is_symbol(t): - used_jumpdests.add(t) + used_jumpdests.add(item.data) # delete jumpdests that aren't used i = 0 - while i < len(assembly) - 2: + while i < len(assembly): if is_symbol(assembly[i]) and assembly[i] not in used_jumpdests: changed = True - del assembly[i : i + 2] + del assembly[i] else: i += 1 @@ -1035,7 +1211,7 @@ def _stack_peephole_opts(assembly): ): changed = True del assembly[i : i + 2] - if assembly[i] == "SWAP1" and assembly[i + 1].lower() in COMMUTATIVE_OPS: + if assembly[i] == "SWAP1" and str(assembly[i + 1]).lower() in COMMUTATIVE_OPS: changed = True del assembly[i] if assembly[i] == "DUP1" and assembly[i + 1] == "SWAP1": @@ -1048,10 +1224,6 @@ def _stack_peephole_opts(assembly): # optimize assembly, in place def optimize_assembly(assembly): - for x in assembly: - if isinstance(x, list) and isinstance(x[0], RuntimeHeader): - optimize_assembly(x) - for _ in range(1024): changed = False @@ -1069,302 +1241,255 @@ def optimize_assembly(assembly): raise CompilerPanic("infinite loop detected during assembly reduction") # pragma: nocover -def adjust_pc_maps(pc_maps, ofst): - assert ofst >= 0 - - ret = {} - # source breakpoints, don't need to modify - ret["breakpoints"] = pc_maps["breakpoints"].copy() - ret["pc_breakpoints"] = {pc + ofst for pc in pc_maps["pc_breakpoints"]} - ret["pc_jump_map"] = {k + ofst: v for (k, v) in pc_maps["pc_jump_map"].items()} - ret["pc_raw_ast_map"] = {k + ofst: v for (k, v) in pc_maps["pc_raw_ast_map"].items()} - ret["error_map"] = {k + ofst: v for (k, v) in pc_maps["error_map"].items()} - - return ret - - SYMBOL_SIZE = 2 # size of a PUSH instruction for a code symbol -def _data_to_evm(assembly, symbol_map): - ret = bytearray() - assert isinstance(assembly[0], DataHeader) - for item in assembly[1:]: - if is_symbol(item): - symbol = symbol_map[item].to_bytes(SYMBOL_SIZE, "big") - ret.extend(symbol) - elif isinstance(item, int): - ret.append(item) - elif isinstance(item, bytes): - ret.extend(item) - else: +# predict what length of an assembly [data] node will be in bytecode +def get_data_segment_lengths(assembly: list[AssemblyInstruction]) -> list[int]: + ret = [] + for item in assembly: + if isinstance(item, DataHeader): + ret.append(0) + continue + if len(ret) == 0: + # haven't yet seen a data header + continue + assert isinstance(item, DATA_ITEM) + if is_symbol(item.data): + ret[-1] += SYMBOL_SIZE + elif isinstance(item.data, bytes): + ret[-1] += len(item.data) + else: # pragma: nocover raise ValueError(f"invalid data {type(item)} {item}") return ret -# predict what length of an assembly [data] node will be in bytecode -def _length_of_data(assembly): - ret = 0 - assert isinstance(assembly[0], DataHeader) - for item in assembly[1:]: - if is_symbol(item): - ret += SYMBOL_SIZE - elif isinstance(item, int): - assert 0 <= item < 256, f"invalid data byte {item}" - ret += 1 - elif isinstance(item, bytes): - ret += len(item) - else: - raise ValueError(f"invalid data {type(item)} {item}") +############################## +# assembly to evm bytecode +############################## - return ret +def _compile_data_item(item: DATA_ITEM, symbol_map: dict[Label, int]) -> bytes: + if isinstance(item.data, bytes): + return item.data + if isinstance(item.data, Label): + symbolbytes = symbol_map[item.data].to_bytes(SYMBOL_SIZE, "big") + return symbolbytes -@dataclass -class RuntimeHeader: - label: str - ctor_mem_size: int - immutables_len: int + raise CompilerPanic(f"Invalid data {type(item.data)}, {item.data}") # pragma: nocover - def __repr__(self): - return f"" +T = TypeVar("T") -@dataclass -class DataHeader: - label: str - def __repr__(self): - return f"DATA {self.label}" - - -def _relocate_segments(assembly): - # relocate all data segments to the end, otherwise data could be - # interpreted as PUSH instructions and mangle otherwise valid jumpdests - # relocate all runtime segments to the end as well - data_segments = [] - non_data_segments = [] - code_segments = [] - for t in assembly: - if isinstance(t, list): - if isinstance(t[0], DataHeader): - data_segments.append(t) - else: - _relocate_segments(t) # recurse - assert isinstance(t[0], RuntimeHeader) - code_segments.append(t) - else: - non_data_segments.append(t) - assembly.clear() - assembly.extend(non_data_segments) - assembly.extend(code_segments) - assembly.extend(data_segments) +def _add_to_symbol_map(symbol_map: dict[T, int], item: T, value: int): + if item in symbol_map: # pragma: nocover + raise CompilerPanic(f"duplicate label: {item}") + symbol_map[item] = value -# TODO: change API to split assembly_to_evm and assembly_to_source/symbol_maps -def assembly_to_evm(assembly, pc_ofst=0, compiler_metadata=None): - bytecode, source_maps, _ = assembly_to_evm_with_symbol_map( - assembly, pc_ofst=pc_ofst, compiler_metadata=compiler_metadata - ) - return bytecode, source_maps +def assembly_to_evm(assembly: list[AssemblyInstruction]) -> tuple[bytes, dict[str, Any]]: + """ + Generate bytecode and source map from assembly + + Returns: + bytecode: bytestring of the EVM bytecode + source_map: source map dict that gets output for the user + """ + # This API might seem a bit strange, but it's backwards compatible + symbol_map, const_map, source_map = resolve_symbols(assembly) + bytecode = _assembly_to_evm(assembly, symbol_map, const_map) + return bytecode, source_map -def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, compiler_metadata=None): +# resolve symbols in assembly +def resolve_symbols( + assembly: list[AssemblyInstruction], +) -> tuple[dict[Label, int], dict[CONSTREF, int], dict[str, Any]]: """ - Assembles assembly into EVM - - assembly: list of asm instructions - pc_ofst: when constructing the source map, the amount to offset all - pcs by (no effect until we add deploy code source map) - compiler_metadata: any compiler metadata to add. pass `None` to indicate - no metadata to be added (should always be `None` for - runtime code). the value is opaque, and will be passed - directly to `cbor2.dumps()`. + Construct symbol map from assembly list + + Returns: + symbol_map: dict from labels to values + const_map: dict from CONSTREFs to values + source_map: source map dict that gets output for the user """ - line_number_map = { - "breakpoints": set(), - "pc_breakpoints": set(), + source_map: dict[str, Any] = { + "breakpoints": OrderedSet(), + "pc_breakpoints": OrderedSet(), "pc_jump_map": {0: "-"}, "pc_raw_ast_map": {}, "error_map": {}, } - pc = 0 - symbol_map = {} + symbol_map: dict[Label, int] = {} + const_map: dict[CONSTREF, int] = {} - runtime_code, runtime_code_start, runtime_code_end = None, None, None + pc: int = 0 - # to optimize the size of deploy code - we want to use the smallest - # PUSH instruction possible which can support all memory symbols - # (and also works with linear pass symbol resolution) - # to do this, we first do a single pass to compile any runtime code - # and use that to calculate mem_ofst_size. - mem_ofst_size, ctor_mem_size = None, None - max_mem_ofst = 0 - for i, item in enumerate(assembly): - if isinstance(item, list) and isinstance(item[0], RuntimeHeader): - assert runtime_code is None, "Multiple subcodes" - - assert ctor_mem_size is None - ctor_mem_size = item[0].ctor_mem_size - - runtime_code, runtime_map = assembly_to_evm(item[1:]) - - runtime_code_start, runtime_code_end = _runtime_code_offsets( - ctor_mem_size, len(runtime_code) - ) - assert runtime_code_end - runtime_code_start == len(runtime_code) - - if is_ofst(item) and is_mem_sym(assembly[i + 1]): - max_mem_ofst = max(assembly[i + 2], max_mem_ofst) - - if runtime_code_end is not None: - mem_ofst_size = calc_mem_ofst_size(runtime_code_end + max_mem_ofst) - - data_section_lengths = [] - immutables_len = None + # resolve constants + for item in assembly: + if isinstance(item, CONST): + # should this be merged into the symbol map? + _add_to_symbol_map(const_map, CONSTREF(item.name), item.value) - # go through the code, resolving symbolic locations - # (i.e. JUMPDEST locations) to actual code locations + # resolve labels (i.e. JUMPDEST locations) to actual code locations, + # and simultaneously build the source map. for i, item in enumerate(assembly): - note_line_num(line_number_map, pc, item) - if item == "DEBUG": - continue # skip debug + # add it to the source map + note_line_num(source_map, pc, item) # update pc_jump_map if item == "JUMP": last = assembly[i - 1] - if is_symbol(last) and last.startswith("_sym_internal"): - if last.endswith("cleanup"): + if isinstance(last, PUSHLABEL) and last.label.label.startswith("internal"): + if last.label.label.endswith("cleanup"): # exit an internal function - line_number_map["pc_jump_map"][pc] = "o" + source_map["pc_jump_map"][pc] = "o" else: # enter an internal function - line_number_map["pc_jump_map"][pc] = "i" + source_map["pc_jump_map"][pc] = "i" else: # everything else - line_number_map["pc_jump_map"][pc] = "-" + source_map["pc_jump_map"][pc] = "-" elif item in ("JUMPI", "JUMPDEST"): - line_number_map["pc_jump_map"][pc] = "-" + source_map["pc_jump_map"][pc] = "-" - # update pc - if is_symbol(item): - if is_symbol_map_indicator(assembly[i + 1]): - # Don't increment pc as the symbol itself doesn't go into code - if item in symbol_map: - raise CompilerPanic(f"duplicate jumpdest {item}") + if item == "DEBUG": + continue # "debug" opcode does not go into bytecode - symbol_map[item] = pc - else: + if isinstance(item, CONST): + continue # CONST declarations do not go into bytecode + + # update pc + if isinstance(item, Label): + _add_to_symbol_map(symbol_map, item, pc) + pc += 1 # jumpdest + + elif isinstance(item, DataHeader): + # Don't increment pc as the symbol itself doesn't go into code + _add_to_symbol_map(symbol_map, item.label, pc) + + elif isinstance(item, PUSHLABEL): + pc += SYMBOL_SIZE + 1 # PUSH2 highbits lowbits + + elif isinstance(item, PUSH_OFST): + assert isinstance(item.ofst, int), item + # [PUSH_OFST, (Label foo), bar] -> PUSH2 (foo+bar) + # [PUSH_OFST, _mem_foo, bar] -> PUSHN (foo+bar) + if isinstance(item.label, Label): pc += SYMBOL_SIZE + 1 # PUSH2 highbits lowbits - elif is_mem_sym(item): - # PUSH item - pc += mem_ofst_size + 1 - elif is_ofst(item): - assert is_symbol(assembly[i + 1]) or is_mem_sym(assembly[i + 1]) - assert isinstance(assembly[i + 2], int) - # [_OFST, _sym_foo, bar] -> PUSH2 (foo+bar) - # [_OFST, _mem_foo, bar] -> PUSHN (foo+bar) - pc -= 1 - elif isinstance(item, list) and isinstance(item[0], RuntimeHeader): - # we are in initcode - symbol_map[item[0].label] = pc - # add source map for all items in the runtime map - t = adjust_pc_maps(runtime_map, pc) - for key in line_number_map: - line_number_map[key].update(t[key]) - immutables_len = item[0].immutables_len - pc += len(runtime_code) - # grab lengths of data sections from the runtime - for t in item: - if isinstance(t, list) and isinstance(t[0], DataHeader): - data_section_lengths.append(_length_of_data(t)) - - elif isinstance(item, list) and isinstance(item[0], DataHeader): - symbol_map[item[0].label] = pc - pc += _length_of_data(item) + elif isinstance(item.label, CONSTREF): + const = const_map[item.label] + val = const + item.ofst + pc += calc_push_size(val) + else: # pragma: nocover + raise CompilerPanic(f"invalid ofst {item.label}") + + elif isinstance(item, DATA_ITEM): + if isinstance(item.data, Label): + pc += SYMBOL_SIZE + else: + assert isinstance(item.data, bytes) + pc += len(item.data) + elif isinstance(item, int): + assert 0 <= item < 256 + pc += 1 else: + assert isinstance(item, str) and item in get_opcodes(), item pc += 1 - bytecode_suffix = b"" - if compiler_metadata is not None: - # this will hold true when we are in initcode - assert immutables_len is not None - metadata = ( - compiler_metadata, - len(runtime_code), - data_section_lengths, - immutables_len, - {"vyper": version_tuple}, - ) - bytecode_suffix += cbor2.dumps(metadata) - # append the length of the footer, *including* the length - # of the length bytes themselves. - suffix_len = len(bytecode_suffix) + 2 - bytecode_suffix += suffix_len.to_bytes(2, "big") - - pc += len(bytecode_suffix) - - symbol_map["_sym_code_end"] = pc - symbol_map["_mem_deploy_start"] = runtime_code_start - symbol_map["_mem_deploy_end"] = runtime_code_end - if runtime_code is not None: - symbol_map["_sym_subcode_size"] = len(runtime_code) - - # TODO refactor into two functions, create symbol_map and assemble + source_map["breakpoints"] = list(source_map["breakpoints"]) + source_map["pc_breakpoints"] = list(source_map["pc_breakpoints"]) + + # magic -- probably the assembler should actually add this label + _add_to_symbol_map(symbol_map, Label("code_end"), pc) + + return symbol_map, const_map, source_map + + +# helper function +def _compile_push_instruction(assembly: list[AssemblyInstruction]) -> bytes: + push_mnemonic = assembly[0] + assert isinstance(push_mnemonic, str) and push_mnemonic.startswith("PUSH") + push_instr = PUSH_OFFSET + int(push_mnemonic[4:]) + ret = [push_instr] + + for item in assembly[1:]: + assert isinstance(item, int) + ret.append(item) + return bytes(ret) + +def _assembly_to_evm( + assembly: list[AssemblyInstruction], + symbol_map: dict[Label, int], + const_map: dict[CONSTREF, int], +) -> bytes: + """ + Assembles assembly into EVM bytecode + + Parameters: + assembly: list of asm instructions + symbol_map: dict from labels to resolved locations in the code + const_map: dict from constrefs to their values + + Returns: bytes representing the bytecode + """ ret = bytearray() # now that all symbols have been resolved, generate bytecode # using the symbol map - to_skip = 0 - for i, item in enumerate(assembly): - if to_skip > 0: - to_skip -= 1 - continue - + for item in assembly: if item in ("DEBUG",): continue # skippable opcodes + elif isinstance(item, CONST): + continue # CONST things do not show up in bytecode + elif isinstance(item, DataHeader): + continue # DataHeader does not show up in bytecode - elif is_symbol(item): + elif isinstance(item, PUSHLABEL): # push a symbol to stack - if not is_symbol_map_indicator(assembly[i + 1]): - bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=SYMBOL_SIZE)) - ret.extend(bytecode) - - elif is_mem_sym(item): - bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=mem_ofst_size)) + label = item.label + bytecode = _compile_push_instruction(PUSH_N(symbol_map[label], n=SYMBOL_SIZE)) ret.extend(bytecode) - elif is_ofst(item): - # _OFST _sym_foo 32 - ofst = symbol_map[assembly[i + 1]] + assembly[i + 2] - n = mem_ofst_size if is_mem_sym(assembly[i + 1]) else SYMBOL_SIZE - bytecode, _ = assembly_to_evm(PUSH_N(ofst, n)) + elif isinstance(item, Label): + jumpdest_opcode = get_opcodes()["JUMPDEST"][0] + assert jumpdest_opcode is not None # help mypy + ret.append(jumpdest_opcode) + + elif isinstance(item, PUSH_OFST): + # PUSH_OFST (LABEL foo) 32 + # PUSH_OFST (const foo) 32 + if isinstance(item.label, Label): + ofst = symbol_map[item.label] + item.ofst + bytecode = _compile_push_instruction(PUSH_N(ofst, SYMBOL_SIZE)) + else: + assert isinstance(item.label, CONSTREF) + ofst = const_map[item.label] + item.ofst + bytecode = _compile_push_instruction(PUSH(ofst)) + ret.extend(bytecode) - to_skip = 2 elif isinstance(item, int): ret.append(item) elif isinstance(item, str) and item.upper() in get_opcodes(): - ret.append(get_opcodes()[item.upper()][0]) + opcode = get_opcodes()[item.upper()][0] + # TODO: fix signature of get_opcodes() + assert opcode is not None # help mypy + ret.append(opcode) + elif isinstance(item, DATA_ITEM): + ret.extend(_compile_data_item(item, symbol_map)) elif item[:4] == "PUSH": ret.append(PUSH_OFFSET + int(item[4:])) elif item[:3] == "DUP": ret.append(DUP_OFFSET + int(item[3:])) elif item[:4] == "SWAP": ret.append(SWAP_OFFSET + int(item[4:])) - elif isinstance(item, list) and isinstance(item[0], RuntimeHeader): - ret.extend(runtime_code) - elif isinstance(item, list) and isinstance(item[0], DataHeader): - ret.extend(_data_to_evm(item, symbol_map)) else: # pragma: no cover # unreachable raise ValueError(f"Weird symbol in assembly: {type(item)} {item}") - ret.extend(bytecode_suffix) - - line_number_map["breakpoints"] = list(line_number_map["breakpoints"]) - line_number_map["pc_breakpoints"] = list(line_number_map["pc_breakpoints"]) - return bytes(ret), line_number_map, symbol_map + return bytes(ret) diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index 4b17dfcecb..5858570cae 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -7,8 +7,10 @@ from vyper.compiler.settings import OptimizationLevel, Settings from vyper.evm.address_space import MEMORY, STORAGE, TRANSIENT from vyper.exceptions import CompilerPanic +from vyper.ir.compile_ir import AssemblyInstruction from vyper.venom.analysis import MemSSA from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.basicblock import IRLabel, IRLiteral from vyper.venom.context import IRContext from vyper.venom.function import IRFunction from vyper.venom.ir_node_to_venom import ir_node_to_venom @@ -40,18 +42,10 @@ def generate_assembly_experimental( - runtime_code: IRContext, - deploy_code: Optional[IRContext] = None, - optimize: OptimizationLevel = DEFAULT_OPT_LEVEL, -) -> list[str]: - # note: VenomCompiler is sensitive to the order of these! - if deploy_code is not None: - functions = [deploy_code, runtime_code] - else: - functions = [runtime_code] - - compiler = VenomCompiler(functions) - return compiler.generate_evm(optimize == OptimizationLevel.NONE) + venom_ctx: IRContext, optimize: OptimizationLevel = DEFAULT_OPT_LEVEL +) -> list[AssemblyInstruction]: + compiler = VenomCompiler(venom_ctx) + return compiler.generate_evm_assembly(optimize == OptimizationLevel.NONE) def _run_passes(fn: IRFunction, optimize: OptimizationLevel, ac: IRAnalysesCache) -> None: @@ -129,9 +123,23 @@ def run_passes_on(ctx: IRContext, optimize: OptimizationLevel) -> None: _run_passes(fn, optimize, ir_analyses[fn]) -def generate_ir(ir: IRnode, settings: Settings) -> IRContext: +def generate_venom( + ir: IRnode, + settings: Settings, + constants: dict[str, int] = None, + data_sections: dict[str, bytes] = None, +) -> IRContext: # Convert "old" IR to "new" IR - ctx = ir_node_to_venom(ir) + constants = constants or {} + ctx = ir_node_to_venom(ir, constants) + + data_sections = data_sections or {} + for section_name, data in data_sections.items(): + ctx.append_data_section(IRLabel(section_name)) + ctx.append_data_item(data) + + for constname, value in constants.items(): + ctx.add_constant(constname, value) optimize = settings.optimize assert optimize is not None # help mypy diff --git a/vyper/venom/analysis/dfg.py b/vyper/venom/analysis/dfg.py index bef9c5f7be..16b8c254a9 100644 --- a/vyper/venom/analysis/dfg.py +++ b/vyper/venom/analysis/dfg.py @@ -85,7 +85,7 @@ def analyze(self): inputs.add(inst) for op in res: # type: ignore - assert isinstance(op, IRVariable) + assert isinstance(op, IRVariable), op self._dfg_outputs[op] = inst def as_graph(self) -> str: diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index 27d1e2c7fd..a58a8ef54b 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -14,9 +14,7 @@ from vyper.venom.function import IRFunction # instructions which can terminate a basic block -BB_TERMINATORS = frozenset( - ["jmp", "djmp", "jnz", "ret", "return", "revert", "stop", "exit", "sink"] -) +BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop", "sink"]) VOLATILE_INSTRUCTIONS = frozenset( [ @@ -50,7 +48,6 @@ "assert", "assert_unreachable", "stop", - "exit", ] ) @@ -79,7 +76,6 @@ "djmp", "jnz", "log", - "exit", "nop", ] ) @@ -247,6 +243,8 @@ def __init__( ): assert isinstance(opcode, str), "opcode must be an str" assert isinstance(operands, list | Iterator), "operands must be a list" + if output is not None: + assert isinstance(output, IRVariable), output self.opcode = opcode self.operands = list(operands) # in case we get an iterator self.output = output diff --git a/vyper/venom/context.py b/vyper/venom/context.py index 30fac4875d..f50dc1220f 100644 --- a/vyper/venom/context.py +++ b/vyper/venom/context.py @@ -33,8 +33,7 @@ def __str__(self): class IRContext: functions: dict[IRLabel, IRFunction] entry_function: Optional[IRFunction] - ctor_mem_size: Optional[int] - immutables_len: Optional[int] + constants: dict[str, int] # globally defined constants data_segment: list[DataSection] last_label: int last_variable: int @@ -42,9 +41,9 @@ class IRContext: def __init__(self) -> None: self.functions = {} self.entry_function = None - self.ctor_mem_size = None - self.immutables_len = None self.data_segment = [] + self.constants = {} + self.last_label = 0 self.last_variable = 0 @@ -99,6 +98,10 @@ def append_data_item(self, data: IRLabel | bytes) -> None: data_section = self.data_segment[-1] data_section.data_items.append(DataItem(data)) + def add_constant(self, name: str, value: int) -> None: + assert name not in self.constants + self.constants[name] = value + def as_graph(self) -> str: s = ["digraph G {"] for fn in self.functions.values(): diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 3ad63b207a..efb9235caa 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -126,15 +126,13 @@ def freshen_varnames(self) -> None: continue inst.operands[i] = varmap[op] + # TODO: move these to the IR builder class def push_source(self, ir): - if isinstance(ir, IRnode): - self._ast_source_stack.append(ir.ast_source) - self._error_msg_stack.append(ir.error_msg) + self._ast_source_stack.append(ir.ast_source) + self._error_msg_stack.append(ir.error_msg) def pop_source(self): - assert len(self._ast_source_stack) > 0, "Empty source stack" self._ast_source_stack.pop() - assert len(self._error_msg_stack) > 0, "Empty error stack" self._error_msg_stack.pop() def get_param_by_id(self, id_: int) -> Optional[IRParameter]: diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 65793ea5c0..006559c58e 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -1,13 +1,14 @@ from __future__ import annotations +import contextlib import functools import re from collections import defaultdict -from typing import Optional +from typing import Iterable, Optional from vyper.codegen.context import Alloca from vyper.codegen.ir_node import IRnode -from vyper.evm.opcodes import get_opcodes +from vyper.ir.compile_ir import _runtime_code_offsets from vyper.venom.basicblock import ( IRBasicBlock, IRInstruction, @@ -69,6 +70,7 @@ "gas", "gasprice", "gaslimit", + "return", "returndatasize", "iload", "istore", @@ -95,7 +97,6 @@ "selfdestruct", "assert", "assert_unreachable", - "exit", "calldatacopy", "mcopy", "extcodecopy", @@ -114,498 +115,433 @@ NOOP_INSTRUCTIONS = frozenset(["pass", "cleanup_repeat", "var_list", "unique_symbol"]) -SymbolTable = dict[str, IROperand] -_alloca_table: dict[int, IROperand] -_callsites: dict[str, list[Alloca]] +SymbolTable = dict[str, IRVariable] MAIN_ENTRY_LABEL_NAME = "__main_entry" -_scratch_alloca_id = 2**32 +class IRnodeToVenom: + _alloca_table: dict[int, IROperand] + _callsites: dict[str, list[Alloca]] -def get_scratch_alloca_id() -> int: - global _scratch_alloca_id - _scratch_alloca_id += 1 - return _scratch_alloca_id + _scratch_alloca_id: int = 2**32 + # current vyper function + _current_func_t = None -# convert IRnode directly to venom -def ir_node_to_venom(ir: IRnode) -> IRContext: - _ = ir.unique_symbols # run unique symbols check + _break_target: Optional[IRBasicBlock] = None + _continue_target: Optional[IRBasicBlock] = None - global _alloca_table, _callsites - _alloca_table = {} - _callsites = defaultdict(list) + constants: dict[str, int] - ctx = IRContext() - fn = ctx.create_function(MAIN_ENTRY_LABEL_NAME) - ctx.entry_function = fn + variables: dict[str, IRVariable] - _convert_ir_bb(fn, ir, {}) + def __init__(self, constants: dict[str, int]): + self._alloca_table = {} + self._callsites = defaultdict(list) - for fn in ctx.functions.values(): - for bb in fn.get_basic_blocks(): - bb.ensure_well_formed() + self.constants = constants - return ctx + def convert(self, ir: IRnode) -> IRContext: + ctx = IRContext() + fn = ctx.create_function(MAIN_ENTRY_LABEL_NAME) + ctx.entry_function = fn + self.variables = {} -def _append_jmp(fn: IRFunction, label: IRLabel) -> None: - bb = fn.get_basic_block() - if bb.is_terminated: - bb = IRBasicBlock(fn.ctx.get_next_label("jmp_target"), fn) - fn.append_basic_block(bb) - - bb.append_instruction("jmp", label) - - -def _new_block(fn: IRFunction) -> None: - bb = IRBasicBlock(fn.ctx.get_next_label(), fn) - fn.append_basic_block(bb) - - -def _append_return_args(fn: IRFunction, ofst: int = 0, size: int = 0): - bb = fn.get_basic_block() - if bb.is_terminated: - bb = IRBasicBlock(fn.ctx.get_next_label("exit_to"), fn) - fn.append_basic_block(bb) - ret_ofst = IRVariable("ret_ofst") - ret_size = IRVariable("ret_size") - bb.append_instruction("store", ofst, ret=ret_ofst) - bb.append_instruction("store", size, ret=ret_size) - - -# func_t: ContractFunctionT -@functools.lru_cache(maxsize=1024) -def _pass_via_stack(func_t) -> dict[str, bool]: - # returns a dict which returns True if a given argument (referered to - # by name) should be passed via the stack - if not ENABLE_NEW_CALL_CONV: - return {arg.name: False for arg in func_t.arguments} - - arguments = {arg.name: arg for arg in func_t.arguments} - - stack_items = 0 - returns_word = _returns_word(func_t) - if returns_word: - stack_items += 1 - - ret = {} - - for arg in arguments.values(): - if not _is_word_type(arg.typ) or stack_items > MAX_STACK_ARGS: - ret[arg.name] = False - else: - ret[arg.name] = True - stack_items += 1 - - return ret - - -def _handle_self_call(fn: IRFunction, ir: IRnode, symbols: SymbolTable) -> Optional[IROperand]: - global _callsites - setup_ir = ir.args[1] - goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] - target_label = goto_ir.args[0].value # goto - - func_t = ir.passthrough_metadata["func_t"] - assert func_t is not None, "func_t not found in passthrough metadata" - - returns_word = _returns_word(func_t) - - if setup_ir != goto_ir: - _convert_ir_bb(fn, setup_ir, symbols) - - converted_args = _convert_ir_bb_list(fn, goto_ir.args[1:], symbols) - - callsite_op = converted_args[-1] - assert isinstance(callsite_op, IRLabel), converted_args - callsite = callsite_op.value - - bb = fn.get_basic_block() - return_buf = None - - if len(converted_args) > 1: - return_buf = converted_args[0] - - stack_args: list[IROperand] = [IRLabel(str(target_label))] - - if return_buf is not None: - if not ENABLE_NEW_CALL_CONV or not returns_word: - stack_args.append(return_buf) # type: ignore - - callsite_args = _callsites[callsite] - if ENABLE_NEW_CALL_CONV: - for alloca in callsite_args: - if not _pass_via_stack(func_t)[alloca.name]: - continue - ptr = _alloca_table[alloca._id] - stack_arg = bb.append_instruction("mload", ptr) - assert stack_arg is not None - stack_args.append(stack_arg) - - if returns_word: - ret_value = bb.append_invoke_instruction(stack_args, returns=True) # type: ignore - assert ret_value is not None - assert isinstance(return_buf, IROperand) - bb.append_instruction("mstore", ret_value, return_buf) - return return_buf - - bb.append_invoke_instruction(stack_args, returns=False) # type: ignore - - return return_buf - - -_current_func_t = None - - -def _is_word_type(typ): - # we can pass it on the stack. - return typ.memory_bytes_required == 32 - - -# func_t: ContractFunctionT -def _returns_word(func_t) -> bool: - return_t = func_t.return_type - return return_t is not None and _is_word_type(return_t) + self.fn = fn + self.convert_ir(ir) + self.finish(ctx) -def _handle_internal_func( - # TODO: remove does_return_data, replace with `func_t.return_type is not None` - fn: IRFunction, - ir: IRnode, - does_return_data: bool, - symbols: SymbolTable, -) -> IRFunction: - global _alloca_table, _current_func_t + return ctx - func_t = ir.passthrough_metadata["func_t"] - context = ir.passthrough_metadata["context"] - assert func_t is not None, "func_t not found in passthrough metadata" - assert context is not None, func_t.name + def finish(self, ctx: IRContext): + for fn in ctx.functions.values(): + for bb in fn.get_basic_blocks(): + bb.ensure_well_formed() - _current_func_t = func_t + def get_scratch_alloca_id(self): + self._scratch_alloca_id += 1 + return self._scratch_alloca_id - funcname = ir.args[0].args[0].value - assert isinstance(funcname, str) - fn = fn.ctx.create_function(funcname) + def convert_ir(self, ir: IRnode): + _ = ir.unique_symbols # run unique symbols check - bb = fn.get_basic_block() + self.fn.push_source(ir) + ret = self._convert_ir(ir) + self.fn.pop_source() - _saved_alloca_table = _alloca_table - _alloca_table = {} + return ret - returns_word = _returns_word(func_t) + @contextlib.contextmanager + def anchor_fn(self, new_fn: IRFunction): + tmp = self.fn + try: + self.fn = new_fn + yield + finally: + self.fn = tmp + + @contextlib.contextmanager + def anchor_variables(self, new_variables: Optional[SymbolTable] = None): + if new_variables is None: + new_variables = self.variables.copy() + + tmp = self.variables + try: + self.variables = new_variables + yield + finally: + self.variables = tmp + + # globally agreed upon ret_ofst/ret_size variables + RET_OFST = IRVariable("ret_ofst") + RET_SIZE = IRVariable("ret_size") + + def _append_return_args(self, ofst: int, size: int): + fn = self.fn + bb = fn.get_basic_block() + if bb.is_terminated: + # NOTE: this generates dead code + bb = IRBasicBlock(fn.ctx.get_next_label("exit_to"), fn) + fn.append_basic_block(bb) - # return buffer - if does_return_data: - if ENABLE_NEW_CALL_CONV and returns_word: - # TODO: remove this once we have proper memory allocator - # functionality in venom. Currently, we hardcode the scratch - # buffer size of 32 bytes. - # TODO: we don't need to use scratch space once the legacy optimizer - # is disabled. - buf = bb.append_instruction("alloca", 0, 32, get_scratch_alloca_id()) - else: - buf = bb.append_instruction("param") - bb.instructions[-1].annotation = "return_buffer" - - assert buf is not None # help mypy - symbols["return_buffer"] = buf - - if ENABLE_NEW_CALL_CONV: - stack_index = 0 - if func_t.return_type is not None and not _returns_word(func_t): - stack_index += 1 - for arg in func_t.arguments: - if not _pass_via_stack(func_t)[arg.name]: - continue - - param = bb.append_instruction("param") - bb.instructions[-1].annotation = arg.name - assert param is not None # help mypy - - var = context.lookup_var(arg.name) - - venom_arg = IRParameter( - name=var.name, - index=stack_index, - offset=var.alloca.offset, - size=var.alloca.size, - id_=var.alloca._id, - call_site_var=None, - func_var=param, - addr_var=None, + bb.append_instruction("store", ofst, ret=self.RET_OFST) + bb.append_instruction("store", size, ret=self.RET_SIZE) + + def _convert_ir_simple(self, ir: IRnode) -> Optional[IRVariable]: + # execute in order + args = self._convert_ir_list(ir.args) + # reverse output variables for stack + args.reverse() + assert isinstance(ir.value, str) # help mypy + return self.fn.get_basic_block().append_instruction(ir.value, *args) + + def _convert_ir_list(self, ir_list: Iterable[IRnode]): + return [self.convert_ir(ir_node) for ir_node in ir_list] + + def _convert_ir(self, ir: IRnode): + fn = self.fn + ctx = fn.ctx + + if isinstance(ir.value, int): + return IRLiteral(ir.value) + + elif ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: + orig_value = ir.value + ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] + new_var = self._convert_ir_simple(ir) + assert new_var is not None # help mypy + ir.value = orig_value + return fn.get_basic_block().append_instruction("iszero", new_var) + + elif ir.value in PASS_THROUGH_INSTRUCTIONS: + return self._convert_ir_simple(ir) + + elif ir.value == "deploy": + ctor_mem_size = ir.args[0].value + immutables_len = ir.args[2].value + runtime_codesize = self.constants["runtime_codesize"] + assert immutables_len == self.constants["immutables_len"] # sanity + assert isinstance(immutables_len, int) # help mypy + + mem_deploy_start, mem_deploy_end = _runtime_code_offsets( + ctor_mem_size, runtime_codesize ) - fn.args.append(venom_arg) - stack_index += 1 - - # return address - return_pc = bb.append_instruction("param") - assert return_pc is not None # help mypy - symbols["return_pc"] = return_pc - bb.instructions[-1].annotation = "return_pc" - # convert the body of the function - _convert_ir_bb(fn, ir.args[0].args[2], symbols) + fn.ctx.add_constant("mem_deploy_end", mem_deploy_end) - _alloca_table = _saved_alloca_table + bb = fn.get_basic_block() - return fn + bb.append_instruction( + "codecopy", runtime_codesize, IRLabel("runtime_begin"), mem_deploy_start + ) + amount_to_return = bb.append_instruction("add", runtime_codesize, immutables_len) + assert amount_to_return is not None # help mypy + bb.append_instruction("return", amount_to_return, mem_deploy_start) + return None + elif ir.value == "seq": + if len(ir.args) == 0: + return None + if ir.is_self_call: + return self._handle_self_call(ir) + elif ir.args[0].value == "label": + labelvalue = ir.args[0].args[0].value + assert isinstance(labelvalue, str) # mypy + is_external = labelvalue.startswith("external") + is_internal = labelvalue.startswith("internal") + if is_internal or len(re.findall(r"external.*__init__\(.*_deploy", labelvalue)) > 0: + # Internal definition + var_list = ir.args[0].args[1] + assert var_list.value == "var_list" + + does_return_data = IRnode.from_list(["return_buffer"]) in var_list.args + + new_variables: SymbolTable = {} + with self.anchor_variables(new_variables): + new_fn = self._handle_internal_func(ir, does_return_data) + with self.anchor_fn(new_fn): + for ir_node in ir.args[1:]: + ret = self.convert_ir(ir_node) + + return None + + assert is_external + + # "parameters" to the exit sequence block + self.variables["ret_ofst"] = self.RET_OFST + self.variables["ret_len"] = self.RET_SIZE + ret = self.convert_ir(ir.args[0]) -def _convert_ir_simple_node( - fn: IRFunction, ir: IRnode, symbols: SymbolTable -) -> Optional[IRVariable]: - # execute in order - args = _convert_ir_bb_list(fn, ir.args, symbols) - # reverse output variables for stack - args.reverse() - return fn.get_basic_block().append_instruction(ir.value, *args) # type: ignore + else: + bb = fn.get_basic_block() + if bb.is_terminated: + bb = IRBasicBlock(ctx.get_next_label("seq"), fn) + fn.append_basic_block(bb) + ret = self.convert_ir(ir.args[0]) + for ir_node in ir.args[1:]: + # seq returns the last item in the list + ret = self.convert_ir(ir_node) -_break_target: Optional[IRBasicBlock] = None -_continue_target: Optional[IRBasicBlock] = None + return ret + elif ir.value == "if": + return self._handle_if_stmt(ir) -def _convert_ir_bb_list(fn, ir, symbols): - ret = [] - for ir_node in ir: - venom = _convert_ir_bb(fn, ir_node, symbols) - ret.append(venom) - return ret + elif ir.value == "with": + varname = ir.args[0].value + # compute the initial value for the variable + ret = self.convert_ir(ir.args[1]) + # ensure it is stored in a variable + ret = fn.get_basic_block().append_instruction("store", ret) -def pop_source_on_return(func): - @functools.wraps(func) - def pop_source(*args, **kwargs): - fn = args[0] - ret = func(*args, **kwargs) - fn.pop_source() - return ret + body_ir = ir.args[2] + with self.anchor_variables(): + assert isinstance(varname, str) + # `with` allows shadowing + self.variables[varname] = ret + return self.convert_ir(body_ir) - return pop_source + elif ir.value == "goto": + bb = fn.get_basic_block() + if bb.is_terminated: + # TODO: this branch seems dead, investigate. + bb = IRBasicBlock(fn.ctx.get_next_label("jmp_target"), fn) + fn.append_basic_block(bb) -@pop_source_on_return -def _convert_ir_bb(fn, ir, symbols): - assert isinstance(ir, IRnode), ir - # TODO: refactor these to not be globals - global _break_target, _continue_target, _alloca_table + assert isinstance(ir.args[0].value, str) # mypy + bb.append_instruction("jmp", IRLabel(ir.args[0].value)) + + elif ir.value == "djump": + args = [self.convert_ir(ir.args[0])] + for target in ir.args[1:]: + assert isinstance(target.value, str) # mypy + args.append(IRLabel(target.value)) + fn.get_basic_block().append_instruction("djmp", *args) + self._append_new_bb() + return + + elif ir.value == "set": + varname = ir.args[0].value + assert isinstance(varname, str) + arg_1 = self.convert_ir(ir.args[1]) + venom_var = self.variables[varname] + fn.get_basic_block().append_instruction("store", arg_1, ret=venom_var) + return + + elif ir.value == "symbol": + assert isinstance(ir.args[0].value, str) # mypy + return IRLabel(ir.args[0].value, True) + + elif ir.value == "data": + assert isinstance(ir.args[0].value, str) # mypy + label = IRLabel(ir.args[0].value, True) + ctx.append_data_section(label) + for c in ir.args[1:]: + if isinstance(c.value, bytes): + ctx.append_data_item(c.value) + elif isinstance(c, IRnode): + data = self.convert_ir(c) + assert isinstance(data, IRLabel) # help mypy + ctx.append_data_item(data) + + elif ir.value == "label": + assert isinstance(ir.args[0].value, str) # mypy + label = IRLabel(ir.args[0].value, True) + bb = fn.get_basic_block() + if not bb.is_terminated: + bb.append_instruction("jmp", label) + bb = IRBasicBlock(label, fn) + fn.append_basic_block(bb) + code = ir.args[2] + self.convert_ir(code) + + elif ir.value == "exit_to": + return self._handle_exit_to(ir) + elif ir.value == "mstore": + # some upstream code depends on reversed order of evaluation -- + # to fix upstream. + val, ptr = self._convert_ir_list(reversed(ir.args)) + return fn.get_basic_block().append_instruction("mstore", val, ptr) + + elif ir.value == "ceil32": + x = ir.args[0] + expanded = IRnode.from_list(["and", ["add", x, 31], ["not", 31]]) + return self.convert_ir(expanded) + + elif ir.value == "select": + cond, a, b = ir.args + expanded = IRnode.from_list( + [ + "with", + "cond", + cond, + [ + "with", + "a", + a, + ["with", "b", b, ["xor", "b", ["mul", "cond", ["xor", "a", "b"]]]], + ], + ] + ) + return self.convert_ir(expanded) + + elif ir.value == "repeat": + return self._handle_repeat(ir) + + elif ir.value == "break": + assert self._break_target is not None + fn.get_basic_block().append_instruction("jmp", self._break_target.label) + self._append_new_bb() + + elif ir.value == "continue": + assert self._continue_target is not None + fn.get_basic_block().append_instruction("jmp", self._continue_target.label) + self._append_new_bb() + + elif ir.value in NOOP_INSTRUCTIONS: + pass + + elif isinstance(ir.value, str) and ir.value.startswith("log"): + log_args = reversed(self._convert_ir_list(ir.args)) + topic_count = int(ir.value[3:]) + assert topic_count >= 0 and topic_count <= 4, "invalid topic count" + fn.get_basic_block().append_instruction("log", topic_count, *log_args) + + elif isinstance(ir.value, str): + if ir.value.startswith("$alloca"): + alloca = ir.passthrough_metadata["alloca"] + if alloca._id not in self._alloca_table: + ptr = fn.get_basic_block().append_instruction( + "alloca", alloca.offset, alloca.size, alloca._id + ) + self._alloca_table[alloca._id] = ptr + return self._alloca_table[alloca._id] + + elif ir.value.startswith("$palloca"): + alloca = ir.passthrough_metadata["alloca"] + if alloca._id not in self._alloca_table: + bb = fn.get_basic_block() + ptr = bb.append_instruction("palloca", alloca.offset, alloca.size, alloca._id) + bb.instructions[-1].annotation = f"{alloca.name} (memory)" + if ENABLE_NEW_CALL_CONV and _pass_via_stack(self._current_func_t)[alloca.name]: + param = fn.get_param_by_id(alloca._id) + assert param is not None + bb.append_instruction("mstore", param.func_var, ptr) + self._alloca_table[alloca._id] = ptr + return self._alloca_table[alloca._id] + elif ir.value.startswith("$calloca"): + alloca = ir.passthrough_metadata["alloca"] + assert alloca._callsite is not None + if alloca._id not in self._alloca_table: + bb = fn.get_basic_block() + + callsite_func = ir.passthrough_metadata["callsite_func"] + if ENABLE_NEW_CALL_CONV and _pass_via_stack(callsite_func)[alloca.name]: + ptr = bb.append_instruction( + "alloca", alloca.offset, alloca.size, alloca._id + ) + else: + # if we use alloca, mstores might get removed. convert + # to calloca until memory analysis is more sound. + ptr = bb.append_instruction( + "calloca", alloca.offset, alloca.size, alloca._id + ) + + self._alloca_table[alloca._id] = ptr + ret = self._alloca_table[alloca._id] + # assumption: callocas appear in the same order as the + # order of arguments to the function. + self._callsites[alloca._callsite].append(alloca) + return ret - # keep a map from external functions to all possible entry points + return self.variables[ir.value] - ctx = fn.ctx - fn.push_source(ir) + else: + raise Exception(f"Unknown IR node: {ir}") - if ir.value in INVERSE_MAPPED_IR_INSTRUCTIONS: - org_value = ir.value - ir.value = INVERSE_MAPPED_IR_INSTRUCTIONS[ir.value] - new_var = _convert_ir_simple_node(fn, ir, symbols) - ir.value = org_value - return fn.get_basic_block().append_instruction("iszero", new_var) - elif ir.value in PASS_THROUGH_INSTRUCTIONS: - return _convert_ir_simple_node(fn, ir, symbols) - elif ir.value == "return": - fn.get_basic_block().append_instruction( - "return", IRVariable("ret_size"), IRVariable("ret_ofst") - ) - elif ir.value == "deploy": - ctx.ctor_mem_size = ir.args[0].value - ctx.immutables_len = ir.args[2].value - fn.get_basic_block().append_instruction("exit") return None - elif ir.value == "seq": - if len(ir.args) == 0: - return None - if ir.is_self_call: - return _handle_self_call(fn, ir, symbols) - elif ir.args[0].value == "label": - current_func = ir.args[0].args[0].value - is_external = current_func.startswith("external") - is_internal = current_func.startswith("internal") - if is_internal or len(re.findall(r"external.*__init__\(.*_deploy", current_func)) > 0: - # Internal definition - var_list = ir.args[0].args[1] - assert var_list.value == "var_list" - does_return_data = IRnode.from_list(["return_buffer"]) in var_list.args - symbols = {} - new_fn = _handle_internal_func(fn, ir, does_return_data, symbols) - for ir_node in ir.args[1:]: - ret = _convert_ir_bb(new_fn, ir_node, symbols) - return ret - elif is_external: - ret = _convert_ir_bb(fn, ir.args[0], symbols) - _append_return_args(fn) - else: - bb = fn.get_basic_block() - if bb.is_terminated: - bb = IRBasicBlock(ctx.get_next_label("seq"), fn) - fn.append_basic_block(bb) - ret = _convert_ir_bb(fn, ir.args[0], symbols) + def _handle_if_stmt(self, ir: IRnode) -> Optional[IRVariable]: + fn = self.fn + ctx = fn.ctx - for ir_node in ir.args[1:]: - ret = _convert_ir_bb(fn, ir_node, symbols) + cond_ir = ir.args[0] - return ret - elif ir.value == "if": - cond = ir.args[0] - - # convert the condition - cont_ret = _convert_ir_bb(fn, cond, symbols) + cond = self.convert_ir(cond_ir) cond_block = fn.get_basic_block() then_block = IRBasicBlock(ctx.get_next_label("then"), fn) else_block = IRBasicBlock(ctx.get_next_label("else"), fn) # convert "then" - cond_symbols = symbols.copy() fn.append_basic_block(then_block) - then_ret_val = _convert_ir_bb(fn, ir.args[1], cond_symbols) - if isinstance(then_ret_val, IRLiteral): - then_ret_val = fn.get_basic_block().append_instruction("store", then_ret_val) + with self.anchor_variables(): + then_ret_val = self.convert_ir(ir.args[1]) then_block_finish = fn.get_basic_block() # convert "else" - cond_symbols = symbols.copy() fn.append_basic_block(else_block) else_ret_val = None if len(ir.args) == 3: - else_ret_val = _convert_ir_bb(fn, ir.args[2], cond_symbols) - if isinstance(else_ret_val, IRLiteral): - assert isinstance(else_ret_val.value, int) # help mypy - else_ret_val = fn.get_basic_block().append_instruction("store", else_ret_val) + with self.anchor_variables(): + else_ret_val = self.convert_ir(ir.args[2]) else_block_finish = fn.get_basic_block() # finish the condition block - cond_block.append_instruction("jnz", cont_ret, then_block.label, else_block.label) + cond_block.append_instruction("jnz", cond, then_block.label, else_block.label) # exit bb - exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), fn) - fn.append_basic_block(exit_bb) + join_bb = IRBasicBlock(ctx.get_next_label("if_exit"), fn) + fn.append_basic_block(join_bb) if_ret = fn.get_next_variable() + # will get converted to phi by make_ssa if then_ret_val is not None and else_ret_val is not None: then_block_finish.append_instruction("store", then_ret_val, ret=if_ret) else_block_finish.append_instruction("store", else_ret_val, ret=if_ret) if not else_block_finish.is_terminated: - else_block_finish.append_instruction("jmp", exit_bb.label) + else_block_finish.append_instruction("jmp", join_bb.label) if not then_block_finish.is_terminated: - then_block_finish.append_instruction("jmp", exit_bb.label) + then_block_finish.append_instruction("jmp", join_bb.label) return if_ret - elif ir.value == "with": - ret = _convert_ir_bb(fn, ir.args[1], symbols) # initialization - - ret = fn.get_basic_block().append_instruction("store", ret) - - sym = ir.args[0] - with_symbols = symbols.copy() - with_symbols[sym.value] = ret - - return _convert_ir_bb(fn, ir.args[2], with_symbols) # body - - elif ir.value == "goto": - _append_jmp(fn, IRLabel(ir.args[0].value)) - elif ir.value == "djump": - args = [_convert_ir_bb(fn, ir.args[0], symbols)] - for target in ir.args[1:]: - args.append(IRLabel(target.value)) - fn.get_basic_block().append_instruction("djmp", *args) - _new_block(fn) - elif ir.value == "set": - sym = ir.args[0] - arg_1 = _convert_ir_bb(fn, ir.args[1], symbols) - fn.get_basic_block().append_instruction("store", arg_1, ret=symbols[sym.value]) - elif ir.value == "symbol": - return IRLabel(ir.args[0].value, True) - elif ir.value == "data": - label = IRLabel(ir.args[0].value, True) - ctx.append_data_section(label) - for c in ir.args[1:]: - if isinstance(c.value, bytes): - ctx.append_data_item(c.value) - elif isinstance(c, IRnode): - data = _convert_ir_bb(fn, c, symbols) - assert isinstance(data, IRLabel) # help mypy - ctx.append_data_item(data) - elif ir.value == "label": - label = IRLabel(ir.args[0].value, True) - bb = fn.get_basic_block() - if not bb.is_terminated: - bb.append_instruction("jmp", label) - bb = IRBasicBlock(label, fn) - fn.append_basic_block(bb) - code = ir.args[2] - _convert_ir_bb(fn, code, symbols) - elif ir.value == "exit_to": - bb = fn.get_basic_block() - if bb.is_terminated: - bb = IRBasicBlock(ctx.get_next_label("exit_to"), fn) - fn.append_basic_block(bb) - - args = _convert_ir_bb_list(fn, ir.args[1:], symbols) - var_list = args - # TODO: only append return args if the function is external - _append_return_args(fn, *var_list) - bb = fn.get_basic_block() - - label = IRLabel(ir.args[0].value) - if label.value == "return_pc": - label = symbols.get("return_pc") - # return label should be top of stack - if _returns_word(_current_func_t) and ENABLE_NEW_CALL_CONV: - buf = symbols["return_buffer"] - val = bb.append_instruction("mload", buf) - bb.append_instruction("ret", val, label) - else: - bb.append_instruction("ret", label) - - else: - bb.append_instruction("jmp", label) - - elif ir.value == "mstore": - # some upstream code depends on reversed order of evaluation -- - # to fix upstream. - val, ptr = _convert_ir_bb_list(fn, reversed(ir.args), symbols) - return fn.get_basic_block().append_instruction("mstore", val, ptr) - - elif ir.value == "ceil32": - x = ir.args[0] - expanded = IRnode.from_list(["and", ["add", x, 31], ["not", 31]]) - return _convert_ir_bb(fn, expanded, symbols) - elif ir.value == "select": - cond, a, b = ir.args - expanded = IRnode.from_list( - [ - "with", - "cond", - cond, - [ - "with", - "a", - a, - ["with", "b", b, ["xor", "b", ["mul", "cond", ["xor", "a", "b"]]]], - ], - ] - ) - return _convert_ir_bb(fn, expanded, symbols) - elif ir.value == "repeat": - - def emit_body_blocks(): - global _break_target, _continue_target - old_targets = _break_target, _continue_target - _break_target, _continue_target = exit_block, incr_block - _convert_ir_bb(fn, body, symbols.copy()) - _break_target, _continue_target = old_targets + def _handle_repeat(self, ir): + fn = self.fn + ctx = fn.ctx + # loop variable name sym = ir.args[0] - start, end, _ = _convert_ir_bb_list(fn, ir.args[1:4], symbols) + start, end, _ = self._convert_ir_list(ir.args[1:4]) assert ir.args[3].is_literal, "repeat bound expected to be literal" bound = ir.args[3].value @@ -623,7 +559,6 @@ def emit_body_blocks(): fn.append_basic_block(entry_block) counter_var = entry_block.append_instruction("store", start) - symbols[sym.value] = counter_var if bound is not None: # assert le end bound @@ -636,12 +571,20 @@ def emit_body_blocks(): entry_block.append_instruction("jmp", cond_block.label) xor_ret = cond_block.append_instruction("xor", counter_var, end) - cont_ret = cond_block.append_instruction("iszero", xor_ret) + cond = cond_block.append_instruction("iszero", xor_ret) + fn.append_basic_block(cond_block) + # convert body fn.append_basic_block(body_block) + backup = self._break_target, self._continue_target + self._break_target = exit_block + self._continue_target = incr_block + with self.anchor_variables(): + self.variables[sym.value] = counter_var + self.convert_ir(body) + self._break_target, self._continue_target = backup - emit_body_blocks() body_end = fn.get_basic_block() if body_end.is_terminated is False: body_end.append_instruction("jmp", incr_block.label) @@ -654,81 +597,223 @@ def emit_body_blocks(): fn.append_basic_block(exit_block) - cond_block.append_instruction("jnz", cont_ret, exit_block.label, body_block.label) - elif ir.value == "break": - assert _break_target is not None, "Break with no break target" - fn.get_basic_block().append_instruction("jmp", _break_target.label) - fn.append_basic_block(IRBasicBlock(ctx.get_next_label(), fn)) - elif ir.value == "continue": - assert _continue_target is not None, "Continue with no contrinue target" - fn.get_basic_block().append_instruction("jmp", _continue_target.label) - fn.append_basic_block(IRBasicBlock(ctx.get_next_label(), fn)) - elif ir.value in NOOP_INSTRUCTIONS: - pass - elif isinstance(ir.value, str) and ir.value.startswith("log"): - args = reversed(_convert_ir_bb_list(fn, ir.args, symbols)) - topic_count = int(ir.value[3:]) - assert topic_count >= 0 and topic_count <= 4, "invalid topic count" - fn.get_basic_block().append_instruction("log", topic_count, *args) - elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): - _convert_ir_opcode(fn, ir, symbols) - elif isinstance(ir.value, str): - if ir.value.startswith("$alloca"): - alloca = ir.passthrough_metadata["alloca"] - if alloca._id not in _alloca_table: - ptr = fn.get_basic_block().append_instruction( - "alloca", alloca.offset, alloca.size, alloca._id + cond_block.append_instruction("jnz", cond, exit_block.label, body_block.label) + + def _handle_exit_to(self, ir): + fn = self.fn + ctx = fn.ctx + + bb = fn.get_basic_block() + if bb.is_terminated: + bb = IRBasicBlock(ctx.get_next_label("exit_to"), fn) + fn.append_basic_block(bb) + + args = self._convert_ir_list(ir.args[1:]) + bb = fn.get_basic_block() + + label = IRLabel(ir.args[0].value) + if label.value == "return_pc": + # return from internal function + + label = self.variables["return_pc"] + # return label should be top of stack + if _returns_word(self._current_func_t) and ENABLE_NEW_CALL_CONV: + buf = self.variables["return_buffer"] + val = bb.append_instruction("mload", buf) + bb.append_instruction("ret", val, label) + else: + bb.append_instruction("ret", label) + + elif len(ir.args) > 1 and ir.args[1].value == "return_pc": + # cleanup routine for internal function + bb.append_instruction("jmp", label) + else: + # cleanup routine for external function + if len(args) > 0: + ofst, size = args + self._append_return_args(ofst, size) + bb = fn.get_basic_block() + bb.append_instruction("jmp", label) + + def _handle_self_call(self, ir: IRnode) -> Optional[IROperand]: + fn = self.fn + + setup_ir = ir.args[1] + goto_ir = [ir for ir in ir.args if ir.value == "goto"][0] + target_label = goto_ir.args[0].value # goto + + func_t = ir.passthrough_metadata["func_t"] + assert func_t is not None, "func_t not found in passthrough metadata" + + returns_word = _returns_word(func_t) + + if setup_ir != goto_ir: + self.convert_ir(setup_ir) + + converted_args = self._convert_ir_list(goto_ir.args[1:]) + + callsite_op = converted_args[-1] + assert isinstance(callsite_op, IRLabel), converted_args + callsite = callsite_op.value + + bb = fn.get_basic_block() + return_buf = None + + if len(converted_args) > 1: + return_buf = converted_args[0] + + # should be list[IROperand], but mypy is stupid + stack_args: list[IROperand | int] + stack_args = [IRLabel(str(target_label))] + + if return_buf is not None: + if not ENABLE_NEW_CALL_CONV or not returns_word: + stack_args.append(return_buf) + + callsite_args = self._callsites[callsite] + if ENABLE_NEW_CALL_CONV: + for alloca in callsite_args: + if not _pass_via_stack(func_t)[alloca.name]: + continue + ptr = self._alloca_table[alloca._id] + stack_arg = bb.append_instruction("mload", ptr) + assert stack_arg is not None + stack_args.append(stack_arg) + + if returns_word: + ret_value = bb.append_invoke_instruction(stack_args, returns=True) + assert ret_value is not None # help mypy + assert return_buf is not None # help mypy + bb.append_instruction("mstore", ret_value, return_buf) + return return_buf + + bb.append_invoke_instruction(stack_args, returns=False) + + return return_buf + + # TODO: remove does_return_data, replace with `func_t.return_type is not None` + def _handle_internal_func(self, ir: IRnode, does_return_data: bool) -> IRFunction: + fn = self.fn + + func_t = ir.passthrough_metadata["func_t"] + context = ir.passthrough_metadata["context"] + assert func_t is not None, "func_t not found in passthrough metadata" + assert context is not None, func_t.name + + self._current_func_t = func_t + + funcname = ir.args[0].args[0].value + assert isinstance(funcname, str) + fn = fn.ctx.create_function(funcname) + + bb = fn.get_basic_block() + + _saved_alloca_table = self._alloca_table + self._alloca_table = {} + + returns_word = _returns_word(func_t) + + # return buffer + if does_return_data: + if ENABLE_NEW_CALL_CONV and returns_word: + # TODO: remove this once we have proper memory allocator + # functionality in venom. Currently, we hardcode the scratch + # buffer size of 32 bytes. + # TODO: we don't need to use scratch space once the legacy optimizer + # is disabled. + buf = bb.append_instruction("alloca", 0, 32, self.get_scratch_alloca_id()) + else: + buf = bb.append_instruction("param") + bb.instructions[-1].annotation = "return_buffer" + + assert buf is not None # help mypy + self.variables["return_buffer"] = buf + + if ENABLE_NEW_CALL_CONV: + stack_index = 0 + if func_t.return_type is not None and not _returns_word(func_t): + stack_index += 1 + for arg in func_t.arguments: + if not _pass_via_stack(func_t)[arg.name]: + continue + + param = bb.append_instruction("param") + bb.instructions[-1].annotation = arg.name + assert param is not None # help mypy + + var = context.lookup_var(arg.name) + + venom_arg = IRParameter( + name=var.name, + index=stack_index, + offset=var.alloca.offset, + size=var.alloca.size, + id_=var.alloca._id, + call_site_var=None, + func_var=param, + addr_var=None, ) - _alloca_table[alloca._id] = ptr - return _alloca_table[alloca._id] + fn.args.append(venom_arg) + stack_index += 1 - elif ir.value.startswith("$palloca"): - alloca = ir.passthrough_metadata["alloca"] - if alloca._id not in _alloca_table: - bb = fn.get_basic_block() - ptr = bb.append_instruction("palloca", alloca.offset, alloca.size, alloca._id) - bb.instructions[-1].annotation = f"{alloca.name} (memory)" - if ENABLE_NEW_CALL_CONV and _pass_via_stack(_current_func_t)[alloca.name]: - param = fn.get_param_by_id(alloca._id) - assert param is not None - bb.append_instruction("mstore", param.func_var, ptr) - _alloca_table[alloca._id] = ptr - return _alloca_table[alloca._id] - elif ir.value.startswith("$calloca"): - global _callsites - alloca = ir.passthrough_metadata["alloca"] - assert alloca._callsite is not None - if alloca._id not in _alloca_table: - bb = fn.get_basic_block() + # return address + return_pc = bb.append_instruction("param") + assert return_pc is not None # help mypy + self.variables["return_pc"] = return_pc + bb.instructions[-1].annotation = "return_pc" - callsite_func = ir.passthrough_metadata["callsite_func"] - if ENABLE_NEW_CALL_CONV and _pass_via_stack(callsite_func)[alloca.name]: - ptr = bb.append_instruction("alloca", alloca.offset, alloca.size, alloca._id) - else: - # if we use alloca, mstores might get removed. convert - # to calloca until memory analysis is more sound. - ptr = bb.append_instruction("calloca", alloca.offset, alloca.size, alloca._id) - - _alloca_table[alloca._id] = ptr - ret = _alloca_table[alloca._id] - # assumption: callocas appear in the same order as the - # order of arguments to the function. - _callsites[alloca._callsite].append(alloca) - return ret + with self.anchor_fn(fn): + # convert the body of the function + self.convert_ir(ir.args[0].args[2]) + + self._alloca_table = _saved_alloca_table - return symbols.get(ir.value) - elif ir.is_literal: - return IRLiteral(ir.value) - else: - raise Exception(f"Unknown IR node: {ir}") + return fn + + def _append_new_bb(self) -> None: + fn = self.fn + bb = IRBasicBlock(fn.ctx.get_next_label(), fn) + fn.append_basic_block(bb) + + +# func_t: ContractFunctionT +@functools.lru_cache(maxsize=1024) +def _pass_via_stack(func_t) -> dict[str, bool]: + # returns a dict which returns True if a given argument (referered to + # by name) should be passed via the stack + if not ENABLE_NEW_CALL_CONV: + return {arg.name: False for arg in func_t.arguments} - return None + arguments = {arg.name: arg for arg in func_t.arguments} + + stack_items = 0 + returns_word = _returns_word(func_t) + if returns_word: + stack_items += 1 + + ret = {} + + for arg in arguments.values(): + if not _is_word_type(arg.typ) or stack_items > MAX_STACK_ARGS: + ret[arg.name] = False + else: + ret[arg.name] = True + stack_items += 1 + + return ret + + +def _is_word_type(typ): + # we can pass it on the stack. + return typ.memory_bytes_required == 32 + + +# func_t: ContractFunctionT +def _returns_word(func_t) -> bool: + return_t = func_t.return_type + return return_t is not None and _is_word_type(return_t) -def _convert_ir_opcode(fn: IRFunction, ir: IRnode, symbols: SymbolTable) -> None: - opcode = ir.value.upper() # type: ignore - inst_args = [] - for arg in ir.args: - if isinstance(arg, IRnode): - inst_args.append(_convert_ir_bb(fn, arg, symbols)) - fn.get_basic_block().append_instruction(opcode, *inst_args) +def ir_node_to_venom(ir: IRnode, constants: Optional[dict[str, int]]) -> IRContext: + constants = constants or {} + return IRnodeToVenom(constants).convert(ir) diff --git a/vyper/venom/memory_location.py b/vyper/venom/memory_location.py index 977c8a1c76..ec2a2f9da8 100644 --- a/vyper/venom/memory_location.py +++ b/vyper/venom/memory_location.py @@ -208,8 +208,6 @@ def _get_memory_read_location(inst) -> MemoryLocation: elif opcode == "sha3": size, offset = inst.operands return MemoryLocation.from_operands(offset, size) - elif opcode == "sha3_32": - raise CompilerPanic("invalid opcode") # should be unused elif opcode == "sha3_64": return MemoryLocation(offset=0, size=64) elif opcode == "log": @@ -251,7 +249,7 @@ def _get_storage_read_location(inst, addr_space: AddrSpace) -> MemoryLocation: return MemoryLocation.UNDEFINED elif opcode in ("create", "create2"): return MemoryLocation.UNDEFINED - elif opcode in ("return", "stop", "exit", "sink"): + elif opcode in ("return", "stop", "sink"): # these opcodes terminate execution and commit to (persistent) # storage, resulting in storage writes escaping our control. # returning `MemoryLocation.UNDEFINED` represents "future" reads diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 4c5a2bfcda..823f346ce1 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -1,12 +1,17 @@ +from __future__ import annotations + from typing import Any, Iterable from vyper.exceptions import CompilerPanic, StackTooDeep from vyper.ir.compile_ir import ( + DATA_ITEM, PUSH, + PUSH_OFST, + PUSHLABEL, + AssemblyInstruction, DataHeader, - Instruction, - RuntimeHeader, - mksymbol, + Label, + TaggedInstruction, optimize_assembly, ) from vyper.utils import MemoryPositions, OrderedSet, wrap256 @@ -107,22 +112,27 @@ ] ) -_REVERT_POSTAMBLE = ["_sym___revert", "JUMPDEST", *PUSH(0), "DUP1", "REVERT"] +_REVERT_POSTAMBLE = [Label("revert"), *PUSH(0), "DUP1", "REVERT"] def apply_line_numbers(inst: IRInstruction, asm) -> list[str]: ret = [] for op in asm: - if isinstance(op, str) and not isinstance(op, Instruction): - ret.append(Instruction(op, inst.ast_source, inst.error_msg)) + if isinstance(op, str) and not isinstance(op, TaggedInstruction): + ret.append(TaggedInstruction(op, inst.ast_source, inst.error_msg)) else: ret.append(op) return ret # type: ignore -def _as_asm_symbol(label: IRLabel) -> str: +def _as_asm_symbol(label: IRLabel) -> Label: # Lower an IRLabel to an assembly symbol - return f"_sym_{label.value}" + return Label(label.value) + + +def _ofst(label: Label, value: int) -> list[Any]: + # resolve at compile time using magic PUSH_OFST op + return [PUSH_OFST(label, value)] # TODO: "assembly" gets into the recursion due to how the original @@ -141,68 +151,55 @@ class VenomCompiler: dfg: DFGAnalysis cfg: CFGAnalysis - def __init__(self, ctxs: list[IRContext]): - self.ctxs = ctxs + def __init__(self, ctx: IRContext): + # TODO: maybe just accept a single IRContext + self.ctx = ctx self.label_counter = 0 self.visited_basicblocks = OrderedSet() - def generate_evm(self, no_optimize: bool = False) -> list[str]: + def mklabel(self, name: str) -> Label: + self.label_counter += 1 + return Label(f"{name}_{self.label_counter}") + + def generate_evm_assembly(self, no_optimize: bool = False) -> list[AssemblyInstruction]: self.visited_basicblocks = OrderedSet() self.label_counter = 0 - asm: list[Any] = [] - top_asm = asm - - for ctx in self.ctxs: - for fn in ctx.functions.values(): - ac = IRAnalysesCache(fn) - - NormalizationPass(ac, fn).run_pass() - self.liveness = ac.request_analysis(LivenessAnalysis) - self.dfg = ac.request_analysis(DFGAnalysis) - self.cfg = ac.request_analysis(CFGAnalysis) - - assert self.cfg.is_normalized(), "Non-normalized CFG!" - - self._generate_evm_for_basicblock_r(asm, fn.entry, StackModel()) - - # TODO make this property on IRFunction - asm.extend(["_sym__ctor_exit", "JUMPDEST"]) - if ctx.immutables_len is not None and ctx.ctor_mem_size is not None: - asm.extend( - ["_sym_subcode_size", "_sym_runtime_begin", "_mem_deploy_start", "CODECOPY"] - ) - asm.extend(["_OFST", "_sym_subcode_size", ctx.immutables_len]) # stack: len - asm.extend(["_mem_deploy_start"]) # stack: len mem_ofst - asm.extend(["RETURN"]) - asm.extend(_REVERT_POSTAMBLE) - runtime_asm = [ - RuntimeHeader("_sym_runtime_begin", ctx.ctor_mem_size, ctx.immutables_len) - ] - asm.append(runtime_asm) - asm = runtime_asm - else: - asm.extend(_REVERT_POSTAMBLE) - - # Append data segment - for data_section in ctx.data_segment: - label = data_section.label - asm_data_section: list[Any] = [] - asm_data_section.append(DataHeader(_as_asm_symbol(label))) - for item in data_section.data_items: - data = item.data - if isinstance(data, IRLabel): - asm_data_section.append(_as_asm_symbol(data)) - else: - assert isinstance(data, bytes) - asm_data_section.append(data) - - asm.append(asm_data_section) + asm: list[AssemblyInstruction] = [] + + for fn in self.ctx.functions.values(): + ac = IRAnalysesCache(fn) + + NormalizationPass(ac, fn).run_pass() + self.liveness = ac.request_analysis(LivenessAnalysis) + self.dfg = ac.request_analysis(DFGAnalysis) + self.cfg = ac.request_analysis(CFGAnalysis) + + assert self.cfg.is_normalized(), "Non-normalized CFG!" + + self._generate_evm_for_basicblock_r(asm, fn.entry, StackModel()) + + asm.extend(_REVERT_POSTAMBLE) + + # Append data segment + for data_section in self.ctx.data_segment: + label = data_section.label + asm_data_section: list[AssemblyInstruction] = [] + asm_data_section.append(DataHeader(_as_asm_symbol(label))) + for item in data_section.data_items: + data = item.data + if isinstance(data, IRLabel): + asm_data_section.append(DATA_ITEM(_as_asm_symbol(data))) + else: + assert isinstance(data, bytes) + asm_data_section.append(DATA_ITEM(data)) + + asm.extend(asm_data_section) if no_optimize is False: - optimize_assembly(top_asm) + optimize_assembly(asm) - return top_asm + return asm def _stack_reorder( self, assembly: list, stack: StackModel, stack_ops: list[IROperand], dry_run: bool = False @@ -262,7 +259,7 @@ def _emit_input_operands( # invoke emits the actual instruction itself so we don't need # to emit it here but we need to add it to the stack map if inst.opcode != "invoke": - assembly.append(_as_asm_symbol(op)) + assembly.append(PUSHLABEL(_as_asm_symbol(op))) stack.push(op) continue @@ -338,7 +335,6 @@ def _generate_evm_for_basicblock_r( # assembly entry point into the block asm.append(_as_asm_symbol(basicblock.label)) - asm.append("JUMPDEST") fn = basicblock.parent if basicblock == fn.entry: @@ -391,7 +387,7 @@ def clean_stack_from_cfg_in( def _generate_evm_for_instruction( self, inst: IRInstruction, stack: StackModel, next_liveness: OrderedSet ) -> list[str]: - assembly: list[str | int] = [] + assembly: list[AssemblyInstruction] = [] opcode = inst.opcode # @@ -450,7 +446,7 @@ def _generate_evm_for_instruction( if opcode == "offset": ofst, label = inst.operands assert isinstance(label, IRLabel) # help mypy - assembly.extend(["_OFST", _as_asm_symbol(label), ofst.value]) + assembly.extend(_ofst(_as_asm_symbol(label), ofst.value)) assert isinstance(inst.output, IROperand), "Offset must have output" stack.push(inst.output) return apply_line_numbers(inst, assembly) @@ -516,19 +512,19 @@ def _generate_evm_for_instruction( elif opcode == "jnz": # jump if not zero if_nonzero_label, if_zero_label = inst.get_label_operands() - assembly.append(_as_asm_symbol(if_nonzero_label)) + assembly.append(PUSHLABEL(_as_asm_symbol(if_nonzero_label))) assembly.append("JUMPI") # make sure the if_zero_label will be optimized out # assert if_zero_label == next(iter(inst.parent.cfg_out)).label - assembly.append(_as_asm_symbol(if_zero_label)) + assembly.append(PUSHLABEL(_as_asm_symbol(if_zero_label))) assembly.append("JUMP") elif opcode == "jmp": (target,) = inst.operands assert isinstance(target, IRLabel) - assembly.append(_as_asm_symbol(target)) + assembly.append(PUSHLABEL(_as_asm_symbol(target))) assembly.append("JUMP") elif opcode == "djmp": assert isinstance( @@ -540,22 +536,14 @@ def _generate_evm_for_instruction( assert isinstance( target, IRLabel ), f"invoke target must be a label (is ${type(target)} ${target})" + return_label = self.mklabel("return_label") assembly.extend( - [ - f"_sym_label_ret_{self.label_counter}", - _as_asm_symbol(target), - "JUMP", - f"_sym_label_ret_{self.label_counter}", - "JUMPDEST", - ] + [PUSHLABEL(return_label), PUSHLABEL(_as_asm_symbol(target)), "JUMP", return_label] ) - self.label_counter += 1 elif opcode == "ret": assembly.append("JUMP") elif opcode == "return": assembly.append("RETURN") - elif opcode == "exit": - assembly.extend(["_sym__ctor_exit", "JUMP"]) elif opcode == "phi": pass elif opcode == "sha3": @@ -573,23 +561,27 @@ def _generate_evm_for_instruction( ] ) elif opcode == "assert": - assembly.extend(["ISZERO", "_sym___revert", "JUMPI"]) + assembly.extend(["ISZERO", PUSHLABEL(Label("revert")), "JUMPI"]) elif opcode == "assert_unreachable": - end_symbol = mksymbol("reachable") - assembly.extend([end_symbol, "JUMPI", "INVALID", end_symbol, "JUMPDEST"]) + end_symbol = self.mklabel("reachable") + assembly.extend([PUSHLABEL(end_symbol), "JUMPI", "INVALID", end_symbol]) elif opcode == "iload": addr = inst.operands[0] + mem_deploy_end = self.ctx.constants["mem_deploy_end"] if isinstance(addr, IRLiteral): - assembly.extend(["_OFST", "_mem_deploy_end", addr.value]) + ptr = mem_deploy_end + addr.value + assembly.extend(PUSH(ptr)) else: - assembly.extend(["_mem_deploy_end", "ADD"]) + assembly.extend([*PUSH(mem_deploy_end), "ADD"]) assembly.append("MLOAD") elif opcode == "istore": addr = inst.operands[1] + mem_deploy_end = self.ctx.constants["mem_deploy_end"] if isinstance(addr, IRLiteral): - assembly.extend(["_OFST", "_mem_deploy_end", addr.value]) + ptr = mem_deploy_end + addr.value + assembly.extend(PUSH(ptr)) else: - assembly.extend(["_mem_deploy_end", "ADD"]) + assembly.extend([*PUSH(mem_deploy_end), "ADD"]) assembly.append("MSTORE") elif opcode == "log": assembly.extend([f"LOG{log_topic_count}"])