Skip to content

Commit 02a20aa

Browse files
Merge pull request #339 from GLevV/catboost-patch
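Bugfix: scale the prior mean (`self._mean`) by the smoothing parameter `self.a` in the numerator of the CatBoost encoding, in both the `y is None` and the training-time branches of `_transform`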
2 parents d737e17 + fa9b350 commit 02a20aa

File tree

1 file changed

+2 −2 lines changed (2 additions, 2 deletions)

1 file changed

+2 −2 lines changed (2 additions, 2 deletions)

category_encoders/cat_boost.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -264,7 +264,7 @@ def _transform(self, X_in, y, mapping=None):
264264
raise ValueError('Columns to be encoded can not contain new values')
265265

266266
if y is None: # Replace level with its mean target; if level occurs only once, use global mean
267-
level_means = ((colmap['sum'] + self._mean) / (colmap['count'] + self.a)).where(level_notunique, self._mean)
267+
level_means = ((colmap['sum'] + self._mean * self.a) / (colmap['count'] + self.a)).where(level_notunique, self._mean)
268268
X[col] = X[col].map(level_means)
269269
else:
270270
# Simulation of CatBoost implementation, which calculates leave-one-out on the fly.
@@ -277,7 +277,7 @@ def _transform(self, X_in, y, mapping=None):
277277
# As a workaround, we cast the grouping column as string.
278278
# See: issue #209
279279
temp = y.groupby(X[col].astype(str)).agg(['cumsum', 'cumcount'])
280-
X[col] = (temp['cumsum'] - y + self._mean) / (temp['cumcount'] + self.a)
280+
X[col] = (temp['cumsum'] - y + self._mean * self.a) / (temp['cumcount'] + self.a)
281281

282282
if self.handle_unknown == 'value':
283283
if X[col].dtype.name == 'category':

0 commit comments

Comments
 (0)