Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/autofix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
name: Autoupdate changelog entry and headers
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5
with:
persist-credentials: false
- uses: actions/setup-python@v5
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/codeql-analysis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:

steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5
with:
persist-credentials: false

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/credit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5
with:
persist-credentials: true
- uses: actions/setup-python@v5
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
package:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5
with:
persist-credentials: false
- uses: actions/setup-python@v5
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 3
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5
with:
persist-credentials: false
- uses: actions/setup-python@v5
Expand Down Expand Up @@ -91,7 +91,7 @@ jobs:
python: '3.10'
kind: old
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5
with:
fetch-depth: 0
persist-credentials: false
Expand Down
8 changes: 8 additions & 0 deletions doc/changes/dev/13361.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
``model`` parameter of :class:`mne.decoding.LinearModel`
will not be modified, use ``model_`` attribute to access the fitted model.
To be compatible with all MNE-Python versions you can use
``getattr(clf, "model_", getattr(clf, "model"))``
The provided ``model`` is expected to be a supervised predictor,
i.e. classifier or regressor (or :class:`sklearn.multiclass.OneVsRestClassifier`),
otherwise an error will be raised.
by `Gennadiy Belonosov`_.
107 changes: 68 additions & 39 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,25 @@
TransformerMixin,
clone,
is_classifier,
is_regressor,
)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import check_scoring
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.utils import check_array, check_X_y, indexable
from sklearn.utils import indexable
from sklearn.utils.validation import check_is_fitted

from ..parallel import parallel_func
from ..utils import _check_option, _pl, _validate_type, logger, pinv, verbose, warn
from ..utils import (
_check_option,
_pl,
_validate_type,
logger,
pinv,
verbose,
warn,
)
from ._fixes import validate_data
from ._ged import (
_handle_restr_mat,
_is_cov_pos_semidef,
Expand Down Expand Up @@ -340,7 +350,8 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
model : object | None
A linear model from scikit-learn with a fit method
that updates a ``coef_`` attribute.
If None the model will be LogisticRegression.
If None the model will be
:class:`sklearn.linear_model.LogisticRegression`.

Attributes
----------
Expand All @@ -364,46 +375,66 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
.. footbibliography::
"""

# TODO: Properly refactor this using
# https://github.com/scikit-learn/scikit-learn/issues/30237#issuecomment-2465572885
_model_attr_wrap = (
"transform",
"fit_transform",
"predict",
"predict_proba",
"_estimator_type",
"__tags__",
"predict_log_proba",
"_estimator_type", # remove after sklearn 1.6
"decision_function",
"score",
"classes_",
)

def __init__(self, model=None):
# TODO: We need to set this to get our tag checking to work properly
if model is None:
model = LogisticRegression(solver="liblinear")
self.model = model

def __sklearn_tags__(self):
"""Get sklearn tags."""
from sklearn.utils import get_tags # added in 1.6

# fit method below does not allow sparse data via check_data, we could
# eventually make it smarter if we had to
tags = get_tags(self.model)
tags.input_tags.sparse = False
tags = super().__sklearn_tags__()
model = self.model if self.model is not None else LogisticRegression()
model_tags = model.__sklearn_tags__()
tags.estimator_type = model_tags.estimator_type
if tags.estimator_type is not None:
model_type_tags = getattr(model_tags, f"{tags.estimator_type}_tags")
setattr(tags, f"{tags.estimator_type}_tags", model_type_tags)
return tags

def __getattr__(self, attr):
"""Wrap to model for some attributes."""
if attr in LinearModel._model_attr_wrap:
return getattr(self.model, attr)
elif attr == "fit_transform" and hasattr(self.model, "fit_transform"):
return super().__getattr__(self, "_fit_transform")
return super().__getattr__(self, attr)
model = self.model_ if "model_" in self.__dict__ else self.model
if attr == "fit_transform" and hasattr(model, "fit_transform"):
return self._fit_transform
else:
return getattr(model, attr)
else:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{attr}'"
)

def _fit_transform(self, X, y):
return self.fit(X, y).transform(X)

def _validate_params(self, X):
if self.model is not None:
model = self.model
if isinstance(model, MetaEstimatorMixin):
model = model.estimator
is_predictor = is_regressor(model) or is_classifier(model)
if not is_predictor:
raise ValueError(
"Linear model should be a supervised predictor "
"(classifier or regressor)"
)

# For sklearn < 1.6
try:
self._check_n_features(X, reset=True)
except AttributeError:
pass

def fit(self, X, y, **fit_params):
"""Estimate the coefficients of the linear model.

Expand All @@ -424,25 +455,18 @@ def fit(self, X, y, **fit_params):
self : instance of LinearModel
Returns the modified instance.
"""
if y is not None:
X = check_array(X)
else:
X, y = check_X_y(X, y)
self.n_features_in_ = X.shape[1]
if y is not None:
y = check_array(y, dtype=None, ensure_2d=False, input_name="y")
if y.ndim > 2:
raise ValueError(
f"LinearModel only accepts up to 2-dimensional y, got {y.shape} "
"instead."
)
self._validate_params(X)
X, y = validate_data(self, X, y, multi_output=True)

# fit the Model
self.model.fit(X, y, **fit_params)
self.model_ = self.model # for better sklearn compat
self.model_ = (
clone(self.model)
if self.model is not None
else LogisticRegression(solver="liblinear")
)
self.model_.fit(X, y, **fit_params)

# Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y

inv_Y = 1.0
X = X - X.mean(0, keepdims=True)
if y.ndim == 2 and y.shape[1] != 1:
Expand All @@ -454,12 +478,17 @@ def fit(self, X, y, **fit_params):

@property
def filters_(self):
if hasattr(self.model, "coef_"):
if hasattr(self.model_, "coef_"):
# Standard Linear Model
filters = self.model.coef_
elif hasattr(self.model.best_estimator_, "coef_"):
filters = self.model_.coef_
elif hasattr(self.model_, "estimators_"):
# Linear model with OneVsRestClassifier
filters = np.vstack([est.coef_ for est in self.model_.estimators_])
elif hasattr(self.model_, "best_estimator_") and hasattr(
self.model_.best_estimator_, "coef_"
):
# Linear Model with GridSearchCV
filters = self.model.best_estimator_.coef_
filters = self.model_.best_estimator_.coef_
else:
raise ValueError("model does not have a `coef_` attribute.")
if filters.ndim == 2 and filters.shape[0] == 1:
Expand Down
49 changes: 26 additions & 23 deletions mne/decoding/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
is_classifier,
is_regressor,
)
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
from sklearn.model_selection import (
GridSearchCV,
KFold,
StratifiedKFold,
cross_val_score,
)
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.estimator_checks import parametrize_with_checks
Expand Down Expand Up @@ -93,12 +95,11 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3):
return X, Y, A


@pytest.mark.filterwarnings("ignore:invalid value encountered in cast.*:RuntimeWarning")
def test_get_coef():
"""Test getting linear coefficients (filters/patterns) from estimators."""
lm_classification = LinearModel()
lm_classification = LinearModel(LogisticRegression(solver="liblinear"))
assert hasattr(lm_classification, "__sklearn_tags__")
if check_version("sklearn", "1.4"):
if check_version("sklearn", "1.6"):
print(lm_classification.__sklearn_tags__())
assert is_classifier(lm_classification.model)
assert is_classifier(lm_classification)
Expand Down Expand Up @@ -200,19 +201,19 @@ def inverse_transform(self, X):
# Retrieve final linear model
filters = get_coef(clf, "filters_", False)
if hasattr(clf, "steps"):
if hasattr(clf.steps[-1][-1].model, "best_estimator_"):
if hasattr(clf.steps[-1][-1].model_, "best_estimator_"):
# Linear Model with GridSearchCV
coefs = clf.steps[-1][-1].model.best_estimator_.coef_
coefs = clf.steps[-1][-1].model_.best_estimator_.coef_
else:
# Standard Linear Model
coefs = clf.steps[-1][-1].model.coef_
coefs = clf.steps[-1][-1].model_.coef_
else:
if hasattr(clf.model, "best_estimator_"):
if hasattr(clf.model_, "best_estimator_"):
# Linear Model with GridSearchCV
coefs = clf.model.best_estimator_.coef_
coefs = clf.model_.best_estimator_.coef_
else:
# Standard Linear Model
coefs = clf.model.coef_
coefs = clf.model_.coef_
if coefs.ndim == 2 and coefs.shape[0] == 1:
coefs = coefs[0]
assert_array_equal(filters, coefs)
Expand Down Expand Up @@ -280,9 +281,7 @@ def test_get_coef_multiclass(n_features, n_targets):
lm = LinearModel(LinearRegression())
assert not hasattr(lm, "model_")
lm.fit(X, Y)
# TODO: modifying non-underscored `model` is a sklearn no-no, maybe should be a
# metaestimator?
assert lm.model is lm.model_
assert lm.model is not lm.model_
assert_array_equal(lm.filters_.shape, lm.patterns_.shape)
if n_targets == 1:
want_shape = (n_features,)
Expand Down Expand Up @@ -328,9 +327,6 @@ def test_get_coef_multiclass(n_features, n_targets):
(3, 1, 2),
],
)
# TODO: Need to fix this properly in LinearModel
@pytest.mark.filterwarnings("ignore:'multi_class' was depr.*:FutureWarning")
@pytest.mark.filterwarnings("ignore:lbfgs failed to converge.*:")
def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
"""Test a full example with pattern extraction."""
data = np.zeros((10 * n_classes, n_channels, n_times))
Expand All @@ -345,7 +341,7 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
clf = make_pipeline(
Scaler(epochs.info),
Vectorizer(),
LinearModel(LogisticRegression(random_state=0, multi_class="ovr")),
LinearModel(OneVsRestClassifier(LogisticRegression(random_state=0))),
)
scorer = "roc_auc_ovr_weighted"
time_gen = GeneralizingEstimator(clf, scorer, verbose=True)
Expand Down Expand Up @@ -382,6 +378,20 @@ def test_linearmodel():
wrong_X = rng.rand(n, n_features, 99)
clf.fit(wrong_X, y)

# check fit_transform call
clf = LinearModel(LinearDiscriminantAnalysis())
_ = clf.fit_transform(X, y)

# check that model has to have coef_, RBF-SVM doesn't
clf = LinearModel(svm.SVC(kernel="rbf"))
with pytest.raises(ValueError, match="does not have a `coef_`"):
clf.fit(X, y)

# check that model has to be a predictor
clf = LinearModel(StandardScaler())
with pytest.raises(ValueError, match="classifier or regressor"):
clf.fit(X, y)

# check categorical target fit in standard linear model with GridSearchCV
parameters = {"kernel": ["linear"], "C": [1, 10]}
clf = LinearModel(
Expand Down Expand Up @@ -481,11 +491,4 @@ def test_cross_val_multiscore():
@parametrize_with_checks([LinearModel(LogisticRegression())])
def test_sklearn_compliance(estimator, check):
"""Test LinearModel compliance with sklearn."""
ignores = (
"check_estimators_overwrite_params", # self.model changes!
"check_dont_overwrite_parameters",
"check_parameters_default_constructible",
)
if any(ignore in str(check) for ignore in ignores):
return
check(estimator)
Loading