Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13350.apichange.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ``on_few_samples`` parameter to :func:`mne.compute_covariance` and :func:`mne.compute_raw_covariance` for controlling behavior when there are fewer samples than channels, which can lead to inaccurate covariance estimates, by :newcontrib:`Emmanuel Ferdman`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
.. _Eduard Ort: https://github.com/eort
.. _Emily Stephen: https://github.com/emilyps14
.. _Emma Bailey: https://www.cbs.mpg.de/employees/bailey
.. _Emmanuel Ferdman: https://github.com/emmanuel-ferdman
.. _Emrecan Çelik: https://github.com/emrecncelik
.. _Enrico Varano: https://github.com/enricovara/
.. _Enzo Altamiranda: https://www.linkedin.com/in/enzoalt
Expand Down
40 changes: 33 additions & 7 deletions mne/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,16 +553,17 @@ def make_ad_hoc_cov(info, std=None, *, verbose=None):
return Covariance(data, ch_names, info["bads"], info["projs"], nfree=0)


def _check_n_samples(n_samples, n_chan):
def _check_n_samples(n_samples, n_chan, on_few_samples="warn"):
"""Check to see if there are enough samples for reliable cov calc."""
n_samples_min = 10 * (n_chan + 1) // 2
if n_samples <= 0:
raise ValueError("No samples found to compute the covariance matrix")
if n_samples < n_samples_min:
warn(
f"Too few samples (required : {n_samples_min} got : {n_samples}), "
"covariance estimate may be unreliable"
msg = (
f"Too few samples (required {n_samples_min} but got {n_samples} for "
f"{n_chan} channels), covariance estimate may be unreliable"
)
_on_missing(on_few_samples, msg, "on_few_samples")


@verbose
Expand All @@ -574,6 +575,8 @@ def compute_raw_covariance(
reject=None,
flat=None,
picks=None,
*,
on_few_samples="warn",
method="empirical",
method_params=None,
cv=3,
Expand Down Expand Up @@ -623,6 +626,12 @@ def compute_raw_covariance(
are floats that set the minimum acceptable peak-to-peak amplitude.
If flat is None then no rejection is done.
%(picks_good_data_noref)s
on_few_samples : str
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when
there are too few samples for the number of channels (the covariance
check requires at least ``10 * (n_channels + 1) // 2`` samples), which
can lead to inaccurate covariance or rank estimates.

.. versionadded:: 1.11
method : str | list | None (default 'empirical')
The method used for covariance estimation.
See :func:`mne.compute_covariance`.
Expand Down Expand Up @@ -736,7 +745,7 @@ def compute_raw_covariance(
mu += raw_segment.sum(axis=1)
data += np.dot(raw_segment, raw_segment.T)
n_samples += raw_segment.shape[1]
_check_n_samples(n_samples, len(picks))
_check_n_samples(n_samples, len(picks), on_few_samples)
data -= mu[:, None] * (mu[None, :] / n_samples)
data /= n_samples - 1.0
logger.info("Number of samples used : %d", n_samples)
Expand Down Expand Up @@ -864,6 +873,8 @@ def compute_covariance(
tmin=None,
tmax=None,
projs=None,
*,
on_few_samples="warn",
method="empirical",
method_params=None,
cv=3,
Expand Down Expand Up @@ -909,6 +920,12 @@ def compute_covariance(
List of projectors to use in covariance calculation, or None
to indicate that the projectors from the epochs should be
inherited. If None, then projectors from all epochs must match.
on_few_samples : str
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when
there are too few samples for the number of channels (the covariance
check requires at least ``10 * (n_channels + 1) // 2`` samples), which
can lead to inaccurate covariance or rank estimates.

.. versionadded:: 1.11
method : str | list | None (default 'empirical')
The method used for covariance estimation. If 'empirical' (default),
the sample covariance will be computed. A list can be passed to
Expand Down Expand Up @@ -1144,7 +1161,7 @@ def _unpack_epochs(epochs):

epochs = np.hstack(epochs)
n_samples_tot = epochs.shape[-1]
_check_n_samples(n_samples_tot, len(picks_meeg))
_check_n_samples(n_samples_tot, len(picks_meeg), on_few_samples)

epochs = epochs.T # sklearn | C-order
cov_data = _compute_covariance_auto(
Expand All @@ -1158,6 +1175,7 @@ def _unpack_epochs(epochs):
picks_list=picks_list,
scalings=scalings,
rank=rank,
on_few_samples=on_few_samples,
)

if keep_sample_mean is False:
Expand Down Expand Up @@ -1221,7 +1239,7 @@ def _eigvec_subspace(eig, eigvec, mask):

@verbose
def _compute_rank_raw_array(
data, info, rank, scalings, *, log_ch_type=None, verbose=None
data, info, rank, scalings, *, log_ch_type=None, on_few_samples="warn", verbose=None
):
from .io import RawArray

Expand All @@ -1231,6 +1249,7 @@ def _compute_rank_raw_array(
scalings,
info,
log_ch_type=log_ch_type,
on_few_samples=on_few_samples,
)


Expand All @@ -1249,6 +1268,7 @@ def _compute_covariance_auto(
cov_kind="",
log_ch_type=None,
log_rank=True,
on_few_samples="warn",
):
"""Compute covariance auto mode."""
# rescale to improve numerical stability
Expand All @@ -1258,6 +1278,7 @@ def _compute_covariance_auto(
info,
rank=rank,
scalings=scalings,
on_few_samples=on_few_samples,
verbose=_verbose_safe_false(),
)
with _scaled_array(data.T, picks_list, scalings):
Expand Down Expand Up @@ -1729,6 +1750,7 @@ def prepare_noise_cov(
rank=None,
scalings=None,
on_rank_mismatch="ignore",
*,
verbose=None,
):
"""Prepare noise covariance matrix.
Expand Down Expand Up @@ -2119,6 +2141,9 @@ def _regularized_covariance(
log_ch_type=None,
log_rank=None,
cov_kind="",
# backward-compat default for decoding (maybe someday we want to expose this but
# it's likely too invasive and since it's usually regularized, unnecessary):
on_few_samples="ignore",
verbose=None,
):
"""Compute a regularized covariance from data using sklearn.
Expand Down Expand Up @@ -2166,6 +2191,7 @@ def _regularized_covariance(
cov_kind=cov_kind,
log_ch_type=log_ch_type,
log_rank=log_rank,
on_few_samples=on_few_samples,
)[reg]["data"]
return cov

Expand Down
3 changes: 3 additions & 0 deletions mne/decoding/_covs_ged.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def _csp_estimate(X, y, reg, cov_method_params, cov_est, info, rank, norm_trace)
rank=rank,
scalings=None,
log_ch_type="data",
on_few_samples="ignore",
)

covs = []
Expand Down Expand Up @@ -158,6 +159,7 @@ def _xdawn_estimate(
rank=rank,
scalings=None,
log_ch_type="data",
on_few_samples="ignore",
)
return covs, C_ref, info, rank, dict()

Expand Down Expand Up @@ -280,5 +282,6 @@ def _spoc_estimate(X, y, reg, cov_method_params, info, rank):
rank=rank,
scalings=None,
log_ch_type="data",
on_few_samples="ignore",
)
return covs, C_ref, info, rank, dict()
2 changes: 1 addition & 1 deletion mne/decoding/tests/test_csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def test_spoc():
# check y
pytest.raises(ValueError, spoc.fit, X, y * 0)

# Check that doesn't take CSP-spcific input
# Check that doesn't take CSP-specific input
pytest.raises(TypeError, SPoC, cov_est="epoch")

# Check mixing matrix on simulated data
Expand Down
62 changes: 54 additions & 8 deletions mne/rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,27 @@ def _estimate_rank_from_s(s, tol="auto", tol_kind="absolute"):


def _estimate_rank_raw(
    raw,
    picks=None,
    tol=1e-4,
    scalings="norm",
    with_ref_meg=False,
    tol_kind="absolute",
    on_few_samples="warn",
):
    """Aid the transition away from raw.estimate_rank.

    Thin wrapper that extracts the picked data and matching info from ``raw``
    and delegates to ``_estimate_rank_meeg_signals``.

    Parameters
    ----------
    raw : object
        Must provide ``.info`` and support ``raw[picks]`` indexing, whose
        first element is passed on as the data (presumably an instance of
        Raw -- confirm against callers).
    picks : None | array-like
        Channels to use. If None, resolved with
        ``_picks_to_idx(raw.info, picks, with_ref_meg=with_ref_meg)``.
    tol : float | str
        Tolerance forwarded to ``_estimate_rank_meeg_signals``.
    scalings : dict | str
        Scalings forwarded to ``_estimate_rank_meeg_signals``.
    with_ref_meg : bool
        Only used when ``picks`` is None; forwarded to ``_picks_to_idx``.
    tol_kind : str
        Tolerance kind forwarded to ``_estimate_rank_meeg_signals``.
    on_few_samples : str
        Forwarded to ``_estimate_rank_meeg_signals``, whose docstring lists
        'warn' (default), 'ignore', or 'raise'.

    Returns
    -------
    rank
        Whatever ``_estimate_rank_meeg_signals`` returns when called with
        ``return_singular=False``.
    """
    if picks is None:
        picks = _picks_to_idx(raw.info, picks, with_ref_meg=with_ref_meg)
    # convenience wrapper to expose the expert "tol" option + scalings options
    return _estimate_rank_meeg_signals(
        raw[picks][0],
        pick_info(raw.info, picks),
        scalings,
        tol,
        False,  # return_singular
        tol_kind,
        log_ch_type=None,
        on_few_samples=on_few_samples,
    )


Expand All @@ -150,6 +163,7 @@ def _estimate_rank_meeg_signals(
return_singular=False,
tol_kind="absolute",
log_ch_type=None,
on_few_samples="warn",
):
"""Estimate rank for M/EEG data.

Expand All @@ -173,6 +187,10 @@ def _estimate_rank_meeg_signals(
to determine the rank.
tol_kind : str
Tolerance kind. See ``estimate_rank``.
on_few_samples : str
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when
there are fewer samples than channels, which can lead to inaccurate rank
estimates.

Returns
-------
Expand All @@ -183,11 +201,14 @@ def _estimate_rank_meeg_signals(
thresholded to determine the rank are also returned.
"""
picks_list = _picks_by_type(info)
if data.shape[1] < data.shape[0]:
ValueError(
"You've got fewer samples than channels, your "
"rank estimate might be inaccurate."
assert data.ndim == 2, data.shape
n_channels, n_samples = data.shape
if n_samples < n_channels:
msg = (
f"Too few samples ({n_samples=} is less than {n_channels=}), "
"rank estimate may be unreliable"
)
_on_missing(on_few_samples, msg, "on_few_samples")
with _scaled_array(data, picks_list, scalings):
out = estimate_rank(
data,
Expand All @@ -214,6 +235,7 @@ def _estimate_rank_meeg_cov(
return_singular=False,
*,
log_ch_type=None,
on_few_samples="warn",
verbose=None,
):
"""Estimate rank of M/EEG covariance data, given the covariance.
Expand All @@ -236,6 +258,10 @@ def _estimate_rank_meeg_cov(
return_singular : bool
If True, also return the singular values that were used
to determine the rank.
on_few_samples : str
Can be 'warn' (default), 'ignore', or 'raise' to control behavior when
there are fewer samples than channels, which can lead to inaccurate rank
estimates.

Returns
-------
Expand All @@ -249,10 +275,11 @@ def _estimate_rank_meeg_cov(
scalings = _handle_default("scalings_cov_rank", scalings)
_apply_scaling_cov(data, picks_list, scalings)
if data.shape[1] < data.shape[0]:
ValueError(
msg = (
"You've got fewer samples than channels, your "
"rank estimate might be inaccurate."
)
_on_missing(on_few_samples, msg, "on_few_samples")
out = estimate_rank(data, tol=tol, norm=False, return_singular=return_singular)
rank = out[0] if isinstance(out, tuple) else out
if log_ch_type is None:
Expand Down Expand Up @@ -325,7 +352,7 @@ def _compute_rank_int(inst, *args, **kwargs):
# XXX eventually we should unify how channel types are handled
# so that we don't need to do this, or we do it everywhere.
# Using pca=True in compute_whitener might help.
return sum(compute_rank(inst, *args, **kwargs).values())
return sum(compute_rank(inst, *args, on_few_samples="ignore", **kwargs).values())


@verbose
Expand All @@ -335,9 +362,11 @@ def compute_rank(
scalings=None,
info=None,
tol="auto",
*,
proj=True,
tol_kind="absolute",
on_rank_mismatch="ignore",
on_few_samples=None,
verbose=None,
):
"""Compute the rank of data or noise covariance.
Expand All @@ -363,6 +392,13 @@ def compute_rank(
considered when ``rank=None`` or ``rank='info'``.
%(tol_kind_rank)s
%(on_rank_mismatch)s
on_few_samples : str | None
Can be 'warn', 'ignore', or 'raise' to control behavior when
there are fewer samples than channels, which can lead to inaccurate rank
estimates. None (default) means "ignore" if ``inst`` is a
:class:`mne.Covariance` or ``rank in ("info", "full")``, and "warn" otherwise.

.. versionadded:: 1.11
%(verbose)s

Returns
Expand All @@ -384,6 +420,7 @@ def compute_rank(
proj=proj,
tol_kind=tol_kind,
on_rank_mismatch=on_rank_mismatch,
on_few_samples=on_few_samples,
)


Expand All @@ -398,6 +435,7 @@ def _compute_rank(
proj=True,
tol_kind="absolute",
on_rank_mismatch="ignore",
on_few_samples=None,
log_ch_type=None,
verbose=None,
):
Expand Down Expand Up @@ -441,6 +479,12 @@ def _compute_rank(
if rank is None:
rank = dict()

if on_few_samples is None:
if inst_type != "covariance" and rank_type == "estimated":
on_few_samples = "warn"
else:
on_few_samples = "ignore"

simple_info = _simplify_info(info)
picks_list = _picks_by_type(info, meg_combined=True, ref_meg=False, exclude="bads")
for ch_type, picks in picks_list:
Expand Down Expand Up @@ -503,6 +547,7 @@ def _compute_rank(
False,
tol_kind,
log_ch_type=log_ch_type,
on_few_samples=on_few_samples,
)
else:
assert isinstance(inst, Covariance)
Expand All @@ -520,6 +565,7 @@ def _compute_rank(
tol,
return_singular=True,
log_ch_type=log_ch_type,
on_few_samples=on_few_samples,
verbose=est_verbose,
)
if ch_type in rank:
Expand Down
3 changes: 2 additions & 1 deletion mne/simulation/tests/test_evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def test_add_noise():
if inst is raw:
cov_new = compute_raw_covariance(inst, picks=picks)
else:
cov_new = compute_covariance(inst)
with pytest.warns(RuntimeWarning, match=".*Too few samples.*"):
cov_new = compute_covariance(inst)
assert cov["names"] == cov_new["names"]
r = np.corrcoef(cov["data"].ravel(), cov_new["data"].ravel())[0, 1]
assert r > 0.99
Expand Down
Loading
Loading