Skip to content

Commit a0bad06

Browse files
GenusterzEdS15B3GCwq
authored and committed
BUG: Refactor LinearModel (mne-tools#13361)
1 parent e2707a0 commit a0bad06

File tree

3 files changed

+102
-62
lines changed

3 files changed

+102
-62
lines changed

doc/changes/dev/13361.bugfix.rst

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,8 @@
1+
``model`` parameter of :class:`mne.decoding.LinearModel`
2+
will not be modified; use the ``model_`` attribute to access the fitted model.
3+
To be compatible with all MNE-Python versions you can use
4+
``getattr(clf, "model_", getattr(clf, "model"))``.
5+
The provided ``model`` is expected to be a supervised predictor,
6+
i.e., a classifier or regressor (or :class:`sklearn.multiclass.OneVsRestClassifier`);
7+
otherwise an error will be raised.
8+
by `Gennadiy Belonosov`_.

mne/decoding/base.py

Lines changed: 68 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -17,15 +17,25 @@
1717
TransformerMixin,
1818
clone,
1919
is_classifier,
20+
is_regressor,
2021
)
2122
from sklearn.linear_model import LogisticRegression
2223
from sklearn.metrics import check_scoring
2324
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
24-
from sklearn.utils import check_array, check_X_y, indexable
25+
from sklearn.utils import indexable
2526
from sklearn.utils.validation import check_is_fitted
2627

2728
from ..parallel import parallel_func
28-
from ..utils import _check_option, _pl, _validate_type, logger, pinv, verbose, warn
29+
from ..utils import (
30+
_check_option,
31+
_pl,
32+
_validate_type,
33+
logger,
34+
pinv,
35+
verbose,
36+
warn,
37+
)
38+
from ._fixes import validate_data
2939
from ._ged import (
3040
_handle_restr_mat,
3141
_is_cov_pos_semidef,
@@ -340,7 +350,8 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
340350
model : object | None
341351
A linear model from scikit-learn with a fit method
342352
that updates a ``coef_`` attribute.
343-
If None the model will be LogisticRegression.
353+
If None the model will be
354+
:class:`sklearn.linear_model.LogisticRegression`.
344355
345356
Attributes
346357
----------
@@ -364,46 +375,66 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
364375
.. footbibliography::
365376
"""
366377

367-
# TODO: Properly refactor this using
368-
# https://github.com/scikit-learn/scikit-learn/issues/30237#issuecomment-2465572885
369378
_model_attr_wrap = (
370379
"transform",
380+
"fit_transform",
371381
"predict",
372382
"predict_proba",
373-
"_estimator_type",
374-
"__tags__",
383+
"predict_log_proba",
384+
"_estimator_type", # remove after sklearn 1.6
375385
"decision_function",
376386
"score",
377387
"classes_",
378388
)
379389

380390
def __init__(self, model=None):
381-
# TODO: We need to set this to get our tag checking to work properly
382-
if model is None:
383-
model = LogisticRegression(solver="liblinear")
384391
self.model = model
385392

386393
def __sklearn_tags__(self):
387394
"""Get sklearn tags."""
388-
from sklearn.utils import get_tags # added in 1.6
389-
390-
# fit method below does not allow sparse data via check_data, we could
391-
# eventually make it smarter if we had to
392-
tags = get_tags(self.model)
393-
tags.input_tags.sparse = False
395+
tags = super().__sklearn_tags__()
396+
model = self.model if self.model is not None else LogisticRegression()
397+
model_tags = model.__sklearn_tags__()
398+
tags.estimator_type = model_tags.estimator_type
399+
if tags.estimator_type is not None:
400+
model_type_tags = getattr(model_tags, f"{tags.estimator_type}_tags")
401+
setattr(tags, f"{tags.estimator_type}_tags", model_type_tags)
394402
return tags
395403

396404
def __getattr__(self, attr):
397405
"""Wrap to model for some attributes."""
398406
if attr in LinearModel._model_attr_wrap:
399-
return getattr(self.model, attr)
400-
elif attr == "fit_transform" and hasattr(self.model, "fit_transform"):
401-
return super().__getattr__(self, "_fit_transform")
402-
return super().__getattr__(self, attr)
407+
model = self.model_ if "model_" in self.__dict__ else self.model
408+
if attr == "fit_transform" and hasattr(model, "fit_transform"):
409+
return self._fit_transform
410+
else:
411+
return getattr(model, attr)
412+
else:
413+
raise AttributeError(
414+
f"'{type(self).__name__}' object has no attribute '{attr}'"
415+
)
403416

404417
def _fit_transform(self, X, y):
405418
return self.fit(X, y).transform(X)
406419

420+
def _validate_params(self, X):
421+
if self.model is not None:
422+
model = self.model
423+
if isinstance(model, MetaEstimatorMixin):
424+
model = model.estimator
425+
is_predictor = is_regressor(model) or is_classifier(model)
426+
if not is_predictor:
427+
raise ValueError(
428+
"Linear model should be a supervised predictor "
429+
"(classifier or regressor)"
430+
)
431+
432+
# For sklearn < 1.6
433+
try:
434+
self._check_n_features(X, reset=True)
435+
except AttributeError:
436+
pass
437+
407438
def fit(self, X, y, **fit_params):
408439
"""Estimate the coefficients of the linear model.
409440
@@ -424,25 +455,18 @@ def fit(self, X, y, **fit_params):
424455
self : instance of LinearModel
425456
Returns the modified instance.
426457
"""
427-
if y is not None:
428-
X = check_array(X)
429-
else:
430-
X, y = check_X_y(X, y)
431-
self.n_features_in_ = X.shape[1]
432-
if y is not None:
433-
y = check_array(y, dtype=None, ensure_2d=False, input_name="y")
434-
if y.ndim > 2:
435-
raise ValueError(
436-
f"LinearModel only accepts up to 2-dimensional y, got {y.shape} "
437-
"instead."
438-
)
458+
self._validate_params(X)
459+
X, y = validate_data(self, X, y, multi_output=True)
439460

440461
# fit the Model
441-
self.model.fit(X, y, **fit_params)
442-
self.model_ = self.model # for better sklearn compat
462+
self.model_ = (
463+
clone(self.model)
464+
if self.model is not None
465+
else LogisticRegression(solver="liblinear")
466+
)
467+
self.model_.fit(X, y, **fit_params)
443468

444469
# Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y
445-
446470
inv_Y = 1.0
447471
X = X - X.mean(0, keepdims=True)
448472
if y.ndim == 2 and y.shape[1] != 1:
@@ -454,12 +478,17 @@ def fit(self, X, y, **fit_params):
454478

455479
@property
456480
def filters_(self):
457-
if hasattr(self.model, "coef_"):
481+
if hasattr(self.model_, "coef_"):
458482
# Standard Linear Model
459-
filters = self.model.coef_
460-
elif hasattr(self.model.best_estimator_, "coef_"):
483+
filters = self.model_.coef_
484+
elif hasattr(self.model_, "estimators_"):
485+
# Linear model with OneVsRestClassifier
486+
filters = np.vstack([est.coef_ for est in self.model_.estimators_])
487+
elif hasattr(self.model_, "best_estimator_") and hasattr(
488+
self.model_.best_estimator_, "coef_"
489+
):
461490
# Linear Model with GridSearchCV
462-
filters = self.model.best_estimator_.coef_
491+
filters = self.model_.best_estimator_.coef_
463492
else:
464493
raise ValueError("model does not have a `coef_` attribute.")
465494
if filters.ndim == 2 and filters.shape[0] == 1:

mne/decoding/tests/test_base.py

Lines changed: 26 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -28,13 +28,15 @@
2828
is_classifier,
2929
is_regressor,
3030
)
31+
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
3132
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
3233
from sklearn.model_selection import (
3334
GridSearchCV,
3435
KFold,
3536
StratifiedKFold,
3637
cross_val_score,
3738
)
39+
from sklearn.multiclass import OneVsRestClassifier
3840
from sklearn.pipeline import make_pipeline
3941
from sklearn.preprocessing import StandardScaler
4042
from sklearn.utils.estimator_checks import parametrize_with_checks
@@ -93,12 +95,11 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3):
9395
return X, Y, A
9496

9597

96-
@pytest.mark.filterwarnings("ignore:invalid value encountered in cast.*:RuntimeWarning")
9798
def test_get_coef():
9899
"""Test getting linear coefficients (filters/patterns) from estimators."""
99-
lm_classification = LinearModel()
100+
lm_classification = LinearModel(LogisticRegression(solver="liblinear"))
100101
assert hasattr(lm_classification, "__sklearn_tags__")
101-
if check_version("sklearn", "1.4"):
102+
if check_version("sklearn", "1.6"):
102103
print(lm_classification.__sklearn_tags__())
103104
assert is_classifier(lm_classification.model)
104105
assert is_classifier(lm_classification)
@@ -200,19 +201,19 @@ def inverse_transform(self, X):
200201
# Retrieve final linear model
201202
filters = get_coef(clf, "filters_", False)
202203
if hasattr(clf, "steps"):
203-
if hasattr(clf.steps[-1][-1].model, "best_estimator_"):
204+
if hasattr(clf.steps[-1][-1].model_, "best_estimator_"):
204205
# Linear Model with GridSearchCV
205-
coefs = clf.steps[-1][-1].model.best_estimator_.coef_
206+
coefs = clf.steps[-1][-1].model_.best_estimator_.coef_
206207
else:
207208
# Standard Linear Model
208-
coefs = clf.steps[-1][-1].model.coef_
209+
coefs = clf.steps[-1][-1].model_.coef_
209210
else:
210-
if hasattr(clf.model, "best_estimator_"):
211+
if hasattr(clf.model_, "best_estimator_"):
211212
# Linear Model with GridSearchCV
212-
coefs = clf.model.best_estimator_.coef_
213+
coefs = clf.model_.best_estimator_.coef_
213214
else:
214215
# Standard Linear Model
215-
coefs = clf.model.coef_
216+
coefs = clf.model_.coef_
216217
if coefs.ndim == 2 and coefs.shape[0] == 1:
217218
coefs = coefs[0]
218219
assert_array_equal(filters, coefs)
@@ -280,9 +281,7 @@ def test_get_coef_multiclass(n_features, n_targets):
280281
lm = LinearModel(LinearRegression())
281282
assert not hasattr(lm, "model_")
282283
lm.fit(X, Y)
283-
# TODO: modifying non-underscored `model` is a sklearn no-no, maybe should be a
284-
# metaestimator?
285-
assert lm.model is lm.model_
284+
assert lm.model is not lm.model_
286285
assert_array_equal(lm.filters_.shape, lm.patterns_.shape)
287286
if n_targets == 1:
288287
want_shape = (n_features,)
@@ -328,9 +327,6 @@ def test_get_coef_multiclass(n_features, n_targets):
328327
(3, 1, 2),
329328
],
330329
)
331-
# TODO: Need to fix this properly in LinearModel
332-
@pytest.mark.filterwarnings("ignore:'multi_class' was depr.*:FutureWarning")
333-
@pytest.mark.filterwarnings("ignore:lbfgs failed to converge.*:")
334330
def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
335331
"""Test a full example with pattern extraction."""
336332
data = np.zeros((10 * n_classes, n_channels, n_times))
@@ -345,7 +341,7 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
345341
clf = make_pipeline(
346342
Scaler(epochs.info),
347343
Vectorizer(),
348-
LinearModel(LogisticRegression(random_state=0, multi_class="ovr")),
344+
LinearModel(OneVsRestClassifier(LogisticRegression(random_state=0))),
349345
)
350346
scorer = "roc_auc_ovr_weighted"
351347
time_gen = GeneralizingEstimator(clf, scorer, verbose=True)
@@ -382,6 +378,20 @@ def test_linearmodel():
382378
wrong_X = rng.rand(n, n_features, 99)
383379
clf.fit(wrong_X, y)
384380

381+
# check fit_transform call
382+
clf = LinearModel(LinearDiscriminantAnalysis())
383+
_ = clf.fit_transform(X, y)
384+
385+
# check that model has to have coef_, RBF-SVM doesn't
386+
clf = LinearModel(svm.SVC(kernel="rbf"))
387+
with pytest.raises(ValueError, match="does not have a `coef_`"):
388+
clf.fit(X, y)
389+
390+
# check that model has to be a predictor
391+
clf = LinearModel(StandardScaler())
392+
with pytest.raises(ValueError, match="classifier or regressor"):
393+
clf.fit(X, y)
394+
385395
# check categorical target fit in standard linear model with GridSearchCV
386396
parameters = {"kernel": ["linear"], "C": [1, 10]}
387397
clf = LinearModel(
@@ -481,11 +491,4 @@ def test_cross_val_multiscore():
481491
@parametrize_with_checks([LinearModel(LogisticRegression())])
482492
def test_sklearn_compliance(estimator, check):
483493
"""Test LinearModel compliance with sklearn."""
484-
ignores = (
485-
"check_estimators_overwrite_params", # self.model changes!
486-
"check_dont_overwrite_parameters",
487-
"check_parameters_default_constructible",
488-
)
489-
if any(ignore in str(check) for ignore in ignores):
490-
return
491494
check(estimator)

0 commit comments

Comments
 (0)