Skip to content

Commit 22fba10

Browse files
committed
compression arg compatibility
1 parent 89747d2 commit 22fba10

File tree

7 files changed

+304
-260
lines changed

7 files changed

+304
-260
lines changed

zarr/codecs.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -828,6 +828,15 @@ def _ensure_bytes(l):
828828
raise ValueError('expected bytes, found %r' % l)
829829

830830

831+
def _ensure_text(l):
    """Return `l` as a text string.

    Parameters
    ----------
    l : text or binary string
        Value to coerce. Text input is returned unchanged; binary input is
        decoded with the 'ascii' codec.

    Returns
    -------
    text : text_type

    Raises
    ------
    ValueError
        If `l` is neither a text nor a binary string.

    """
    if isinstance(l, text_type):
        return l
    elif isinstance(l, binary_type):
        return text_type(l, 'ascii')
    else:
        # Wrap the argument in a 1-tuple: with a bare `% l`, a tuple value
        # would be unpacked as the format arguments, raising TypeError (or
        # silently formatting only its single element) instead of reporting
        # the offending value in the intended ValueError.
        raise ValueError('expected text, found %r' % (l,))
838+
839+
831840
class Categorize(Codec):
832841
"""Filter encoding categorical string data as integers.
833842
@@ -862,10 +871,13 @@ class Categorize(Codec):
862871
codec_id = 'categorize'
863872

864873
def __init__(self, labels, dtype, astype='u1'):
865-
self.labels = [_ensure_bytes(l) for l in labels]
866874
self.dtype = np.dtype(dtype)
867-
if self.dtype.kind != 'S':
868-
raise ValueError('only string data types are supported')
875+
if self.dtype.kind == 'S':
876+
self.labels = [_ensure_bytes(l) for l in labels]
877+
elif self.dtype.kind == 'U':
878+
self.labels = [_ensure_text(l) for l in labels]
879+
else:
880+
raise ValueError('data type not supported')
869881
self.astype = np.dtype(astype)
870882

871883
def encode(self, buf):
@@ -909,7 +921,7 @@ def decode(self, buf, out=None):
909921
def get_config(self):
910922
config = dict()
911923
config['id'] = self.codec_id
912-
config['labels'] = [text_type(l, 'ascii') for l in self.labels]
924+
config['labels'] = [_ensure_text(l) for l in self.labels]
913925
config['dtype'] = encode_dtype(self.dtype)
914926
config['astype'] = encode_dtype(self.astype)
915927
return config
@@ -922,8 +934,12 @@ def from_config(cls, config):
922934
return cls(labels=labels, dtype=dtype, astype=astype)
923935

924936
def __repr__(self):
925-
r = '%s(dtype=%s, astype=%s, labels=%r)' % \
926-
(type(self).__name__, self.dtype, self.astype, self.labels)
937+
# make sure labels part is not too long
938+
labels = repr(self.labels[:3])
939+
if len(self.labels) > 3:
940+
labels = labels[:-1] + ', ...]'
941+
r = '%s(dtype=%s, astype=%s, labels=%s)' % \
942+
(type(self).__name__, self.dtype, self.astype, labels)
927943
return r
928944

929945

0 commit comments

Comments
 (0)