@@ -129,7 +129,7 @@ def compress(source, int level=DEFAULT_CLEVEL, bint checksum=False):
129129 level = MAX_CLEVEL
130130
131131 # obtain source memoryview
132- source_mv = ensure_continguous_memoryview (source)
132+ source_mv = ensure_contiguous_memoryview (source)
133133 source_pb = PyMemoryView_GET_BUFFER(source_mv)
134134
135135 # setup source buffer
@@ -202,30 +202,31 @@ def decompress(source, dest=None):
202202 Py_buffer* dest_pb
203203 char * dest_ptr
204204 size_t source_size, dest_size, decompressed_size
205- size_t nbytes, cbytes, blocksize
206205 size_t dest_nbytes
206+ unsigned long long content_size
207207
208208 # obtain source memoryview
209- source_mv = ensure_continguous_memoryview (source)
209+ source_mv = ensure_contiguous_memoryview (source)
210210 source_pb = PyMemoryView_GET_BUFFER(source_mv)
211211
212212 # get source pointer
213213 source_ptr = < const char * > source_pb.buf
214214 source_size = source_pb.len
215215
216216 try :
217-
218- # determine uncompressed size
219- dest_size = ZSTD_getFrameContentSize(source_ptr, source_size)
220- if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR:
217+ # determine uncompressed size using unsigned long long for full range
218+ content_size = ZSTD_getFrameContentSize(source_ptr, source_size)
219+ if content_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None :
220+ return stream_decompress(source_pb)
221+ elif content_size == ZSTD_CONTENTSIZE_ERROR or content_size == 0 :
221222 raise RuntimeError (' Zstd decompression error: invalid input data' )
223+ elif content_size > (< unsigned long long > SIZE_MAX):
224+ raise RuntimeError (' Zstd decompression error: content size too large for platform' )
222225
223- if dest_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None :
224- return stream_decompress(source_pb)
226+ dest_size = < size_t> content_size
225227
226228 # setup destination buffer
227229 if dest is None :
228- # allocate memory
229230 dest_1d = dest = PyBytes_FromStringAndSize(NULL , dest_size)
230231 else :
231232 dest_1d = ensure_contiguous_ndarray(dest)
@@ -236,9 +237,6 @@ def decompress(source, dest=None):
236237 dest_ptr = < char * > dest_pb.buf
237238 dest_nbytes = dest_pb.len
238239
239- if dest_size == ZSTD_CONTENTSIZE_UNKNOWN:
240- dest_size = dest_nbytes
241-
242240 # validate output buffer
243241 if dest_nbytes < dest_size:
244242 raise ValueError (' destination buffer too small; expected at least %s , '
@@ -388,7 +386,7 @@ class Zstd(Codec):
388386 return decompress(buf, out)
389387
390388 def __repr__ (self ):
391- r = ' %s (level=%r )' % \
389+ r = ' %s (level=%r )' %
392390 (type (self ).__name__,
393391 self .level)
394392 return r
@@ -406,4 +404,4 @@ class Zstd(Codec):
406404 @classmethod
407405 def max_level (cls ):
408406 """ Returns the maximum compression level of the underlying zstd library."""
409- return ZSTD_maxCLevel()
407+ return ZSTD_maxCLevel()
0 commit comments