Skip to content

Commit 8ba41ad

Browse files
committed
add categorize
1 parent a919ea9 commit 8ba41ad

File tree

6 files changed

+255
-0
lines changed

6 files changed

+255
-0
lines changed

docs/categorize.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Categorize
2+
==========
3+
.. module:: numcodecs.categorize
4+
5+
.. autoclass:: Categorize

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Contents
3030
delta
3131
fixedscaleoffset
3232
packbits
33+
categorize
3334
release
3435

3536
Acknowledgments

numcodecs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,6 @@
3131

3232
from numcodecs.packbits import PackBits
3333
register_codec(PackBits)
34+
35+
from numcodecs.categorize import Categorize
36+
register_codec(Categorize)

numcodecs/categorize.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import absolute_import, print_function, division
3+
4+
5+
from numcodecs.abc import Codec
6+
from numcodecs.compat import ndarray_from_buffer, buffer_copy, ensure_text, \
7+
ensure_bytes
8+
9+
10+
import numpy as np
11+
12+
13+
class Categorize(Codec):
14+
"""Filter encoding categorical string data as integers.
15+
16+
Parameters
17+
----------
18+
labels : sequence of strings
19+
Category labels.
20+
dtype : dtype
21+
Data type to use for decoded data.
22+
astype : dtype, optional
23+
Data type to use for encoded data.
24+
25+
Examples
26+
--------
27+
>>> import numcodecs as codecs
28+
>>> import numpy as np
29+
>>> x = np.array([b'male', b'female', b'female', b'male', b'unexpected'])
30+
>>> x
31+
array([b'male', b'female', b'female', b'male', b'unexpected'],
32+
dtype='|S10')
33+
>>> f = codecs.Categorize(labels=[b'female', b'male'], dtype=x.dtype)
34+
>>> y = f.encode(x)
35+
>>> y
36+
array([2, 1, 1, 2, 0], dtype=uint8)
37+
>>> z = f.decode(y)
38+
>>> z
39+
array([b'male', b'female', b'female', b'male', b''],
40+
dtype='|S10')
41+
42+
"""
43+
44+
codec_id = 'categorize'
45+
46+
def __init__(self, labels, dtype, astype='u1'):
47+
self.dtype = np.dtype(dtype)
48+
if self.dtype.kind == 'S':
49+
self.labels = [ensure_bytes(l) for l in labels]
50+
elif self.dtype.kind == 'U':
51+
self.labels = [ensure_text(l) for l in labels]
52+
else:
53+
self.labels = labels
54+
self.astype = np.dtype(astype)
55+
56+
def encode(self, buf):
57+
58+
# view input as ndarray
59+
arr = ndarray_from_buffer(buf, self.dtype)
60+
61+
# setup output array
62+
enc = np.zeros_like(arr, dtype=self.astype)
63+
64+
# apply encoding, reserving 0 for values not specified in labels
65+
for i, l in enumerate(self.labels):
66+
enc[arr == l] = i + 1
67+
68+
return enc
69+
70+
def decode(self, buf, out=None):
71+
72+
# view encoded data as ndarray
73+
enc = ndarray_from_buffer(buf, self.astype)
74+
75+
# setup output
76+
if isinstance(out, np.ndarray):
77+
# optimization, decode directly to output
78+
dec = out.reshape(-1, order='A')
79+
copy_needed = False
80+
else:
81+
dec = np.zeros_like(enc, dtype=self.dtype)
82+
copy_needed = True
83+
84+
# apply decoding
85+
for i, l in enumerate(self.labels):
86+
dec[enc == (i + 1)] = l
87+
88+
# handle output
89+
if copy_needed:
90+
dec = buffer_copy(dec, out)
91+
92+
return dec
93+
94+
def get_config(self):
95+
if self.dtype.kind == 'S':
96+
labels = [ensure_text(l) for l in self.labels]
97+
else:
98+
labels = self.labels
99+
config = dict(
100+
id=self.codec_id,
101+
labels=labels,
102+
dtype=self.dtype.str,
103+
astype=self.astype.str
104+
)
105+
return config
106+
107+
def __repr__(self):
108+
# make sure labels part is not too long
109+
labels = repr(self.labels[:3])
110+
if len(self.labels) > 3:
111+
labels = labels[:-1] + ', ...]'
112+
r = '%s(dtype=%r, astype=%r, labels=%s)' % \
113+
(type(self).__name__, self.dtype.str, self.astype.str, labels)
114+
return r

numcodecs/compat.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,17 @@ def ndarray_from_buffer(buf, dtype):
8989
else:
9090
arr = np.frombuffer(buf, dtype=dtype)
9191
return arr
92+
93+
94+
def ensure_bytes(l, encoding='utf-8'):
95+
if isinstance(l, binary_type):
96+
return l
97+
else:
98+
return l.encode(encoding=encoding)
99+
100+
101+
def ensure_text(l, encoding='utf-8'):
102+
if isinstance(l, text_type):
103+
return l
104+
else:
105+
return text_type(l, encoding=encoding)

numcodecs/tests/test_categorize.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import absolute_import, print_function, division
3+
4+
5+
import numpy as np
6+
from numpy.testing import assert_array_equal
7+
from nose.tools import eq_ as eq
8+
9+
10+
from numcodecs.categorize import Categorize
11+
from numcodecs.tests.common import check_encode_decode, check_config
12+
from numcodecs.compat import PY2
13+
14+
15+
labels = [b'foo', b'bar', b'baz', b'quux']
16+
labels_u = [u'ƒöõ', u'ßàř', u'ßāẑ', u'ƪùüx']
17+
labels_num = [1000000, 2000000, 3000000]
18+
arrays = [
19+
np.random.choice(labels, size=1000),
20+
np.random.choice(labels, size=(100, 10)),
21+
np.random.choice(labels, size=(10, 10, 10)),
22+
np.random.choice(labels, size=1000).reshape(100, 10, order='F'),
23+
]
24+
arrays_u = [
25+
np.random.choice(labels_u, size=1000),
26+
np.random.choice(labels_u, size=(100, 10)),
27+
np.random.choice(labels_u, size=(10, 10, 10)),
28+
np.random.choice(labels_u, size=1000).reshape(100, 10, order='F'),
29+
]
30+
arrays_num = [
31+
np.random.choice(labels_num, size=1000),
32+
np.random.choice(labels_num, size=(100, 10)),
33+
np.random.choice(labels_num, size=(10, 10, 10)),
34+
np.random.choice(labels_num, size=1000).reshape(100, 10, order='F'),
35+
]
36+
37+
38+
def test_encode_decode():
39+
40+
# string dtype
41+
for arr in arrays:
42+
codec = Categorize(labels, dtype=arr.dtype)
43+
check_encode_decode(arr, codec)
44+
45+
# unicode dtype
46+
for arr in arrays_u:
47+
codec = Categorize(labels_u, dtype=arr.dtype)
48+
check_encode_decode(arr, codec)
49+
50+
# other dtype
51+
for arr in arrays_num:
52+
codec = Categorize(labels_num, dtype=arr.dtype)
53+
check_encode_decode(arr, codec)
54+
55+
56+
def test_encode():
57+
arr = np.array([b'foo', b'bar', b'foo', b'baz', b'quux'])
58+
# miss off quux
59+
codec = Categorize(labels=labels[:-1], dtype=arr.dtype, astype='u1')
60+
61+
# test encoding
62+
expect = np.array([1, 2, 1, 3, 0], dtype='u1')
63+
enc = codec.encode(arr)
64+
assert_array_equal(expect, enc)
65+
eq(expect.dtype, enc.dtype)
66+
67+
# test decoding with unexpected value
68+
dec = codec.decode(enc)
69+
expect = arr.copy()
70+
expect[expect == b'quux'] = b''
71+
assert_array_equal(expect, dec)
72+
eq(arr.dtype, dec.dtype)
73+
74+
75+
def test_encode_unicode():
76+
arr = np.array([u'ƒöõ', u'ßàř', u'ƒöõ', u'ßāẑ', u'ƪùüx'])
77+
# miss off quux
78+
codec = Categorize(labels=labels_u[:-1], dtype=arr.dtype, astype='u1')
79+
80+
# test encoding
81+
expect = np.array([1, 2, 1, 3, 0], dtype='u1')
82+
enc = codec.encode(arr)
83+
assert_array_equal(expect, enc)
84+
eq(expect.dtype, enc.dtype)
85+
86+
# test decoding with unexpected value
87+
dec = codec.decode(enc)
88+
expect = arr.copy()
89+
expect[expect == u'ƪùüx'] = u''
90+
assert_array_equal(expect, dec)
91+
eq(arr.dtype, dec.dtype)
92+
93+
94+
def test_config():
95+
codec = Categorize(labels=labels, dtype='S4')
96+
check_config(codec)
97+
codec = Categorize(labels=labels_u, dtype='U4')
98+
check_config(codec)
99+
100+
101+
def test_repr():
102+
if not PY2:
103+
104+
dtype = '|S5'
105+
astype = '|u1'
106+
codec = Categorize(labels=labels, dtype=dtype, astype=astype)
107+
expect = "Categorize(dtype='|S5', astype='|u1', " \
108+
"labels=[b'foo', b'bar', b'baz', ...])"
109+
actual = repr(codec)
110+
eq(expect, actual)
111+
112+
dtype = '<U5'
113+
astype = '|u1'
114+
codec = Categorize(labels=labels_u, dtype=dtype, astype=astype)
115+
expect = "Categorize(dtype='<U5', astype='|u1', " \
116+
"labels=['ƒöõ', 'ßàř', 'ßāẑ', ...])"
117+
actual = repr(codec)
118+
eq(expect, actual)

0 commit comments

Comments
 (0)