Skip to content

Commit 17add9f

Browse files
Fix BaseNEncoder number of output columns (#296)
BaseNEncoder used an incorrect formula for calculating the number of digits (bits, when the base is 2) required in the output. If there are `nvals` distinct values and we reserve one encoding to represent "missing or unknown", then the correct number of digits is `ceil(log(nvals + 1, base))`. However, the code was previously using the formula `ceil(log(nvals, base)) + 1`. Fixes #264. Changes: (1) change the formula to `ceil(log(nvals + 1, base))`; (2) switch the formula to use integer math so we don't have to worry about floating point rounding errors; (3) add a test; (4) fix a non-deterministic test.
1 parent f56b2c7 commit 17add9f

File tree

3 files changed

+61
-4
lines changed

3 files changed

+61
-4
lines changed

category_encoders/basen.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,27 @@
1212
__author__ = 'willmcginnis'
1313

1414

15+
def _ceillogint(n, base):
16+
"""
17+
Returns ceil(log(n, base)) for integers n and base.
18+
19+
Uses integer math, so the result is not subject to floating point rounding errors.
20+
21+
base must be >= 2 and n must be >= 1.
22+
"""
23+
if base < 2:
24+
raise ValueError('base must be >= 2')
25+
if n < 1:
26+
raise ValueError('n must be >= 1')
27+
28+
n -= 1
29+
ret = 0
30+
while n > 0:
31+
ret += 1
32+
n //= base
33+
return ret
34+
35+
1536
class BaseNEncoder(BaseEstimator, TransformerMixin):
1637
"""Base-N encoder encodes the categories into arrays of their base-N representation. A base of 1 is equivalent to
1738
one-hot encoding (not really base-1, but useful), a base of 2 is equivalent to binary encoding. N=number of actual
@@ -296,7 +317,7 @@ def calc_required_digits(self, values):
296317
if self.base == 1:
297318
digits = len(values) + 1
298319
else:
299-
digits = int(np.ceil(math.log(len(values), self.base))) + 1
320+
digits = _ceillogint(len(values) + 1, self.base)
300321

301322
return digits
302323

tests/test_basen.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def test_HandleUnknown_HaveOnlyKnown_ExpectSecondColumn(self):
8282
result = encoder.fit_transform(train)
8383

8484
self.assertEqual(2, result.shape[0])
85-
self.assertListEqual([0, 0, 1], result.iloc[0, :].tolist())
86-
self.assertListEqual([0, 1, 0], result.iloc[1, :].tolist())
85+
self.assertListEqual([0, 1], result.iloc[0, :].tolist())
86+
self.assertListEqual([1, 0], result.iloc[1, :].tolist())
8787

8888
def test_inverse_transform_HaveNanInTrainAndHandleMissingValue_ExpectReturnedWithNan(self):
8989
train = pd.DataFrame({'city': ['chicago', np.nan]})
@@ -139,3 +139,39 @@ def test_inverse_transform_HaveHandleMissingValueAndHandleUnknownReturnNan_Expec
139139
original = enc.inverse_transform(result)
140140

141141
pd.testing.assert_frame_equal(expected, original)
142+
143+
def test_num_cols(self):
    """
    Check that BaseNEncoder emits the expected number of output columns.

    The value 0 is reserved for encoding unseen values, so with nvals distinct input
    values the encoder needs enough digits to represent nvals + 1 distinct encodings.
    That is ceil(log(nvals + 1, base)) digits.

    Only the case where BaseNEncoder is initialized with handle_unknown='value' and
    handle_missing='value' (i.e. the defaults) is covered here.
    """
    def num_cols(nvals, base):
        """Fit and transform nvals distinct string values and return the output column count"""
        frame = pd.DataFrame({'vals': [str(i) for i in range(nvals)]})
        fitted = encoders.BaseNEncoder(base=base)
        fitted.fit(frame)
        transformed = fitted.transform(frame)
        return len(list(transformed))

    # Each entry is (nvals, base, expected number of columns); entries are checked in order.
    cases = [
        (1, 2, 1),
        (2, 2, 2),
        (3, 2, 2),
        (4, 2, 3),
        (7, 2, 3),
        (8, 2, 4),
        (62, 2, 6),
        (63, 2, 6),
        (64, 2, 7),
        (65, 2, 7),
        # nvals = 0 returns the original dataframe unchanged, so it still has 1 column even
        # though logically there should be zero.
        (0, 2, 1),
        (55, 7, 3),
    ]
    for nvals, base, expected in cases:
        self.assertEqual(num_cols(nvals, base), expected)

tests/test_glmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# data definitions
77
X = th.create_dataset(n_rows=100)
8-
np_y = np.random.randn(100) > 0.5
8+
np_y = np.random.default_rng(42).standard_normal(100) > 0.5
99

1010
class TestGLMMEncoder(TestCase):
1111
def test_continuous(self):

0 commit comments

Comments
 (0)