Skip to content

Commit bba3bff

Browse files
author
rocky
committed
Add program to check bytecode writing.
Fix some errors in writing bytecode. Respecting Bytecode code type in marshal had bugs in it. (It may still have bugs.)
1 parent d43b2a1 commit bba3bff

File tree

4 files changed

+394
-40
lines changed

4 files changed

+394
-40
lines changed

test/test_writing_pyc.py

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
#!/usr/bin/env python
2+
# (C) Copyright 2025 by Rocky Bernstein
3+
#
4+
# This program is free software; you can redistribute it and/or
5+
# modify it under the terms of the GNU General Public License
6+
# as published by the Free Software Foundation; either version 2
7+
# of the License, or (at your option) any later version.
8+
#
9+
# This program is distributed in the hope that it will be useful,
10+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
# GNU General Public License for more details.
13+
#
14+
# You should have received a copy of the GNU General Public License
15+
# along with this program; if not, write to the Free Software
16+
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17+
"""
18+
test_writing_pyc.py
19+
20+
Command-line tool that:
21+
- takes a Python bytecode file (.pyc) as an argument
22+
- loads it with xdis.load.load_module_from_file_object()
23+
- writes it back out using xdis.load.write_bytecode_file() to a temporary file
24+
- compares the original and written files by raw bytes and by the
25+
module/code-object structure they contain
26+
27+
Usage:
28+
python test_writing_pyc.py path/to/bytecode_file.pyc
29+
"""
30+
31+
import filecmp
32+
import os
33+
import os.path as osp
34+
import sys
35+
import tempfile
36+
import types
37+
from typing import List, Union
38+
39+
from xdis.codetype.base import CodeBase
40+
from xdis.load import load_module_from_file_object, write_bytecode_file
41+
from xdis.version_info import version_tuple_to_str
42+
43+
# For later:
44+
# from xdis.unmarshal import (
45+
# _VersionIndependentUnmarshaller,
46+
# FLAG_REF,
47+
# UNMARSHAL_DISPATCH_TABLE,
48+
# )
49+
50+
51+
def compare_consts(c1: tuple, c2: tuple) -> bool:
    """
    Compare two tuples of constants, recursing into code objects when found.

    Two constants match when either:
      - both look like code objects (they have a ``co_code`` attribute) and
        compare_code_objects() reports them equal, or
      - they have exactly the same type and compare equal.

    The type check matters for bytecode: ``1``, ``1.0`` and ``True`` are all
    ``==`` to each other in Python, but they are distinct constants.

    Returns True if the tuples have the same length and every pair of
    constants matches; False otherwise.
    """
    if len(c1) != len(c2):
        return False
    for a, b in zip(c1, c2):
        if hasattr(a, "co_code") and hasattr(b, "co_code"):
            # Code-object-like on both sides: compare structurally.
            if not compare_code_objects(a, b):
                return False
        elif type(a) is not type(b) or a != b:
            return False
    return True
64+
65+
def compare_showing_error(orig_path: str, new_path: str):
    """
    Report on stderr how the contents of two files differ.

    If the files have different sizes, a size-mismatch message is printed.
    Then the first offset (in hex) at which the byte values differ is
    printed, together with the differing byte from each file. Only the
    first differing byte is reported. Nothing is printed for files that
    agree over their common length and have the same size.

    orig_path: path of the original file.
    new_path: path of the rewritten file to compare against it.

    Returns None.
    """
    # Read each file inside a "with" block so that the file handles are
    # closed promptly instead of being left for the garbage collector.
    with open(orig_path, "rb") as orig_fp:
        orig_bytes = orig_fp.read()
    with open(new_path, "rb") as new_fp:
        new_bytes = new_fp.read()

    if (orig_n := len(orig_bytes)) != (new_n := len(new_bytes)):
        print(f"MISMATCH: original has {orig_n} bytes; new has {new_n} bytes", file=sys.stderr)

    # zip() stops at the shorter file, so only the common prefix is scanned.
    for i, (old_byte, new_byte) in enumerate(zip(orig_bytes, new_bytes)):
        if old_byte != new_byte:
            print(f"MISMATCH at {hex(i)}: old {hex(old_byte)}; new: {hex(new_byte)}", file=sys.stderr)
            return
78+
79+
80+
def get_code_attrs(co: Union[CodeBase, types.CodeType]) -> dict:
    """
    Collect the comparison-relevant attributes of a code object into a dict.

    Only attributes that actually exist on ``co`` are included, so the
    result may hold fewer keys for some kinds of code object. Keys appear
    in the order the attribute names are listed below.
    """
    wanted = (
        "co_argcount",
        "co_posonlyargcount",
        "co_kwonlyargcount",
        "co_nlocals",
        "co_stacksize",
        "co_flags",
        "co_code",
        "co_consts",
        "co_names",
        "co_varnames",
        "co_freevars",
        "co_cellvars",
        "co_filename",
        "co_name",
    )
    return {field: getattr(co, field) for field in wanted if hasattr(co, field)}
105+
106+
107+
def compare_code_objects(
    a: Union[CodeBase, types.CodeType], b: Union[CodeBase, types.CodeType]
) -> bool:
    """
    Decide whether two code objects agree on a selection of attributes.

    Only attributes present on both objects are compared, which keeps the
    check permissive across Python implementations. ``co_consts`` is
    compared with compare_consts() so that nested code objects are handled
    recursively; every other attribute is compared with ``!=``.

    Returns True if every shared attribute matches, False otherwise.
    """
    left = get_code_attrs(a)
    right = get_code_attrs(b)

    # When both objects expose the same attributes, the intersection is
    # simply the full attribute set, so one expression covers both cases.
    shared = sorted(set(left) & set(right))

    for attr in shared:
        left_value, right_value = left[attr], right[attr]
        if attr == "co_consts":
            matches = compare_consts(tuple(left_value), tuple(right_value))
        else:
            matches = not (left_value != right_value)
        if not matches:
            return False
    return True
135+
136+
137+
# For later
138+
# def r_object(self, bytes_for_s: bool = False):
139+
# """
140+
# Replacement r_object for classification
141+
# """
142+
# byte1 = ord(self.fp.read(1))
143+
144+
# # FLAG_REF indicates whether we "intern" or
145+
# # save a reference to the object.
146+
# # byte1 without that reference is the
147+
# # marshal type code, an ASCII character.
148+
# save_ref = False
149+
# if byte1 & FLAG_REF:
150+
# # Since 3.4, "flag" is the marshal.c name
151+
# save_ref = True
152+
# byte1 = byte1 & (FLAG_REF - 1)
153+
# marshal_type = chr(byte1)
154+
155+
# # print(marshal_type) # debug
156+
157+
# if marshal_type in UNMARSHAL_DISPATCH_TABLE:
158+
# func_suffix = UNMARSHAL_DISPATCH_TABLE[marshal_type]
159+
# unmarshal_func = getattr(self, "t_" + func_suffix)
160+
# return unmarshal_func(save_ref, bytes_for_s)
161+
# else:
162+
# mess = ("Unknown type %i (hex %x) %c\n"
163+
# % (ord(marshal_type), ord(marshal_type), marshal_type))
164+
# raise TypeError(mess)
165+
# return
166+
167+
def load_meta_and_code_from_filename(path: str):
    """
    Read the bytecode file at ``path`` with load_module_from_file_object().

    The file is opened in binary mode and the loader's result is returned
    unchanged. Callers in this file unpack it as the 8-tuple:
    (version_tuple, timestamp, magic_int, co, is_pypy, source_size, sip_hash, file_offsets)
    """
    with open(path, "rb") as bytecode_fp:
        loaded = load_module_from_file_object(bytecode_fp, filename=path, get_code=True)
    return loaded
174+
175+
176+
def main(argv: List[str]) -> int:
    """
    Round-trip a bytecode file through xdis and report whether the bytes survive.

    The file named by argv[1] is loaded, written back out to a temporary
    file with write_bytecode_file(), and the two files are compared byte
    for byte. The temporary file is left in place for inspection unless
    writing it failed.

    argv: a sys.argv-style list; argv[1] is the path of the bytecode file.

    Returns the process exit status:
      0 - the rewritten file is byte-for-byte identical to the original
      1 - no path was given, or the rewritten file differs from the original
      2 - the path does not exist or is not a regular file
      3 - the original bytecode file could not be loaded
      4 - the bytecode file could not be written
    """
    # parser = argparse.ArgumentParser(
    #     description="Load a .pyc with xdis, rewrite it to a temporary file, and compare."
    # )
    # parser.add_argument("pycfile", help="Path to the .pyc (or other bytecode) file")
    # args = parser.parse_args(argv)
    if len(argv) < 2:
        print("ERROR: you need to pass a pyc file", file=sys.stderr)
        return 1

    orig_path = argv[1]
    if not osp.exists(orig_path):
        print(f"ERROR: file does not exist: {orig_path}", file=sys.stderr)
        return 2
    if not osp.isfile(orig_path):
        print(f"ERROR: not a file: {orig_path}", file=sys.stderr)
        return 2

    # Load the original; the helper opens and closes the file for us.
    try:
        (
            orig_version,
            orig_timestamp,
            orig_magic_int,
            orig_co,
            orig_is_pypy,
            orig_source_size,
            orig_sip_hash,
            orig_file_offsets,
        ) = load_meta_and_code_from_filename(orig_path)
    except Exception as e:
        print(f"ERROR: failed to load original bytecode file: {e}", file=sys.stderr)
        return 3

    # Build the temporary-file prefix from the original name without its
    # bytecode extension (.pyc or .pyo), plus the bytecode's Python version.
    tf_name_base = osp.basename(orig_path)
    for suffix in (".pyc", ".pyo"):
        if tf_name_base.endswith(suffix):
            tf_name_base = tf_name_base[: -len(suffix)]
            break

    version_str = version_tuple_to_str(orig_version)
    tf_name_base += f"-{version_str}-"

    # Write to a temporary file using write_bytecode_file
    tf = tempfile.NamedTemporaryFile(prefix=tf_name_base, suffix=".pyc", delete=False)
    tf_name = tf.name
    tf.close()  # write_bytecode_file will open/write the file itself

    try:
        try:
            write_bytecode_file(
                tf_name,
                orig_co,
                orig_magic_int,
                compilation_ts=orig_timestamp,
                filesize=orig_source_size or 0,
            )
        except TypeError:
            # Older/newer signatures might name the timestamp param differently; try without names
            write_bytecode_file(
                tf_name, orig_co, orig_magic_int, orig_timestamp, orig_source_size or 0
            )
    except Exception as e:
        print(f"ERROR: failed to write bytecode file: {e}", file=sys.stderr)
        # Best-effort cleanup: a failed write leaves nothing worth inspecting.
        try:
            os.unlink(tf_name)
        except OSError:
            pass
        return 4

    # Compare raw bytes first
    same_bytes = False
    try:
        same_bytes = filecmp.cmp(orig_path, tf_name, shallow=False)
    except Exception as e:
        print(f"WARNING: could not do raw byte comparison: {e}", file=sys.stderr)

    print("Original file:", orig_path)
    print("Rewritten file:", tf_name)
    print("Raw-bytes identical:", same_bytes)
    if same_bytes:
        return 0

    compare_showing_error(orig_path, tf_name)

    # # Now compare by loading both and comparing metadata and code-object structure
    # try:
    #     (
    #         new_version,
    #         new_timestamp,
    #         new_magic_int,
    #         new_co,
    #         new_is_pypy,
    #         new_source_size,
    #         new_sip_hash,
    #         new_file_offsets,
    #     ) = load_meta_and_code_from_filename(tf_name)
    # except Exception as e:
    #     print(
    #         f"ERROR: failed to load rewritten bytecode file {tf_name}:\n\t{e}",
    #         file=sys.stderr,
    #     )
    #     return 5

    # meta_equal = (
    #     orig_version == new_version
    #     and orig_magic_int == new_magic_int
    #     and (orig_timestamp == new_timestamp)
    #     and (orig_source_size == new_source_size)
    #     and (orig_sip_hash == new_sip_hash)
    # )

    # print(
    #     "Metadata equal (version, magic, timestamp, source_size, sip_hash):", meta_equal
    # )
    # print("Original is PyPy:", orig_is_pypy, "Rewritten is PyPy:", new_is_pypy)

    # # Compare code objects
    # codes_equal = compare_code_objects(orig_co, new_co)
    # print("Code objects structurally equal:", codes_equal)

    # # Compare file_offsets if present
    # offsets_equal = orig_file_offsets == new_file_offsets
    # print("File offsets equal:", offsets_equal)

    # all_identical = same_bytes and meta_equal and codes_equal and offsets_equal
    # if all_identical:
    #     print(
    #         "RESULT: files are identical both as raw bytes and in contained module/code objects."
    #     )
    #     ret = 0
    # else:
    #     print("RESULT: files differ.")
    #     # Helpful diagnostic: show which pieces disagree
    #     if not same_bytes:
    #         print("- raw-bytes differ")
    #     if not meta_equal:
    #         print("- metadata differs")
    #         print(
    #             "  original:",
    #             (
    #                 orig_version,
    #                 orig_magic_int,
    #                 orig_timestamp,
    #                 orig_source_size,
    #                 orig_sip_hash,
    #             ),
    #         )
    #         print(
    #             "  rewritten:",
    #             (
    #                 new_version,
    #                 new_magic_int,
    #                 new_timestamp,
    #                 new_source_size,
    #                 new_sip_hash,
    #             ),
    #         )
    #     if not codes_equal:
    #         print("- code objects differ (some attributes compared):")
    #         # attempt to print a short diff of co_code lengths and names
    #         try:
    #             print("  orig co_code len:", len(getattr(orig_co, "co_code", b"")))
    #             print("  new  co_code len:", len(getattr(new_co, "co_code", b"")))
    #             print("  orig co_names:", getattr(orig_co, "co_names", None))
    #             print("  new  co_names:", getattr(new_co, "co_names", None))
    #         except Exception:
    #             pass
    #     if not offsets_equal:
    #         print("- file offsets differ")
    #     ret = 1

    # Do not delete the temporary file automatically; leave it for inspection.
    print("Temporary rewritten file left at:", tf_name)
    return 1
359+
360+
361+
if __name__ == "__main__":
    # Run as a script: hand main()'s return value back as the exit status.
    raise SystemExit(main(sys.argv))

xdis/load.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -386,10 +386,10 @@ def write_bytecode_file(
386386
magic_int (i.e. bytecode associated with some version of Python)
387387
"""
388388
fp = open(bytecode_path, "wb")
389-
version = py_str2tuple(magicint2version[magic_int])
390-
if version >= (3, 0):
389+
version_tuple = py_str2tuple(magicint2version[magic_int])
390+
if version_tuple >= (3, 0):
391391
fp.write(pack("<Hcc", magic_int, b"\r", b"\n"))
392-
if version >= (3, 7): # pep552 bytes
392+
if version_tuple >= (3, 7): # pep552 bytes
393393
fp.write(pack("<I", 0)) # pep552 bytes
394394
else:
395395
fp.write(pack("<Hcc", magic_int, b"\r", b"\n"))
@@ -404,13 +404,13 @@ def write_bytecode_file(
404404
else:
405405
fp.write(pack("<I", int(datetime.now().timestamp())))
406406

407-
if version >= (3, 3):
407+
if version_tuple >= (3, 3):
408408
# In Python 3.3+, these 4 bytes are the size of the source code_obj file (mod 2^32)
409409
fp.write(pack("<I", filesize))
410410
if isinstance(code_obj, types.CodeType):
411411
fp.write(marshal.dumps(code_obj))
412412
else:
413-
fp.write(xdis.marsh.dumps(code_obj))
413+
fp.write(xdis.marsh.dumps(code_obj, python_version=version_tuple))
414414
fp.close()
415415

416416

0 commit comments

Comments
 (0)