Skip to content

Commit ee4a515

Browse files
committed
add object dtype convenience API
1 parent e584f68 commit ee4a515

File tree

3 files changed

+80
-14
lines changed

3 files changed

+80
-14
lines changed

zarr/storage.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525

2626

2727
from zarr.util import (normalize_shape, normalize_chunks, normalize_order,
28-
normalize_storage_path, buffer_size, normalize_fill_value, nolock)
28+
normalize_storage_path, buffer_size,
29+
normalize_fill_value, nolock, normalize_dtype)
2930
from zarr.meta import encode_array_metadata, encode_group_metadata
3031
from zarr.compat import PY2, binary_type
3132
from numcodecs.registry import codec_registry
@@ -308,10 +309,7 @@ def _init_array_metadata(store, shape, chunks=None, dtype=None, compressor='defa
308309

309310
# normalize metadata
310311
shape = normalize_shape(shape)
311-
dtype = np.dtype(dtype)
312-
if dtype.kind in 'mM':
313-
raise ValueError('datetime64 and timedelta64 dtypes are not currently supported; '
314-
'please store the data using int64 instead')
312+
dtype, object_codec = normalize_dtype(dtype, object_codec)
315313
chunks = normalize_chunks(chunks, shape, dtype.itemsize)
316314
order = normalize_order(order)
317315
fill_value = normalize_fill_value(fill_value, dtype)

zarr/tests/test_core.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
DBMStore, LMDBStore, atexit_rmtree, atexit_rmglob)
2222
from zarr.core import Array
2323
from zarr.errors import PermissionError
24-
from zarr.compat import PY2
24+
from zarr.compat import PY2, text_type, binary_type
2525
from zarr.util import buffer_size
2626
from numcodecs import (Delta, FixedScaleOffset, Zlib, Blosc, BZ2, MsgPack, Pickle,
2727
Categorize, JSON, VLenUTF8, VLenBytes, VLenArray)
@@ -941,6 +941,13 @@ def test_object_arrays_vlen_text(self):
941941
z[:] = data
942942
assert_array_equal(data, z[:])
943943

944+
# convenience API
945+
z = self.create_array(shape=data.shape, dtype=text_type)
946+
assert z.dtype == object
947+
assert isinstance(z.filters[0], VLenUTF8)
948+
z[:] = data
949+
assert_array_equal(data, z[:])
950+
944951
z = self.create_array(shape=data.shape, dtype=object, object_codec=MsgPack())
945952
z[:] = data
946953
assert_array_equal(data, z[:])
@@ -967,6 +974,13 @@ def test_object_arrays_vlen_bytes(self):
967974
z[:] = data
968975
assert_array_equal(data, z[:])
969976

977+
# convenience API
978+
z = self.create_array(shape=data.shape, dtype=binary_type)
979+
assert z.dtype == object
980+
assert isinstance(z.filters[0], VLenBytes)
981+
z[:] = data
982+
assert_array_equal(data, z[:])
983+
970984
z = self.create_array(shape=data.shape, dtype=object, object_codec=Pickle())
971985
z[:] = data
972986
assert_array_equal(data, z[:])
@@ -977,18 +991,29 @@ def test_object_arrays_vlen_array(self):
977991
np.array([5]),
978992
np.array([2, 8, 12])] * 1000, dtype=object)
979993

994+
def compare_arrays(expected, actual, item_dtype):
995+
assert isinstance(actual, np.ndarray)
996+
assert actual.dtype == object
997+
assert actual.shape == expected.shape
998+
for e, a in zip(expected.flat, actual.flat):
999+
assert isinstance(a, np.ndarray)
1000+
assert_array_equal(e, a)
1001+
assert a.dtype == item_dtype
1002+
9801003
codecs = VLenArray(int), VLenArray('<u4')
9811004
for codec in codecs:
9821005
z = self.create_array(shape=data.shape, dtype=object, object_codec=codec)
9831006
z[:] = data
984-
a = z[:]
985-
assert isinstance(a, np.ndarray)
986-
assert a.dtype == object
987-
assert a.shape == data.shape
988-
for expected, actual in zip(data.flat, a.flat):
989-
assert isinstance(actual, np.ndarray)
990-
assert_array_equal(expected, actual)
991-
assert actual.dtype == codec.dtype
1007+
compare_arrays(data, z[:], codec.dtype)
1008+
1009+
# convenience API
1010+
for item_type in 'int', '<u4':
1011+
z = self.create_array(shape=data.shape, dtype='array:{}'.format(item_type))
1012+
assert z.dtype == object
1013+
assert isinstance(z.filters[0], VLenArray)
1014+
assert z.filters[0].dtype == np.dtype(item_type)
1015+
z[:] = data
1016+
compare_arrays(data, z[:], np.dtype(item_type))
9921017

9931018
def test_object_arrays_danger(self):
9941019

zarr/util.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,26 @@
44
from textwrap import TextWrapper, dedent
55
import numbers
66
import uuid
7+
import inspect
78

89

910
from asciitree import BoxStyle, LeftAligned
1011
from asciitree.traversal import Traversal
1112
import numpy as np
13+
from numcodecs.registry import codec_registry
1214

1315

1416
from zarr.compat import PY2, reduce, text_type, binary_type
1517

1618

19+
# codecs to use for object dtype convenience API
20+
object_codecs = {
21+
text_type.__name__: 'vlen-utf8',
22+
binary_type.__name__: 'vlen-bytes',
23+
'array': 'vlen-array',
24+
}
25+
26+
1727
def normalize_shape(shape):
1828
"""Convenience function to normalize the `shape` argument."""
1929

@@ -116,6 +126,39 @@ def normalize_chunks(chunks, shape, typesize):
116126
return chunks
117127

118128

129+
def normalize_dtype(dtype, object_codec):
130+
131+
# convenience API for object arrays
132+
if inspect.isclass(dtype):
133+
dtype = dtype.__name__
134+
if isinstance(dtype, str):
135+
tokens = dtype.split(':')
136+
key = tokens[0]
137+
if key in object_codecs:
138+
dtype = np.dtype(object)
139+
if object_codec is None:
140+
codec_id = object_codecs[key]
141+
if len(tokens) > 1:
142+
args = tokens[1].split(',')
143+
else:
144+
args = ()
145+
try:
146+
object_codec = codec_registry[codec_id](*args)
147+
except KeyError:
148+
raise ValueError('codec %r for object type %r is not '
149+
'available; please provide an '
150+
'object_codec manually' % (codec_id, key))
151+
return dtype, object_codec
152+
153+
dtype = np.dtype(dtype)
154+
155+
if dtype.kind in 'mM':
156+
raise ValueError('datetime64 and timedelta64 dtypes are not currently '
157+
'supported; please store the data using int64 instead')
158+
159+
return dtype, object_codec
160+
161+
119162
# noinspection PyTypeChecker
120163
def is_total_slice(item, shape):
121164
"""Determine whether `item` specifies a complete slice of array with the

0 commit comments

Comments
 (0)