Skip to content

Commit 08f64a2

Browse files
zhijingzdrammocklarsoner
authored
BUG: Raise early on non-finite values in PSD (Welch) and ICA.fit (Fix… (mne-tools#13486)
Co-authored-by: Daniel McCloy <[email protected]> Co-authored-by: Eric Larson <[email protected]>
1 parent ff3e0d5 commit 08f64a2

File tree

7 files changed

+117
-10
lines changed

7 files changed

+117
-10
lines changed

doc/changes/dev/13486.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Non-finite values (NaN/Inf) in the input data are now detected early:
2+
:meth:`inst.compute_psd(method="welch") <mne.io.Raw.compute_psd>` warns and returns NaN for the affected channels, and
3+
:meth:`ICA.fit() <mne.preprocessing.ICA.fit>` raises a :class:`ValueError` with an informative message, by :newcontrib:`Emma Zhang`.

doc/changes/names.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
.. _Eduard Ort: https://github.com/eort
8080
.. _Emily Stephen: https://github.com/emilyps14
8181
.. _Emma Bailey: https://www.cbs.mpg.de/employees/bailey
82+
.. _Emma Zhang: https://portfolio-production-ed03.up.railway.app/
8283
.. _Emmanuel Ferdman: https://github.com/emmanuel-ferdman
8384
.. _Emrecan Çelik: https://github.com/emrecncelik
8485
.. _Enrico Varano: https://github.com/enricovara/

doc/sphinxext/related_software.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,6 @@
6464
"Home-page": "https://github.com/mind-inria/picard",
6565
"Summary": "Preconditioned ICA for Real Data",
6666
},
67-
"mne-features": {
68-
"Home-page": "https://mne.tools/mne-features",
69-
"Summary": "MNE-Features software for extracting features from multivariate time series", # noqa: E501
70-
},
7167
"mffpy": {
7268
"Home-page": "https://github.com/BEL-Public/mffpy",
7369
"Summary": "Reader and Writer for Philips' MFF file format.",

mne/preprocessing/ica.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,9 @@ def _pre_whiten(self, data):
891891

892892
def _fit(self, data, fit_type):
893893
"""Aux function."""
894+
if not np.isfinite(data).all():
895+
raise ValueError("Input data contains non-finite values (NaN/Inf). ")
896+
894897
random_state = check_random_state(self.random_state)
895898
n_channels, n_samples = data.shape
896899
self._compute_pre_whitener(data)

mne/preprocessing/tests/test_ica.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,3 +1746,30 @@ def test_ica_get_sources_concatenated():
17461746
# get sources
17471747
raw_sources = ica.get_sources(raw_concat) # but this only has 3 seconds of data
17481748
assert raw_concat.n_times == raw_sources.n_times # this will fail
1749+
1750+
1751+
@pytest.mark.filterwarnings(
    "ignore:The data has not been high-pass filtered.:RuntimeWarning"
)
@pytest.mark.filterwarnings(
    "ignore:invalid value encountered in subtract:RuntimeWarning"
)
def test_ica_rejects_nonfinite():
    """Check that ICA.fit raises ValueError when the data hold NaN or Inf."""
    ch_names = ["Fz", "Cz", "Pz", "Oz"]
    info = create_info(ch_names, sfreq=100.0, ch_types="eeg")
    clean = np.random.RandomState(1).standard_normal(size=(4, 1000))

    # One contaminated sample per case: (channel index, sample index, value).
    contaminations = ((0, 25, np.nan), (1, 50, np.inf))
    for ch_idx, samp_idx, bad_value in contaminations:
        # Fresh copy each time so the cases do not contaminate one another.
        raw = RawArray(clean.copy(), info)
        raw._data[ch_idx, samp_idx] = bad_value
        ica = ICA(
            n_components=2, random_state=0, method="fastica", max_iter="auto"
        )
        with pytest.raises(
            ValueError, match=r"Input data contains non-finite values"
        ):
            ica.fit(raw)

mne/time_frequency/psd.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from scipy.signal import spectrogram
1010

1111
from ..parallel import parallel_func
12-
from ..utils import _check_option, _ensure_int, logger, verbose
12+
from ..utils import _check_option, _ensure_int, logger, verbose, warn
1313
from ..utils.numerics import _mask_to_onsets_offsets
1414

1515

@@ -200,6 +200,42 @@ def psd_array_welch(
200200
del freq_mask
201201
freqs = freqs[freq_sl]
202202

203+
step = max(int(n_per_seg) - int(n_overlap), 1)
204+
if n_times >= n_per_seg:
205+
n_segments = 1 + (n_times - n_per_seg) // step
206+
analyzed_end = step * (n_segments - 1) + n_per_seg
207+
else:
208+
n_segments = 0
209+
analyzed_end = 0
210+
211+
nan_mask_full = np.isnan(x)
212+
nan_present = nan_mask_full.any()
213+
if nan_present:
214+
good_mask_full = ~nan_mask_full
215+
aligned_nan = np.allclose(good_mask_full, good_mask_full[[0]], equal_nan=True)
216+
else:
217+
aligned_nan = False
218+
219+
if analyzed_end > 0:
220+
# Inf always counts as non-finite per-channel
221+
nonfinite_mask = np.isinf(x[..., :analyzed_end])
222+
# NaNs count per-channel only if NOT aligned (i.e., not annotations)
223+
if nan_present and not aligned_nan:
224+
nonfinite_mask |= nan_mask_full[..., :analyzed_end]
225+
bad_ch = nonfinite_mask.any(axis=-1)
226+
else:
227+
bad_ch = np.zeros(x.shape[0], dtype=bool)
228+
229+
if bad_ch.any():
230+
warn(
231+
"Non-finite values (NaN/Inf) detected in some channels; PSD for "
232+
"those channels will be NaN.",
233+
)
234+
# avoid downstream NumPy warnings by zeroing bad channels;
235+
# will overwrite their PSD rows with NaN at the end
236+
x = x.copy()
237+
x[bad_ch] = 0.0
238+
203239
# Parallelize across first N-1 dimensions
204240
logger.debug(
205241
f"Spectogram using {n_fft}-point FFT on {n_per_seg} samples with "
@@ -217,11 +253,9 @@ def psd_array_welch(
217253
window=window,
218254
mode=mode,
219255
)
220-
if np.any(np.isnan(x)):
221-
good_mask = ~np.isnan(x)
222-
# NaNs originate from annot, so must match for all channels. Note that we CANNOT
223-
# use np.testing.assert_allclose() here; it is strict about shapes/broadcasting
224-
assert np.allclose(good_mask, good_mask[[0]], equal_nan=True)
256+
if nan_present and aligned_nan:
257+
# Aligned NaNs across channels → treat as bad annotations.
258+
good_mask = ~nan_mask_full
225259
t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
226260
x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
227261
# weights reflect the number of samples used from each span. For spans longer
@@ -257,6 +291,12 @@ def func(*args, **kwargs):
257291
return _func(*args, **kwargs)
258292

259293
else:
294+
# Either no NaNs, or NaNs are not aligned across channels.
295+
if nan_present and not aligned_nan:
296+
logger.info(
297+
"NaN masks are not aligned across channels; treating NaNs as "
298+
"per-channel contamination."
299+
)
260300
x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
261301
agg_func = np.concatenate
262302
func = _func
@@ -268,5 +308,9 @@ def func(*args, **kwargs):
268308
shape = dshape + (len(freqs),)
269309
if average is None:
270310
shape = shape + (-1,)
311+
312+
if bad_ch.any():
313+
psds[bad_ch] = np.nan
314+
271315
psds.shape = shape
272316
return psds, freqs

mne/time_frequency/tests/test_psd.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,36 @@ def test_psd_array_welch_n_jobs():
226226
data = np.zeros((1, 2048))
227227
psd_array_welch(data, 1024, n_jobs=1)
228228
psd_array_welch(data, 1024, n_jobs=2)
229+
230+
231+
def test_psd_nan_in_data():
    """Check that +Inf inside the analyzed samples only spoils its own channel.

    ``psd_array_welch`` is expected to emit a ``RuntimeWarning`` (it does not
    raise), return an all-NaN PSD for the contaminated channel, and return a
    fully finite PSD for the clean channel.
    """
    n_samples, n_fft, n_overlap = 2048, 256, 128
    rng = np.random.RandomState(0)
    x = rng.standard_normal(size=(2, n_samples))
    # Put +Inf inside the series; this falls within the Welch windows
    x[0, 800] = np.inf  # channel 0 has Inf -> bad channel
    with pytest.warns(RuntimeWarning, match="Non-finite values"):
        psds, _ = psd_array_welch(x, float(n_fft), n_fft=n_fft, n_overlap=n_overlap)

    # Channel 0 is contaminated -> NaN PSD
    assert np.isnan(psds[0]).all()

    # Channel 1 is clean -> every PSD value must be finite (``.any()`` would
    # still pass if NaNs leaked into part of the clean channel)
    assert np.isfinite(psds[1]).all()
246+
247+
248+
def test_psd_misaligned_nan_across_channels():
    """Check handling of NaNs whose masks are NOT aligned across channels.

    A NaN present in only one channel cannot come from a shared annotation, so
    ``psd_array_welch`` is expected to warn (it does not raise), return an
    all-NaN PSD for that channel, and return a fully finite PSD for the other.
    """
    n_samples, n_fft, n_overlap = 2048, 256, 128
    rng = np.random.RandomState(42)
    x = rng.standard_normal(size=(2, n_samples))
    # NaN only in ch0; ch1 has no NaN => masks not aligned -> should warn
    x[0, 500] = np.nan
    with pytest.warns(RuntimeWarning, match="Non-finite values"):
        psds, _ = psd_array_welch(x, float(n_fft), n_fft=n_fft, n_overlap=n_overlap)

    # Bad channel gets NaN PSD
    assert np.isnan(psds[0]).all()

    # Good channel must be finite everywhere (``.any()`` would still pass if
    # NaNs leaked into part of the clean channel)
    assert np.isfinite(psds[1]).all()

0 commit comments

Comments
 (0)