Skip to content

Commit 985b389

Browse files
committed
add soft per-file --timeout option to allow slow files to be skipped in batch jobs
1 parent 1790864 commit 985b389

File tree

4 files changed

+37
-13
lines changed

4 files changed

+37
-13
lines changed

pylingual/decompiler.py

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
import tempfile
3131
import sys
3232
import shutil
33+
import time
3334
from dataclasses import dataclass
3435
from pathlib import Path
3536
from typing import TYPE_CHECKING
@@ -99,7 +100,7 @@ class Decompiler:
99100
:param trust_lnotab: Decides whether or not to use line number information
100101
"""
101102

102-
def __init__(self, pyc: PYCFile, segmenter: transformers.Pipeline, translator: CacheTranslator, version: PythonVersion, top_k=10, trust_lnotab=False):
103+
def __init__(self, pyc: PYCFile, segmenter: transformers.Pipeline, translator: CacheTranslator, version: PythonVersion, top_k=10, trust_lnotab=False, timeout=None):
103104
self.pyc = pyc
104105
self.pyc.copy()
105106
self.name = pyc.pyc_path.name if pyc.pyc_path is not None else repr(pyc)
@@ -110,16 +111,27 @@ def __init__(self, pyc: PYCFile, segmenter: transformers.Pipeline, translator: C
110111
self.highest_k_used = 0
111112
self.tmpn = 0
112113
self.trust_lnotab = trust_lnotab
114+
self.timeout = timeout
115+
self.start_time = time.time()
116+
117+
def check_timeout(self):
118+
if self.timeout is not None and time.time() - self.start_time > self.timeout:
119+
raise TimeoutError(f"Decompilation of {self.name} timed out after {self.timeout} seconds")
113120

114121
def __call__(self):
115122
with tempfile.TemporaryDirectory() as tmp:
116123
self.tmp = Path(tmp)
117124

118125
self.mask_bytecode()
126+
self.check_timeout()
119127
self.run_segmentation()
128+
self.check_timeout()
120129
self.run_translation()
130+
self.check_timeout()
121131
self.unmask_lines()
132+
self.check_timeout()
122133
self.run_cflow_reconstruction()
134+
self.check_timeout()
123135
self.reconstruct_source()
124136

125137
self.equivalence_results = self.check_reconstruction(self.indented_source)
@@ -247,7 +259,7 @@ def run_segmentation(self):
247259

248260
window_coordinates, flat_window_requests, inst_index = zip(*window_segmentation_requests)
249261

250-
window_segmentation_results = [filter_subwords(segmentation_result) for segmentation_result in self.segmenter(TrackedDataset(SEGMENTATION_STEP, list(flat_window_requests)), batch_size=8)]
262+
window_segmentation_results = [filter_subwords(segmentation_result) for segmentation_result in self.segmenter(TrackedDataset(SEGMENTATION_STEP, list(flat_window_requests), check_timeout=self.check_timeout), batch_size=8)]
251263

252264
self.segmentation_results = merge(list(window_coordinates), window_segmentation_results, list(inst_index), MAX_WINDOW_LENGTH, STEP_SIZE) # merge everything
253265

@@ -292,7 +304,7 @@ def run_translation(self):
292304
for instructions, boundary_predictions in zip(self.ordered_instructions, self.segmentation_results):
293305
translation_requests.append(self.make_translation_request(instructions, boundary_predictions))
294306
flattened_translation_requests = list(itertools.chain.from_iterable(translation_requests))
295-
self.translation_results = self.translator(flattened_translation_requests)
307+
self.translation_results = self.translator(flattened_translation_requests, check_timeout=self.check_timeout)
296308
unflatten(self.translation_results, translation_requests)
297309
self.update_source_lines()
298310
except Exception as e:
@@ -302,7 +314,7 @@ def run_translation(self):
302314
def run_cflow_reconstruction(self):
303315
logger.info(f"Reconstructing control flow for {self.name}...")
304316
try:
305-
cfts = {bc.codeobj: bc_to_cft(bc, self.source_lines) for bc in TrackedList(CFLOW_STEP, self.ordered_bytecodes)}
317+
cfts = {bc.codeobj: bc_to_cft(bc, self.source_lines) for bc in TrackedList(CFLOW_STEP, self.ordered_bytecodes, check_timeout=self.check_timeout)}
306318
self.source_context = SourceContext(self.pyc, self.source_lines, cfts)
307319
version = magicint2version.get(self.pyc.magic, "?")
308320
time = datetime.datetime.fromtimestamp(self.pyc.timestamp, datetime.UTC).strftime("%Y-%m-%d %H:%M:%S UTC")
@@ -352,7 +364,7 @@ def check_reconstruction(self, source: str) -> list[TestResult]:
352364
logger.info(f"Checking decompilation for {self.name}...")
353365
src = self.tmpfile()
354366
pyc = self.tmpfile()
355-
src.write_text(source, encoding='utf-8')
367+
src.write_text(source, encoding="utf-8")
356368
try:
357369
compile_version(src, pyc, self.version)
358370
except CompileError as e:
@@ -428,7 +440,7 @@ def update_starts_line(self):
428440
inst.starts_line = None
429441

430442

431-
def decompile(pyc: PYCFile | Path, save_to: Path | None = None, config_file: Path | None = None, version: str | None = None, top_k: int = 10, trust_lnotab: bool = False) -> DecompilerResult:
443+
def decompile(pyc: PYCFile | Path, save_to: Path | None = None, config_file: Path | None = None, version: str | None = None, top_k: int = 10, trust_lnotab: bool = False, timeout: int | None = None) -> DecompilerResult:
432444
"""
433445
Decompile a PYC file.
434446
@@ -438,6 +450,7 @@ def decompile(pyc: PYCFile | Path, save_to: Path | None = None, config_file: Pat
438450
:param version: Loads the models corresponding to this python version. if None, automatically detects version based on input PYC file.
439451
:param top_k: Max number of pyc segmentations to consider.
440452
:param trust_lnotab: Trust the lnotab in the input PYC for segmentation (False recommended).
453+
:param timeout: Maximum time in seconds to allow decompilation to run.
441454
:return: DecompilerResult class including important information about decompilation
442455
"""
443456
logger.info(f"Loading {pyc}...")
@@ -472,12 +485,12 @@ def decompile(pyc: PYCFile | Path, save_to: Path | None = None, config_file: Pat
472485
logger.info(f"Decompiling pyc {pyc.pyc_path.resolve() if pyc.pyc_path else repr(pyc)} to {save_to.resolve()}")
473486
else:
474487
logger.info(f"Decompiling pyc {pyc.pyc_path.resolve() if pyc.pyc_path else repr(pyc)}")
475-
decompiler = Decompiler(pyc, segmenter, translator, pversion, top_k, trust_lnotab)
488+
decompiler = Decompiler(pyc, segmenter, translator, pversion, top_k, trust_lnotab, timeout)
476489
result = decompiler()
477490

478491
logger.info("Decompilation complete")
479492
logger.info(f"{result.calculate_success_rate():.2%} code object success rate")
480493
if save_to:
481-
save_to.write_text(result.decompiled_source, encoding='utf-8')
494+
save_to.write_text(result.decompiled_source, encoding="utf-8")
482495
logger.info(f"Result saved to {save_to}")
483496
return result

pylingual/main.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def add_file(source: Path, dest: Path):
9797
@click.option("--force", is_flag=True, default=False, help="Overwrite existing output files.")
9898
@click.option("--trust-lnotab", is_flag=True, default=False, help="Use the lnotab for segmentation instead of the segmentation model.")
9999
@click.option("--init-pyenv", is_flag=True, default=False, help="Install pyenv before decompiling.")
100-
def main(files: list[Path], out_dir: Path | None, config_file: Path | None, version: PythonVersion | None, top_k: int, flatten: bool, force: bool, trust_lnotab: bool, init_pyenv: bool, quiet: bool):
100+
@click.option("--timeout", default=None, type=int, help="Maximum time in seconds to allow decompilation to run per file.", metavar="SECONDS")
101+
def main(files: list[Path], out_dir: Path | None, config_file: Path | None, version: PythonVersion | None, top_k: int, flatten: bool, force: bool, trust_lnotab: bool, init_pyenv: bool, quiet: bool, timeout: int | None):
101102
rich.reconfigure(markup=False, emoji=False, quiet=quiet, theme=Theme({"logging.keyword": "yellow not bold"}))
102103
console = rich.get_console()
103104
log_handler = RichHandler(console=console, rich_tracebacks=True)
@@ -173,9 +174,13 @@ def init(self):
173174
version=version,
174175
top_k=top_k,
175176
trust_lnotab=trust_lnotab,
177+
timeout=timeout,
176178
)
177179
pyc = result.original_pyc
178180
print_result(f"Equivalence Results for {pyc.pyc_path.name if pyc.pyc_path else repr(pyc)}", result.equivalence_results)
181+
except TimeoutError as e:
182+
logger.error(str(e))
183+
continue
179184
except Exception:
180185
logger.exception(f"Failed to decompile {pyc_path}")
181186
console.rule()

pylingual/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,14 @@ def _translate_with_backoff(self, translation_requests: TrackedDataset) -> list[
8585
translation_results.append("'''Decompiler error: line too long for translation. Please decompile this statement manually.'''")
8686
return translation_results
8787

88-
def __call__(self, args: list, **_):
88+
def __call__(self, args: list, check_timeout: callable = None, **_):
8989
normalized_args = [normalize_masks(fix_jump_targets(x)) for x in args]
9090

9191
# New are those not in the local cache
9292
new = TrackedDataset(
9393
TRANSLATION_STEP,
9494
list({norm for norm, _ in normalized_args if norm not in self.cache}),
95+
check_timeout=check_timeout,
9596
)
9697

9798
# Now, "new" has been updated to those not in local

pylingual/utils/tracked_list.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@ class TrackedList:
1818
Used to display progress bars when PyLingual is run as a script, does nothing otherwise
1919
"""
2020

21-
def __init__(self, name: str, x: list):
21+
def __init__(self, name: str, x: list, check_timeout: callable = None):
2222
self.name = name
2323
self.x = x
2424
self.i = 0
25+
self.check_timeout = check_timeout
2526
self.init()
2627

2728
# overwritten when run as script
2829
def init(self):
2930
pass
3031

3132
def __getitem__(self, i):
33+
if self.check_timeout:
34+
self.check_timeout()
3235
self.progress(i - self.i)
3336
self.i = i
3437
return self.x[i]
@@ -40,6 +43,8 @@ def __iter__(self):
4043
return self
4144

4245
def __next__(self):
46+
if self.check_timeout:
47+
self.check_timeout()
4348
try:
4449
n = self.x[self.i]
4550
except:
@@ -58,6 +63,6 @@ class TrackedDataset(TrackedList):
5863
Like TrackedList, but inherits from Dataset
5964
"""
6065

61-
def __init__(self, name: str, x: list):
62-
super().__init__(name, x)
66+
def __init__(self, name: str, x: list, check_timeout: callable = None):
67+
super().__init__(name, x, check_timeout)
6368
TrackedDataset.__bases__ = (TrackedList, transformers.pipelines.base.Dataset)

0 commit comments

Comments
 (0)