11import numpy as np
22
3+ from typing import Callable
34
45from .abc import Codec
56from .compat import ensure_ndarray_like , ndarray_copy
@@ -29,19 +30,30 @@ class BitRound(Codec):
2930 Parameters
3031 ----------
3132
32- keepbits: int
33+ keepbits: int or function
3334 The number of bits of the mantissa to keep. The range allowed
3435 depends on the dtype input data. If keepbits is
3536 equal to the maximum allowed for the data type, this is equivalent
36- to no transform.
37+ to no transform. Alternatively, pass a function to determine the
38+ number of bits to keep from the input data. The function should
39+ take a single argument, the input data, and return an integer
40+ specifying the number of bits to keep.
3741 """
3842
3943 codec_id = 'bitround'
4044
41- def __init__ (self , keepbits : int ):
42- if keepbits < 0 :
45+ def __init__ (self , keepbits : [ int , Callable ] ):
46+ if isinstance ( keepbits , int ) and keepbits < 0 :
4347 raise ValueError ("keepbits must be zero or positive" )
44- self .keepbits = keepbits
48+
49+ elif isinstance (keepbits , int ):
50+ self .keepbits = [lambda x : keepbits ]
51+
52+ elif isinstance (keepbits , Callable ):
53+ self .keepbits = keepbits
54+
55+ else :
56+ raise TypeError ("keepbits must be an integer or function" )
4557
4658 def encode (self , buf ):
4759 """Create int array by rounding floating-point data
@@ -56,12 +68,13 @@ def encode(self, buf):
5668 # cast float to int type of same width (preserve endianness)
5769 a_int_dtype = np .dtype (a .dtype .str .replace ("f" , "i" ))
5870 all_set = np .array (- 1 , dtype = a_int_dtype )
59- if self .keepbits == bits :
71+ buf_keepbits = self .keepbits (buf )
72+ if buf_keepbits == bits :
6073 return a
61- if self . keepbits > bits :
74+ if buf_keepbits > bits :
6275 raise ValueError ("Keepbits too large for given dtype" )
6376 b = a .view (a_int_dtype )
64- maskbits = bits - self . keepbits
77+ maskbits = bits - buf_keepbits
6578 mask = (all_set >> maskbits ) << maskbits
6679 half_quantum1 = (1 << (maskbits - 1 )) - 1
6780 b += ((b >> maskbits ) & 1 ) + half_quantum1
0 commit comments