Skip to content

Commit 7a6fad3

Browse files
authored
Add streaming decompression for ZSTD_CONTENTSIZE_UNKNOWN case (#707)
* Add streaming decompression for ZSTD_CONTENTSIZE_UNKNOWN case * Add tests and documentation for streaming Zstd * Add better tests from numcodecs.js * Adapt zstd.pyx streaming for Py_buffer * Formatting with ruff * Fix zstd comparison of different signedness * Undo change to unrelated test * Add zstd cli tests * Test Zstd against pyzstd * Apply ruff * Fix zstd tests, coverage * Make imports not optional * Add EndlessZstdDecompressor tests * Add docstrings to test_pyzstd.py
1 parent 4fdb625 commit 7a6fad3

File tree

6 files changed

+286
-10
lines changed

6 files changed

+286
-10
lines changed

docs/compression/zstd.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ Zstd
77
.. autoattribute:: codec_id
88
.. automethod:: encode
99
.. automethod:: decode
10+
.. note::
11+
If the compressed data does not contain the decompressed size, streaming
12+
decompression will be used.
1013
.. automethod:: get_config
1114
.. automethod:: from_config
1215

docs/release.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ Maintenance
9595

9696
Improvements
9797
~~~~~~~~~~~~
98+
* Add streaming decompression for ZSTD (:issue:`699`)
99+
By :user:`Mark Kittisopikul <mkitti>`.
98100
* Raise a custom `UnknownCodecError` when trying to retrieve an unavailable codec.
99101
By :user:`Cas Wognum <cwognum>`.
100102

numcodecs/tests/test_pyzstd.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Check Zstd against pyzstd package
2+
3+
import numpy as np
4+
import pytest
5+
import pyzstd
6+
7+
from numcodecs.zstd import Zstd
8+
9+
# Inputs shared by every parametrised test in this module.
# The random entry is drawn from a seeded generator so that the parameter set
# (and therefore the generated test IDs) is identical on every collection;
# unseeded data makes failures unreproducible and breaks pytest-xdist, which
# requires all workers to collect the same tests.
test_data = [
    b"Hello World!",
    np.arange(113).tobytes(),
    np.arange(10, 15).tobytes(),
    np.random.default_rng(42).integers(3, 50, size=(53,), dtype=np.uint16).tobytes(),
]
15+
16+
17+
@pytest.mark.parametrize("input", test_data)
def test_pyzstd_simple(input):
    """
    Round-trip a single frame between numcodecs' Zstd codec and pyzstd.

    Data compressed by pyzstd must decode with Zstd.decode, and data
    produced by Zstd.encode must decompress with pyzstd.
    """
    codec = Zstd()

    # pyzstd -> numcodecs
    compressed_by_pyzstd = pyzstd.compress(input)
    assert codec.decode(compressed_by_pyzstd) == input

    # numcodecs -> pyzstd
    compressed_by_codec = codec.encode(input)
    assert pyzstd.decompress(compressed_by_codec) == input
26+
27+
28+
@pytest.mark.xfail
@pytest.mark.parametrize("input", test_data)
def test_pyzstd_simple_multiple_frames_decode(input):
    """
    Decode two concatenated frames whose content sizes are recorded.

    Expected to fail for numcodecs.zstd.Zstd at present: it sizes the output
    from the first frame only, whereas it should walk every frame up to the
    end of the input buffer.
    """
    codec = Zstd()
    two_frames = pyzstd.compress(input) * 2
    doubled_input = input * 2

    # pyzstd itself handles the concatenated frames.
    assert pyzstd.decompress(two_frames) == doubled_input
    # A freshly compressed (but equal) pair of frames is given to the codec.
    assert codec.decode(pyzstd.compress(input) * 2) == doubled_input
41+
42+
43+
@pytest.mark.parametrize("input", test_data)
def test_pyzstd_simple_multiple_frames_encode(input):
    """
    Check that pyzstd decompresses two concatenated Zstd.encode frames.
    """
    codec = Zstd()
    two_frames = codec.encode(input) * 2
    assert pyzstd.decompress(two_frames) == input * 2
50+
51+
52+
@pytest.mark.parametrize("input", test_data)
def test_pyzstd_streaming(input):
    """
    Decode streaming-mode frames, whose headers omit the decompressed size.

    Covers a single frame and several runs of concatenated frames, comparing
    Zstd.decode against pyzstd's EndlessZstdDecompressor.
    """
    compressor = pyzstd.ZstdCompressor()
    single_frame_decompressor = pyzstd.ZstdDecompressor()
    endless_decompressor = pyzstd.EndlessZstdDecompressor()
    codec = Zstd()

    plain = input
    # The compressed frame is taken from flush(), as in streaming usage.
    compressor.compress(plain)
    frame = compressor.flush()

    assert codec.decode(frame) == plain
    assert single_frame_decompressor.decompress(codec.encode(plain)) == plain

    # Test multiple streaming frames
    for repeats in (2, 3, 4, 5, 7, 11, 13, 99):
        assert codec.decode(frame * repeats) == endless_decompressor.decompress(frame * repeats)

numcodecs/tests/test_zstd.py

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
import itertools
2+
import subprocess
23

34
import numpy as np
45
import pytest
56

6-
try:
7-
from numcodecs.zstd import Zstd
8-
except ImportError: # pragma: no cover
9-
pytest.skip("numcodecs.zstd not available", allow_module_level=True)
10-
11-
127
from numcodecs.tests.common import (
138
check_backwards_compatibility,
149
check_config,
@@ -17,6 +12,7 @@
1712
check_err_encode_object_buffer,
1813
check_repr,
1914
)
15+
from numcodecs.zstd import Zstd
2016

2117
codecs = [
2218
Zstd(),
@@ -90,3 +86,73 @@ def test_native_functions():
9086
assert Zstd.default_level() == 3
9187
assert Zstd.min_level() == -131072
9288
assert Zstd.max_level() == 22
89+
90+
91+
def test_streaming_decompression():
    """
    Decode frames whose header does not record the content size.

    When the zstd command line tool is available, the hard-coded fixtures are
    additionally cross-checked against what the tool produces and decodes.
    """
    codec = Zstd()

    # If the zstd command line interface is available, check the bytes
    have_cli = zstd_cli_available()
    if have_cli:
        view_zstd_streaming_bytes()

    # Short frame that was produced by streaming compression
    hello_frame = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!'
    hello_plain = b'Hello World!'
    assert codec.decode(hello_frame) == hello_plain
    if have_cli:
        assert hello_frame == generate_zstd_streaming_bytes(hello_plain)
        assert hello_plain == generate_zstd_streaming_bytes(hello_frame, decompress=True)

    # Two consecutive frames given as input
    double_frame = hello_frame * 2
    double_plain = b'Hello World!Hello World!'
    assert codec.decode(double_frame) == double_plain
    if have_cli:
        assert double_plain == generate_zstd_streaming_bytes(double_frame, decompress=True)

    # Single long frame that decompresses to a large output
    long_frame = b'(\xb5/\xfd\x00X$\x02\x00\xa4\x03ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz\x01\x00:\xfc\xdfs\x05\x05L\x00\x00\x08s\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08k\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08c\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08[\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08S\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08K\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08C\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08u\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08m\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08e\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08]\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08U\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08M\x01\x00\xfc\xff9\x10\x02M\x00\x00\x08E\x01\x00\xfc\x7f\x1d\x08\x01'
    long_plain = b'ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz' * 1024 * 32
    assert codec.decode(long_frame) == long_plain
    if have_cli:
        assert long_frame == generate_zstd_streaming_bytes(long_plain)
        assert long_plain == generate_zstd_streaming_bytes(long_frame, decompress=True)

    # Garbage input (eight zero bytes) results in an error
    garbage = b'\x00' * 8
    with pytest.raises(RuntimeError, match='Zstd decompression error: invalid input data'):
        codec.decode(garbage)
130+
131+
132+
def generate_zstd_streaming_bytes(input: bytes, *, decompress: bool = False) -> bytes:
    """
    Use the zstd command line interface to compress or decompress bytes in streaming mode.

    The data is piped through ``zstd --no-check`` (with ``-d`` added when
    ``decompress`` is true) and whatever the tool writes to stdout is returned.

    Raises
    ------
    subprocess.CalledProcessError
        If the zstd process exits with a non-zero status. Without this the
        helper would silently hand back empty or partial output.
    """
    mode_args = ["-d"] if decompress else []
    completed = subprocess.run(
        ["zstd", "--no-check", *mode_args],
        input=input,
        capture_output=True,
        check=True,
    )
    return completed.stdout
143+
144+
145+
def view_zstd_streaming_bytes():
    """
    Print the CLI-compressed form of the payloads used as streaming fixtures.

    The output is formatted as assignments (``bytes_val = ...``,
    ``bytes3 = ...``) so that it can be compared with, or pasted over, the
    literals hard-coded in ``test_streaming_decompression``.
    """
    # Same payload (including capitalisation) as the ``bytes_val`` fixture.
    bytes_val = generate_zstd_streaming_bytes(b"Hello World!")
    print(f" bytes_val = {bytes_val}")

    bytes3 = generate_zstd_streaming_bytes(
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz" * 1024 * 32
    )
    print(f" bytes3 = {bytes3}")
153+
154+
155+
def zstd_cli_available() -> bool:
    """
    Report whether the ``zstd`` command line tool can be run.

    Returns True only when ``zstd -V`` starts and exits with status 0.
    """
    try:
        probe = subprocess.run(
            ["zstd", "-V"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
        )
    except OSError:
        # subprocess.run raises (e.g. FileNotFoundError) when the executable
        # is missing or not runnable; that simply means "not available".
        return False
    return probe.returncode == 0

numcodecs/zstd.pyx

Lines changed: 131 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ from .compat_ext cimport PyBytes_RESIZE, ensure_continguous_memoryview
1313
from .compat import ensure_contiguous_ndarray
1414
from .abc import Codec
1515

16+
from libc.stdlib cimport malloc, realloc, free
1617

1718
cdef extern from "zstd.h":
1819

@@ -21,6 +22,23 @@ cdef extern from "zstd.h":
2122
struct ZSTD_CCtx_s:
2223
pass
2324
ctypedef ZSTD_CCtx_s ZSTD_CCtx
25+
26+
struct ZSTD_DStream_s:
27+
pass
28+
ctypedef ZSTD_DStream_s ZSTD_DStream
29+
30+
struct ZSTD_inBuffer_s:
31+
const void* src
32+
size_t size
33+
size_t pos
34+
ctypedef ZSTD_inBuffer_s ZSTD_inBuffer
35+
36+
struct ZSTD_outBuffer_s:
37+
void* dst
38+
size_t size
39+
size_t pos
40+
ctypedef ZSTD_outBuffer_s ZSTD_outBuffer
41+
2442
cdef enum ZSTD_cParameter:
2543
ZSTD_c_compressionLevel=100
2644
ZSTD_c_checksumFlag=201
@@ -36,12 +54,20 @@ cdef extern from "zstd.h":
3654
size_t dstCapacity,
3755
const void* src,
3856
size_t srcSize) nogil
39-
4057
size_t ZSTD_decompress(void* dst,
4158
size_t dstCapacity,
4259
const void* src,
4360
size_t compressedSize) nogil
4461

62+
size_t ZSTD_decompressStream(ZSTD_DStream* zds,
63+
ZSTD_outBuffer* output,
64+
ZSTD_inBuffer* input) nogil
65+
66+
size_t ZSTD_DStreamOutSize() nogil
67+
ZSTD_DStream* ZSTD_createDStream() nogil
68+
size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil
69+
size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil
70+
4571
cdef long ZSTD_CONTENTSIZE_UNKNOWN
4672
cdef long ZSTD_CONTENTSIZE_ERROR
4773
unsigned long long ZSTD_getFrameContentSize(const void* src,
@@ -55,7 +81,7 @@ cdef extern from "zstd.h":
5581

5682
unsigned ZSTD_isError(size_t code) nogil
5783

58-
const char* ZSTD_getErrorName(size_t code)
84+
const char* ZSTD_getErrorName(size_t code) nogil
5985

6086

6187
VERSION_NUMBER = ZSTD_versionNumber()
@@ -157,7 +183,10 @@ def decompress(source, dest=None):
157183
source : bytes-like
158184
Compressed data. Can be any object supporting the buffer protocol.
159185
dest : array-like, optional
160-
Object to decompress into.
186+
Object to decompress into. If the content size is unknown, the
187+
length of dest must match the decompressed size. If the content size
188+
is unknown and dest is not provided, streaming decompression will be
189+
used.
161190
162191
Returns
163192
-------
@@ -174,6 +203,7 @@ def decompress(source, dest=None):
174203
char* dest_ptr
175204
size_t source_size, dest_size, decompressed_size
176205
size_t nbytes, cbytes, blocksize
206+
size_t dest_nbytes
177207

178208
# obtain source memoryview
179209
source_mv = ensure_continguous_memoryview(source)
@@ -187,9 +217,12 @@ def decompress(source, dest=None):
187217

188218
# determine uncompressed size
189219
dest_size = ZSTD_getFrameContentSize(source_ptr, source_size)
190-
if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_UNKNOWN or dest_size == ZSTD_CONTENTSIZE_ERROR:
220+
if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR:
191221
raise RuntimeError('Zstd decompression error: invalid input data')
192222

223+
if dest_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
224+
return stream_decompress(source_pb)
225+
193226
# setup destination buffer
194227
if dest is None:
195228
# allocate memory
@@ -203,6 +236,9 @@ def decompress(source, dest=None):
203236
dest_ptr = <char*>dest_pb.buf
204237
dest_nbytes = dest_pb.len
205238

239+
if dest_size == ZSTD_CONTENTSIZE_UNKNOWN:
240+
dest_size = dest_nbytes
241+
206242
# validate output buffer
207243
if dest_nbytes < dest_size:
208244
raise ValueError('destination buffer too small; expected at least %s, '
@@ -225,6 +261,97 @@ def decompress(source, dest=None):
225261

226262
return dest
227263

264+
cdef stream_decompress(const Py_buffer* source_pb):
    """Decompress data of unknown size

    The output buffer starts at twice the source size (but at least
    ``ZSTD_DStreamOutSize()``) and is grown by ``ZSTD_DStreamOutSize()`` bytes
    whenever it fills up while zstd still has data to produce.

    Parameters
    ----------
    source_pb : Py_buffer
        Compressed data buffer

    Returns
    -------
    dest : bytes
        Object containing decompressed data.

    Raises
    ------
    MemoryError
        If the initial output buffer or the decompression stream cannot be
        allocated.
    RuntimeError
        If zstd reports an error, or the output buffer cannot be grown.
    """

    cdef:
        const char *source_ptr
        void *dest_ptr
        void *new_dst
        size_t source_size, dest_size, new_size
        size_t DEST_GROWTH_SIZE, status
        int failure
        ZSTD_inBuffer input
        ZSTD_outBuffer output
        ZSTD_DStream *zds

    # Recommended size for output buffer, guaranteed to flush at least
    # one complete block in all circumstances
    DEST_GROWTH_SIZE = ZSTD_DStreamOutSize()

    source_ptr = <const char*>source_pb.buf
    source_size = source_pb.len

    # unknown content size, guess it is twice the size as the source
    dest_size = source_size * 2

    if dest_size < DEST_GROWTH_SIZE:
        # minimum dest_size is DEST_GROWTH_SIZE
        dest_size = DEST_GROWTH_SIZE

    # Start from a well-defined state so the finally block can release
    # whatever has been acquired, no matter where a failure occurs.
    # Both free(NULL) and ZSTD_freeDStream(NULL) are no-ops.
    zds = NULL
    output = ZSTD_outBuffer(NULL, 0, 0)

    try:

        dest_ptr = malloc(dest_size)
        if dest_ptr == NULL:
            raise MemoryError('Zstd stream decompression error on malloc: could not allocate output buffer')

        # From here on output.dst is the single owner of the output memory
        input = ZSTD_inBuffer(source_ptr, source_size, 0)
        output = ZSTD_outBuffer(dest_ptr, dest_size, 0)

        zds = ZSTD_createDStream()
        if zds == NULL:
            raise MemoryError('Zstd stream decompression error on ZSTD_createDStream: could not allocate stream')

        status = ZSTD_initDStream(zds)
        if ZSTD_isError(status):
            error = ZSTD_getErrorName(status)
            # zds is released exactly once, in the finally block
            raise RuntimeError('Zstd stream decompression error on ZSTD_initDStream: %s' % error)

        # Failures inside the nogil section are recorded in this flag and
        # raised once the GIL is held again:
        #   1: ZSTD_decompressStream returned an error code (kept in status)
        #   2: growing the output buffer would overflow size_t
        #   3: realloc failed
        failure = 0

        with nogil:

            # Initialize to 1 to force a loop iteration
            status = 1
            while status > 0 or input.pos < input.size:
                # Possible returned values of ZSTD_decompressStream:
                # 0: frame is completely decoded and fully flushed
                # error code: detected with ZSTD_isError
                # >0: suggested next input size
                status = ZSTD_decompressStream(zds, &output, &input)

                if ZSTD_isError(status):
                    failure = 1
                    break

                # There is more to decompress, grow the buffer
                if status > 0 and output.pos == output.size:
                    new_size = output.size + DEST_GROWTH_SIZE

                    if new_size < output.size or new_size < DEST_GROWTH_SIZE:
                        failure = 2
                        break

                    new_dst = realloc(output.dst, new_size)

                    if new_dst == NULL:
                        # output.dst is still valid; freed in finally block
                        failure = 3
                        break

                    output.dst = new_dst
                    output.size = new_size

        if failure == 1:
            error = ZSTD_getErrorName(status)
            raise RuntimeError('Zstd stream decompression error on ZSTD_decompressStream: %s' % error)
        if failure == 2:
            raise RuntimeError('Zstd stream decompression error: output buffer overflow')
        if failure == 3:
            raise RuntimeError('Zstd stream decompression error on realloc: could not expand output buffer')

        # Copy the output to a bytes object
        dest = PyBytes_FromStringAndSize(<char *>output.dst, output.pos)

    finally:
        ZSTD_freeDStream(zds)
        free(output.dst)

    return dest
228355

229356
class Zstd(Codec):
230357
"""Codec providing compression using Zstandard.

0 commit comments

Comments
 (0)