Skip to content

Commit 08d96a7

Browse files
committed
relax object encoding accepted inputs
1 parent 0397986 commit 08d96a7

File tree

5 files changed

+17
-40
lines changed

5 files changed

+17
-40
lines changed

numcodecs/msgpacks.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@
66

77

88
from numcodecs.abc import Codec
9-
from numcodecs.compat import ndarray_from_buffer, buffer_copy
109
import msgpack
1110

1211

1312
class MsgPack(Codec):
14-
"""Codec to encode data as msgpacked bytes. Useful for encoding python
15-
strings
13+
"""Codec to encode data as msgpacked bytes. Useful for encoding an array of Python strings
1614
1715
Raises
1816
------
@@ -27,14 +25,16 @@ class MsgPack(Codec):
2725
>>> f.decode(f.encode(x))
2826
array(['foo', 'bar', 'baz'], dtype=object)
2927
28+
See Also
29+
--------
30+
:class:`numcodecs.pickles.Pickle`
31+
3032
""" # flake8: noqa
3133

3234
codec_id = 'msgpack'
3335

3436
def encode(self, buf):
35-
if hasattr(buf, 'dtype') and buf.dtype != 'object':
36-
raise ValueError("cannot encode non-object ndarrays, %s "
37-
"dtype was passed" % buf.dtype)
37+
buf = np.asarray(buf, dtype='object')
3838
return msgpack.packb(buf.tolist(), encoding='utf-8')
3939

4040
def decode(self, buf, out=None):

numcodecs/pickles.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,16 @@
1414

1515

1616
class Pickle(Codec):
17-
"""Codec to encode data as as pickled bytes. Useful for encoding python
18-
strings.
17+
"""Codec to encode data as as pickled bytes. Useful for encoding an array of Python strings.
1918
2019
Parameters
2120
----------
2221
protocol : int, defaults to pickle.HIGHEST_PROTOCOL
23-
the protocol used to pickle data
22+
The protocol used to pickle data.
2423
2524
Raises
2625
------
27-
encoding a non-object dtyped ndarray will raise ValueError
26+
Encoding a non-object dtyped ndarray will raise ValueError.
2827
2928
Examples
3029
--------
@@ -35,6 +34,10 @@ class Pickle(Codec):
3534
>>> f.decode(f.encode(x))
3635
array(['foo', 'bar', 'baz'], dtype=object)
3736
37+
See Also
38+
--------
39+
:class:`numcodecs.msgpacks.MsgPack`
40+
3841
""" # flake8: noqa
3942

4043
codec_id = 'pickle'
@@ -43,9 +46,7 @@ def __init__(self, protocol=pickle.HIGHEST_PROTOCOL):
4346
self.protocol = protocol
4447

4548
def encode(self, buf):
46-
if hasattr(buf, 'dtype') and buf.dtype != 'object':
47-
raise ValueError("cannot encode non-object ndarrays, %s "
48-
"dtype was passed" % buf.dtype)
49+
buf = np.asarray(buf, dtype='object')
4950
return pickle.dumps(buf, protocol=self.protocol)
5051

5152
def decode(self, buf, out=None):

numcodecs/tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def compare(res, arr=arr):
117117
dec = codec.decode(enc)
118118
compare(dec)
119119

120-
out = np.empty_like(arr)
120+
out = np.empty_like(arr, dtype='object')
121121
codec.decode(enc, out=out)
122122
compare(out)
123123

numcodecs/tests/test_msgpacks.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66
import nose
7-
from numpy.testing import assert_raises
87

98
try:
109
from numcodecs.msgpacks import MsgPack
@@ -21,22 +20,11 @@
2120
np.array(['foo', 'bar', 'baz'] * 300, dtype=object),
2221
np.array([['foo', 'bar', np.nan]] * 300, dtype=object),
2322
np.array(['foo', 1.0, 2] * 300, dtype=object),
24-
]
25-
26-
27-
# non-object ndarrays
28-
arrays_incompat = [
2923
np.arange(1000, dtype='i4'),
3024
np.array(['foo', 'bar', 'baz'] * 300),
3125
]
3226

3327

34-
def test_encode_errors():
35-
for arr in arrays_incompat:
36-
codec = MsgPack()
37-
assert_raises(ValueError, codec.encode, arr)
38-
39-
4028
def test_encode_decode():
4129
for arr in arrays:
4230
codec = MsgPack()

numcodecs/tests/test_pickle.py renamed to numcodecs/tests/test_pickles.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33

44

55
import numpy as np
6-
from numpy.testing import assert_raises
76

87

98
from numcodecs.pickles import Pickle
10-
from numcodecs.tests.common import (check_config, check_repr,
11-
check_encode_decode_objects)
9+
from numcodecs.tests.common import check_config, check_repr, check_encode_decode_objects
1210

1311

1412
# object array with strings
@@ -18,24 +16,14 @@
1816
np.array(['foo', 'bar', 'baz'] * 300, dtype=object),
1917
np.array([['foo', 'bar', np.nan]] * 300, dtype=object),
2018
np.array(['foo', 1.0, 2] * 300, dtype=object),
21-
]
22-
23-
# non-object ndarrays
24-
arrays_incompat = [
2519
np.arange(1000, dtype='i4'),
2620
np.array(['foo', 'bar', 'baz'] * 300),
2721
]
2822

2923

30-
def test_encode_errors():
31-
for arr in arrays_incompat:
32-
codec = Pickle()
33-
assert_raises(ValueError, codec.encode, arr)
34-
35-
3624
def test_encode_decode():
25+
codec = Pickle()
3726
for arr in arrays:
38-
codec = Pickle()
3927
check_encode_decode_objects(arr, codec)
4028

4129

0 commit comments

Comments
 (0)