Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13521.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug preventing :func:`mne.time_frequency.read_spectrum` from reading saved :class:`mne.time_frequency.Spectrum` objects created from :meth:`mne.time_frequency.EpochsSpectrum.average`, by `Thomas Binns`_.
7 changes: 5 additions & 2 deletions mne/io/edf/edf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2117,8 +2117,11 @@ def read_raw_bdf(
>>> events[:, 2] &= (2**16 - 1) # doctest:+SKIP

The above operation can be carried out directly in :func:`mne.find_events`
using the ``mask`` and ``mask_type`` parameters (see
:func:`mne.find_events` for more details).
using the ``mask`` parameter as follows:

>>> events = mne.find_events(..., mask=2**16 - 1) # doctest:+SKIP

See :func:`mne.find_events` for more details.

It is also possible to retrieve system codes, but no particular effort has
been made to decode these in MNE. In case it is necessary, for instance to
Expand Down
2 changes: 1 addition & 1 deletion mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,7 +1752,7 @@ def read_spectrum(fname):
n_jobs=None,
verbose=None,
)
Klass = EpochsSpectrum if hdf5_dict["inst_type_str"] == "Epochs" else Spectrum
Klass = EpochsSpectrum if "epoch" in hdf5_dict["dims"] else Spectrum
return Klass(hdf5_dict, **defaults)


Expand Down
7 changes: 7 additions & 0 deletions mne/time_frequency/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ def test_spectrum_io(inst, tmp_path, request, evoked):
orig.save(fname)
loaded = read_spectrum(fname)
assert orig == loaded
# Test Spectrum from EpochsSpectrum.average() can be read (gh-13521)
if isinstance(inst, BaseEpochs):
origavg = orig.average()
origavg.save(fname, overwrite=True)
loadedavg = read_spectrum(fname)
assert origavg == loadedavg


def test_spectrum_copy(raw_spectrum):
Expand Down Expand Up @@ -320,6 +326,7 @@ def test_epochs_spectrum_average(epochs_spectrum, method):
avg_spect = epochs_spectrum.average(method=method)
assert avg_spect.shape == epochs_spectrum.shape[1:]
assert avg_spect._dims == ("channel", "freq") # no 'epoch'
assert repr(avg_spect).startswith("<Averaged Power Spectrum (from Epochs")


@pytest.mark.parametrize("inst", ("raw_spectrum", "epochs_spectrum", "evoked"))
Expand Down
9 changes: 9 additions & 0 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,11 @@ def test_tfr_io(inst, average_tfr, request, tmp_path):
with tfr.info._unlock():
tfr.info["meas_date"] = want
assert tfr_loaded == tfr
# test AverageTFR from EpochsTFR.average() can be read (gh-13521)
tfravg = tfr.average()
tfravg.save(fname, overwrite=True)
tfravg_loaded = read_tfrs(fname)
assert tfravg == tfravg_loaded
# test with taper dimension and weights
n_tapers = 3 # anything >= 1 should do
weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs
Expand Down Expand Up @@ -1217,6 +1222,10 @@ def test_averaging_epochsTFR():
):
tapered.average()

# Test repr from original instance info is preserved
avgpower = power.average()
assert repr(avgpower).startswith("<Average Power from Epochs")


def test_averaging_freqsandtimes_epochsTFR():
"""Test that EpochsTFR averaging freqs methods work."""
Expand Down
1 change: 0 additions & 1 deletion mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3313,7 +3313,6 @@ def average(self, method="mean", *, dim="epochs", copy=False):
state["freqs"] = freqs
state["times"] = times
if dim == "epochs":
state["inst_type_str"] = "Evoked"
state["nave"] = n_epochs
state["comment"] = f"{method} of {n_epochs} EpochsTFR{_pl(n_epochs)}"
out = AverageTFR(inst=state)
Expand Down
Loading