Skip to content

Commit f2f1e3c

Browse files
authored
Early stopping support (#94)
* Early stopping support * Typo fix
1 parent cc56e2f commit f2f1e3c

File tree

1 file changed

+46
-1
lines changed

1 file changed

+46
-1
lines changed

boruta/boruta_py.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ class BorutaPy(BaseEstimator, TransformerMixin):
114114
- 1: displays iteration number
115115
- 2: which features have been selected already
116116
117+
early_stopping : bool, default = False
118+
Whether to use early stopping to terminate the selection process
119+
before reaching `max_iter` iterations if the algorithm cannot
120+
confirm a tentative feature for `n_iter_no_change` iterations.
121+
Will speed up the process at the cost of a possibly
122+
worse result.
123+
124+
n_iter_no_change : int, default = 20
125+
Ignored if `early_stopping` is False. The maximum number of
126+
consecutive iterations in which no feature's decision changes.
127+
117128
Attributes
118129
----------
119130
@@ -180,7 +191,8 @@ class BorutaPy(BaseEstimator, TransformerMixin):
180191
"""
181192

182193
def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
183-
two_step=True, max_iter=100, random_state=None, verbose=0):
194+
two_step=True, max_iter=100, random_state=None, verbose=0,
195+
early_stopping=False, n_iter_no_change=20):
184196
self.estimator = estimator
185197
self.n_estimators = n_estimators
186198
self.perc = perc
@@ -189,6 +201,8 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
189201
self.max_iter = max_iter
190202
self.random_state = random_state
191203
self.verbose = verbose
204+
self.early_stopping = early_stopping
205+
self.n_iter_no_change = n_iter_no_change
192206
self.__version__ = '0.3'
193207
self._is_lightgbm = 'lightgbm' in str(type(self.estimator))
194208

@@ -279,9 +293,25 @@ def _fit(self, X, y):
279293
y = self._validate_pandas_input(y)
280294

281295
self.random_state = check_random_state(self.random_state)
296+
297+
early_stopping = False
298+
if self.early_stopping:
299+
if self.n_iter_no_change >= self.max_iter:
300+
if self.verbose > 0:
301+
print(
302+
f"n_iter_no_change is bigger or equal to max_iter"
303+
f"({self.n_iter_no_change} >= {self.max_iter}), "
304+
f"early stopping will not be used."
305+
)
306+
else:
307+
early_stopping = True
308+
282309
# setup variables for Boruta
283310
n_sample, n_feat = X.shape
284311
_iter = 1
312+
# early stopping vars
313+
_same_iters = 1
314+
_last_dec_reg = None
285315
# holds the decision about each feature:
286316
# 0 - default state = tentative in original code
287317
# 1 - accepted in original code
@@ -335,6 +365,21 @@ def _fit(self, X, y):
335365
self._print_results(dec_reg, _iter, 0)
336366
if _iter < self.max_iter:
337367
_iter += 1
368+
369+
# early stopping
370+
if early_stopping:
371+
if _last_dec_reg is not None and (_last_dec_reg == dec_reg).all():
372+
_same_iters += 1
373+
if self.verbose > 0:
374+
print(
375+
f"Early stopping: {_same_iters} out "
376+
f"of {self.n_iter_no_change}"
377+
)
378+
else:
379+
_same_iters = 1
380+
_last_dec_reg = dec_reg.copy()
381+
if _same_iters > self.n_iter_no_change:
382+
break
338383

339384
# we automatically apply R package's rough fix for tentative ones
340385
confirmed = np.where(dec_reg == 1)[0]

0 commit comments

Comments
 (0)