@@ -23,11 +23,9 @@ class BitInfo(BitRound):
2323 Parameters
2424 ----------
2525
26- inflevel: float
27- The number of bits of the mantissa to keep. The range allowed
28- depends on the dtype input data. If keepbits is
29- equal to the maximum allowed for the data type, this is equivalent
30- to no transform.
26+ info_level: float
27+ The level of information to preserve in the data. The value should be
28+ between 0. and 1.0. Higher values preserve more information.
3129
3230 axes: int or list of int, optional
3331 Axes along which to calculate the bit information. If None, all axes
@@ -36,13 +34,22 @@ class BitInfo(BitRound):
3634
3735 codec_id = 'bitinfo'
3836
39- def __init__ (self , inflevel : float , axes = None ):
40- if (inflevel < 0 ) or (inflevel > 1.0 ):
41- raise ValueError ("Please provide `inflevel ` from interval [0.,1.]" )
37+ def __init__ (self , info_level : float , axes = None ):
38+ if (info_level < 0 ) or (info_level > 1.0 ):
39+ raise ValueError ("Please provide `info_level ` from interval [0.,1.]" )
4240
43- self .inflevel = inflevel
41+ elif axes is not None and not isinstance (axes , list ):
42+ if int (axes ) != axes :
43+ raise ValueError ("axis must be an integer or a list of integers." )
44+ axes = [axes ]
45+
46+ elif isinstance (axes , list ) and not all (int (ax ) == ax for ax in axes ):
47+ raise ValueError ("axis must be an integer or a list of integers." )
48+
49+ self .info_level = info_level
4450 self .axes = axes
4551
52+
4653 def encode (self , buf ):
4754 """Create int array by rounding floating-point data
4855
@@ -68,11 +75,11 @@ def encode(self, buf):
6875
6976 for ax in self .axes :
7077 info_per_bit = bitinformation (a , axis = ax )
71- keepbits .append (get_keepbits (info_per_bit , self .inflevel ))
78+ keepbits .append (get_keepbits (info_per_bit , self .info_level ))
7279
7380 keepbits = max (keepbits )
7481
75- return BitRound ._bitround ( a , keepbits , dtype )
82+ return BitRound .bitround ( buf , keepbits , dtype )
7683
7784
7885def exponent_bias (dtype ):
@@ -117,12 +124,12 @@ def signed_exponent(A):
117124
118125 Parameters
119126 ----------
120- A : :py:class:`numpy. array`
127+ a : array
121128 Array to transform
122129
123130 Returns
124131 -------
125- B : :py:class:`numpy. array`
132+ array
126133
127134 Example
128135 -------
@@ -162,8 +169,7 @@ def signed_exponent(A):
162169 eabs = np .uint64 (eabs )
163170 esign = np .uint64 (esign )
164171 esigned = esign | (eabs << sbits )
165- B = (sf | esigned ).view (np .int64 )
166- return B
172+ return (sf | esigned ).view (np .int64 )
167173
168174
169175def bitpaircount_u1 (a , b ):
@@ -260,7 +266,8 @@ def get_keepbits(info_per_bit, inflevel=0.99):
260266
261267def _cdf_from_info_per_bit (info_per_bit ):
262268 """Convert info_per_bit to cumulative distribution function"""
263- tol = info_per_bit [- 4 :].max () * 1.5
264- info_per_bit [info_per_bit < tol ] = 0
269+ # TODO this threshold isn't working yet
270+ #tol = info_per_bit[-4:].max() * 1.5
271+ #info_per_bit[info_per_bit < tol] = 0
265272 cdf = info_per_bit .cumsum ()
266273 return cdf / cdf [- 1 ]
0 commit comments