Add streaming decompression for ZSTD_CONTENTSIZE_UNKNOWN case #707

Merged
merged 15 commits into from
Jul 10, 2025
3 changes: 3 additions & 0 deletions docs/compression/zstd.rst
@@ -7,6 +7,9 @@ Zstd
    .. autoattribute:: codec_id
    .. automethod:: encode
    .. automethod:: decode
        .. note::
            If the compressed data does not contain the decompressed size, streaming
            decompression will be used.
    .. automethod:: get_config
    .. automethod:: from_config
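
Whether a frame "contains the decompressed size" is decided by the Frame_Header_Descriptor byte described in the Zstandard format document. The following is a minimal pure-Python sketch of that check, written for illustration only; `frame_content_size` is a hypothetical helper and not part of numcodecs, which obtains the same answer from `ZSTD_getFrameContentSize`.

```python
ZSTD_MAGIC = b"\x28\xb5\x2f\xfd"  # magic number 0xFD2FB528, stored little-endian


def frame_content_size(frame: bytes):
    """Return the size declared in the frame header, or None if it is absent.

    Hypothetical helper for illustration; skippable frames are not handled.
    """
    if frame[:4] != ZSTD_MAGIC:
        raise ValueError("not a Zstandard frame")
    descriptor = frame[4]
    fcs_flag = descriptor >> 6              # Frame_Content_Size_flag
    single_segment = (descriptor >> 5) & 1  # Single_Segment_flag
    dict_id_flag = descriptor & 3           # Dictionary_ID_flag

    fcs_size = (0, 2, 4, 8)[fcs_flag]
    if fcs_flag == 0 and single_segment:
        fcs_size = 1
    if fcs_size == 0:
        return None  # size not stored: the streaming path is needed

    pos = 5
    if not single_segment:
        pos += 1  # Window_Descriptor byte
    pos += (0, 1, 2, 4)[dict_id_flag]  # optional Dictionary_ID field
    value = int.from_bytes(frame[pos:pos + fcs_size], "little")
    if fcs_size == 2:
        value += 256  # the 2-byte encoding is offset by 256
    return value


streamed = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!'
one_shot = b'(\xb5/\xfd\x20\x0ca\x00\x00Hello World!'
print(frame_content_size(streamed))  # None
print(frame_content_size(one_shot))  # 12
```

The second frame is an assumed example of a header with Single_Segment_flag set and a one-byte content size; only the first frame appears in this pull request.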

2 changes: 2 additions & 0 deletions docs/release.rst
@@ -95,6 +95,8 @@ Maintenance

Improvements
~~~~~~~~~~~~
* Add streaming decompression for ZSTD (:issue:`699`)
  By :user:`Mark Kittisopikul <mkitti>`.
* Raise a custom `UnknownCodecError` when trying to retrieve an unavailable codec.
  By :user:`Cas Wognum <cwognum>`.

25 changes: 25 additions & 0 deletions numcodecs/tests/test_zstd.py
@@ -90,3 +90,28 @@ def test_native_functions():
    assert Zstd.default_level() == 3
    assert Zstd.min_level() == -131072
    assert Zstd.max_level() == 22


def test_streaming_decompression():
    # Test input frames with unknown frame content size
    codec = Zstd()

    # Encode bytes directly that were the result of streaming compression
    bytes_val = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!'
    dec = codec.decode(bytes_val)
    assert dec == b'Hello World!'

Contributor

where did these bytes come from? Ideally we would have a test that generated a streaming output from another zstd tool, and used that as an input. Is this particularly onerous to test?

Contributor

alternatively, include explanatory reference information from the zstd spec as a comment. basically imagine someone else coming to do maintenance on this test -- how will they know where to look to decipher bytes_val?

Contributor Author

Could we compare it to the zstd CLI?

In [1]: import subprocess

In [2]: subprocess.run(["zstd","--no-check"], input=b"Hello world!", capture_output=True).stdout
Out[2]: b'(\xb5/\xfd\x00Xa\x00\x00Hello world!'

$ echo -n "Hello world!" | zstd --no-check | hd
00000000  28 b5 2f fd 00 58 61 00  00 48 65 6c 6c 6f 20 77  |(./..Xa..Hello w|
00000010  6f 72 6c 64 21                                    |orld!|
00000015

or should we break down the frame format:
https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
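
As a rough illustration of what "breaking down the frame format" could look like for this particular frame, here is a sketch that decodes the nine header bytes by hand, following the linked format document. `dissect_frame` is a hypothetical helper; it assumes the simplest layout (no dictionary ID, no content size field, a single block), which happens to match the test bytes.

```python
def dissect_frame(frame: bytes) -> dict:
    """Decode the fixed-layout header of a minimal zstd frame (illustrative only)."""
    magic = int.from_bytes(frame[0:4], "little")
    descriptor = frame[4]
    # This sketch only supports Frame_Content_Size_flag = 0,
    # Single_Segment_flag = 0 and Dictionary_ID_flag = 0.
    if descriptor >> 5 or descriptor & 3:
        raise ValueError("header layout not covered by this sketch")
    window_descriptor = frame[5]  # present because Single_Segment_flag is 0
    block_header = int.from_bytes(frame[6:9], "little")  # 3-byte Block_Header
    block_size = block_header >> 3
    return {
        "magic": hex(magic),
        "content_size_flag": descriptor >> 6,
        "checksum_flag": (descriptor >> 2) & 1,
        "window_log": 10 + (window_descriptor >> 3),
        "last_block": block_header & 1,
        "block_type": (block_header >> 1) & 3,  # 0 means Raw_Block
        "block_size": block_size,
        "payload": frame[9:9 + block_size],
    }


info = dissect_frame(b'(\xb5/\xfd\x00Xa\x00\x00Hello World!')
for key, value in info.items():
    print(f"{key}: {value}")
```

Read this way, the bytes are the magic number, a descriptor of `0x00` (no content size, no checksum), a window descriptor of `0x58`, and a block header of `0x000061` marking a final raw block of 12 bytes, followed by the uncompressed payload.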

Contributor

whatever's easiest for you. the goal is to upgrade from a blob of (apparently) random bytes in our test to something that has a clear tie to the zstd spec.

Contributor Author

The hard part for me is figuring out how to manage the test dependency and what the basic requirements are there. If we need to cobble this together relying on PyPI alone, then it probably makes sense to pull in python-zstandard or pyzstd as a test dependency.

If we could do a conda package, then the conda-forge zstd package would probably be the way to go. In this case, we would not need to rely on another 3rd party wrapper but could just depend on the 1st party command line interface.

https://anaconda.org/conda-forge/zstd

Alternatively, we could also use the system package manager to install the zstd CLI.

The byte strings are there to decouple the dependency. They are the same as the bytes in the numcodecs.js test implementation.

https://github.com/manzt/numcodecs.js/blob/main/test%2Fzstd.test.js

I suppose what we could do is just add a bunch of tests that are conditional on the available tools or packages which would be optional dependencies.

I still think we should keep the byte strings there for the case when no other tools are available. We could of course better document how to generate those bytes.
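
A minimal sketch of the "use the tool if it is available, otherwise fall back to the stored bytes" idea, using only the standard library. The helper names are hypothetical; the only CLI flag used is `--no-check`, as in the session quoted earlier in this thread.

```python
import shutil
import subprocess

PAYLOAD = b"Hello World!"
# Stored frame used when no zstd executable can be found
FALLBACK_FRAME = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!'


def zstd_cli_frame(data: bytes):
    """Compress data with the zstd CLI, or return None if that is not possible."""
    exe = shutil.which("zstd")
    if exe is None:
        return None
    result = subprocess.run([exe, "--no-check"], input=data, capture_output=True)
    if result.returncode != 0:
        return None
    return result.stdout


def streaming_frame() -> bytes:
    """Return a frame for PAYLOAD, preferring freshly generated CLI output."""
    frame = zstd_cli_frame(PAYLOAD)
    return frame if frame is not None else FALLBACK_FRAME


print(streaming_frame())
```

In a pytest suite the same availability check could instead drive a skip marker, so that the byte-string test always runs and the CLI-based test runs only where the tool is installed.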

Contributor

> I still think we should keep the byte strings there for the case when no other tools are available. We could of course better document how to generate those bytes.

this is also fine, which is why I said earlier that a comment explaining where the bytes came from would be sufficient.

Contributor Author

> can you declare the python zstd package as a test dependency?

Which one?

https://pypi.org/project/pyzstd/
https://pypi.org/project/zstd/
https://pypi.org/project/zstandard/

I'm leaning towards pyzstd since it is relatively complete and well maintained.

Contributor

i don't care which one 🤣

Contributor Author

I added pyzstd tests. I even included a test that decompresses multiple frames in one buffer, which currently fails, so I marked it as xfail.

Contributor Author

The multiple frames issue can be resolved by #757


    # Two consecutive frames given as input
    bytes2 = bytes(bytearray(bytes_val * 2))
    dec2 = codec.decode(bytes2)
    assert dec2 == b'Hello World!Hello World!'

    # Single long frame that decompresses to a large output
    bytes3 = b'(\xb5/\xfd\x00X$\x02\x00\xa4\x03ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz\x01\x00:\xfc\xdfs\x05\x05L\x00\x00\x08s\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08k\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08c\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08[\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08S\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08K\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08C\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08u\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08m\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08e\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08]\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08U\x01\x00\xfc\xff9\x10\x02L\x00\x00\x08M\x01\x00\xfc\xff9\x10\x02M\x00\x00\x08E\x01\x00\xfc\x7f\x1d\x08\x01'
    dec3 = codec.decode(bytes3)
    assert dec3 == b'ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz' * 1024 * 32

    # Garbage input results in an error
    bytes4 = bytes(bytearray([0, 0, 0, 0, 0, 0, 0, 0]))
    with pytest.raises(RuntimeError, match='Zstd decompression error: invalid input data'):
        codec.decode(bytes4)
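
For the "two consecutive frames" case, it can help to see where one frame ends and the next begins. The sketch below walks the block headers to find frame boundaries without decompressing anything. It is a hypothetical pure-Python illustration based on the format document (libzstd exposes `ZSTD_findFrameCompressedSize` for the same purpose) and it does not handle skippable frames.

```python
ZSTD_MAGIC = b"\x28\xb5\x2f\xfd"


def frame_length(buf: bytes, start: int = 0) -> int:
    """Return the compressed length of the frame that begins at buf[start]."""
    if buf[start:start + 4] != ZSTD_MAGIC:
        raise ValueError("not a Zstandard frame")
    descriptor = buf[start + 4]
    fcs_flag = descriptor >> 6
    single_segment = (descriptor >> 5) & 1
    has_checksum = (descriptor >> 2) & 1
    dict_id_flag = descriptor & 3

    fcs_size = (0, 2, 4, 8)[fcs_flag]
    if fcs_flag == 0 and single_segment:
        fcs_size = 1
    # Skip magic, descriptor, optional window descriptor, dictionary ID, content size
    pos = start + 5 + (0 if single_segment else 1) + (0, 1, 2, 4)[dict_id_flag] + fcs_size

    while True:
        if pos + 3 > len(buf):
            raise ValueError("truncated frame")
        header = int.from_bytes(buf[pos:pos + 3], "little")
        pos += 3
        last_block = header & 1
        block_type = (header >> 1) & 3
        block_size = header >> 3
        if block_type == 3:
            raise ValueError("reserved block type")
        # An RLE block stores one byte; raw and compressed blocks store block_size bytes
        pos += 1 if block_type == 1 else block_size
        if last_block:
            break
    if has_checksum:
        pos += 4
    if pos > len(buf):
        raise ValueError("truncated frame")
    return pos - start


def split_frames(buf: bytes) -> list:
    """Split a buffer of concatenated frames into individual frames."""
    frames, pos = [], 0
    while pos < len(buf):
        n = frame_length(buf, pos)
        frames.append(buf[pos:pos + n])
        pos += n
    return frames


bytes_val = b'(\xb5/\xfd\x00Xa\x00\x00Hello World!'
print([len(f) for f in split_frames(bytes_val * 2)])  # [21, 21]
```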
135 changes: 131 additions & 4 deletions numcodecs/zstd.pyx
@@ -13,6 +13,7 @@ from .compat_ext cimport PyBytes_RESIZE, ensure_continguous_memoryview
from .compat import ensure_contiguous_ndarray
from .abc import Codec

from libc.stdlib cimport malloc, realloc, free

cdef extern from "zstd.h":

@@ -21,6 +22,23 @@ cdef extern from "zstd.h":
    struct ZSTD_CCtx_s:
        pass
    ctypedef ZSTD_CCtx_s ZSTD_CCtx

    struct ZSTD_DStream_s:
        pass
    ctypedef ZSTD_DStream_s ZSTD_DStream

    struct ZSTD_inBuffer_s:
        const void* src
        size_t size
        size_t pos
    ctypedef ZSTD_inBuffer_s ZSTD_inBuffer

    struct ZSTD_outBuffer_s:
        void* dst
        size_t size
        size_t pos
    ctypedef ZSTD_outBuffer_s ZSTD_outBuffer

    cdef enum ZSTD_cParameter:
        ZSTD_c_compressionLevel=100
        ZSTD_c_checksumFlag=201
@@ -36,12 +54,20 @@ cdef extern from "zstd.h":
                          size_t dstCapacity,
                          const void* src,
                          size_t srcSize) nogil

    size_t ZSTD_decompress(void* dst,
                           size_t dstCapacity,
                           const void* src,
                           size_t compressedSize) nogil

    size_t ZSTD_decompressStream(ZSTD_DStream* zds,
                                 ZSTD_outBuffer* output,
                                 ZSTD_inBuffer* input) nogil

    size_t ZSTD_DStreamOutSize() nogil
    ZSTD_DStream* ZSTD_createDStream() nogil
    size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil
    size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil

    cdef long ZSTD_CONTENTSIZE_UNKNOWN
    cdef long ZSTD_CONTENTSIZE_ERROR
    unsigned long long ZSTD_getFrameContentSize(const void* src,
@@ -55,7 +81,7 @@ cdef extern from "zstd.h":

    unsigned ZSTD_isError(size_t code) nogil

-    const char* ZSTD_getErrorName(size_t code)
+    const char* ZSTD_getErrorName(size_t code) nogil


VERSION_NUMBER = ZSTD_versionNumber()
@@ -157,7 +183,10 @@ def decompress(source, dest=None):
    source : bytes-like
        Compressed data. Can be any object supporting the buffer protocol.
    dest : array-like, optional
-        Object to decompress into.
+        Object to decompress into. If the content size is unknown, the
+        length of dest must match the decompressed size. If the content size
+        is unknown and dest is not provided, streaming decompression will be
+        used.

    Returns
    -------
@@ -174,6 +203,7 @@ def decompress(source, dest=None):
        char* dest_ptr
        size_t source_size, dest_size, decompressed_size
        size_t nbytes, cbytes, blocksize
        size_t dest_nbytes

    # obtain source memoryview
    source_mv = ensure_continguous_memoryview(source)
@@ -187,9 +217,12 @@ def decompress(source, dest=None):

    # determine uncompressed size
    dest_size = ZSTD_getFrameContentSize(source_ptr, source_size)
-    if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_UNKNOWN or dest_size == ZSTD_CONTENTSIZE_ERROR:
+    if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR:
        raise RuntimeError('Zstd decompression error: invalid input data')

    if dest_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
        return stream_decompress(source_pb)

    # setup destination buffer
    if dest is None:
        # allocate memory
@@ -203,6 +236,9 @@ def decompress(source, dest=None):
    dest_ptr = <char*>dest_pb.buf
    dest_nbytes = dest_pb.len

    if dest_size == ZSTD_CONTENTSIZE_UNKNOWN:
        dest_size = dest_nbytes

    # validate output buffer
    if dest_nbytes < dest_size:
        raise ValueError('destination buffer too small; expected at least %s, '
@@ -225,6 +261,97 @@ def decompress(source, dest=None):

    return dest
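
The branch order introduced in `decompress` can be summarised in a few lines of plain Python. This is only a sketch of the decision logic: `plan_decompression` is a hypothetical name, and the two sentinel values are written out as defined in `zstd.h` (`0ULL - 1` and `0ULL - 2`).

```python
ZSTD_CONTENTSIZE_UNKNOWN = 2**64 - 1  # (0ULL - 1) in zstd.h
ZSTD_CONTENTSIZE_ERROR = 2**64 - 2    # (0ULL - 2) in zstd.h


def plan_decompression(content_size, dest_nbytes=None):
    """Mirror the checks in decompress(): return (strategy, output_size).

    content_size stands for the value reported by ZSTD_getFrameContentSize;
    dest_nbytes is the length of a caller-supplied buffer, or None.
    """
    if content_size == 0 or content_size == ZSTD_CONTENTSIZE_ERROR:
        raise RuntimeError('Zstd decompression error: invalid input data')

    if content_size == ZSTD_CONTENTSIZE_UNKNOWN:
        if dest_nbytes is None:
            # No size in the header and no buffer to bound it: stream
            return ("stream", None)
        # Trust the caller: the buffer length stands in for the size
        content_size = dest_nbytes

    if dest_nbytes is not None and dest_nbytes < content_size:
        raise ValueError('destination buffer too small; expected at least %s, '
                         'got %s' % (content_size, dest_nbytes))
    return ("one-shot", content_size)


print(plan_decompression(12))                            # ('one-shot', 12)
print(plan_decompression(ZSTD_CONTENTSIZE_UNKNOWN))      # ('stream', None)
print(plan_decompression(ZSTD_CONTENTSIZE_UNKNOWN, 12))  # ('one-shot', 12)
```

One consequence worth noting is that a frame with an unknown size and a caller-supplied `dest` still takes the one-shot path, which is why the docstring asks for `dest` to match the decompressed size in that case.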

cdef stream_decompress(const Py_buffer* source_pb):
    """Decompress data of unknown size

    Parameters
    ----------
    source : Py_buffer
        Compressed data buffer

    Returns
    -------
    dest : bytes
        Object containing decompressed data.
    """

    cdef:
        const char *source_ptr
        void *dest_ptr
        void *new_dst
        size_t source_size, dest_size, decompressed_size
        size_t DEST_GROWTH_SIZE, status
        ZSTD_inBuffer input
        ZSTD_outBuffer output
        ZSTD_DStream *zds

    # Recommended size for output buffer, guaranteed to flush at least
    # one complete block in all circumstances
    DEST_GROWTH_SIZE = ZSTD_DStreamOutSize();

    source_ptr = <const char*>source_pb.buf
    source_size = source_pb.len

    # unknown content size, guess it is twice the size of the source
    dest_size = source_size * 2

    if dest_size < DEST_GROWTH_SIZE:
        # minimum dest_size is DEST_GROWTH_SIZE
        dest_size = DEST_GROWTH_SIZE

    dest_ptr = <char *>malloc(dest_size)
    zds = ZSTD_createDStream()

    try:

        with nogil:

            status = ZSTD_initDStream(zds)
            if ZSTD_isError(status):
                error = ZSTD_getErrorName(status)
                ZSTD_freeDStream(zds);
                raise RuntimeError('Zstd stream decompression error on ZSTD_initDStream: %s' % error)

            input = ZSTD_inBuffer(source_ptr, source_size, 0)
            output = ZSTD_outBuffer(dest_ptr, dest_size, 0)

            # Initialize to 1 to force a loop iteration
            status = 1
            while(status > 0 or input.pos < input.size):
                # Possible returned values of ZSTD_decompressStream:
                # 0: frame is completely decoded and fully flushed
                # error code: test with ZSTD_isError
                # >0: suggested next input size
                status = ZSTD_decompressStream(zds, &output, &input)

                if ZSTD_isError(status):
                    error = ZSTD_getErrorName(status)
                    raise RuntimeError('Zstd stream decompression error on ZSTD_decompressStream: %s' % error)

                # There is more to decompress, grow the buffer
                if status > 0 and output.pos == output.size:
                    new_size = output.size + DEST_GROWTH_SIZE

                    if new_size < output.size or new_size < DEST_GROWTH_SIZE:
                        raise RuntimeError('Zstd stream decompression error: output buffer overflow')

                    new_dst = realloc(output.dst, new_size)

                    if new_dst == NULL:
                        # output.dst freed in finally block
                        raise RuntimeError('Zstd stream decompression error on realloc: could not expand output buffer')

                    output.dst = new_dst
                    output.size = new_size

        # Copy the output to a bytes object
        dest = PyBytes_FromStringAndSize(<char *>output.dst, output.pos)

    finally:
        ZSTD_freeDStream(zds)
        free(output.dst)

    return dest
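
The grow-as-needed loop in `stream_decompress` may be easier to follow in plain Python. The standard library has no guaranteed zstd binding, so the sketch below deliberately swaps in `zlib.decompressobj` as a stand-in codec. The buffer handling (an initial guess of twice the source size, a fixed growth step, continuing until the stream reports completion) follows the same pattern, but it is not the numcodecs implementation, and the 64 KiB default step is an arbitrary assumption rather than the value of `ZSTD_DStreamOutSize()`.

```python
import zlib


def stream_decompress(source: bytes, growth: int = 64 * 1024) -> bytes:
    """Decompress a zlib stream of unknown size into a growing buffer.

    zlib is a stand-in for zstd here; growth plays the role of DEST_GROWTH_SIZE.
    """
    decomp = zlib.decompressobj()
    # Unknown content size: guess twice the source size, with a minimum
    dest = bytearray(max(2 * len(source), growth))
    pos = 0
    pending = source

    while True:
        room = len(dest) - pos
        if room == 0:
            # Output buffer is full and the stream is not finished: grow it
            dest.extend(bytes(growth))
            room = growth
        # Produce at most `room` bytes; unread input is kept in unconsumed_tail
        chunk = decomp.decompress(pending, room)
        dest[pos:pos + len(chunk)] = chunk
        pos += len(chunk)
        pending = decomp.unconsumed_tail
        if decomp.eof:
            break
        if not pending and not chunk:
            raise RuntimeError("stream decompression error: truncated input")

    # Copy only the filled part of the buffer to a bytes object
    return bytes(dest[:pos])


data = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" * 4096
assert stream_decompress(zlib.compress(data), growth=1024) == data
```

Growing by a fixed step keeps peak memory close to the final output size, at the cost of more reallocations for highly compressible input such as the long frame in the test above.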

class Zstd(Codec):
"""Codec providing compression using Zstandard.