Skip to content

Commit 41150f2

Browse files
committed
Update zstd.pyx with new content
1 parent 506c89b commit 41150f2

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

numcodecs/zstd.pyx

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def compress(source, int level=DEFAULT_CLEVEL, bint checksum=False):
129129
level = MAX_CLEVEL
130130

131131
# obtain source memoryview
132-
source_mv = ensure_continguous_memoryview(source)
132+
source_mv = ensure_contiguous_memoryview(source)
133133
source_pb = PyMemoryView_GET_BUFFER(source_mv)
134134

135135
# setup source buffer
@@ -202,30 +202,31 @@ def decompress(source, dest=None):
202202
Py_buffer* dest_pb
203203
char* dest_ptr
204204
size_t source_size, dest_size, decompressed_size
205-
size_t nbytes, cbytes, blocksize
206205
size_t dest_nbytes
206+
unsigned long long content_size
207207

208208
# obtain source memoryview
209-
source_mv = ensure_continguous_memoryview(source)
209+
source_mv = ensure_contiguous_memoryview(source)
210210
source_pb = PyMemoryView_GET_BUFFER(source_mv)
211211

212212
# get source pointer
213213
source_ptr = <const char*>source_pb.buf
214214
source_size = source_pb.len
215215

216216
try:
217-
218-
# determine uncompressed size
219-
dest_size = ZSTD_getFrameContentSize(source_ptr, source_size)
220-
if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR:
217+
# determine uncompressed size using unsigned long long for full range
218+
content_size = ZSTD_getFrameContentSize(source_ptr, source_size)
219+
if content_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
220+
return stream_decompress(source_pb)
221+
elif content_size == ZSTD_CONTENTSIZE_ERROR or content_size == 0:
221222
raise RuntimeError('Zstd decompression error: invalid input data')
223+
elif content_size > (<unsigned long long>SIZE_MAX):
224+
raise RuntimeError('Zstd decompression error: content size too large for platform')
222225

223-
if dest_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
224-
return stream_decompress(source_pb)
226+
dest_size = <size_t>content_size
225227

226228
# setup destination buffer
227229
if dest is None:
228-
# allocate memory
229230
dest_1d = dest = PyBytes_FromStringAndSize(NULL, dest_size)
230231
else:
231232
dest_1d = ensure_contiguous_ndarray(dest)
@@ -236,9 +237,6 @@ def decompress(source, dest=None):
236237
dest_ptr = <char*>dest_pb.buf
237238
dest_nbytes = dest_pb.len
238239

239-
if dest_size == ZSTD_CONTENTSIZE_UNKNOWN:
240-
dest_size = dest_nbytes
241-
242240
# validate output buffer
243241
if dest_nbytes < dest_size:
244242
raise ValueError('destination buffer too small; expected at least %s, '
@@ -388,7 +386,7 @@ class Zstd(Codec):
388386
return decompress(buf, out)
389387

390388
def __repr__(self):
391-
r = '%s(level=%r)' % \
389+
r = '%s(level=%r)' %
392390
(type(self).__name__,
393391
self.level)
394392
return r
@@ -406,4 +404,4 @@ class Zstd(Codec):
406404
@classmethod
407405
def max_level(cls):
408406
"""Returns the maximum compression level of the underlying zstd library."""
409-
return ZSTD_maxCLevel()
407+
return ZSTD_maxCLevel()

0 commit comments

Comments
 (0)