Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions fickling/fickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,19 @@ def __init__(self):
self.calls: list[ast.Call] = []
self.non_setstate_calls: list[ast.Call] = []
self.likely_safe_imports: set[str] = set()
self._visited: set[int] = set() # Track visited nodes by id to detect cycles

def visit(self, node: ast.AST) -> Any:
    """Visit *node* once, guarding against cycles in the AST.

    Pickle files can create cyclic AST structures via MEMOIZE + GET
    opcodes; without this guard, visiting such a structure would recurse
    forever.
    """
    marker = id(node)
    if marker not in self._visited:
        self._visited.add(marker)
        return super().visit(node)
    # Already seen (cycle or shared subtree) — do not descend again.
    return None

def _process_import(self, node: ast.Import | ast.ImportFrom):
self.imports.append(node)
Expand Down Expand Up @@ -572,6 +585,8 @@ def __init__(self, opcodes: Iterable[Opcode], has_invalid_opcode: bool = False):
# Whether the pickled sequence was interrupted because of
# an invalid opcode
self._has_invalid_opcode: bool = has_invalid_opcode
# Whether the pickle contains cyclic references
self._has_cycles: bool = False
self._has_interpretation_error: bool = False

def __len__(self) -> int:
Expand Down Expand Up @@ -891,6 +906,7 @@ def has_invalid_opcode(self) -> bool:

@property
def has_interpretation_error(self) -> bool:
    """Whether interpreting this pickle raised an InterpretationError."""
    # Interpretation is lazy; touch the AST so the error flag is populated.
    self.ast
    return self._has_interpretation_error

@staticmethod
Expand Down Expand Up @@ -1036,7 +1052,9 @@ def non_standard_imports(self) -> Iterator[ast.Import | ast.ImportFrom]:
def ast(self) -> ast.Module:
if self._ast is None:
try:
self._ast = Interpreter.interpret(self)
interpreter = Interpreter(self)
self._ast = interpreter.to_ast()
self._has_cycles = interpreter._has_cycle
except InterpretationError as e:
self._has_interpretation_error = True
sys.stderr.write(
Expand All @@ -1046,6 +1064,12 @@ def ast(self) -> ast.Module:
self._ast = ast.Module(body=[], type_ignores=[])
return self._ast

@property
def has_cycles(self) -> bool:
    """Whether the interpreted pickle contains cyclic references."""
    # Interpretation is lazy; touch the AST so the cycle flag is populated.
    self.ast
    return self._has_cycles

@property
def nb_opcodes(self) -> int:
    """Total number of opcodes in this pickled program."""
    return len(self._opcodes)
Expand Down Expand Up @@ -1130,6 +1154,7 @@ def __init__(
self._module: ast.Module | None = None
self._var_counter: int = first_variable_id
self._opcodes: Iterator[Opcode] = iter(pickled)
self._has_cycle: bool = False

@property
def next_variable_id(self) -> int:
Expand Down Expand Up @@ -1477,11 +1502,15 @@ def run(self, interpreter: Interpreter):
raise InterpretationError("Exhausted the stack while searching for a MarkObject!")
if not interpreter.stack:
raise ValueError("Stack was empty; expected a pyset")
pyset = interpreter.stack.pop()
pyset = interpreter.stack[-1]
if not isinstance(pyset, ast.Set):
raise ValueError(
f"{pyset!r} was expected to be a set-like object with an `add` function"
)
# Check for cyclic references - sets cannot contain themselves (unhashable)
for elem in to_add:
if elem is pyset:
raise InterpretationError("Set cannot contain itself (unhashable type)")
pyset.elts.extend(reversed(to_add))


Expand Down Expand Up @@ -1760,11 +1789,17 @@ def create(memo_id: int) -> Get:
class SetItems(StackSliceOpcode):
name = "SETITEMS"

def run(self, interpreter: Interpreter, stack_slice: List[ast.expr]):
def run(self, interpreter: Interpreter, stack_slice: list[ast.expr]):
pydict = interpreter.stack.pop()
update_dict_keys = []
update_dict_values = []
for key, value in zip(stack_slice[::2], stack_slice[1::2], strict=False):
# Check for cyclic references
if key is pydict:
raise InterpretationError("Dict cannot use itself as key (unhashable type)")
if value is pydict:
value = ast.Set(elts=[ast.Constant(value=...)])
interpreter._has_cycle = True
update_dict_keys.append(key)
update_dict_values.append(value)
if isinstance(pydict, ast.Dict) and not pydict.keys:
Expand Down Expand Up @@ -1792,6 +1827,12 @@ def run(self, interpreter: Interpreter):
value = interpreter.stack.pop()
key = interpreter.stack.pop()
pydict = interpreter.stack.pop()
# Check for cyclic references
if key is pydict:
raise InterpretationError("Dict cannot use itself as key (unhashable type)")
if value is pydict:
value = ast.Set(elts=[ast.Constant(value=...)])
interpreter._has_cycle = True
if isinstance(pydict, ast.Dict) and not pydict.keys:
# the dict is empty, so add a new one
interpreter.stack.append(ast.Dict(keys=[key], values=[value]))
Expand Down Expand Up @@ -1871,6 +1912,9 @@ def run(self, interpreter: Interpreter):
value = interpreter.stack.pop()
list_obj = interpreter.stack[-1]
if isinstance(list_obj, ast.List):
if value is list_obj:
value = ast.List(elts=[ast.Constant(value=...)], ctx=ast.Load())
interpreter._has_cycle = True
list_obj.elts.append(value)
else:
raise ValueError(f"Expected a list on the stack, but instead found {list_obj!r}")
Expand All @@ -1879,9 +1923,13 @@ def run(self, interpreter: Interpreter):
class Appends(StackSliceOpcode):
    """APPENDS opcode: extend the list on top of the stack with many items."""

    name = "APPENDS"

    def run(self, interpreter: Interpreter, stack_slice: list[ast.expr]):
        """Extend the topmost stack list with *stack_slice*.

        Any element that *is* the target list itself (a cycle created via
        MEMOIZE + GET) is replaced by a ``[...]`` placeholder so the resulting
        AST stays acyclic, and the interpreter's cycle flag is set.

        Raises:
            ValueError: if the top of the stack is not an ``ast.List``.
        """
        list_obj = interpreter.stack[-1]
        if isinstance(list_obj, ast.List):
            for i, elem in enumerate(stack_slice):
                if elem is list_obj:
                    # Self-reference: substitute a placeholder to break the cycle.
                    stack_slice[i] = ast.List(elts=[ast.Constant(value=...)], ctx=ast.Load())
                    interpreter._has_cycle = True
            list_obj.elts.extend(stack_slice)
        else:
            raise ValueError(f"Expected a list on the stack, but instead found {list_obj!r}")
Expand Down
89 changes: 89 additions & 0 deletions test/test_crashes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@

from fickling.analysis import Severity, check_safety
from fickling.fickle import (
AddItems,
Append,
BinGet,
BinInt1,
BinPut,
BinUnicode,
EmptyDict,
EmptyList,
EmptySet,
Get,
Global,
Mark,
Memoize,
Pickled,
Pop,
Proto,
Reduce,
SetItem,
ShortBinUnicode,
StackGlobal,
Stop,
Expand Down Expand Up @@ -164,3 +171,85 @@ def test_missing_mark_before_tuple(self):
# Safety check should flag it
results = check_safety(loaded)
self.assertGreater(results.severity, Severity.LIKELY_SAFE)

def test_cyclic_pickle(self):
    """Reproduces https://github.com/trailofbits/fickling/issues/196"""
    # A list appended to itself: L = []; L.append(L)
    self_referencing_list = Pickled(
        [
            Proto(2),
            EmptyList(),
            Memoize(),
            Get(0),
            Append(),
            Stop(),
        ]
    )

    # The cycle must be detected...
    self.assertTrue(self_referencing_list.has_cycles)

    # ...and the safety check must finish instead of hitting RecursionError.
    self.assertIsNotNone(check_safety(self_referencing_list))

    # The back-reference is rendered as a placeholder.
    self.assertEqual("result = [[...]]", unparse(self_referencing_list.ast))

    # A dict holding itself as a value: d = {}; d["self"] = d
    self_referencing_dict = Pickled(
        [
            Proto(2),
            EmptyDict(),
            Memoize(),
            ShortBinUnicode("self"),
            Get(0),
            SetItem(),
            Stop(),
        ]
    )
    self.assertTrue(self_referencing_dict.has_cycles)
    self.assertEqual("result = {'self': {...}}", unparse(self_referencing_dict.ast))

def test_impossible_cyclic_pickle(self):
    """Test that impossible cyclic structures raise InterpretationError."""
    # A dict keyed by itself: d = {}; d[d] = "value"
    # CPython would raise: TypeError: unhashable type: 'dict'
    unhashable_dict_key = Pickled(
        [
            Proto(2),
            EmptyDict(),
            Memoize(),
            Get(0),
            ShortBinUnicode("value"),
            SetItem(),
            Stop(),
        ]
    )
    # Interpretation must fail...
    self.assertTrue(unhashable_dict_key.has_interpretation_error)
    # ...and the safety check must flag the file.
    self.assertGreater(check_safety(unhashable_dict_key).severity, Severity.LIKELY_SAFE)

    # A set containing itself: s = set(); s.add(s)
    # CPython would raise: TypeError: unhashable type: 'set'
    unhashable_set_member = Pickled(
        [
            Proto(4),
            EmptySet(),
            Memoize(),
            Mark(),
            Get(0),
            AddItems(),
            Stop(),
        ]
    )
    self.assertTrue(unhashable_set_member.has_interpretation_error)
    self.assertGreater(check_safety(unhashable_set_member).severity, Severity.LIKELY_SAFE)
Loading