Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
contents: read
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]

steps:
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
Expand Down
14 changes: 6 additions & 8 deletions fickling/fickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
overload,
)

from stdlib_list import in_stdlib

from fickling.exception import WrongMethodError

T = TypeVar("T")
Expand All @@ -34,14 +32,14 @@
GenericSequence = Sequence[T]
make_constant = ast.Constant

BUILTIN_MODULE_NAMES: frozenset[str] = frozenset(sys.builtin_module_names)
BUILTIN_STDLIB_MODULE_NAMES: frozenset[str] = sys.stdlib_module_names

OPCODES_BY_NAME: dict[str, type[Opcode]] = {}
OPCODE_INFO_BY_NAME: dict[str, OpcodeInfo] = {opcode.name: opcode for opcode in opcodes}


def is_std_module(module_name: str) -> bool:
return in_stdlib(module_name) or module_name in BUILTIN_MODULE_NAMES
return module_name in BUILTIN_STDLIB_MODULE_NAMES


class MarkObject:
Expand Down Expand Up @@ -396,7 +394,7 @@ def insert(self, index: int, opcode: Opcode):
self._properties = None

def _is_constant_type(self, obj: Any) -> bool:
return isinstance(obj, (int, float, str, bytes))
return isinstance(obj, int | float | str | bytes)

def _encode_python_obj(self, obj: Any) -> List[Opcode]:
"""Create an opcode sequence that builds an arbitrary python object on the top of the
Expand Down Expand Up @@ -453,7 +451,7 @@ def insert_python(
# its stack so it remains how we left it!
# TODO: Add code to emulate the code afterward and confirm that the stack is sane!
i = 0
while isinstance(self[i], (Proto, Frame)):
while isinstance(self[i], Proto | Frame):
i += 1
self.insert(i, Global.create(module, attr))
i += 1
Expand Down Expand Up @@ -698,7 +696,7 @@ def has_invalid_opcode(self) -> bool:

@staticmethod
def make_stream(data: Buffer | BinaryIO) -> BinaryIO:
if isinstance(data, (bytes, bytearray, Buffer)):
if isinstance(data, bytes | bytearray | Buffer):
data = BytesIO(data)
elif (not hasattr(data, "seekable") or not data.seekable()) and hasattr(data, "read"):
data = BytesIO(data.read())
Expand Down Expand Up @@ -1469,7 +1467,7 @@ def run(self, interpreter: Interpreter, stack_slice: List[ast.expr]):
pydict = interpreter.stack.pop()
update_dict_keys = []
update_dict_values = []
for key, value in zip(stack_slice[::2], stack_slice[1::2]):
for key, value in zip(stack_slice[::2], stack_slice[1::2], strict=False):
update_dict_keys.append(key)
update_dict_values.append(value)
if isinstance(pydict, ast.Dict) and not pydict.keys:
Expand Down
3 changes: 1 addition & 2 deletions fickling/import_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import warnings
from collections.abc import Sequence
from types import ModuleType
from typing import Union

import fickling.loader as loader

Expand Down Expand Up @@ -47,7 +46,7 @@ def __init__(self, verbose=False):
def find_spec(
self,
fullname: str,
path: Sequence[Union[bytes, str]] | None,
path: Sequence[bytes | str] | None,
target: ModuleType | None = None,
):
if fullname == "pickle":
Expand Down
7 changes: 3 additions & 4 deletions fickling/tracing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import ast
from ast import unparse
from typing import Union

from .fickle import Interpreter, MarkObject, Opcode, Stack

Expand All @@ -9,14 +8,14 @@ class Trace:
def __init__(self, interpreter: Interpreter):
self.interpreter: Interpreter = interpreter

def on_pop(self, popped_value: Union[ast.expr, MarkObject]):
def on_pop(self, popped_value: ast.expr | MarkObject):
if isinstance(popped_value, MarkObject):
value = "MARK"
else:
value = unparse(popped_value).strip()
print(f"\tPopped {value}")

def on_push(self, pushed_value: Union[ast.expr, MarkObject]):
def on_push(self, pushed_value: ast.expr | MarkObject):
if isinstance(pushed_value, MarkObject):
value = "MARK"
else:
Expand Down Expand Up @@ -51,7 +50,7 @@ def run(self) -> ast.AST:
for added in self.interpreter.module_body[len_module_before:]:
self.on_statement(added)
common_prefix_length = 0
for before, after in zip(stack_before, self.interpreter.stack):
for before, after in zip(stack_before, self.interpreter.stack, strict=False):
if before != after:
break
common_prefix_length += 1
Expand Down
16 changes: 10 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,24 @@ classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: GNU Lesser General Public License v3 or later (LGPLv3+)",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Topic :: Security",
"Topic :: Software Development :: Testing",
"Topic :: Utilities",
]
dependencies = ["stdlib_list ~= 0.11.1"]
requires-python = ">=3.9"
requires-python = ">=3.10"

[project.optional-dependencies]
torch = ["torch >= 2.1.0", "torchvision >= 0.16.1", "numpy >= 1.24.0"]
torch = [
"torch >= 2.1.0",
"torchvision >= 0.24.1",
"numpy >= 1.24.0; python_version < '3.14'",
"numpy >= 2.3.5; python_version >= '3.14'",
]
lint = [
"ruff >= 0.8.0",
"mypy >= 1.10.0",
Expand All @@ -42,7 +46,7 @@ test = [
"coverage[toml] >= 7.0.0",
"numpy",
"torch >= 2.1.0",
"torchvision >= 0.16.1",
"torchvision >= 0.24.1",
]
dev = [
"fickling[lint,test,torch]",
Expand Down Expand Up @@ -100,7 +104,7 @@ packages = ["fickling"]

[tool.ruff]
line-length = 100
target-version = "py39"
target-version = "py310"

[tool.ruff.lint]
select = [
Expand Down
41 changes: 27 additions & 14 deletions test/test_polyglot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import random
import string
import sys
import tarfile
import unittest
import zipfile
Expand All @@ -11,6 +12,8 @@

import fickling.polyglot as polyglot

_lacks_torch_jit_support = sys.version_info >= (3, 14)


def create_pytorch_legacy_tar(tar_file_name):
# This is an intentional polymock
Expand Down Expand Up @@ -61,14 +64,17 @@ def setUp(self):
self.filename_legacy_pickle = "model_legacy_pickle.pth"
torch.save(model, self.filename_legacy_pickle, _use_new_zipfile_serialization=False)

# TorchScript v1.4
m = torch.jit.script(model)
self.filename_torchscript = "model_torchscript.pt"
torch.jit.save(m, self.filename_torchscript)
if not _lacks_torch_jit_support:
# TorchScript v1.4
m = torch.jit.script(model)
self.filename_torchscript = "model_torchscript.pt"
torch.jit.save(m, self.filename_torchscript)

# TorchScript v1.4
self.filename_torchscript_dup = "model_torchscript_dup.pt"
torch.jit.save(m, self.filename_torchscript_dup)

# TorchScript v1.4
self.filename_torchscript_dup = "model_torchscript_dup.pt"
torch.jit.save(m, self.filename_torchscript_dup)
self.standard_torchscript_polyglot_name = "test_polyglot.pt"

# PyTorch v0.1.1
self.filename_legacy_tar = "model_legacy_tar.pth"
Expand Down Expand Up @@ -96,23 +102,27 @@ def setUp(self):
archive.write(self.numpy_pickle, self.numpy_pickle)
archive.close()

self.standard_torchscript_polyglot_name = "test_polyglot.pt"

def tearDown(self):
for filename in [
files = [
self.filename_v1_3,
self.filename_legacy_pickle,
self.filename_torchscript,
self.filename_legacy_tar,
self.zip_filename,
self.filename_torchscript_dup,
self.filename_v1_3_dup,
self.standard_torchscript_polyglot_name,
self.numpy_not_pickle,
self.numpy_pickle,
self.tar_numpy_pickle,
self.zip_numpy_pickle,
]:
]
if not _lacks_torch_jit_support:
files.extend(
[
self.filename_torchscript,
self.filename_torchscript_dup,
self.standard_torchscript_polyglot_name,
]
)
for filename in files:
if os.path.exists(filename):
os.remove(filename)

Expand All @@ -125,6 +135,7 @@ def test_v1_3(self):
# formats = polyglot.identify_pytorch_file_format(self.filename_legacy_pickle)
# self.assertEqual(formats, ["PyTorch v0.1.10"])

@unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+")
def test_torchscript(self):
formats = polyglot.identify_pytorch_file_format(self.filename_torchscript)
self.assertEqual(formats, ["TorchScript v1.4", "TorchScript v1.3", "PyTorch v1.3"])
Expand Down Expand Up @@ -279,6 +290,7 @@ def test_legacy_pickle_properties(self):
}
self.assertEqual(properties, proper_result)

@unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+")
def test_torchscript_properties(self):
properties = polyglot.find_file_properties(self.filename_torchscript)
proper_result = {
Expand Down Expand Up @@ -316,6 +328,7 @@ def test_zip_properties(self):
}
self.assertEqual(properties, proper_result)

@unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+")
def test_create_standard_torchscript_polyglot(self):
polyglot.create_polyglot(
self.filename_v1_3_dup,
Expand Down
17 changes: 13 additions & 4 deletions test/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
import unittest

import torch
Expand All @@ -7,19 +8,25 @@
from fickling.fickle import Pickled
from fickling.pytorch import PyTorchModelWrapper

_lacks_torch_jit_support = sys.version_info >= (3, 14)


class TestPyTorchModule(unittest.TestCase):
def setUp(self):
model = models.mobilenet_v2()
self.filename_v1_3 = "test_model.pth"
torch.save(model, self.filename_v1_3)
self.zip_filename = "test_random_data.zip"
m = torch.jit.script(model)
self.torchscript_filename = "test_model_torchscript.pth"
torch.jit.save(m, self.torchscript_filename)
if not _lacks_torch_jit_support:
m = torch.jit.script(model)
self.torchscript_filename = "test_model_torchscript.pth"
torch.jit.save(m, self.torchscript_filename)

def tearDown(self):
for filename in [self.filename_v1_3, self.zip_filename, self.torchscript_filename]:
files = [self.filename_v1_3, self.zip_filename]
if not _lacks_torch_jit_support:
files.append(self.torchscript_filename)
for filename in files:
if os.path.exists(filename):
os.remove(filename)

Expand All @@ -29,6 +36,7 @@ def test_wrapper(self):
except Exception as e: # noqa
self.fail(f"PyTorchModelWrapper was not able to load a PyTorch v1.3 file: {e}")

@unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+")
def test_torchscript_wrapper(self):
try:
PyTorchModelWrapper(self.torchscript_filename)
Expand All @@ -40,6 +48,7 @@ def test_pickled(self):
pickled_portion = result.pickled
self.assertIsInstance(pickled_portion, Pickled)

@unittest.skipIf(_lacks_torch_jit_support, "PyTorch 2.9.1 JIT broken with Python 3.14+")
def test_torchscript_pickled(self):
result = PyTorchModelWrapper(self.torchscript_filename)
pickled_portion = result.pickled
Expand Down
Loading
Loading