3030import tempfile
3131import sys
3232import shutil
33+ import time
3334from dataclasses import dataclass
3435from pathlib import Path
3536from typing import TYPE_CHECKING
@@ -99,7 +100,7 @@ class Decompiler:
99100 :param trust_lnotab: Decides whether or not to use line number information
100101 """
101102
102- def __init__ (self , pyc : PYCFile , segmenter : transformers .Pipeline , translator : CacheTranslator , version : PythonVersion , top_k = 10 , trust_lnotab = False ):
103+ def __init__ (self , pyc : PYCFile , segmenter : transformers .Pipeline , translator : CacheTranslator , version : PythonVersion , top_k = 10 , trust_lnotab = False , timeout = None ):
103104 self .pyc = pyc
104105 self .pyc .copy ()
105106 self .name = pyc .pyc_path .name if pyc .pyc_path is not None else repr (pyc )
@@ -110,16 +111,27 @@ def __init__(self, pyc: PYCFile, segmenter: transformers.Pipeline, translator: C
110111 self .highest_k_used = 0
111112 self .tmpn = 0
112113 self .trust_lnotab = trust_lnotab
114+ self .timeout = timeout
115+ self .start_time = time .time ()
116+
117+ def check_timeout (self ):
118+ if self .timeout is not None and time .time () - self .start_time > self .timeout :
119+ raise TimeoutError (f"Decompilation of { self .name } timed out after { self .timeout } seconds" )
113120
114121 def __call__ (self ):
115122 with tempfile .TemporaryDirectory () as tmp :
116123 self .tmp = Path (tmp )
117124
118125 self .mask_bytecode ()
126+ self .check_timeout ()
119127 self .run_segmentation ()
128+ self .check_timeout ()
120129 self .run_translation ()
130+ self .check_timeout ()
121131 self .unmask_lines ()
132+ self .check_timeout ()
122133 self .run_cflow_reconstruction ()
134+ self .check_timeout ()
123135 self .reconstruct_source ()
124136
125137 self .equivalence_results = self .check_reconstruction (self .indented_source )
@@ -247,7 +259,7 @@ def run_segmentation(self):
247259
248260 window_coordinates , flat_window_requests , inst_index = zip (* window_segmentation_requests )
249261
250- window_segmentation_results = [filter_subwords (segmentation_result ) for segmentation_result in self .segmenter (TrackedDataset (SEGMENTATION_STEP , list (flat_window_requests )), batch_size = 8 )]
262+ window_segmentation_results = [filter_subwords (segmentation_result ) for segmentation_result in self .segmenter (TrackedDataset (SEGMENTATION_STEP , list (flat_window_requests ), check_timeout = self . check_timeout ), batch_size = 8 )]
251263
252264 self .segmentation_results = merge (list (window_coordinates ), window_segmentation_results , list (inst_index ), MAX_WINDOW_LENGTH , STEP_SIZE ) # merge everything
253265
@@ -292,7 +304,7 @@ def run_translation(self):
292304 for instructions , boundary_predictions in zip (self .ordered_instructions , self .segmentation_results ):
293305 translation_requests .append (self .make_translation_request (instructions , boundary_predictions ))
294306 flattened_translation_requests = list (itertools .chain .from_iterable (translation_requests ))
295- self .translation_results = self .translator (flattened_translation_requests )
307+ self .translation_results = self .translator (flattened_translation_requests , check_timeout = self . check_timeout )
296308 unflatten (self .translation_results , translation_requests )
297309 self .update_source_lines ()
298310 except Exception as e :
@@ -302,7 +314,7 @@ def run_translation(self):
302314 def run_cflow_reconstruction (self ):
303315 logger .info (f"Reconstructing control flow for { self .name } ..." )
304316 try :
305- cfts = {bc .codeobj : bc_to_cft (bc , self .source_lines ) for bc in TrackedList (CFLOW_STEP , self .ordered_bytecodes )}
317+ cfts = {bc .codeobj : bc_to_cft (bc , self .source_lines ) for bc in TrackedList (CFLOW_STEP , self .ordered_bytecodes , check_timeout = self . check_timeout )}
306318 self .source_context = SourceContext (self .pyc , self .source_lines , cfts )
307319 version = magicint2version .get (self .pyc .magic , "?" )
308320 time = datetime .datetime .fromtimestamp (self .pyc .timestamp , datetime .UTC ).strftime ("%Y-%m-%d %H:%M:%S UTC" )
@@ -352,7 +364,7 @@ def check_reconstruction(self, source: str) -> list[TestResult]:
352364 logger .info (f"Checking decompilation for { self .name } ..." )
353365 src = self .tmpfile ()
354366 pyc = self .tmpfile ()
355- src .write_text (source , encoding = ' utf-8' )
367+ src .write_text (source , encoding = " utf-8" )
356368 try :
357369 compile_version (src , pyc , self .version )
358370 except CompileError as e :
@@ -428,7 +440,7 @@ def update_starts_line(self):
428440 inst .starts_line = None
429441
430442
431- def decompile (pyc : PYCFile | Path , save_to : Path | None = None , config_file : Path | None = None , version : str | None = None , top_k : int = 10 , trust_lnotab : bool = False ) -> DecompilerResult :
443+ def decompile (pyc : PYCFile | Path , save_to : Path | None = None , config_file : Path | None = None , version : str | None = None , top_k : int = 10 , trust_lnotab : bool = False , timeout : int | None = None ) -> DecompilerResult :
432444 """
433445 Decompile a PYC file.
434446
@@ -438,6 +450,7 @@ def decompile(pyc: PYCFile | Path, save_to: Path | None = None, config_file: Pat
438450 :param version: Loads the models corresponding to this python version. if None, automatically detects version based on input PYC file.
439451 :param top_k: Max number of pyc segmentations to consider.
440452 :param trust_lnotab: Trust the lnotab in the input PYC for segmentation (False recommended).
453+ :param timeout: Maximum time in seconds to allow decompilation to run.
441454 :return: DecompilerResult class including important information about decompilation
442455 """
443456 logger .info (f"Loading { pyc } ..." )
@@ -472,12 +485,12 @@ def decompile(pyc: PYCFile | Path, save_to: Path | None = None, config_file: Pat
472485 logger .info (f"Decompiling pyc { pyc .pyc_path .resolve () if pyc .pyc_path else repr (pyc )} to { save_to .resolve ()} " )
473486 else :
474487 logger .info (f"Decompiling pyc { pyc .pyc_path .resolve () if pyc .pyc_path else repr (pyc )} " )
475- decompiler = Decompiler (pyc , segmenter , translator , pversion , top_k , trust_lnotab )
488+ decompiler = Decompiler (pyc , segmenter , translator , pversion , top_k , trust_lnotab , timeout )
476489 result = decompiler ()
477490
478491 logger .info ("Decompilation complete" )
479492 logger .info (f"{ result .calculate_success_rate ():.2%} code object success rate" )
480493 if save_to :
481- save_to .write_text (result .decompiled_source , encoding = ' utf-8' )
494+ save_to .write_text (result .decompiled_source , encoding = " utf-8" )
482495 logger .info (f"Result saved to { save_to } " )
483496 return result
0 commit comments