8
8
9
9
from zarr .meta import encode_dtype , decode_dtype
10
10
from zarr .compressors import registry as compressor_registry
11
+ from zarr .compat import text_type , binary_type
11
12
12
13
13
14
filter_registry = dict ()
@@ -31,6 +32,12 @@ class DeltaFilter(object):
31
32
astype : dtype, optional
32
33
Data type to use for encoded data.
33
34
35
+ Notes
36
+ -----
37
+ If `astype` is an integer data type, please ensure that it is
38
+ sufficiently large to store encoded values. No checks are made and data
39
+ may become corrupted due to integer overflow if `astype` is too small.
40
+
34
41
Examples
35
42
--------
36
43
>>> import zarr
@@ -86,7 +93,7 @@ def get_filter_config(self):
86
93
def from_filter_config (cls , config ):
87
94
dtype = decode_dtype (config ['dtype' ])
88
95
astype = decode_dtype (config ['astype' ])
89
- return cls (dtype = dtype , asdtype = astype )
96
+ return cls (dtype = dtype , astype = astype )
90
97
91
98
92
99
filter_registry [DeltaFilter .filter_name ] = DeltaFilter
@@ -109,6 +116,12 @@ class FixedScaleOffsetFilter(object):
109
116
astype : dtype, optional
110
117
Data type to use for encoded data.
111
118
119
+ Notes
120
+ -----
121
+ If `astype` is an integer data type, please ensure that it is
122
+ sufficiently large to store encoded values. No checks are made and data
123
+ may become corrupted due to integer overflow if `astype` is too small.
124
+
112
125
Examples
113
126
--------
114
127
>>> import zarr
@@ -248,6 +261,8 @@ def __init__(self, digits, dtype, astype=None):
248
261
self .astype = self .dtype
249
262
else :
250
263
self .astype = np .dtype (astype )
264
+ if self .dtype .kind != 'f' or self .astype .kind != 'f' :
265
+ raise ValueError ('only floating point data types are supported' )
251
266
252
267
def encode (self , buf ):
253
268
# interpret buffer as 1D array
@@ -324,11 +339,17 @@ def encode(self, buf):
324
339
arr = _ndarray_from_buffer (buf , bool )
325
340
# determine size of packed data
326
341
n = arr .size
327
- n_bytes_packed = (n // 8 ) + 1
328
- n_bits_padded = n % 8
342
+ n_bytes_packed = (n // 8 )
343
+ n_bits_leftover = n % 8
344
+ if n_bits_leftover > 0 :
345
+ n_bytes_packed += 1
329
346
# setup output
330
347
enc = np .empty (n_bytes_packed + 1 , dtype = 'u1' )
331
348
# remember how many bits were padded
349
+ if n_bits_leftover :
350
+ n_bits_padded = 8 - n_bits_leftover
351
+ else :
352
+ n_bits_padded = 0
332
353
enc [0 ] = n_bits_padded
333
354
# apply encoding
334
355
enc [1 :] = np .packbits (arr )
@@ -342,7 +363,8 @@ def decode(self, buf):
342
363
# apply decoding
343
364
dec = np .unpackbits (enc [1 :])
344
365
# remove padded bits
345
- dec = dec [:- n_bits_padded ]
366
+ if n_bits_padded :
367
+ dec = dec [:- n_bits_padded ]
346
368
# view as boolean array
347
369
dec = dec .view (bool )
348
370
return dec
@@ -360,6 +382,94 @@ def from_filter_config(cls, config):
360
382
filter_registry [PackBitsFilter .filter_name ] = PackBitsFilter
361
383
362
384
385
+ def _ensure_bytes (l ):
386
+ if isinstance (l , binary_type ):
387
+ return l
388
+ elif isinstance (l , text_type ):
389
+ return l .encode ('ascii' )
390
+ else :
391
+ raise ValueError ('expected bytes, found %r' % l )
392
+
393
+
394
+ class CategoryFilter (object ):
395
+ """Filter encoding categorical string data as integers.
396
+
397
+ Parameters
398
+ ----------
399
+ labels : sequence of strings
400
+ Category labels.
401
+ dtype : dtype
402
+ Data type to use for decoded data.
403
+ astype : dtype, optional
404
+ Data type to use for encoded data.
405
+
406
+ Examples
407
+ --------
408
+ >>> import zarr
409
+ >>> import numpy as np
410
+ >>> x = np.array([b'male', b'female', b'female', b'male', b'unexpected'])
411
+ >>> x
412
+ array([b'male', b'female', b'female', b'male', b'unexpected'],
413
+ dtype='|S10')
414
+ >>> f = zarr.CategoryFilter(labels=[b'female', b'male'], dtype=x.dtype)
415
+ >>> y = f.encode(x)
416
+ >>> y
417
+ array([2, 1, 1, 2, 0], dtype=uint8)
418
+ >>> z = f.decode(y)
419
+ >>> z
420
+ array([b'male', b'female', b'female', b'male', b''],
421
+ dtype='|S10')
422
+
423
+ """
424
+
425
+ filter_name = 'category'
426
+
427
+ def __init__ (self , labels , dtype , astype = 'u1' ):
428
+ self .labels = [_ensure_bytes (l ) for l in labels ]
429
+ self .dtype = np .dtype (dtype )
430
+ if self .dtype .kind != 'S' :
431
+ raise ValueError ('only string data types are supported' )
432
+ self .astype = np .dtype (astype )
433
+
434
+ def encode (self , buf ):
435
+ # view input as ndarray
436
+ arr = _ndarray_from_buffer (buf , self .dtype )
437
+ # setup output array
438
+ enc = np .zeros_like (arr , dtype = self .astype )
439
+ # apply encoding, reserving 0 for values not specified in labels
440
+ for i , l in enumerate (self .labels ):
441
+ enc [arr == l ] = i + 1
442
+ return enc
443
+
444
+ def decode (self , buf ):
445
+ # view encoded data as ndarray
446
+ enc = _ndarray_from_buffer (buf , self .astype )
447
+ # setup output
448
+ dec = np .zeros_like (enc , dtype = self .dtype )
449
+ # apply decoding
450
+ for i , l in enumerate (self .labels ):
451
+ dec [enc == (i + 1 )] = l
452
+ return dec
453
+
454
+ def get_filter_config (self ):
455
+ config = dict ()
456
+ config ['name' ] = self .filter_name
457
+ config ['labels' ] = [text_type (l , 'ascii' ) for l in self .labels ]
458
+ config ['dtype' ] = encode_dtype (self .dtype )
459
+ config ['astype' ] = encode_dtype (self .astype )
460
+ return config
461
+
462
+ @classmethod
463
+ def from_filter_config (cls , config ):
464
+ dtype = decode_dtype (config ['dtype' ])
465
+ astype = decode_dtype (config ['astype' ])
466
+ labels = config ['labels' ]
467
+ return cls (labels = labels , dtype = dtype , astype = astype )
468
+
469
+
470
+ filter_registry [CategoryFilter .filter_name ] = CategoryFilter
471
+
472
+
363
473
# add in compressors as filters
364
474
for cls in compressor_registry .values ():
365
475
if hasattr (cls , 'filter_name' ):
0 commit comments