Skip to content
Open
2 changes: 2 additions & 0 deletions docs/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Improvements
By :user:`John Kirkham <jakirkham>`, :issue:`723`
* All codecs are now pickleable.
By :user:`Tom Nicholas <TomNicholas>`, :issue:`744`
* The Zstandard codec can now decode bytes containing multiple frames.
By :user:`Mark Kittisopikul <mkitti>`, :issue:`757`

Fixes
~~~~~
Expand Down
1 change: 0 additions & 1 deletion numcodecs/tests/test_pyzstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def test_pyzstd_simple(input):
assert pyzstd.decompress(z.encode(input)) == input


@pytest.mark.xfail
@pytest.mark.parametrize("input", test_data)
def test_pyzstd_simple_multiple_frames_decode(input):
"""
Expand Down
31 changes: 31 additions & 0 deletions numcodecs/tests/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,34 @@ def zstd_cli_available() -> bool:
return not subprocess.run(
["zstd", "-V"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
).returncode


def test_multi_frame():
    """Check that ``Zstd.decode`` handles buffers made of concatenated frames."""
    zstd_codec = Zstd()

    # One frame on its own, then the very same frame repeated back to back.
    greeting_frame = zstd_codec.encode(b"Hello world!")
    assert zstd_codec.decode(greeting_frame) == b"Hello world!"
    assert zstd_codec.decode(greeting_frame * 2) == b"Hello world!Hello world!"

    # Two different frames, decoded individually and then concatenated.
    hola_frame = zstd_codec.encode(b"Hola ")
    mundo_frame = zstd_codec.encode(b"Mundo!")
    assert zstd_codec.decode(hola_frame) == b"Hola "
    assert zstd_codec.decode(mundo_frame) == b"Mundo!"
    assert zstd_codec.decode(hola_frame + mundo_frame) == b"Hola Mundo!"

    # Hand-written frame used below as the "unknown size" frame.
    streamed_frame = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!'
    decoded = zstd_codec.decode(streamed_frame)
    expected_payload = b'Hello World!'
    assert decoded == expected_payload
    if zstd_cli_available():
        # Cross-check the literal frame against the zstd command line tool.
        assert streamed_frame == generate_zstd_streaming_bytes(expected_payload)
        assert expected_payload == generate_zstd_streaming_bytes(
            streamed_frame, decompress=True
        )

    # Concatenate frames of known sizes and unknown sizes
    # unknown size frame at the end
    assert (
        zstd_codec.decode(hola_frame + mundo_frame + streamed_frame)
        == b"Hola Mundo!Hello World!"
    )
    # unknown size frame at the beginning
    assert (
        zstd_codec.decode(streamed_frame + hola_frame + mundo_frame)
        == b"Hello World!Hola Mundo!"
    )
    # unknown size frame in the middle
    assert (
        zstd_codec.decode(hola_frame + streamed_frame + mundo_frame)
        == b"Hola Hello World!Mundo!"
    )
52 changes: 49 additions & 3 deletions numcodecs/zstd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ cdef extern from "zstd.h":
size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil
size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil

cdef long ZSTD_CONTENTSIZE_UNKNOWN
cdef long ZSTD_CONTENTSIZE_ERROR
cdef unsigned long long ZSTD_CONTENTSIZE_UNKNOWN
cdef unsigned long long ZSTD_CONTENTSIZE_ERROR

unsigned long long ZSTD_getFrameContentSize(const void* src,
size_t srcSize) nogil
size_t ZSTD_findFrameCompressedSize(const void* src, size_t srcSize) nogil

int ZSTD_minCLevel() nogil
int ZSTD_maxCLevel() nogil
Expand Down Expand Up @@ -218,7 +220,11 @@ def decompress(source, dest=None):

try:
# determine uncompressed size using unsigned long long for full range
content_size = ZSTD_getFrameContentSize(source_ptr, source_size)
try:
content_size = findTotalContentSize(source_ptr, source_size)
except RuntimeError:
raise RuntimeError('Zstd decompression error: invalid input data')

if content_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
return stream_decompress(source_pb)
elif content_size == ZSTD_CONTENTSIZE_UNKNOWN:
Expand Down Expand Up @@ -362,6 +368,46 @@ cdef stream_decompress(const Py_buffer* source_pb):

return dest

cdef unsigned long long findTotalContentSize(const char* source_ptr, size_t source_size):
    """Find the total uncompressed content size of all frames in the source buffer

    The buffer is walked one frame at a time: the compressed size of the frame
    at the current offset is used both to bound the content-size query and to
    advance to the next frame. An empty buffer yields a total of 0.

    Parameters
    ----------
    source_ptr : Pointer to the beginning of the buffer
    source_size : Size of the buffer containing the frame sizes to sum

    Returns
    -------
    total_content_size: Sum of the content size of all frames within the source buffer
        If any of the frame sizes is unknown, return ZSTD_CONTENTSIZE_UNKNOWN.
        If any of the frames causes ZSTD_getFrameContentSize to error, return ZSTD_CONTENTSIZE_ERROR.

    Raises
    ------
    RuntimeError
        If ZSTD_findFrameCompressedSize reports an error for any frame.
    """
    cdef:
        unsigned long long frame_content_size = 0
        unsigned long long total_content_size = 0
        size_t frame_compressed_size = 0
        size_t offset = 0

    while offset < source_size:
        # Length in bytes of the compressed frame starting at `offset`.
        frame_compressed_size = ZSTD_findFrameCompressedSize(source_ptr + offset, source_size - offset)

        if ZSTD_isError(frame_compressed_size):
            error = ZSTD_getErrorName(frame_compressed_size)
            raise RuntimeError('Could not determine zstd frame size: %s' % error)

        frame_content_size = ZSTD_getFrameContentSize(source_ptr + offset, frame_compressed_size)

        # Sentinel values are passed straight through; a partial sum would be
        # meaningless once any single frame cannot be sized.
        if frame_content_size == ZSTD_CONTENTSIZE_ERROR:
            return ZSTD_CONTENTSIZE_ERROR

        if frame_content_size == ZSTD_CONTENTSIZE_UNKNOWN:
            return ZSTD_CONTENTSIZE_UNKNOWN

        total_content_size += frame_content_size
        offset += frame_compressed_size

    return total_content_size

class Zstd(Codec):
"""Codec providing compression using Zstandard.

Expand Down
Loading