Skip to content

Commit 1349304

Browse files
author
Mark Kittisopikul
committed
Fix LLM issues ..
1 parent b0fe556 commit 1349304

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

numcodecs/zstd.pyx

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ from .abc import Codec
1515

1616
from libc.stdlib cimport malloc, realloc, free
1717

18+
cdef extern from "stdint.h":
19+
cdef size_t SIZE_MAX
20+
1821
cdef extern from "zstd.h":
1922

2023
unsigned ZSTD_versionNumber() nogil
@@ -129,7 +132,7 @@ def compress(source, int level=DEFAULT_CLEVEL, bint checksum=False):
129132
level = MAX_CLEVEL
130133

131134
# obtain source memoryview
132-
source_mv = ensure_contiguous_memoryview(source)
135+
source_mv = ensure_continguous_memoryview(source)
133136
source_pb = PyMemoryView_GET_BUFFER(source_mv)
134137

135138
# setup source buffer
@@ -206,7 +209,7 @@ def decompress(source, dest=None):
206209
unsigned long long content_size
207210

208211
# obtain source memoryview
209-
source_mv = ensure_contiguous_memoryview(source)
212+
source_mv = ensure_continguous_memoryview(source)
210213
source_pb = PyMemoryView_GET_BUFFER(source_mv)
211214

212215
# get source pointer
@@ -218,6 +221,8 @@ def decompress(source, dest=None):
218221
content_size = ZSTD_getFrameContentSize(source_ptr, source_size)
219222
if content_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None:
220223
return stream_decompress(source_pb)
224+
elif content_size == ZSTD_CONTENTSIZE_UNKNOWN:
225+
# dest is not None
221226
elif content_size == ZSTD_CONTENTSIZE_ERROR or content_size == 0:
222227
raise RuntimeError('Zstd decompression error: invalid input data')
223228
elif content_size > (<unsigned long long>SIZE_MAX):
@@ -227,6 +232,7 @@ def decompress(source, dest=None):
227232

228233
# setup destination buffer
229234
if dest is None:
235+
# allocate memory
230236
dest_1d = dest = PyBytes_FromStringAndSize(NULL, dest_size)
231237
else:
232238
dest_1d = ensure_contiguous_ndarray(dest)
@@ -237,6 +243,9 @@ def decompress(source, dest=None):
237243
dest_ptr = <char*>dest_pb.buf
238244
dest_nbytes = dest_pb.len
239245

246+
if content_size == ZSTD_CONTENTSIZE_UNKNOWN:
247+
dest_size = dest_nbytes
248+
240249
# validate output buffer
241250
if dest_nbytes < dest_size:
242251
raise ValueError('destination buffer too small; expected at least %s, '
@@ -404,4 +413,4 @@ class Zstd(Codec):
404413
@classmethod
405414
def max_level(cls):
406415
"""Returns the maximum compression level of the underlying zstd library."""
407-
return ZSTD_maxCLevel()
416+
return ZSTD_maxCLevel()

0 commit comments

Comments
 (0)