Skip to content

Commit e1b590e

Browse files
author
rocky
committed
DRY unmarshal integer reading routines
1 parent 95dc22f commit e1b590e

File tree

3 files changed

+60
-65
lines changed

3 files changed

+60
-65
lines changed

xdis/unmarsh_graal.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def graal_readBigInteger(self):
235235
Reads a marshaled big integer from the input stream.
236236
"""
237237
negative = False
238-
sz = self.graal_readInt() # Get the size in shorts
238+
sz = self.read_uint32() # Get the size in shorts
239239
if sz < 0:
240240
negative = True
241241
sz = -sz
@@ -277,7 +277,7 @@ def graal_readBooleanArray(self) -> tuple[bool, ...]:
277277
Python equivalent of Python Graal's readBooleanArray() from
278278
MarshalModuleBuiltins.java
279279
"""
280-
length: int = int(unpack("<i", self.fp.read(4))[0])
280+
length: int = self.read_uint32()
281281
return tuple([bool(self.graal_readByte()) for _ in range(length)])
282282

283283
def graal_readByte(self) -> int:
@@ -292,7 +292,7 @@ def graal_readBytes(self) -> bytes:
292292
Python equivalent of Python Graal's readBytes() from
293293
MarshalModuleBuiltins.java
294294
"""
295-
length: int = unpack("<i", self.fp.read(4))[0]
295+
length: int = self.read_uint32()
296296
return bytes([self.graal_readByte() for _ in range(length)])
297297

298298
def graal_readDouble(self) -> float:
@@ -307,23 +307,16 @@ def graal_readDoubleArray(self) -> tuple[float, ...]:
307307
Python equivalent of Python Graal's readDoubleArray() from
308308
MarshalModuleBuiltins.java
309309
"""
310-
length: int = int(unpack("<i", self.fp.read(4))[0])
310+
length: int = self.read_uint32()
311311
return tuple([self.graal_readDouble() for _ in range(length)])
312312

313-
def graal_readInt(self) -> int:
314-
"""
315-
Python equivalent of Python Graal's readInt() from
316-
MarshalModuleBuiltins.java
317-
"""
318-
return int(unpack("<i", self.fp.read(4))[0])
319-
320313
def graal_readIntArray(self) -> tuple[int, ...]:
321314
"""
322315
Python equivalent of Python Graal's readIntArray() from
323316
MarshalModuleBuiltins.java
324317
"""
325318
length: int = int(unpack("<i", self.fp.read(4))[0])
326-
return tuple([self.graal_readInt() for _ in range(length)])
319+
return tuple([self.read_int32() for _ in range(length)])
327320

328321
def graal_readLong(self) -> int:
329322
"""
@@ -370,18 +363,18 @@ def graal_readStringArray(self) -> tuple[str, ...]:
370363
Python equivalent of Python Graal's readObjectArray() from
371364
MarshalModuleBuiltins.java
372365
"""
373-
length: int = self.graal_readInt()
366+
length: int = self.read_uint32()
374367
return tuple([self.graal_readString() for _ in range(length)])
375368

376369
def graal_readSparseTable(self) -> Dict[int, tuple]:
377370
"""
378371
Python equivalent of Python Graal's readObjectArray() from
379372
MarshalModuleBuiltins.java
380373
"""
381-
self.graal_readInt() # the length return value isn't used.
374+
self.read_uint32() # the length return value isn't used.
382375
table = {} # new int[length][];
383376
while True:
384-
i = self.graal_readInt()
377+
i = self.read_int32()
385378
if i == -1:
386379
return table
387380
table[i] = self.graal_readIntArray()
@@ -494,15 +487,15 @@ def t_graal_CodeUnit(self, save_ref, bytes_for_s: bool = False):
494487

495488
co_name = self.graal_readString()
496489
co_qualname = self.graal_readString()
497-
co_argcount = self.graal_readInt()
498-
co_kwonlyargcount = self.graal_readInt()
499-
co_posonlyargcount = self.graal_readInt()
490+
co_argcount = self.read_uint32()
491+
co_kwonlyargcount = self.read_uint32()
492+
co_posonlyargcount = self.read_uint32()
500493

501-
co_stacksize = self.graal_readInt()
494+
co_stacksize = self.read_uint32()
502495
co_code_offset_in_file = self.fp.tell()
503496
co_code = self.graal_readBytes()
504497
other_fields["srcOffsetTable"] = self.graal_readBytes()
505-
co_flags = self.graal_readInt()
498+
co_flags = self.read_uint32()
506499

507500
# writeStringArray(code.names);
508501
# writeStringArray(code.varnames);
@@ -542,11 +535,11 @@ def t_graal_CodeUnit(self, save_ref, bytes_for_s: bool = False):
542535

543536
other_fields["primitiveConstants"] = self.graal_readLongArray()
544537
other_fields["exception_handler_ranges"] = self.graal_readIntArray()
545-
other_fields["condition_profileCount"] = self.graal_readInt()
546-
other_fields["startLine"] = self.graal_readInt()
547-
other_fields["startColumn"] = self.graal_readInt()
548-
other_fields["endLine"] = self.graal_readInt()
549-
other_fields["endColumn"] = self.graal_readInt()
538+
other_fields["condition_profileCount"] = self.read_uint32()
539+
other_fields["startLine"] = self.read_uint32()
540+
other_fields["startColumn"] = self.read_uint32()
541+
other_fields["endLine"] = self.read_uint32()
542+
other_fields["endColumn"] = self.read_uint32()
550543
other_fields["outputCanQuicken"] = self.graal_readBytes()
551544
other_fields["variableShouldUnbox"] = self.graal_readBytes()
552545
other_fields["generalizeInputsMap"] = self.graal_readSparseTable()

xdis/unmarsh_rust.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
object.
2525
"""
2626

27-
from struct import unpack
2827
from typing import Any, Dict, List, Tuple, Union
2928

3029
from xdis.codetype.code313rust import Code313Rust, SourceLocation
@@ -277,18 +276,6 @@ def t_bigint(self, save_ref: bool=False, bytes_for_s: bool=False):
277276
value = int.from_bytes(byte_data, byteorder='little')
278277
return value if is_positive else -value
279278

280-
def read_int16(self):
281-
return int(unpack("<h", self.fp.read(2))[0])
282-
283-
def read_int32(self):
284-
return int(unpack("<i", self.fp.read(4))[0])
285-
286-
def read_slice(self, n: int) -> bytes:
287-
return self.fp.read(n)
288-
289-
def read_uint32(self):
290-
return int(unpack("<I", self.fp.read(4))[0])
291-
292279
def read_string(self, n: int, bytes_for_s: bool=False) -> Union[bytes, str]:
293280
s = self.read_slice(n)
294281
if not bytes_for_s:

xdis/unmarshal.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,21 @@ def __init__(self, fp, magic_int, bytes_for_s, code_objects={}) -> None:
195195

196196
self.UNMARSHAL_DISPATCH_TABLE = UNMARSHAL_DISPATCH_TABLE
197197

198+
def read_int16(self) -> int:
199+
return int(unpack("<h", self.fp.read(2))[0])
200+
201+
def read_int32(self) -> int:
202+
return int(unpack("<i", self.fp.read(4))[0])
203+
204+
def read_int64(self) -> int:
205+
return int(unpack("<q", self.fp.read(8))[0])
206+
207+
def read_slice(self, n: int) -> bytes:
208+
return self.fp.read(n)
209+
210+
def read_uint32(self) -> int:
211+
return int(unpack("<I", self.fp.read(4))[0])
212+
198213
def load(self):
199214
"""
200215
``marshal.load()`` written in Python. When the Python bytecode magic loaded is the
@@ -302,16 +317,16 @@ def t_True(self, save_ref, bytes_for_s: bool = False) -> bool:
302317
return True
303318

304319
def t_int32(self, save_ref, bytes_for_s: bool = False):
305-
return self.r_ref(int(unpack("<i", self.fp.read(4))[0]), save_ref)
320+
return self.r_ref(self.read_int32(), save_ref)
306321

307322
def t_long(self, save_ref, bytes_for_s: bool = False):
308-
n = unpack("<i", self.fp.read(4))[0]
323+
n = self.read_uint32()
309324
if n == 0:
310325
return long(0)
311326
size = abs(n)
312327
d = long(0)
313328
for j in range(0, size):
314-
md = int(unpack("<h", self.fp.read(2))[0])
329+
md = self.read_int16()
315330
# This operation can turn "d" from a long back
316331
# into an int.
317332
d += md << j * 15
@@ -323,7 +338,7 @@ def t_long(self, save_ref, bytes_for_s: bool = False):
323338

324339
# Python 3.4 removed this.
325340
def t_int64(self, save_ref, bytes_for_s: bool = False):
326-
obj = unpack("<q", self.fp.read(8))[0]
341+
obj = self.read_int64()
327342
if save_ref:
328343
self.intern_objects.append(obj)
329344
return obj
@@ -342,7 +357,7 @@ def unpack_pre_24() -> float:
342357
return float(self.fp.read(unpack("B", self.fp.read(1))[0]))
343358

344359
def unpack_newer() -> float:
345-
return float(self.fp.read(unpack("<i", self.fp.read(4))[0]))
360+
return float(self.fp.read(self.read_int32()))
346361

347362
get_float = unpack_pre_24 if self.magic_int <= 62061 else unpack_newer
348363

@@ -363,7 +378,7 @@ def t_string(self, save_ref, bytes_for_s: bool):
363378
In Python3, this is a ``bytes`` type. In Python2, it is a string type;
364379
``bytes_for_s`` is True when a Python 3 interpreter is reading Python 2 bytecode.
365380
"""
366-
strsize = unpack("<i", self.fp.read(4))[0]
381+
strsize = self.read_uint32()
367382
s = self.fp.read(strsize)
368383
if not bytes_for_s:
369384
s = compat_str(s)
@@ -377,7 +392,7 @@ def t_ASCII_interned(self, save_ref, bytes_for_s: bool = False):
377392
the string.
378393
"""
379394
# FIXME: check
380-
strsize = unpack("<i", self.fp.read(4))[0]
395+
strsize = self.read_uint32()
381396
interned = compat_str(self.fp.read(strsize))
382397
self.intern_strings.append(interned)
383398
return self.r_ref(interned, save_ref)
@@ -388,7 +403,7 @@ def t_ASCII(self, save_ref, bytes_for_s: bool = False):
388403
There are true strings in Python3 as opposed to
389404
bytes.
390405
"""
391-
strsize = unpack("<i", self.fp.read(4))[0]
406+
strsize = self.read_uint32()
392407
s = self.fp.read(strsize)
393408
s = compat_str(s)
394409
return self.r_ref(s, save_ref)
@@ -407,13 +422,13 @@ def t_short_ASCII_interned(self, save_ref, bytes_for_s: bool = False):
407422
return self.r_ref(interned, save_ref)
408423

409424
def t_interned(self, save_ref, bytes_for_s: bool = False):
410-
strsize = unpack("<i", self.fp.read(4))[0]
425+
strsize = self.read_uint32()
411426
interned = compat_str(self.fp.read(strsize))
412427
self.intern_strings.append(interned)
413428
return self.r_ref(interned, save_ref)
414429

415430
def t_unicode(self, save_ref, bytes_for_s: bool = False):
416-
strsize = unpack("<i", self.fp.read(4))[0]
431+
strsize = self.read_uint32()
417432
unicodestring = self.fp.read(strsize)
418433
if self.version_triple < (3, 0):
419434
string = UnicodeForPython3(unicodestring)
@@ -434,7 +449,7 @@ def t_small_tuple(self, save_ref, bytes_for_s: bool = False):
434449
return self.r_ref_insert(ret, i)
435450

436451
def t_tuple(self, save_ref, bytes_for_s: bool = False):
437-
tuplesize = unpack("<i", self.fp.read(4))[0]
452+
tuplesize = self.read_uint32()
438453
ret = self.r_ref(tuple(), save_ref)
439454
while tuplesize > 0:
440455
ret += (self.r_object(bytes_for_s=bytes_for_s),)
@@ -443,15 +458,15 @@ def t_tuple(self, save_ref, bytes_for_s: bool = False):
443458

444459
def t_list(self, save_ref, bytes_for_s: bool = False):
445460
# FIXME: check me
446-
n = unpack("<i", self.fp.read(4))[0]
461+
n = self.read_uint32()
447462
ret = self.r_ref(list(), save_ref)
448463
while n > 0:
449464
ret += (self.r_object(bytes_for_s=bytes_for_s),)
450465
n -= 1
451466
return ret
452467

453468
def t_frozenset(self, save_ref, bytes_for_s: bool = False):
454-
setsize = unpack("<i", self.fp.read(4))[0]
469+
setsize = self.read_uint32()
455470
collection, i = self.r_ref_reserve([], save_ref)
456471
while setsize > 0:
457472
collection.append(self.r_object(bytes_for_s=bytes_for_s))
@@ -462,7 +477,7 @@ def t_frozenset(self, save_ref, bytes_for_s: bool = False):
462477
return self.r_ref_insert(final_frozenset, i)
463478

464479
def t_set(self, save_ref, bytes_for_s: bool = False):
465-
setsize = unpack("<i", self.fp.read(4))[0]
480+
setsize = self.read_uint32()
466481
ret, i = self.r_ref_reserve(tuple(), save_ref)
467482
while setsize > 0:
468483
ret += (self.r_object(bytes_for_s=bytes_for_s),)
@@ -484,7 +499,7 @@ def t_dict(self, save_ref, bytes_for_s: bool = False):
484499
return ret
485500

486501
def t_python2_string_reference(self, save_ref, bytes_for_s: bool = False):
487-
refnum = unpack("<i", self.fp.read(4))[0]
502+
refnum = self.read_uint32()
488503
return self.intern_strings[refnum]
489504

490505
def t_slice(self, save_ref, bytes_for_s: bool = False):
@@ -522,9 +537,9 @@ def t_code(self, save_ref, bytes_for_s: bool = False):
522537
self.version_triple = magic_int2tuple(self.magic_int)
523538

524539
if self.version_triple >= (2, 3):
525-
co_argcount = unpack("<i", self.fp.read(4))[0]
540+
co_argcount = self.read_uint32()
526541
elif self.version_triple >= (1, 3):
527-
co_argcount = unpack("<h", self.fp.read(2))[0]
542+
co_argcount = self.read_int16()
528543
else:
529544
co_argcount = 0
530545

@@ -538,7 +553,7 @@ def t_code(self, save_ref, bytes_for_s: bool = False):
538553
co_posonlyargcount = None
539554

540555
if self.version_triple >= (3, 0):
541-
kwonlyargcount = unpack("<i", self.fp.read(4))[0]
556+
kwonlyargcount = self.read_uint32()
542557
else:
543558
kwonlyargcount = 0
544559

@@ -547,21 +562,21 @@ def t_code(self, save_ref, bytes_for_s: bool = False):
547562
self.version_triple[:2] == (3, 11) and self.is_pypy
548563
):
549564
if self.version_triple >= (2, 3):
550-
co_nlocals = unpack("<i", self.fp.read(4))[0]
565+
co_nlocals = self.read_uint32()
551566
elif self.version_triple >= (1, 3):
552-
co_nlocals = unpack("<h", self.fp.read(2))[0]
567+
co_nlocals = self.read_int16()
553568

554569
if self.version_triple >= (2, 3):
555-
co_stacksize = unpack("<i", self.fp.read(4))[0]
570+
co_stacksize = self.read_uint32()
556571
elif self.version_triple >= (1, 5):
557-
co_stacksize = unpack("<h", self.fp.read(2))[0]
572+
co_stacksize = self.read_int16()
558573
else:
559574
co_stacksize = 0
560575

561576
if self.version_triple >= (2, 3):
562-
co_flags = unpack("<i", self.fp.read(4))[0]
577+
co_flags = self.read_uint32()
563578
elif self.version_triple >= (1, 3):
564-
co_flags = unpack("<h", self.fp.read(2))[0]
579+
co_flags = self.read_int16()
565580
else:
566581
co_flags = 0
567582

@@ -628,9 +643,9 @@ def t_code(self, save_ref, bytes_for_s: bool = False):
628643

629644
if self.version_triple >= (1, 5):
630645
if self.version_triple >= (2, 3):
631-
co_firstlineno = unpack("<i", self.fp.read(4))[0]
646+
co_firstlineno = self.read_int32()
632647
else:
633-
co_firstlineno = unpack("<h", self.fp.read(2))[0]
648+
co_firstlineno = self.read_int16()
634649

635650
if self.version_triple >= (3, 11) and not self.is_pypy:
636651
co_linetable = self.r_object(bytes_for_s=bytes_for_s)
@@ -775,7 +790,7 @@ def t_code_old(self, _, bytes_for_s: bool = False):
775790

776791
# Since Python 3.4
777792
def t_object_reference(self, save_ref=None, bytes_for_s: bool = False):
778-
refnum = unpack("<i", self.fp.read(4))[0]
793+
refnum = self.read_uint32()
779794
return self.intern_objects[refnum]
780795

781796
def t_unknown(self, save_ref=None, bytes_for_s: bool = False):

0 commit comments

Comments
 (0)