Skip to content

Commit d04e536

Browse files
committed
Add streaming decompression for ZSTD_CONTENTSIZE_UNKNOWN case
1 parent 3cf8ab1 commit d04e536

File tree

1 file changed

+128
-5
lines changed

1 file changed

+128
-5
lines changed

numcodecs/zstd.pyx

Lines changed: 128 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
from cpython.buffer cimport PyBUF_ANY_CONTIGUOUS, PyBUF_WRITEABLE
99
from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AS_STRING
1010

11-
1211
from .compat_ext cimport Buffer
1312
from .compat_ext import Buffer
1413
from .compat import ensure_contiguous_ndarray
1514
from .abc import Codec
1615

16+
from libc.stdlib cimport malloc, realloc, free
1717

1818
cdef extern from "zstd.h":
1919

@@ -22,6 +22,23 @@ cdef extern from "zstd.h":
2222
struct ZSTD_CCtx_s:
2323
pass
2424
ctypedef ZSTD_CCtx_s ZSTD_CCtx
25+
26+
struct ZSTD_DStream_s:
27+
pass
28+
ctypedef ZSTD_DStream_s ZSTD_DStream
29+
30+
struct ZSTD_inBuffer_s:
31+
const void* src
32+
size_t size
33+
size_t pos
34+
ctypedef ZSTD_inBuffer_s ZSTD_inBuffer
35+
36+
struct ZSTD_outBuffer_s:
37+
void* dst
38+
size_t size
39+
size_t pos
40+
ctypedef ZSTD_outBuffer_s ZSTD_outBuffer
41+
2542
cdef enum ZSTD_cParameter:
2643
ZSTD_c_compressionLevel=100
2744
ZSTD_c_checksumFlag=201
@@ -37,12 +54,20 @@ cdef extern from "zstd.h":
3754
size_t dstCapacity,
3855
const void* src,
3956
size_t srcSize) nogil
40-
4157
size_t ZSTD_decompress(void* dst,
4258
size_t dstCapacity,
4359
const void* src,
4460
size_t compressedSize) nogil
4561

62+
size_t ZSTD_decompressStream(ZSTD_DStream* zds,
63+
ZSTD_outBuffer* output,
64+
ZSTD_inBuffer* input) nogil
65+
66+
size_t ZSTD_DStreamOutSize() nogil
67+
ZSTD_DStream* ZSTD_createDStream() nogil
68+
size_t ZSTD_freeDStream(ZSTD_DStream* zds) nogil
69+
size_t ZSTD_initDStream(ZSTD_DStream* zds) nogil
70+
4671
cdef long ZSTD_CONTENTSIZE_UNKNOWN
4772
cdef long ZSTD_CONTENTSIZE_ERROR
4873
unsigned long long ZSTD_getFrameContentSize(const void* src,
@@ -56,7 +81,7 @@ cdef extern from "zstd.h":
5681

5782
unsigned ZSTD_isError(size_t code) nogil
5883

59-
const char* ZSTD_getErrorName(size_t code)
84+
const char* ZSTD_getErrorName(size_t code) nogil
6085

6186

6287
VERSION_NUMBER = ZSTD_versionNumber()
@@ -156,7 +181,8 @@ def decompress(source, dest=None):
156181
source : bytes-like
157182
Compressed data. Can be any object supporting the buffer protocol.
158183
dest : array-like, optional
159-
Object to decompress into.
184+
Object to decompress into. If the content size is unknown, the
185+
length of dest must match the decompressed size.
160186
161187
Returns
162188
-------
@@ -180,9 +206,12 @@ def decompress(source, dest=None):
180206

181207
# determine uncompressed size
182208
dest_size = ZSTD_getFrameContentSize(source_ptr, source_size)
183-
if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_UNKNOWN or dest_size == ZSTD_CONTENTSIZE_ERROR:
209+
if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR:
184210
raise RuntimeError('Zstd decompression error: invalid input data')
185211

212+
if dest_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
213+
return stream_decompress(source_buffer)
214+
186215
# setup destination buffer
187216
if dest is None:
188217
# allocate memory
@@ -192,6 +221,8 @@ def decompress(source, dest=None):
192221
arr = ensure_contiguous_ndarray(dest)
193222
dest_buffer = Buffer(arr, PyBUF_ANY_CONTIGUOUS | PyBUF_WRITEABLE)
194223
dest_ptr = dest_buffer.ptr
224+
if dest_size == ZSTD_CONTENTSIZE_UNKNOWN:
225+
dest_size = dest_buffer.nbytes
195226
if dest_buffer.nbytes < dest_size:
196227
raise ValueError('destination buffer too small; expected at least %s, '
197228
'got %s' % (dest_size, dest_buffer.nbytes))
@@ -217,6 +248,98 @@ def decompress(source, dest=None):
217248

218249
return dest
219250

251+
cdef stream_decompress(Buffer source_buffer):
    """Decompress data of unknown size

    Called by ``decompress`` when the frame does not report its content size
    and no destination was supplied. The data is decoded with the Zstandard
    streaming API into a heap buffer that grows on demand, and the result is
    copied into a new bytes object.

    Parameters
    ----------
    source_buffer : Buffer
        Compressed data buffer

    Returns
    -------
    dest : bytes
        Object containing decompressed data.

    Raises
    ------
    MemoryError
        If the initial output buffer or the decompression stream could not be
        allocated.
    RuntimeError
        If Zstandard reports an error, the input ends before the frame is
        complete, or the output buffer cannot be grown.
    """

    cdef:
        char *source_ptr
        void *new_dst
        const char *error
        size_t source_size, dest_size, new_size
        size_t DEST_GROWTH_SIZE, status
        ZSTD_inBuffer input
        ZSTD_outBuffer output
        ZSTD_DStream *zds

    # Recommended size for output buffer, guaranteed to flush at least
    # one complete block in all circumstances
    DEST_GROWTH_SIZE = ZSTD_DStreamOutSize()

    source_ptr = source_buffer.ptr
    source_size = source_buffer.nbytes

    # unknown content size, guess it is twice the size as the source
    dest_size = source_size * 2

    if dest_size < DEST_GROWTH_SIZE:
        # minimum dest_size is DEST_GROWTH_SIZE
        dest_size = DEST_GROWTH_SIZE

    # Fully initialise both stream descriptors *before* entering the try
    # block so that the cleanup in ``finally`` never sees an indeterminate
    # pointer, whichever step fails.
    input.src = source_ptr
    input.size = source_size
    input.pos = 0

    output.dst = malloc(dest_size)
    output.size = dest_size
    output.pos = 0

    zds = ZSTD_createDStream()

    try:

        if output.dst == NULL:
            raise MemoryError('Zstd stream decompression error on malloc: could not allocate output buffer')

        if zds == NULL:
            raise MemoryError('Zstd stream decompression error on ZSTD_createDStream: could not create stream')

        with nogil:

            status = ZSTD_initDStream(zds)
            if ZSTD_isError(status):
                error = ZSTD_getErrorName(status)
                # zds is released exactly once, in the finally block
                raise RuntimeError('Zstd stream decompression error on ZSTD_initDStream: %s' % error)

            # Initialize to 1 to force a loop iteration
            status = 1
            while status > 0 or input.pos < input.size:
                # Possible returned values of ZSTD_decompressStream:
                # 0: frame is completely decoded and fully flushed
                # error (<0)
                # >0: suggested next input size
                status = ZSTD_decompressStream(zds, &output, &input)

                if ZSTD_isError(status):
                    error = ZSTD_getErrorName(status)
                    raise RuntimeError('Zstd stream decompression error on ZSTD_decompressStream: %s' % error)

                if status > 0 and output.pos == output.size:
                    # There is more to decompress, grow the buffer
                    new_size = output.size + DEST_GROWTH_SIZE

                    # size_t arithmetic wraps around on overflow
                    if new_size < output.size or new_size < DEST_GROWTH_SIZE:
                        raise RuntimeError('Zstd stream decompression error: output buffer overflow')

                    new_dst = realloc(output.dst, new_size)

                    if new_dst == NULL:
                        # realloc leaves the old block valid on failure;
                        # output.dst is freed in the finally block
                        raise RuntimeError('Zstd stream decompression error on realloc: could not expand output buffer')

                    output.dst = new_dst
                    output.size = new_size

                elif status > 0 and input.pos == input.size:
                    # The decoder still expects input, everything available
                    # has been consumed and the output buffer has spare
                    # room: the frame is truncated. Stop here rather than
                    # calling the decoder again with no way to progress.
                    raise RuntimeError('Zstd stream decompression error: incomplete input data')

        # Copy the output to a bytes object
        dest = PyBytes_FromStringAndSize(<char *>output.dst, output.pos)

    finally:
        if zds != NULL:
            ZSTD_freeDStream(zds)
        # free(NULL) is a no-op, so this is safe when malloc failed
        free(output.dst)

    return dest
220343

221344
class Zstd(Codec):
222345
"""Codec providing compression using Zstandard.

0 commit comments

Comments
 (0)