99from scipy .signal import spectrogram
1010
1111from ..parallel import parallel_func
12- from ..utils import _check_option , _ensure_int , logger , verbose
12+ from ..utils import _check_option , _ensure_int , logger , verbose , warn
1313from ..utils .numerics import _mask_to_onsets_offsets
1414
1515
@@ -200,6 +200,42 @@ def psd_array_welch(
200200 del freq_mask
201201 freqs = freqs [freq_sl ]
202202
203+ step = max (int (n_per_seg ) - int (n_overlap ), 1 )
204+ if n_times >= n_per_seg :
205+ n_segments = 1 + (n_times - n_per_seg ) // step
206+ analyzed_end = step * (n_segments - 1 ) + n_per_seg
207+ else :
208+ n_segments = 0
209+ analyzed_end = 0
210+
211+ nan_mask_full = np .isnan (x )
212+ nan_present = nan_mask_full .any ()
213+ if nan_present :
214+ good_mask_full = ~ nan_mask_full
215+ aligned_nan = np .allclose (good_mask_full , good_mask_full [[0 ]], equal_nan = True )
216+ else :
217+ aligned_nan = False
218+
219+ if analyzed_end > 0 :
220+ # Inf always counts as non-finite per-channel
221+ nonfinite_mask = np .isinf (x [..., :analyzed_end ])
222+ # NaNs count per-channel only if NOT aligned (i.e., not annotations)
223+ if nan_present and not aligned_nan :
224+ nonfinite_mask |= nan_mask_full [..., :analyzed_end ]
225+ bad_ch = nonfinite_mask .any (axis = - 1 )
226+ else :
227+ bad_ch = np .zeros (x .shape [0 ], dtype = bool )
228+
229+ if bad_ch .any ():
230+ warn (
231+ "Non-finite values (NaN/Inf) detected in some channels; PSD for "
232+ "those channels will be NaN." ,
233+ )
234+ # avoid downstream NumPy warnings by zeroing bad channels;
235+ # will overwrite their PSD rows with NaN at the end
236+ x = x .copy ()
237+ x [bad_ch ] = 0.0
238+
203239 # Parallelize across first N-1 dimensions
204240 logger .debug (
205241 f"Spectogram using { n_fft } -point FFT on { n_per_seg } samples with "
@@ -217,11 +253,9 @@ def psd_array_welch(
217253 window = window ,
218254 mode = mode ,
219255 )
220- if np .any (np .isnan (x )):
221- good_mask = ~ np .isnan (x )
222- # NaNs originate from annot, so must match for all channels. Note that we CANNOT
223- # use np.testing.assert_allclose() here; it is strict about shapes/broadcasting
224- assert np .allclose (good_mask , good_mask [[0 ]], equal_nan = True )
256+ if nan_present and aligned_nan :
257+ # Aligned NaNs across channels → treat as bad annotations.
258+ good_mask = ~ nan_mask_full
225259 t_onsets , t_offsets = _mask_to_onsets_offsets (good_mask [0 ])
226260 x_splits = [x [..., t_ons :t_off ] for t_ons , t_off in zip (t_onsets , t_offsets )]
227261 # weights reflect the number of samples used from each span. For spans longer
@@ -257,6 +291,12 @@ def func(*args, **kwargs):
257291 return _func (* args , ** kwargs )
258292
259293 else :
294+ # Either no NaNs, or NaNs are not aligned across channels.
295+ if nan_present and not aligned_nan :
296+ logger .info (
297+ "NaN masks are not aligned across channels; treating NaNs as "
298+ "per-channel contamination."
299+ )
260300 x_splits = [arr for arr in np .array_split (x , n_jobs ) if arr .size != 0 ]
261301 agg_func = np .concatenate
262302 func = _func
@@ -268,5 +308,9 @@ def func(*args, **kwargs):
268308 shape = dshape + (len (freqs ),)
269309 if average is None :
270310 shape = shape + (- 1 ,)
311+
312+ if bad_ch .any ():
313+ psds [bad_ch ] = np .nan
314+
271315 psds .shape = shape
272316 return psds , freqs
0 commit comments