Skip to content

Commit 09c49d7

Browse files
feat: Add zstd support for reading and writing (#95)
This uses the `compression.zstd` module in Python 3.14 and `backports.zstd` (unconditionally installed from pypi) for lower versions
1 parent 39f72e1 commit 09c49d7

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ classifiers = [
2424
"Topic :: Scientific/Engineering",
2525
"Intended Audience :: Developers",
2626
]
27-
dependencies = ["numpy>=1.21", "packaging"]
27+
dependencies = ["numpy>=1.21", "packaging", "backports.zstd"]
2828
dynamic = ["version"]
2929

3030
[project.urls]

src/pyhepmc/io.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,14 @@ def __init__(
247247
import lzma
248248

249249
open = lzma.open # type:ignore
250+
elif fn.endswith(".zst") or fn.endswith(".zstd"):
251+
from sys import version_info
252+
253+
if version_info >= (3, 14):
254+
from compression import zstd
255+
else:
256+
from backports import zstd
257+
open = zstd.open
250258
else:
251259
from builtins import open # type:ignore
252260

tests/test_io.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
else:
1717
list_type = typing.List
1818

19+
if version_info >= (3, 14):
20+
from compression import zstd
21+
else:
22+
from backports import zstd
23+
1924
# this only does something if pyhepmc is compiled in debug mode
2025
hep.Setup.print_warnings = True
2126

@@ -106,6 +111,23 @@ def test_pystream_5():
106111
pio.write(b"foo")
107112

108113

114+
def test_pystream_6(evt):
115+
fn = "test_pystream_6.dat.zst"
116+
with zstd.open(fn, "w") as f:
117+
with pyiostream(f, 1000) as s:
118+
with io.WriterAscii(s) as w:
119+
w.write(evt)
120+
121+
with zstd.open(fn) as f:
122+
with pyiostream(f, 1000) as s:
123+
with io.ReaderAscii(s) as r:
124+
evt2 = r.read()
125+
126+
assert evt == evt2
127+
128+
os.unlink(fn)
129+
130+
109131
def test_read_event_write_event(evt):
110132
oss = stringstream()
111133
with io.WriterAscii(oss) as f:

0 commit comments

Comments
 (0)