Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13393.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make decoding classes sklearn-compliant, by `Gennadiy Belonosov`_.
41 changes: 41 additions & 0 deletions mne/decoding/_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,44 @@ def validate_data(
out = X, y

return out


def _check_n_features_3d(estimator, X, reset):
    """Set the `n_features_in_` attribute, or check against it on an estimator.

    Sklearn takes n_features from ``X.shape[1]``, but here the number of
    features is read from the last axis, i.e. ``X.shape[-1]``.

    Parameters
    ----------
    estimator : estimator instance
        The estimator to validate the input for.
    X : {ndarray, sparse matrix} of shape ([n_epochs], n_samples, n_features)
        The input samples.
    reset : bool
        If True, the `n_features_in_` attribute is set to ``X.shape[-1]``.
        If False and the attribute exists, then check that it is equal to
        ``X.shape[-1]``. If False and the attribute does *not* exist, then
        the check is skipped.

        .. note::
           It is recommended to call reset=True in `fit` and in the first
           call to `partial_fit`. All other methods that validate `X`
           should set `reset=False`.

    Raises
    ------
    ValueError
        If ``reset`` is False, ``estimator.n_features_in_`` exists, and it
        differs from ``X.shape[-1]``.
    """
    # Features live on the last axis, regardless of how many leading axes
    # X has.
    n_features = X.shape[-1]
    if reset:
        estimator.n_features_in_ = n_features
        return

    if not hasattr(estimator, "n_features_in_"):
        # Skip this check if the expected number of input features
        # was not recorded by calling fit first. This is typically the case
        # for stateless transformers.
        return

    if n_features != estimator.n_features_in_:
        raise ValueError(
            f"X has {n_features} features, but {estimator.__class__.__name__} "
            f"is expecting {estimator.n_features_in_} features as input."
        )
1 change: 1 addition & 0 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def __sklearn_tags__(self):
tags.target_tags.one_d_labels = True
tags.input_tags.two_d_array = True
tags.input_tags.three_d_array = True
tags.requires_fit = True
return tags


Expand Down
18 changes: 17 additions & 1 deletion mne/decoding/csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ def __init__(
R_func=sum,
)

def __sklearn_tags__(self):
    """Tag the transformer.

    Starts from the tags returned by the parent class and sets both
    ``target_tags.required`` and ``target_tags.multi_output`` to True.

    Returns
    -------
    tags : object
        The parent's tags object with the two target tags updated.
    """
    tags = super().__sklearn_tags__()
    tags.target_tags.required = True
    tags.target_tags.multi_output = True
    return tags

def _validate_params(self, *, y):
_validate_type(self.n_components, int, "n_components")
if hasattr(self, "cov_est"):
Expand Down Expand Up @@ -187,7 +194,10 @@ def _validate_params(self, *, y):
self.classes_ = np.unique(y)
n_classes = len(self.classes_)
if n_classes < 2:
raise ValueError(f"n_classes must be >= 2, but got {n_classes} class")
raise ValueError(
"y should be a 1d array with more than two classes, "
f"but got {n_classes} class from {y}"
)
elif n_classes > 2 and self.component_order == "alternate":
raise ValueError(
"component_order='alternate' requires two classes, but data contains "
Expand Down Expand Up @@ -756,6 +766,12 @@ def __init__(
delattr(self, "cov_est")
delattr(self, "norm_trace")

def __sklearn_tags__(self):
    """Tag the transformer.

    Starts from the tags returned by the parent class and sets
    ``target_tags.multi_output`` to False.

    Returns
    -------
    tags : object
        The parent's tags object with ``target_tags.multi_output`` disabled.
    """
    tags = super().__sklearn_tags__()
    tags.target_tags.multi_output = False
    return tags

def fit(self, X, y):
"""Estimate the SPoC decomposition on epochs.

Expand Down
70 changes: 59 additions & 11 deletions mne/decoding/receptive_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sklearn.metrics import r2_score

from ..utils import _validate_type, fill_doc, pinv
from ._fixes import _check_n_features_3d, validate_data
from .base import _check_estimator, get_coef
from .time_delaying_ridge import TimeDelayingRidge

Expand Down Expand Up @@ -125,7 +126,7 @@ def __init__(
self.tmax = tmax
self.sfreq = sfreq
self.feature_names = feature_names
self.estimator = 0.0 if estimator is None else estimator
self.estimator = estimator
self.fit_intercept = fit_intercept
self.scoring = scoring
self.patterns = patterns
Expand All @@ -152,6 +153,19 @@ def __repr__(self): # noqa: D105
s += f"scored ({self.scoring})"
return f"<ReceptiveField | {s}>"

def __sklearn_tags__(self):
    """Tag the estimator as a regressor.

    Starts from the tags returned by the parent class, sets
    ``estimator_type`` to ``"regressor"``, attaches a default
    ``RegressorTags`` instance, and enables the ``three_d_array`` input tag
    and the ``one_d_labels``, ``multi_output`` and ``required`` target tags.

    Returns
    -------
    tags : object
        The parent's tags object with the regressor-related tags updated.
    """
    # Imported locally rather than at module level.
    # NOTE(review): presumably to tolerate sklearn versions that lack
    # RegressorTags at import time -- confirm.
    from sklearn.utils import RegressorTags

    tags = super().__sklearn_tags__()
    tags.estimator_type = "regressor"
    tags.regressor_tags = RegressorTags()
    tags.input_tags.three_d_array = True
    tags.target_tags.one_d_labels = True
    tags.target_tags.multi_output = True
    tags.target_tags.required = True
    return tags

def _delay_and_reshape(self, X, y=None):
"""Delay and reshape the variables."""
if not isinstance(self.estimator_, TimeDelayingRidge):
Expand All @@ -169,6 +183,32 @@ def _delay_and_reshape(self, X, y=None):
y = y.reshape(-1, y.shape[-1], order="F")
return X, y

def _check_data(self, X, y=None, reset=False):
    """Validate the input arrays and record or check the feature count.

    Parameters
    ----------
    X : array-like
        The input data, passed to ``validate_data`` with N-d input allowed.
    y : array-like | None
        The targets. Only validated when ``reset`` is True; otherwise it is
        returned exactly as it was passed in.
    reset : bool
        Forwarded to ``validate_data`` and ``_check_n_features_3d``. When
        True, ``X`` and ``y`` are validated separately with the same
        options.

    Returns
    -------
    X : array-like
        The validated input data.
    y : array-like | None
        The validated targets if ``reset`` is True, else the original ``y``.
    """
    # Options shared by every validation call: allow N-d input and do not
    # force 2D.
    nd_options = dict(allow_nd=True, ensure_2d=False)
    if not reset:
        X = validate_data(self, X=X, reset=reset, **nd_options)
    else:
        X, y = validate_data(
            self,
            X=X,
            y=y,
            reset=reset,
            # to take care of 3D y
            validate_separately=(dict(nd_options), dict(nd_options)),
        )
    _check_n_features_3d(self, X, reset)
    return X, y

def _validate_params(self, X):
    """Validate ``scoring``, ``sfreq`` and the ``tmin``/``tmax`` window.

    ``self.sfreq_`` is set to ``float(self.sfreq)`` once ``scoring`` has
    been accepted. ``X`` is not used by this method.

    Raises
    ------
    ValueError
        If ``self.scoring`` is not a key of ``_SCORERS``, or if
        ``self.tmin`` is greater than ``self.tmax``.
    """
    if self.scoring not in _SCORERS:
        valid_scorings = sorted(_SCORERS)
        raise ValueError(
            f"scoring must be one of {valid_scorings}, got {self.scoring}"
        )

    self.sfreq_ = float(self.sfreq)

    if self.tmin > self.tmax:
        raise ValueError(f"tmin ({self.tmin}) must be at most tmax ({self.tmax})")

def fit(self, X, y):
"""Fit a receptive field model.

Expand All @@ -184,22 +224,18 @@ def fit(self, X, y):
self : instance
The instance so you can chain operations.
"""
if self.scoring not in _SCORERS.keys():
raise ValueError(
f"scoring must be one of {sorted(_SCORERS.keys())}, got {self.scoring} "
)
self.sfreq_ = float(self.sfreq)
X, y = self._check_data(X, y, reset=True)
self._validate_params(X)
X, y, _, self._y_dim = self._check_dimensions(X, y)

if self.tmin > self.tmax:
raise ValueError(f"tmin ({self.tmin}) must be at most tmax ({self.tmax})")
# Initialize delays
self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq_)

# Define the slice that we should use in the middle
self.valid_samples_ = _delays_to_slice(self.delays_)

if isinstance(self.estimator, numbers.Real):
if self.estimator is None or isinstance(self.estimator, numbers.Real):
alpha = self.estimator if self.estimator is not None else 0.0
if self.fit_intercept is None:
self.fit_intercept_ = True
else:
Expand All @@ -208,7 +244,7 @@ def fit(self, X, y):
self.tmin,
self.tmax,
self.sfreq_,
alpha=self.estimator,
alpha=alpha,
fit_intercept=self.fit_intercept_,
n_jobs=self.n_jobs,
edge_correction=self.edge_correction,
Expand Down Expand Up @@ -259,6 +295,12 @@ def fit(self, X, y):

# Inverse-transform model weights
if self.patterns:
n_total_samples = n_times * n_epochs
if n_total_samples < 2:
raise ValueError(
"Cannot compute patterns with only one sample; "
f"got n_samples = {n_total_samples}."
)
if isinstance(self.estimator_, TimeDelayingRidge):
cov_ = self.estimator_.cov_ / float(n_times * n_epochs - 1)
y = y.reshape(-1, y.shape[-1], order="F")
Expand Down Expand Up @@ -300,7 +342,10 @@ def predict(self, X):
"""
if not hasattr(self, "delays_"):
raise NotFittedError("Estimator has not been fit yet.")

X, _ = self._check_data(X)
X, _, X_dim = self._check_dimensions(X, None, predict=True)[:3]

del _
# convert to sklearn and back
pred_shape = X.shape[:-1]
Expand Down Expand Up @@ -384,7 +429,10 @@ def _check_dimensions(self, X, y, predict=False):
)
else:
raise ValueError(
f"X must be shape (n_times[, n_epochs], n_features), got {X.shape}"
"X must be shape (n_times[, n_epochs], n_features), "
f"got {X.shape}. Reshape your data to 2D or 3D "
"(e.g., array.reshape(-1, 1) for a single feature, "
"or array.reshape(1, -1) for a single sample)."
)
if y is not None:
if X.shape[0] != y.shape[0]:
Expand Down
14 changes: 10 additions & 4 deletions mne/decoding/search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,11 @@ def _transform(self, X, method):
y_pred = np.concatenate(y_pred, axis=1)
if orig_method == "transform":
y_pred = y_pred.astype(X.dtype)
if orig_method == "predict_proba" and not is_nd:
y_pred = y_pred[:, 0, :]
elif (
orig_method in ("predict", "predict_proba", "decision_function")
and not is_nd
):
y_pred = y_pred.squeeze()
return y_pred

def transform(self, X):
Expand Down Expand Up @@ -525,8 +528,11 @@ def _transform(self, X, method):
y_pred = np.concatenate(y_pred, axis=2)
if orig_method == "transform":
y_pred = y_pred.astype(X.dtype)
if orig_method == "predict_proba" and not is_nd:
y_pred = y_pred[:, 0, 0, :]
if (
orig_method in ("predict", "predict_proba", "decision_function")
and not is_nd
):
y_pred = y_pred.squeeze()
return y_pred

def transform(self, X):
Expand Down
9 changes: 2 additions & 7 deletions mne/decoding/tests/test_ged.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,8 @@ def _mock_mod_ged_callable(evals, evecs, covs, **kwargs):
cov_callable=[partial(_mock_cov_callable, cov_method_params=dict(reg="empirical"))],
mod_ged_callable=[_mock_mod_ged_callable],
dec_type=["single", "multi"],
# XXX: Not covering "ssd" here because test_ssd.py works with 2D data.
# Need to fix its tests first.
restr_type=[
"restricting",
"whitening",
],
R_func=[partial(np.sum, axis=0)],
restr_type=["restricting", "whitening"],
R_func=[None, partial(np.sum, axis=0)],
)

ged_estimators = [_GEDTransformer(**p) for p in ParameterGrid(param_grid)]
Expand Down
35 changes: 10 additions & 25 deletions mne/decoding/tests/test_receptive_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,22 +589,12 @@ def test_linalg_warning():
@parametrize_with_checks([TimeDelayingRidge(0, 10, 1.0, 0.1, "laplacian", n_jobs=1)])
def test_tdr_sklearn_compliance(estimator, check):
"""Test sklearn estimator compliance."""
# We don't actually comply with a bunch of the regressor specs :(
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
ignores = (
"check_supervised_y_no_nan",
"check_regressor",
"check_parameters_default_constructible",
"check_estimators_unfitted",
"_invariance",
"check_complex_data",
"check_estimators_empty_data_messages",
"check_estimators_nan_inf",
"check_supervised_y_2d",
"check_n_features_in",
"check_fit2d_1sample",
"check_fit1d",
"check_fit2d_predict1d",
"check_requires_y_none",
# TDR convolves and thus its output cannot be invariant when
# shuffled or subsampled.
"check_methods_sample_order_invariance",
"check_methods_subset_invariance",
)
if any(ignore in str(check) for ignore in ignores):
return
Expand All @@ -615,17 +605,12 @@ def test_tdr_sklearn_compliance(estimator, check):
@parametrize_with_checks([ReceptiveField(-1, 2, 1.0, estimator=Ridge(), patterns=True)])
def test_rf_sklearn_compliance(estimator, check):
"""Test sklearn RF compliance."""
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
ignores = (
"check_parameters_default_constructible",
"_invariance",
"check_fit2d_1sample",
# Should probably fix these?
"check_complex_data",
"check_dtype_object",
"check_estimators_empty_data_messages",
"check_n_features_in",
"check_fit2d_predict1d",
"check_estimators_unfitted",
# RF does time-lagging, so its output cannot be invariant when
# shuffled or subsampled.
"check_methods_sample_order_invariance",
"check_methods_subset_invariance",
)
if any(ignore in str(check) for ignore in ignores):
return
Expand Down
10 changes: 1 addition & 9 deletions mne/decoding/tests/test_search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,5 @@ def predict_proba(self, X):
]
)
def test_sklearn_compliance(estimator, check):
"""Test LinearModel compliance with sklearn."""
ignores = (
# TODO: we don't handle singleton right (probably)
"check_classifiers_one_label_sample_weights",
"check_classifiers_classes",
"check_classifiers_train",
)
if any(ignore in str(check) for ignore in ignores):
return
"""Test searchlights compliance with sklearn."""
check(estimator)
10 changes: 8 additions & 2 deletions mne/decoding/tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,14 @@ def test_sklearn_compliance(estimator, check):
"""Test LinearModel compliance with sklearn."""
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
ignores = (
"check_methods_sample_order_invariance",
# Shape stuff
# Checks below fail because what sklearn passes as (n_samples, n_features)
# is considered (n_channels, n_times) by SSD and creates problems
# when n_channels change between fit and transform.
# This could potentially be fixed by using
# `if X.ndim == 2: X = np.expand_dims(X, axis=2)` in fit and transform
# instead of axis=0. However, that would require dropping support for 2D
# inputs and expecting users to provide a 3D array even for a continuous signal.
"check_methods_sample_order_invariance", # SSD is not time-invariant
"check_fit_idempotent",
"check_methods_subset_invariance",
"check_transformer_general",
Expand Down
7 changes: 1 addition & 6 deletions mne/decoding/tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,12 +344,7 @@ def test_sklearn_compliance(estimator, check):
if estimator.__class__.__name__ == "FilterEstimator":
ignores += [
"check_estimators_overwrite_params", # we modify self.info
"check_methods_sample_order_invariance",
]
if estimator.__class__.__name__.startswith(("PSD", "Temporal")):
ignores += [
"check_transformers_unfitted", # allow unfitted transform
"check_methods_sample_order_invariance",
"check_methods_sample_order_invariance", # Filtering is not time invariant
]
if any(ignore in str(check) for ignore in ignores):
return
Expand Down
17 changes: 17 additions & 0 deletions mne/decoding/tests/test_xdawn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import pytest

pytest.importorskip("sklearn")
from sklearn.utils.estimator_checks import parametrize_with_checks

from mne.decoding import XdawnTransformer


@pytest.mark.filterwarnings("ignore:.*Only one sample available.*")
@parametrize_with_checks([XdawnTransformer(reg="oas")])  # oas handles few sample cases
def test_sklearn_compliance(estimator, check):
    """Test XdawnTransformer compliance with sklearn.

    Each ``check`` supplied by ``parametrize_with_checks`` is called on the
    ``XdawnTransformer(reg="oas")`` instance; warnings matching
    "Only one sample available" are ignored.
    """
    check(estimator)
Loading
Loading