Skip to content

Commit 73a21c4

Browse files
mkittid-v-b
andauthored
Allow Zstandard to decompress multiple concatenated frames (#757)
* Add support for multiple zstd frames in decompression * Add release notes * Format with ruff * Address MSVC type errors * Explicitly declare return type of findTotalContentSize * Mark multiframe pyzstd tests as now passing * Format with ruff * Test concatenated frames of known and unknown sizes * Add docstring for findTotalContentSize --------- Co-authored-by: Davis Bennett <[email protected]>
1 parent ea08835 commit 73a21c4

File tree

4 files changed

+82
-4
lines changed

4 files changed

+82
-4
lines changed

docs/release.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ Improvements
2727
By :user:`John Kirkham <jakirkham>`, :issue:`723`
2828
* All codecs are now pickleable.
2929
By :user:`Tom Nicholas <TomNicholas>`, :issue:`744`
30+
* The Zstandard codec can now decode bytes containing multiple frames
31+
By :user:`Mark Kittisopikul <mkitti>`, :issue:`757`
3032

3133
Fixes
3234
~~~~~

numcodecs/tests/test_pyzstd.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def test_pyzstd_simple(input):
2525
assert pyzstd.decompress(z.encode(input)) == input
2626

2727

28-
@pytest.mark.xfail
2928
@pytest.mark.parametrize("input", test_data)
3029
def test_pyzstd_simple_multiple_frames_decode(input):
3130
"""

numcodecs/tests/test_zstd.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,34 @@ def zstd_cli_available() -> bool:
156156
return not subprocess.run(
157157
["zstd", "-V"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
158158
).returncode
159+
160+
161+
def test_multi_frame():
162+
codec = Zstd()
163+
164+
hello_world = codec.encode(b"Hello world!")
165+
assert codec.decode(hello_world) == b"Hello world!"
166+
assert codec.decode(hello_world * 2) == b"Hello world!Hello world!"
167+
168+
hola = codec.encode(b"Hola ")
169+
mundo = codec.encode(b"Mundo!")
170+
assert codec.decode(hola) == b"Hola "
171+
assert codec.decode(mundo) == b"Mundo!"
172+
assert codec.decode(hola + mundo) == b"Hola Mundo!"
173+
174+
bytes_val = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!'
175+
dec = codec.decode(bytes_val)
176+
dec_expected = b'Hello World!'
177+
assert dec == dec_expected
178+
cli = zstd_cli_available()
179+
if cli:
180+
assert bytes_val == generate_zstd_streaming_bytes(dec_expected)
181+
assert dec_expected == generate_zstd_streaming_bytes(bytes_val, decompress=True)
182+
183+
# Concatenate frames of known sizes and unknown sizes
184+
# unknown size frame at the end
185+
assert codec.decode(hola + mundo + bytes_val) == b"Hola Mundo!Hello World!"
186+
# unknown size frame at the beginning
187+
assert codec.decode(bytes_val + hola + mundo) == b"Hello World!Hola Mundo!"
188+
# unknown size frame in the middle
189+
assert codec.decode(hola + bytes_val + mundo) == b"Hola Hello World!Mundo!"

numcodecs/zstd.pyx

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,12 @@ cdef extern from "zstd.h":
7171
size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil
7272
size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil
7373

74-
cdef long ZSTD_CONTENTSIZE_UNKNOWN
75-
cdef long ZSTD_CONTENTSIZE_ERROR
74+
cdef unsigned long long ZSTD_CONTENTSIZE_UNKNOWN
75+
cdef unsigned long long ZSTD_CONTENTSIZE_ERROR
76+
7677
unsigned long long ZSTD_getFrameContentSize(const void* src,
7778
size_t srcSize) nogil
79+
size_t ZSTD_findFrameCompressedSize(const void* src, size_t srcSize) nogil
7880

7981
int ZSTD_minCLevel() nogil
8082
int ZSTD_maxCLevel() nogil
@@ -218,7 +220,11 @@ def decompress(source, dest=None):
218220

219221
try:
220222
# determine uncompressed size using unsigned long long for full range
221-
content_size = ZSTD_getFrameContentSize(source_ptr, source_size)
223+
try:
224+
content_size = findTotalContentSize(source_ptr, source_size)
225+
except RuntimeError:
226+
raise RuntimeError('Zstd decompression error: invalid input data')
227+
222228
if content_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
223229
return stream_decompress(source_pb)
224230
elif content_size == ZSTD_CONTENTSIZE_UNKNOWN:
@@ -362,6 +368,46 @@ cdef stream_decompress(const Py_buffer* source_pb):
362368

363369
return dest
364370

371+
cdef unsigned long long findTotalContentSize(const char* source_ptr, size_t source_size):
372+
"""Find the total uncompressed content size of all frames in the source buffer
373+
374+
Parameters
375+
----------
376+
source_ptr : Pointer to the beginning of the buffer
377+
source_size : Size of the buffer containing the frame sizes to sum
378+
379+
Returns
380+
-------
381+
total_content_size: Sum of the content size of all frames within the source buffer
382+
If any of the frame sizes is unknown, return ZSTD_CONTENTSIZE_UNKNOWN.
383+
If any of the frames causes ZSTD_getFrameContentSize to error, return ZSTD_CONTENTSIZE_ERROR.
384+
"""
385+
cdef:
386+
unsigned long long frame_content_size = 0
387+
unsigned long long total_content_size = 0
388+
size_t frame_compressed_size = 0
389+
size_t offset = 0
390+
391+
while offset < source_size:
392+
frame_compressed_size = ZSTD_findFrameCompressedSize(source_ptr + offset, source_size - offset);
393+
394+
if ZSTD_isError(frame_compressed_size):
395+
error = ZSTD_getErrorName(frame_compressed_size)
396+
raise RuntimeError('Could not set determine zstd frame size: %s' % error)
397+
398+
frame_content_size = ZSTD_getFrameContentSize(source_ptr + offset, frame_compressed_size);
399+
400+
if frame_content_size == ZSTD_CONTENTSIZE_ERROR:
401+
return ZSTD_CONTENTSIZE_ERROR
402+
403+
if frame_content_size == ZSTD_CONTENTSIZE_UNKNOWN:
404+
return ZSTD_CONTENTSIZE_UNKNOWN
405+
406+
total_content_size += frame_content_size
407+
offset += frame_compressed_size
408+
409+
return total_content_size
410+
365411
class Zstd(Codec):
366412
"""Codec providing compression using Zstandard.
367413

0 commit comments

Comments
 (0)