Skip to content

Commit 17c84d2

Browse files
nilichenKatrina Nirasbt
authored
Implementation of both use_clones and fit_base_estimators (#670)
* implementation of both use_clones and fit_base_estimators for EnsembleVoteClassifier, StackingClassifier and StackingCVClassifier as well as tests for these two parameters * some formatting changes, e.g., sort imports * updated CHANGELOG and user_guide * travis mvn fix * add conftest * mod conftest * mod conftest * mod conftest * mod conftest * upd travis * fix test case * minor updates Co-authored-by: Katrina Ni <[email protected]> Co-authored-by: rasbt <[email protected]>
1 parent 6b457a1 commit 17c84d2

File tree

11 files changed

+1025
-811
lines changed

11 files changed

+1025
-811
lines changed

ci/.travis_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ if [[ "$NOTEBOOKS" == "true" ]]; then
4444
find sources -name "*.ipynb" -not -path "sources/user_guide/image/*" -exec jupyter nbconvert --to notebook --execute {} \;
4545

4646
fi
47-
fi
47+
fi

docs/sources/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ The CHANGELOG for the current development version is available at
2323

2424
##### Changes
2525

26-
- -
26+
- Implemented both `use_clones` and `fit_base_estimators` (previously `refit` in `EnsembleVoteClassifier`) for `EnsembleVoteClassifier` and `StackingClassifier`. ([#670](https://github.com/rasbt/mlxtend/pull/670) via [Katrina Ni](https://github.com/nilichen))
2727

2828
##### Bug Fixes
2929

docs/sources/user_guide/classifier/EnsembleVoteClassifier.ipynb

Lines changed: 652 additions & 346 deletions
Large diffs are not rendered by default.

docs/sources/user_guide/classifier/StackingClassifier.ipynb

Lines changed: 103 additions & 319 deletions
Large diffs are not rendered by default.

mlxtend/classifier/ensemble_vote.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
#
99
# License: BSD 3 clause
1010

11-
from sklearn.base import BaseEstimator
12-
from sklearn.base import ClassifierMixin
13-
from sklearn.base import TransformerMixin
14-
from sklearn.preprocessing import LabelEncoder
15-
from sklearn.base import clone
11+
import numpy as np
12+
import warnings
13+
from sklearn.base import (BaseEstimator, ClassifierMixin, TransformerMixin,
14+
clone)
1615
from sklearn.exceptions import NotFittedError
17-
from ..externals.name_estimators import _name_estimators
16+
from sklearn.preprocessing import LabelEncoder
17+
1818
from ..externals import six
19-
import numpy as np
19+
from ..externals.name_estimators import _name_estimators
2020

2121

2222
class EnsembleVoteClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
@@ -28,9 +28,10 @@ class EnsembleVoteClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
2828
clfs : array-like, shape = [n_classifiers]
2929
A list of classifiers.
3030
Invoking the `fit` method on the `VotingClassifier` will fit clones
31-
of those original classifiers that will
31+
of those original classifiers
3232
be stored in the class attribute
33-
`self.clfs_` if `refit=True` (default).
33+
if `use_clones=True` (default) and
34+
`fit_base_estimators=True` (default).
3435
voting : str, {'hard', 'soft'} (default='hard')
3536
If 'hard', uses predicted class labels for majority rule voting.
3637
Else if 'soft', predicts the class label based on the argmax of
@@ -47,22 +48,33 @@ class EnsembleVoteClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
4748
- `verbose=2`: Prints info about the parameters of the clf being fitted
4849
- `verbose>2`: Changes `verbose` param of the underlying clf to
4950
self.verbose - 2
50-
refit : bool (default: True)
51+
use_clones : bool (default: True)
52+
Clones the classifiers for stacking classification if True (default)
53+
or else uses the original ones, which will be refitted on the dataset
54+
upon calling the `fit` method. Hence, if use_clones=True, the original
55+
input classifiers will remain unmodified upon using the
56+
EnsembleVoteClassifier's `fit` method.
57+
Setting `use_clones=False` is
58+
recommended if you are working with estimators that are supporting
59+
the scikit-learn fit/predict API interface but are not compatible
60+
to scikit-learn's `clone` function.
61+
fit_base_estimators : bool (default: True)
5162
Refits classifiers in `clfs` if True; uses references to the `clfs`,
5263
otherwise (assumes that the classifiers were already fit).
53-
Note: refit=False is incompatible to mist scikit-learn wrappers!
64+
Note: fit_base_estimators=False will enforce use_clones to be False,
65+
and is incompatible with most scikit-learn wrappers!
5466
For instance, if any form of cross-validation is performed
5567
this would require the re-fitting classifiers to training folds, which
56-
would raise a NotFitterError if refit=False.
68+
would raise a NotFittedError if fit_base_estimators=False.
5769
(New in mlxtend v0.6.)
5870
5971
Attributes
6072
----------
6173
classes_ : array-like, shape = [n_predictions]
6274
clf : array-like, shape = [n_predictions]
63-
The unmodified input classifiers
75+
The input classifiers; may be overwritten if `use_clones=False`
6476
clf_ : array-like, shape = [n_predictions]
65-
Fitted clones of the input classifiers
77+
Fitted input classifiers; clones if `use_clones=True`
6678
6779
Examples
6880
--------
@@ -96,15 +108,19 @@ class EnsembleVoteClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
96108
For more usage examples, please see
97109
http://rasbt.github.io/mlxtend/user_guide/classifier/EnsembleVoteClassifier/
98110
"""
111+
99112
def __init__(self, clfs, voting='hard',
100-
weights=None, verbose=0, refit=True):
113+
weights=None, verbose=0,
114+
use_clones=True,
115+
fit_base_estimators=True):
101116

102117
self.clfs = clfs
103118
self.named_clfs = {key: value for key, value in _name_estimators(clfs)}
104119
self.voting = voting
105120
self.weights = weights
106121
self.verbose = verbose
107-
self.refit = refit
122+
self.use_clones = use_clones
123+
self.fit_base_estimators = fit_base_estimators
108124

109125
def fit(self, X, y, sample_weight=None):
110126
"""Learn weight coefficients from training data for each classifier.
@@ -146,12 +162,17 @@ def fit(self, X, y, sample_weight=None):
146162
self.le_.fit(y)
147163
self.classes_ = self.le_.classes_
148164

149-
if not self.refit:
150-
self.clfs_ = [clf for clf in self.clfs]
165+
if not self.fit_base_estimators:
166+
warnings.warn("fit_base_estimators=False "
167+
"enforces use_clones to be `False`")
168+
self.use_clones = False
151169

170+
if self.use_clones:
171+
self.clfs_ = clone(self.clfs)
152172
else:
153-
self.clfs_ = [clone(clf) for clf in self.clfs]
173+
self.clfs_ = self.clfs
154174

175+
if self.fit_base_estimators:
155176
if self.verbose > 0:
156177
print("Fitting %d classifiers..." % (len(self.clfs)))
157178

@@ -204,8 +225,8 @@ def predict(self, X):
204225
predictions = self._predict(X)
205226

206227
maj = np.apply_along_axis(lambda x:
207-
np.argmax(np.bincount(x,
208-
weights=self.weights)),
228+
np.argmax(np.bincount(
229+
x, weights=self.weights)),
209230
axis=1,
210231
arr=predictions)
211232

@@ -266,15 +287,15 @@ def get_params(self, deep=True):
266287
for key, value in six.iteritems(step.get_params(deep=True)):
267288
out['%s__%s' % (name, key)] = value
268289

269-
for key, value in six.iteritems(super(EnsembleVoteClassifier,
270-
self).get_params(deep=False)):
290+
for key, value in six.iteritems(
291+
super(EnsembleVoteClassifier, self).get_params(deep=False)):
271292
out['%s' % key] = value
272293
return out
273294

274295
def _predict(self, X):
275296
"""Collect results from clf.predict calls."""
276297

277-
if self.refit:
298+
if self.fit_base_estimators:
278299
return np.asarray([clf.predict(X) for clf in self.clfs_]).T
279300
else:
280301
return np.asarray([self.le_.transform(clf.predict(X))

mlxtend/classifier/stacking_classification.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
#
99
# License: BSD 3 clause
1010

11+
import numpy as np
12+
import warnings
13+
from scipy import sparse
14+
from sklearn.base import TransformerMixin, clone
15+
1116
from ..externals.estimator_checks import check_is_fitted
1217
from ..externals.name_estimators import _name_estimators
1318
from ..utils.base_compostion import _BaseXComposition
1419
from ._base_classification import _BaseStackingClassifier
15-
from scipy import sparse
16-
from sklearn.base import TransformerMixin
17-
from sklearn.base import clone
18-
import numpy as np
1920

2021

2122
class StackingClassifier(_BaseXComposition, _BaseStackingClassifier,
@@ -30,7 +31,8 @@ class StackingClassifier(_BaseXComposition, _BaseStackingClassifier,
3031
Invoking the `fit` method on the `StackingClassifer` will fit clones
3132
of these original classifiers that will
3233
be stored in the class attribute
33-
`self.clfs_`.
34+
`self.clfs_` if `use_clones=True` (default) and
35+
`fit_base_estimators=True` (default).
3436
meta_classifier : object
3537
The meta-classifier to be fitted on the ensemble of
3638
classifiers
@@ -77,6 +79,16 @@ class StackingClassifier(_BaseXComposition, _BaseStackingClassifier,
7779
recommended if you are working with estimators that are supporting
7880
the scikit-learn fit/predict API interface but are not compatible
7981
to scikit-learn's `clone` function.
82+
fit_base_estimators: bool (default: True)
83+
Refits classifiers in `classifiers` if True; uses references to the
84+
`classifiers`, otherwise (assumes that the classifiers were
85+
already fit).
86+
Note: fit_base_estimators=False will enforce use_clones to be False,
87+
and is incompatible with most scikit-learn wrappers!
88+
For instance, if any form of cross-validation is performed
89+
this would require the re-fitting classifiers to training folds, which
90+
would raise a NotFittedError if fit_base_estimators=False.
91+
(New in mlxtend v0.6.)
8092
8193
Attributes
8294
----------
@@ -100,7 +112,7 @@ def __init__(self, classifiers, meta_classifier,
100112
average_probas=False, verbose=0,
101113
use_features_in_secondary=False,
102114
store_train_meta_features=False,
103-
use_clones=True):
115+
use_clones=True, fit_base_estimators=True):
104116

105117
self.classifiers = classifiers
106118
self.meta_classifier = meta_classifier
@@ -117,6 +129,7 @@ def __init__(self, classifiers, meta_classifier,
117129
self.use_features_in_secondary = use_features_in_secondary
118130
self.store_train_meta_features = store_train_meta_features
119131
self.use_clones = use_clones
132+
self.fit_base_estimators = fit_base_estimators
120133

121134
@property
122135
def named_classifiers(self):
@@ -143,33 +156,39 @@ def fit(self, X, y, sample_weight=None):
143156
self : object
144157
145158
"""
159+
if not self.fit_base_estimators:
160+
warnings.warn("fit_base_estimators=False "
161+
"enforces use_clones to be `False`")
162+
self.use_clones = False
163+
146164
if self.use_clones:
147165
self.clfs_ = clone(self.classifiers)
148166
self.meta_clf_ = clone(self.meta_classifier)
149167
else:
150168
self.clfs_ = self.classifiers
151169
self.meta_clf_ = self.meta_classifier
152170

153-
if self.verbose > 0:
154-
print("Fitting %d classifiers..." % (len(self.classifiers)))
171+
if self.fit_base_estimators:
172+
if self.verbose > 0:
173+
print("Fitting %d classifiers..." % (len(self.classifiers)))
155174

156-
for clf in self.clfs_:
175+
for clf in self.clfs_:
157176

158-
if self.verbose > 0:
159-
i = self.clfs_.index(clf) + 1
160-
print("Fitting classifier%d: %s (%d/%d)" %
161-
(i, _name_estimators((clf,))[0][0], i, len(self.clfs_)))
162-
163-
if self.verbose > 2:
164-
if hasattr(clf, 'verbose'):
165-
clf.set_params(verbose=self.verbose - 2)
166-
167-
if self.verbose > 1:
168-
print(_name_estimators((clf,))[0][1])
169-
if sample_weight is None:
170-
clf.fit(X, y)
171-
else:
172-
clf.fit(X, y, sample_weight=sample_weight)
177+
if self.verbose > 0:
178+
i = self.clfs_.index(clf) + 1
179+
print("Fitting classifier%d: %s (%d/%d)" %
180+
(i, _name_estimators((clf,))[0][0], i, len(self.clfs_)))
181+
182+
if self.verbose > 2:
183+
if hasattr(clf, 'verbose'):
184+
clf.set_params(verbose=self.verbose - 2)
185+
186+
if self.verbose > 1:
187+
print(_name_estimators((clf,))[0][1])
188+
if sample_weight is None:
189+
clf.fit(X, y)
190+
else:
191+
clf.fit(X, y, sample_weight=sample_weight)
173192

174193
meta_features = self.predict_meta_features(X)
175194

mlxtend/classifier/stacking_cv_classification.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,17 @@
99
#
1010
# License: BSD 3 clause
1111

12-
from ..externals.name_estimators import _name_estimators
13-
from ..externals.estimator_checks import check_is_fitted
14-
from ..utils.base_compostion import _BaseXComposition
15-
from ._base_classification import _BaseStackingClassifier
1612
import numpy as np
1713
from scipy import sparse
18-
from sklearn.base import TransformerMixin
19-
from sklearn.base import clone
14+
from sklearn.base import TransformerMixin, clone
2015
from sklearn.model_selection import cross_val_predict
2116
from sklearn.model_selection._split import check_cv
17+
18+
from ..externals.estimator_checks import check_is_fitted
19+
from ..externals.name_estimators import _name_estimators
20+
from ..utils.base_compostion import _BaseXComposition
21+
from ._base_classification import _BaseStackingClassifier
22+
2223
# from sklearn.utils import check_X_y
2324

2425

@@ -35,7 +36,7 @@ class StackingCVClassifier(_BaseXComposition, _BaseStackingClassifier,
3536
A list of classifiers.
3637
Invoking the `fit` method on the `StackingCVClassifer` will fit clones
3738
of these original classifiers that will
38-
be stored in the class attribute `self.clfs_`.
39+
be stored in the class attribute `self.clfs_` if `use_clones=True`.
3940
meta_classifier : object
4041
The meta-classifier to be fitted on the ensemble of
4142
classifiers
@@ -139,6 +140,7 @@ class StackingCVClassifier(_BaseXComposition, _BaseStackingClassifier,
139140
http://rasbt.github.io/mlxtend/user_guide/classifier/StackingCVClassifier/
140141
141142
"""
143+
142144
def __init__(self, classifiers, meta_classifier,
143145
use_probas=False, drop_proba_col=None,
144146
cv=2, shuffle=True,
@@ -245,10 +247,10 @@ def fit(self, X, y, groups=None, sample_weight=None):
245247
print(_name_estimators((model,))[0][1])
246248

247249
prediction = cross_val_predict(
248-
model, X, y, groups=groups, cv=final_cv,
249-
n_jobs=self.n_jobs, fit_params=fit_params,
250-
verbose=self.verbose, pre_dispatch=self.pre_dispatch,
251-
method='predict_proba' if self.use_probas else 'predict')
250+
model, X, y, groups=groups, cv=final_cv,
251+
n_jobs=self.n_jobs, fit_params=fit_params,
252+
verbose=self.verbose, pre_dispatch=self.pre_dispatch,
253+
method='predict_proba' if self.use_probas else 'predict')
252254

253255
if not self.use_probas:
254256
prediction = prediction[:, np.newaxis]

0 commit comments

Comments
 (0)