diff --git a/doc/changes/dev/13526.bugfix.rst b/doc/changes/dev/13526.bugfix.rst new file mode 100644 index 00000000000..09827e7a581 --- /dev/null +++ b/doc/changes/dev/13526.bugfix.rst @@ -0,0 +1 @@ +Fix bug preventing reading of :class:`mne.time_frequency.Spectrum` and :class:`mne.time_frequency.BaseTFR` objects created in MNE<1.8 using the deprecated subject info birthday tuple format, by `Thomas Binns`_. \ No newline at end of file diff --git a/doc/changes/dev/13528.bugfix.rst b/doc/changes/dev/13528.bugfix.rst new file mode 100644 index 00000000000..ef24b8c6e24 --- /dev/null +++ b/doc/changes/dev/13528.bugfix.rst @@ -0,0 +1 @@ +Fix bug where invalid date formats passed to :meth:`mne.Info.set_meas_date` were not caught, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 708b7135a9c..90454b9a699 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -833,6 +833,10 @@ def set_meas_date(self, meas_date): """ from ..annotations import _handle_meas_date + _validate_type( + meas_date, (datetime.datetime, "numeric", tuple, None), "meas_date" + ) + info = self if isinstance(self, Info) else self.info meas_date = _handle_meas_date(meas_date) diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index 4e409d262e0..d0effacde91 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -1199,6 +1199,13 @@ def test_invalid_subject_birthday(): assert "birthday" not in raw.info["subject_info"] +def test_invalid_set_meas_date(): + """Test set_meas_date catches invalid str input.""" + info = create_info(1, 1000, "eeg") + with pytest.raises(TypeError, match=r"meas_date must be an instance of"): + info.set_meas_date("2025-01-01 00:00:00.000000") + + @pytest.mark.slowtest @pytest.mark.parametrize( "fname", diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index ac9551252c5..4651123d499 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -198,7 +198,7 @@ def _create_raw_for_edf_tests(stim_channel_index=None): def test_double_export_edf(tmp_path): """Test exporting an EDF file multiple times.""" raw = _create_raw_for_edf_tests(stim_channel_index=2) - raw.info.set_meas_date("2023-09-04 14:53:09.000") + raw.info.set_meas_date(datetime(2023, 9, 4, 14, 53, 9, tzinfo=timezone.utc)) raw.set_annotations(Annotations(onset=[1], duration=[0], description=["test"])) # include subject info and measurement date diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 0d0ce0c30c8..6591527a28b 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -46,7 +46,11 @@ check_fname, ) from ..utils.misc import _pl -from ..utils.spectrum import _get_instance_type_string, _split_psd_kwargs +from ..utils.spectrum import ( + _convert_old_birthday_format, + _get_instance_type_string, + _split_psd_kwargs, +) from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap from ..viz.utils import ( @@ -391,7 +395,7 @@ def __setstate__(self, state): self._freqs = state["freqs"] self._dims = state["dims"] self._sfreq = state["sfreq"] - self.info = Info(**state["info"]) + self.info = Info(**_convert_old_birthday_format(state["info"])) self._data_type = state["data_type"] self._nave = state.get("nave") # objs saved before #11282 won't have `nave` self._weights = state.get("weights") # objs saved before #12747 won't have diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index b1ad677352d..c3197173492 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -2,6 +2,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +import datetime import re from functools import partial @@ -26,7 +27,7 @@ SpectrumArray, combine_spectrum, ) -from mne.utils import _record_warnings +from mne.utils import _import_h5io_funcs, _record_warnings def test_compute_psd_errors(raw): @@ -178,6 +179,7 @@ def _get_inst(inst, request, *, evoked=None, average_tfr=None): def test_spectrum_io(inst, tmp_path, request, evoked): """Test save/load of spectrum objects.""" pytest.importorskip("h5io") + h5py = pytest.importorskip("h5py") fname = tmp_path / f"{inst}-spectrum.h5" inst = _get_inst(inst, request, evoked=evoked) if isinstance(inst, BaseEpochs): @@ -190,12 +192,25 @@ def test_spectrum_io(inst, tmp_path, request, evoked): orig.save(fname) loaded = read_spectrum(fname) assert orig == loaded + # Only check following for one type + if not isinstance(inst, BaseEpochs): + return + # Test loading with old-style birthday format + fname_subject_info = tmp_path / "subject-info.h5" + _, write_hdf5 = _import_h5io_funcs() + write_hdf5(fname_subject_info, dict(birthday=(2000, 1, 1)), title="subject_info") + with h5py.File(fname, "r+") as f: + del f["mnepython/key_info/key_subject_info"] + f["mnepython/key_info/key_subject_info"] = h5py.ExternalLink( + fname_subject_info, "subject_info" + ) + loaded = read_spectrum(fname) + assert isinstance(loaded.info["subject_info"]["birthday"], datetime.date) # Test Spectrum from EpochsSpectrum.average() can be read (gh-13521) - if isinstance(inst, BaseEpochs): - origavg = orig.average() - origavg.save(fname, overwrite=True) - loadedavg = read_spectrum(fname) - assert origavg == loadedavg + origavg = orig.average() + origavg.save(fname, overwrite=True) + loadedavg = read_spectrum(fname) + assert origavg == loadedavg def test_spectrum_copy(raw_spectrum): diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index ed6ddd6da82..fdf89a836c0 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -51,7 +51,7 @@ tfr_multitaper, write_tfrs, ) -from mne.utils import catch_logging, grand_average +from mne.utils import _import_h5io_funcs, catch_logging, grand_average from mne.utils._testing import _get_suptitle from mne.viz.utils import ( _channel_type_prettyprint, @@ -620,6 +620,7 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): """Test TFR I/O.""" pytest.importorskip("h5io") pd = pytest.importorskip("pandas") + h5py = pytest.importorskip("h5py") tfr = _get_inst(inst, request, average_tfr=average_tfr) fname = tmp_path / "temp_tfr.hdf5" @@ -679,6 +680,22 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): tfravg.save(fname, overwrite=True) tfravg_loaded = read_tfrs(fname) assert tfravg == tfravg_loaded + # test loading with old-style birthday format + fname_multi = tmp_path / "temp_multi_tfr.hdf5" + write_tfrs(fname_multi, tfr) # also check for multiple files from write_tfrs + fname_subject_info = tmp_path / "subject-info.hdf5" + _, write_hdf5 = _import_h5io_funcs() + write_hdf5(fname_subject_info, dict(birthday=(2000, 1, 1)), title="subject_info") + for this_fname in (fname, fname_multi): + with h5py.File(this_fname, "r+") as f: + if f.get("mnepython/key_info/key_subject_info"): + path = "mnepython/key_info/key_subject_info" + else: # multi-files on linux have different path to attrs + path = "mnepython/idx_0/idx_1/key_info/key_subject_info" + del f[path] + f[path] = h5py.ExternalLink(fname_subject_info, "subject_info") + tfr_loaded = read_tfrs(this_fname) + assert isinstance(tfr_loaded.info["subject_info"]["birthday"], datetime.date) # test with taper dimension and weights n_tapers = 3 # anything >= 1 should do weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index f64680845c4..f232ff30158 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -59,7 +59,7 @@ verbose, warn, ) -from ..utils.spectrum import _get_instance_type_string +from ..utils.spectrum import _convert_old_birthday_format, _get_instance_type_string from ..viz.topo import _imshow_tfr, _imshow_tfr_unified, _plot_topo from ..viz.topomap import ( _add_colorbar, @@ -1433,7 +1433,7 @@ def __setstate__(self, state): self._dims = defaults["dims"] self._raw_times = np.asarray(defaults["times"], dtype=np.float64) self._baseline = defaults["baseline"] - self.info = Info(**defaults["info"]) + self.info = Info(**_convert_old_birthday_format(defaults["info"])) self._data_type = defaults["data_type"] self._decim = defaults["decim"] self.preload = True @@ -4141,7 +4141,7 @@ def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None): if key != condition: continue tfr = dict(tfr) - tfr["info"] = Info(tfr["info"]) + tfr["info"] = Info(_convert_old_birthday_format(tfr["info"])) tfr["info"]._check_consistency() if "metadata" in tfr: tfr["metadata"] = _prepare_read_metadata(tfr["metadata"]) diff --git a/mne/utils/spectrum.py b/mne/utils/spectrum.py index 69052f21797..1efd06381c9 100644 --- a/mne/utils/spectrum.py +++ b/mne/utils/spectrum.py @@ -4,6 +4,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from datetime import datetime from inspect import currentframe, getargvalues, signature from ..utils import warn @@ -102,3 +103,13 @@ def _split_psd_kwargs(*, plot_fun=None, kwargs=None): for k in plot_kwargs: del kwargs[k] return kwargs, plot_kwargs + + +def _convert_old_birthday_format(info): + """Convert deprecated birthday tuple to datetime.""" + subject_info = info.get("subject_info") + if subject_info is not None: + birthday = subject_info.get("birthday") + if isinstance(birthday, tuple): + info["subject_info"]["birthday"] = datetime(*birthday) + return info