@@ -114,6 +114,17 @@ class BorutaPy(BaseEstimator, TransformerMixin):
114
114
- 1: displays iteration number
115
115
- 2: which features have been selected already
116
116
117
+ early_stopping : bool, default = False
118
+ Whether to use early stopping to terminate the selection process
119
+ before reaching `max_iter` iterations if the algorithm cannot
120
+ confirm a tentative feature for `n_iter_no_change` iterations.
121
+ Speeds up the process at the cost of a possibly worse
122
+ result.
123
+
124
+ n_iter_no_change : int, default = 20
125
+ Ignored if `early_stopping` is False. The maximum number of
126
+ consecutive iterations without confirming a tentative feature.
127
+
117
128
Attributes
118
129
----------
119
130
@@ -180,7 +191,8 @@ class BorutaPy(BaseEstimator, TransformerMixin):
180
191
"""
181
192
182
193
def __init__ (self , estimator , n_estimators = 1000 , perc = 100 , alpha = 0.05 ,
183
- two_step = True , max_iter = 100 , random_state = None , verbose = 0 ):
194
+ two_step = True , max_iter = 100 , random_state = None , verbose = 0 ,
195
+ early_stopping = False , n_iter_no_change = 20 ):
184
196
self .estimator = estimator
185
197
self .n_estimators = n_estimators
186
198
self .perc = perc
@@ -189,6 +201,8 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
189
201
self .max_iter = max_iter
190
202
self .random_state = random_state
191
203
self .verbose = verbose
204
+ self .early_stopping = early_stopping
205
+ self .n_iter_no_change = n_iter_no_change
192
206
self .__version__ = '0.3'
193
207
self ._is_lightgbm = 'lightgbm' in str (type (self .estimator ))
194
208
@@ -279,9 +293,25 @@ def _fit(self, X, y):
279
293
y = self ._validate_pandas_input (y )
280
294
281
295
self .random_state = check_random_state (self .random_state )
296
+
297
+ early_stopping = False
298
+ if self .early_stopping :
299
+ if self .n_iter_no_change >= self .max_iter :
300
+ if self .verbose > 0 :
301
+ print (
302
+ f"n_iter_no_change is bigger or equal to max_iter"
303
+ f"({ self .n_iter_no_change } >= { self .max_iter } ), "
304
+ f"early stopping will not be used."
305
+ )
306
+ else :
307
+ early_stopping = True
308
+
282
309
# setup variables for Boruta
283
310
n_sample , n_feat = X .shape
284
311
_iter = 1
312
+ # early stopping vars
313
+ _same_iters = 1
314
+ _last_dec_reg = None
285
315
# holds the decision about each feature:
286
316
# 0 - default state = tentative in original code
287
317
# 1 - accepted in original code
@@ -335,6 +365,21 @@ def _fit(self, X, y):
335
365
self ._print_results (dec_reg , _iter , 0 )
336
366
if _iter < self .max_iter :
337
367
_iter += 1
368
+
369
+ # early stopping
370
+ if early_stopping :
371
+ if _last_dec_reg is not None and (_last_dec_reg == dec_reg ).all ():
372
+ _same_iters += 1
373
+ if self .verbose > 0 :
374
+ print (
375
+ f"Early stopping: { _same_iters } out "
376
+ f"of { self .n_iter_no_change } "
377
+ )
378
+ else :
379
+ _same_iters = 1
380
+ _last_dec_reg = dec_reg .copy ()
381
+ if _same_iters > self .n_iter_no_change :
382
+ break
338
383
339
384
# we automatically apply R package's rough fix for tentative ones
340
385
confirmed = np .where (dec_reg == 1 )[0 ]
0 commit comments