diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index efa6f13dcc4..474cdf326aa 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -114,7 +114,7 @@ stages:
       - bash: |
           set -e
           python -m pip install --progress-bar off --upgrade pip
-          python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git" pyvista scikit-learn python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1,!=6.8.1.1,!=6.9.1" pandas neo pymatreader antio defusedxml
+          python -m pip install --progress-bar off "mne-qt-browser[opengl] @ git+https://github.com/mne-tools/mne-qt-browser.git" pyvista scikit-learn python-picard qtpy nibabel sphinx-gallery "PySide6!=6.8.0,!=6.8.0.1,!=6.8.1.1,!=6.9.1" pandas neo pymatreader antio defusedxml curryreader
           python -m pip uninstall -yq mne
           python -m pip install --progress-bar off --upgrade -e . --group=test
         displayName: 'Install dependencies with pip'
diff --git a/doc/_includes/dig_formats.rst b/doc/_includes/dig_formats.rst
index 5928b081aea..47d99a889bc 100644
--- a/doc/_includes/dig_formats.rst
+++ b/doc/_includes/dig_formats.rst
@@ -22,21 +22,23 @@ function for more info on reading specific file types.
 .. cssclass:: table-bordered
 .. rst-class:: midvalign
 
-================= ================ ==============================================
-Vendor            Extension(s)     MNE-Python function
-================= ================ ==============================================
-Neuromag          .fif             :func:`mne.channels.read_dig_fif`
+===================== ================ ==============================================
+Vendor                Extension(s)     MNE-Python function
+===================== ================ ==============================================
+Neuromag              .fif             :func:`mne.channels.read_dig_fif`
 
-Polhemus ISOTRAK  .hsp, .elp, .eeg :func:`mne.channels.read_dig_polhemus_isotrak`
+Polhemus ISOTRAK      .hsp, .elp, .eeg :func:`mne.channels.read_dig_polhemus_isotrak`
 
-EGI               .xml             :func:`mne.channels.read_dig_egi`
+EGI                   .xml             :func:`mne.channels.read_dig_egi`
 
-MNE-C             .hpts            :func:`mne.channels.read_dig_hpts`
+MNE-C                 .hpts            :func:`mne.channels.read_dig_hpts`
 
-Brain Products    .bvct            :func:`mne.channels.read_dig_captrak`
+Brain Products        .bvct            :func:`mne.channels.read_dig_captrak`
 
-Compumedics       .dat             :func:`mne.channels.read_dig_dat`
-================= ================ ==============================================
+Compumedics           .dat, .cdt       :func:`mne.channels.read_dig_curry`
+
+Compumedics (legacy)  .dat             :func:`mne.channels.read_dig_dat`
+===================== ================ ==============================================
 
 To load Polhemus FastSCAN files you can use :func:`montage <mne.channels.read_polhemus_fastscan>`.
diff --git a/doc/api/preprocessing.rst b/doc/api/preprocessing.rst
index 9fe3f995cc4..07443e518aa 100644
--- a/doc/api/preprocessing.rst
+++ b/doc/api/preprocessing.rst
@@ -42,6 +42,7 @@ Projections:
     read_dig_polhemus_isotrak
     read_dig_captrak
     read_dig_dat
+    read_dig_curry
     read_dig_egi
     read_dig_fif
     read_dig_hpts
diff --git a/doc/changes/dev/13156.newfeature.rst b/doc/changes/dev/13156.newfeature.rst
new file mode 100644
index 00000000000..067a3cd4481
--- /dev/null
+++ b/doc/changes/dev/13156.newfeature.rst
@@ -0,0 +1 @@
+Added support for file-like objects in :func:`read_raw_bdf <mne.io.read_raw_bdf>`, :func:`read_raw_edf <mne.io.read_raw_edf>` and :func:`read_raw_gdf <mne.io.read_raw_gdf>`, by :newcontrib:`Santi Martínez`.
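Note: the changelog entry above is the user-facing side of the file-like support added to the EDF/BDF/GDF readers later in this patch. A minimal usage sketch (the file name is hypothetical; per the new docstrings, file-like input requires preloading):

```python
import io

import mne

# Read the raw bytes once (e.g. from a zip archive or a network blob), then
# hand the reader an in-memory stream instead of a path.
with open("recording.bdf", "rb") as f:  # hypothetical file
    stream = io.BytesIO(f.read())

raw = mne.io.read_raw_bdf(stream, preload=True)  # preload is required here
```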
\ No newline at end of file diff --git a/doc/changes/dev/13176.dependency.rst b/doc/changes/dev/13176.dependency.rst new file mode 100644 index 00000000000..713ce3ba502 --- /dev/null +++ b/doc/changes/dev/13176.dependency.rst @@ -0,0 +1 @@ +New reader for Neuroscan Curry files, using the curry-python-reader module, by `Dominik Welke`_. \ No newline at end of file diff --git a/doc/changes/dev/13176.newfeature.rst b/doc/changes/dev/13176.newfeature.rst new file mode 100644 index 00000000000..dd483b214fc --- /dev/null +++ b/doc/changes/dev/13176.newfeature.rst @@ -0,0 +1 @@ +Read impedances and montage from Neuroscan Curry files, by `Dominik Welke`_. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 2d362aeac66..709a7ebd623 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -283,6 +283,7 @@ .. _Samuel Louviot: https://github.com/Sam54000 .. _Samuel Powell: https://github.com/samuelpowell .. _Santeri Ruuskanen: https://github.com/ruuskas +.. _Santi Martínez: https://github.com/szz-dvl .. _Sara Sommariva: https://github.com/sarasommariva .. _Sawradip Saha: https://sawradip.github.io/ .. _Scott Huberty: https://orcid.org/0000-0003-2637-031X diff --git a/environment.yml b/environment.yml index e7620d41fa4..b72d39925e4 100644 --- a/environment.yml +++ b/environment.yml @@ -5,6 +5,7 @@ channels: dependencies: - python >=3.10 - antio >=0.5.0 + - curryreader >=0.1.2 - darkdetect - decorator - defusedxml diff --git a/mne/_edf/open.py b/mne/_edf/open.py new file mode 100644 index 00000000000..2fd97833b29 --- /dev/null +++ b/mne/_edf/open.py @@ -0,0 +1,23 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +# Maybe we can move this one to utils or something like that. 
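+# NB: _NoCloseRead (imported below) wraps a user-supplied file-like object so
+# that its close() is a no-op; reading code can then use ``with`` blocks
+# without closing a stream the caller still owns.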
+from pathlib import Path
+
+from mne._fiff.open import _NoCloseRead
+
+from ..utils import _file_like, _validate_type, logger
+
+
+def _gdf_edf_get_fid(fname, **kwargs):
+    """Open an EDF/BDF/GDF file with no additional parsing."""
+    if _file_like(fname):
+        logger.debug("Using file-like I/O")
+        fid = _NoCloseRead(fname)
+        fid.seek(0)
+    else:
+        _validate_type(fname, [Path, str], "fname", extra="or file-like")
+        logger.debug("Using normal I/O")
+        fid = open(fname, "rb", **kwargs)  # Open in binary mode
+    return fid
diff --git a/mne/annotations.py b/mne/annotations.py
index c724bb8d354..74a5878dca2 100644
--- a/mne/annotations.py
+++ b/mne/annotations.py
@@ -1415,6 +1415,8 @@ def read_annotations(
         ".csv": _read_annotations_csv,
         ".cnt": _read_annotations_cnt,
         ".ds": _read_annotations_ctf,
+        ".dat": _read_annotations_curry,
+        ".cdt": _read_annotations_curry,
         ".cef": _read_annotations_curry,
         ".set": _read_annotations_eeglab,
         ".edf": _read_annotations_edf,
@@ -1427,6 +1429,8 @@ def read_annotations(
     kwargs = {
         ".vmrk": {"sfreq": sfreq, "ignore_marker_types": ignore_marker_types},
         ".amrk": {"sfreq": sfreq, "ignore_marker_types": ignore_marker_types},
+        ".dat": {"sfreq": sfreq},
+        ".cdt": {"sfreq": sfreq},
         ".cef": {"sfreq": sfreq},
         ".set": {"uint16_codec": uint16_codec},
         ".edf": {"encoding": encoding},
diff --git a/mne/channels/__init__.pyi b/mne/channels/__init__.pyi
index 05f273a713d..470bc0cfd0b 100644
--- a/mne/channels/__init__.pyi
+++ b/mne/channels/__init__.pyi
@@ -22,6 +22,7 @@ __all__ = [
     "read_ch_adjacency",
     "read_custom_montage",
     "read_dig_captrak",
+    "read_dig_curry",
     "read_dig_dat",
     "read_dig_egi",
     "read_dig_fif",
@@ -67,6 +68,7 @@ from .montage import (
     make_standard_montage,
     read_custom_montage,
     read_dig_captrak,
+    read_dig_curry,
     read_dig_dat,
     read_dig_egi,
     read_dig_fif,
diff --git a/mne/channels/_dig_montage_utils.py b/mne/channels/_dig_montage_utils.py
index a59e209b2e4..31bee83648d 100644
--- a/mne/channels/_dig_montage_utils.py
+++ b/mne/channels/_dig_montage_utils.py
@@ -2,6 +2,8 @@
 # License: BSD-3-Clause
 # Copyright the MNE-Python contributors.
 
+import re
+
 import numpy as np
 
 from ..utils import Bunch, _check_fname, _soft_import, warn
@@ -94,3 +96,46 @@ def _parse_brainvision_dig_montage(fname, scale):
         ch_pos=dig_ch_pos,
         coord_frame="unknown",
     )
+
+
+def _read_dig_montage_curry(ch_names, ch_types, ch_pos, landmarks, landmarkslabels):
+    # scale ch_pos to m?!
+    ch_pos /= 1000.0
+    landmarks /= 1000.0
+    # channel locations
+    # what about misc without pos? can they mess things up if unordered?
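+    # NB: the division above assumes Curry stores sensor/landmark positions
+    # in mm, while DigMontage coordinates are expected in m.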
+ assert len(ch_pos) >= (ch_types.count("mag") + ch_types.count("eeg")) + assert len(ch_pos) == (ch_types.count("mag") + ch_types.count("eeg")) + ch_pos_eeg = { + ch_names[i]: ch_pos[i, :3] for i, t in enumerate(ch_types) if t == "eeg" + } + # landmarks and headshape + landmark_dict = dict(zip(landmarkslabels, landmarks)) + for k in ["Nas", "RPA", "LPA"]: + if k not in landmark_dict.keys(): + landmark_dict[k] = None + if len(landmarkslabels) > 0: + hpi_pos = landmarks[ + [i for i, n in enumerate(landmarkslabels) if re.match("HPI[1-99]", n)], : + ] + else: + hpi_pos = None + if len(landmarkslabels) > 0: + hsp_pos = landmarks[ + [i for i, n in enumerate(landmarkslabels) if re.match("H[1-99]", n)], : + ] + else: + hsp_pos = None + # compile dig montage positions for eeg + if len(ch_pos_eeg) > 0: + return dict( + ch_pos=ch_pos_eeg, + nasion=landmark_dict["Nas"], + lpa=landmark_dict["LPA"], + rpa=landmark_dict["RPA"], + hsp=hsp_pos, + hpi=hpi_pos, + coord_frame="unknown", + ) + else: # not recorded? + raise ValueError("No eeg sensor locations found in header file.") diff --git a/mne/channels/montage.py b/mne/channels/montage.py index cc6ad705cf5..35fdbce917c 100644 --- a/mne/channels/montage.py +++ b/mne/channels/montage.py @@ -49,12 +49,17 @@ check_fname, copy_function_doc_to_method_doc, fill_doc, + legacy, verbose, warn, ) from ..utils.docs import docdict from ..viz import plot_montage -from ._dig_montage_utils import _parse_brainvision_dig_montage, _read_dig_montage_egi +from ._dig_montage_utils import ( + _parse_brainvision_dig_montage, + _read_dig_montage_curry, + _read_dig_montage_egi, +) @dataclass @@ -322,7 +327,6 @@ class DigMontage: See Also -------- read_dig_captrak - read_dig_dat read_dig_egi read_dig_fif read_dig_hpts @@ -757,6 +761,7 @@ def transform_to_head(montage): return montage +@legacy(alt="read_dig_curry()") def read_dig_dat(fname): r"""Read electrode positions from a ``*.dat`` file. @@ -779,7 +784,7 @@ def read_dig_dat(fname): See Also -------- read_dig_captrak - read_dig_dat + read_dig_curry read_dig_egi read_dig_fif read_dig_hpts @@ -845,9 +850,9 @@ def read_dig_fif(fname, *, verbose=None): See Also -------- DigMontage - read_dig_dat read_dig_egi read_dig_captrak + read_dig_curry read_dig_polhemus_isotrak read_dig_hpts read_dig_localite @@ -898,7 +903,7 @@ def read_dig_hpts(fname, unit="mm"): -------- DigMontage read_dig_captrak - read_dig_dat + read_dig_curry read_dig_egi read_dig_fif read_dig_localite @@ -991,7 +996,7 @@ def read_dig_egi(fname): -------- DigMontage read_dig_captrak - read_dig_dat + read_dig_curry read_dig_fif read_dig_hpts read_dig_localite @@ -1023,7 +1028,7 @@ def read_dig_captrak(fname): See Also -------- DigMontage - read_dig_dat + read_dig_curry read_dig_egi read_dig_fif read_dig_hpts @@ -1037,6 +1042,51 @@ def read_dig_captrak(fname): return make_dig_montage(**data) +def read_dig_curry(fname): + """Read electrode locations from Neuroscan Curry files. + + Parameters + ---------- + fname : path-like + A valid Curry file. + + Returns + ------- + montage : instance of DigMontage | None + The montage. + + See Also + -------- + DigMontage + read_dig_captrak + read_dig_egi + read_dig_fif + read_dig_hpts + read_dig_localite + read_dig_polhemus_isotrak + make_dig_montage + + Notes + ----- + .. 
versionadded:: 1.11 + """ + from ..io.curry.curry import ( + _check_curry_filename, + _extract_curry_info, + ) + + # TODO - REVIEW NEEDED + fname = _check_curry_filename(fname) + (_, _, ch_names, ch_types, ch_pos, landmarks, landmarkslabels, _, _, _, _, _, _) = ( + _extract_curry_info(fname) + ) + data = _read_dig_montage_curry( + ch_names, ch_types, ch_pos, landmarks, landmarkslabels + ) + mont = make_dig_montage(**data) if data else None + return mont + + def read_dig_localite(fname, nasion=None, lpa=None, rpa=None): """Read Localite .csv file. @@ -1060,7 +1110,7 @@ def read_dig_localite(fname, nasion=None, lpa=None, rpa=None): -------- DigMontage read_dig_captrak - read_dig_dat + read_dig_curry read_dig_egi read_dig_fif read_dig_hpts @@ -1461,7 +1511,7 @@ def read_dig_polhemus_isotrak(fname, ch_names=None, unit="m"): make_dig_montage read_polhemus_fastscan read_dig_captrak - read_dig_dat + read_dig_curry read_dig_egi read_dig_fif read_dig_localite @@ -1821,8 +1871,8 @@ def make_standard_montage(kind, head_size="auto"): Notes ----- Individualized (digitized) electrode positions should be read in using - :func:`read_dig_captrak`, :func:`read_dig_dat`, :func:`read_dig_egi`, - :func:`read_dig_fif`, :func:`read_dig_polhemus_isotrak`, + :func:`read_dig_captrak`, :func:`read_dig_curry`, + :func:`read_dig_egi`, :func:`read_dig_fif`, :func:`read_dig_polhemus_isotrak`, :func:`read_dig_hpts`, or manually made with :func:`make_dig_montage`. .. versionadded:: 0.19.0 diff --git a/mne/datasets/config.py b/mne/datasets/config.py index 62937779692..10d3dea7fa4 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓↓↓↓ RELEASES = dict( - testing="0.167", + testing="0.168", misc="0.27", phantom_kit="0.2", ucl_opm_auditory="0.2", @@ -115,7 +115,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS["testing"] = dict( archive_name=f"{TESTING_VERSIONED}.tar.gz", - hash="md5:d82318a83b436ca2c7ca8420487c05c2", + hash="md5:7782a64f170b9435b0fd126862b0cf63", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" f"tar.gz/{RELEASES['testing']}" diff --git a/mne/fixes.py b/mne/fixes.py index 070d4125d18..2148330fb34 100644 --- a/mne/fixes.py +++ b/mne/fixes.py @@ -16,12 +16,14 @@ # because this module is imported many places (but not always used)! import inspect +import io import operator as operator_module import os import warnings from math import log import numpy as np +import numpy.typing from packaging.version import parse ############################################################################### @@ -733,3 +735,33 @@ def sph_harm_y(n, m, theta, phi, *, diff_n=0): return special.sph_harm_y(n, m, theta, phi, diff_n=diff_n) else: return special.sph_harm(m, n, phi, theta) + + +############################################################################### +# workaround: Numpy won't allow to read from file-like objects with numpy.fromfile, +# we try to use numpy.fromfile, if a blob is used we use numpy.frombuffer to read +# from the file-like object. +def read_from_file_or_buffer( + file: str | bytes | os.PathLike | io.IOBase, + dtype: numpy.typing.DTypeLike = float, + count: int = -1, +): + """numpy.fromfile() wrapper, handling io.BytesIO file-like streams. + + Numpy requires open files to be actual files on disk, i.e., must support + file.fileno(), so it fails with file-like streams such as io.BytesIO(). 
+ + If numpy.fromfile() fails due to no file.fileno() support, this wrapper + reads the required bytes from file and redirects the call to + numpy.frombuffer(). + + See https://github.com/numpy/numpy/issues/2230#issuecomment-949795210 + """ + try: + return np.fromfile(file, dtype=dtype, count=count) + except io.UnsupportedOperation as e: + if not (e.args and e.args[0] == "fileno" and isinstance(file, io.IOBase)): + raise # Nothing I can do about it + dtype = np.dtype(dtype) + buffer = file.read(dtype.itemsize * count) + return np.frombuffer(buffer, dtype=dtype, count=count) diff --git a/mne/io/cnt/cnt.py b/mne/io/cnt/cnt.py index da91ee59f9e..196a87564d1 100644 --- a/mne/io/cnt/cnt.py +++ b/mne/io/cnt/cnt.py @@ -206,7 +206,7 @@ def read_raw_cnt( - Standard montages with :func:`mne.channels.make_standard_montage` - Montages for `Compumedics systems `__ with - :func:`mne.channels.read_dig_dat` + :func:`mne.channels.read_dig_curry` - Other reader functions are listed under *See Also* at :class:`mne.channels.DigMontage` diff --git a/mne/io/curry/__init__.py b/mne/io/curry/__init__.py index fce6b7d9a32..5b2e89b6798 100644 --- a/mne/io/curry/__init__.py +++ b/mne/io/curry/__init__.py @@ -5,3 +5,4 @@ # Copyright the MNE-Python contributors. from .curry import read_raw_curry +from .curry import read_impedances_curry diff --git a/mne/io/curry/curry.py b/mne/io/curry/curry.py index 3e8347fba0d..c1ffcad96fb 100644 --- a/mne/io/curry/curry.py +++ b/mne/io/curry/curry.py @@ -3,9 +3,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -import os.path as op import re -from collections import namedtuple from datetime import datetime, timezone from pathlib import Path @@ -16,7 +14,8 @@ from ..._fiff.meas_info import create_info from ..._fiff.tag import _coil_trans_to_loc from ..._fiff.utils import _mult_cal_one, _read_segments_file -from ...annotations import Annotations +from ...annotations import annotations_from_events +from ...epochs import Epochs from ...surface import _normal_orth from ...transforms import ( Transform, @@ -27,395 +26,637 @@ invert_transform, rot_to_quat, ) -from ...utils import _check_fname, check_fname, logger, verbose +from ...utils import ( + _on_missing, + _soft_import, + catch_logging, + logger, + verbose, + warn, +) from ..base import BaseRaw from ..ctf.trans import _quaternion_align -FILE_EXTENSIONS = { - "Curry 7": { - "info": ".dap", - "data": ".dat", - "labels": ".rs3", - "events_cef": ".cef", - "events_ceo": ".ceo", - "hpi": ".hpi", - }, - "Curry 8": { - "info": ".cdt.dpa", - "data": ".cdt", - "labels": ".cdt.dpa", - "events_cef": ".cdt.cef", - "events_ceo": ".cdt.ceo", - "hpi": ".cdt.hpi", - }, -} -CHANTYPES = {"meg": "_MAG1", "eeg": "", "misc": "_OTHERS"} -FIFFV_CHANTYPES = { - "meg": FIFF.FIFFV_MEG_CH, - "eeg": FIFF.FIFFV_EEG_CH, - "misc": FIFF.FIFFV_MISC_CH, -} -FIFFV_COILTYPES = { - "meg": FIFF.FIFFV_COIL_CTF_GRAD, - "eeg": FIFF.FIFFV_COIL_EEG, - "misc": FIFF.FIFFV_COIL_NONE, -} -SI_UNITS = dict(V=FIFF.FIFF_UNIT_V, T=FIFF.FIFF_UNIT_T) -SI_UNIT_SCALE = dict(c=1e-2, m=1e-3, u=1e-6, µ=1e-6, n=1e-9, p=1e-12, f=1e-15) - -CurryParameters = namedtuple( - "CurryParameters", - "n_samples, sfreq, is_ascii, unit_dict, n_chans, dt_start, chanidx_in_file", -) +CURRY_SUFFIX_DATA = [".cdt", ".dat"] +CURRY_SUFFIX_HDR = [".cdt.dpa", ".cdt.dpo", ".dap"] +CURRY_SUFFIX_LABELS = [".cdt.dpa", ".cdt.dpo", ".rs3"] -def _get_curry_version(file_extension): +def _get_curry_version(fname): """Check out the curry file version.""" - return "Curry 8" if "cdt" in file_extension else 
"Curry 7" + fname_hdr = _check_curry_header_filename(_check_curry_filename(fname)) + content_hdr = fname_hdr.read_text() + return ( + "Curry 7" + if ".dap" in str(fname_hdr) + else "Curry 8" + if re.compile(r"FileVersion\s*=\s*[0-9]+") + .search(content_hdr) + .group(0) + .split()[-1][0] + == "8" + else "Curry 9" + if re.compile(r"FileVersion\s*=\s*[0-9]+") + .search(content_hdr) + .group(0) + .split()[-1][0] + == "9" + else None + ) + + +def _check_curry_filename(fname): + fname_in = Path(fname).expanduser() + fname_out = None + # try suffixes + if fname_in.suffix in CURRY_SUFFIX_DATA: + fname_out = fname_in + elif ( + fname_in.with_suffix("").exists() + and fname_in.with_suffix("").suffix in CURRY_SUFFIX_DATA + ): + fname_out = fname_in.with_suffix("") + else: + for data_suff in CURRY_SUFFIX_DATA: + if fname_in.with_suffix(data_suff).exists(): + fname_out = fname_in.with_suffix(data_suff) + break + # final check + if not fname_out or not fname_out.exists(): + raise FileNotFoundError( + f"no curry data file found (.dat or .cdt), checked {fname_out or fname_in}" + ) + return fname_out + + +def _check_curry_header_filename(fname): + fname_in = Path(fname) + fname_hdr = None + # try suffixes + for hdr_suff in CURRY_SUFFIX_HDR: + if fname_in.with_suffix(hdr_suff).exists(): + fname_hdr = fname_in.with_suffix(hdr_suff) + break + # final check + if not fname_hdr or not fname_in.exists(): + raise FileNotFoundError( + f"no corresponding header file found {CURRY_SUFFIX_HDR}" + ) + return fname_hdr -def _get_curry_file_structure(fname, required=()): - """Store paths to a dict and check for required files.""" - _msg = ( - "The following required files cannot be found: {0}.\nPlease make " - "sure all required files are located in the same directory as {1}." +def _check_curry_labels_filename(fname): + fname_in = Path(fname) + fname_labels = None + # try suffixes + for hdr_suff in CURRY_SUFFIX_LABELS: + if fname_in.with_suffix(hdr_suff).exists(): + fname_labels = fname_in.with_suffix(hdr_suff) + break + # final check + if not fname_labels or not fname_in.exists(): + raise FileNotFoundError( + f"no corresponding labels file found {CURRY_SUFFIX_HDR}" + ) + return fname_labels + + +def _check_curry_sfreq_consistency(fname_hdr): + content_hdr = fname_hdr.read_text() + stime = float( + re.compile(r"SampleTimeUsec\s*=\s*.+").search(content_hdr).group(0).split()[-1] + ) + sfreq = float( + re.compile(r"SampleFreqHz\s*=\s*.+").search(content_hdr).group(0).split()[-1] ) - fname = Path(_check_fname(fname, "read", True, "fname")) - - # we don't use os.path.splitext to also handle extensions like .cdt.dpa - # this won't handle a dot in the filename, but it should handle it in - # the parent directories - fname_base = fname.name.split(".", maxsplit=1)[0] - ext = fname.name[len(fname_base) :] - fname_base = str(fname) - fname_base = fname_base[: len(fname_base) - len(ext)] - del fname - version = _get_curry_version(ext) - my_curry = dict() - for key in ("info", "data", "labels", "events_cef", "events_ceo", "hpi"): - fname = fname_base + FILE_EXTENSIONS[version][key] - if op.isfile(fname): - _key = "events" if key.startswith("events") else key - my_curry[_key] = fname - - missing = [field for field in required if field not in my_curry] - if missing: - raise FileNotFoundError(_msg.format(np.unique(missing), fname)) - - return my_curry - - -def _read_curry_lines(fname, regex_list): - """Read through the lines of a curry parameter files and save data. 
+ if stime == 0: + raise ValueError("Header file indicates a sampling interval of 0µs.") + if not np.isclose(1e6 / stime, sfreq): + warn( + f"Sample distance ({stime}µs) and sample frequency ({sfreq}Hz) in header " + "file do not match! sfreq will be derived from sample distance." + ) - Parameters - ---------- - fname : path-like - Path to a curry file. - regex_list : list of str - A list of strings or regular expressions to search within the file. - Each element `regex` in `regex_list` must be formulated so that - `regex + " START_LIST"` initiates the start and `regex + " END_LIST"` - initiates the end of the elements that should be saved. - Returns - ------- - data_dict : dict - A dictionary containing the extracted data. For each element `regex` - in `regex_list` a dictionary key `data_dict[regex]` is created, which - contains a list of the according data. +def _get_curry_meas_info(fname): + # Note that the time zone information is not stored in the Curry info + # file, and it seems the start time info is in the local timezone + # of the acquisition system (which is unknown); therefore, just set + # the timezone to be UTC. If the user knows otherwise, they can + # change it later. (Some Curry files might include StartOffsetUTCMin, + # but its presence is unpredictable, so we won't rely on it.) + fname_hdr = _check_curry_header_filename(fname) + content_hdr = fname_hdr.read_text() + + # read meas_date + meas_date = [ + int(re.compile(rf"{v}\s*=\s*-?\d+").search(content_hdr).group(0).split()[-1]) + for v in [ + "StartYear", + "StartMonth", + "StartDay", + "StartHour", + "StartMin", + "StartSec", + "StartMillisec", + ] + ] + try: + meas_date = datetime( + *meas_date[:-1], + meas_date[-1] * 1000, # -> microseconds + timezone.utc, + ) + except Exception: + meas_date = None + + # read datatype + byteorder = ( + re.compile(r"DataByteOrder\s*=\s*[A-Z]+") + .search(content_hdr) + .group() + .split()[-1] + ) + is_ascii = byteorder == "ASCII" + + # amplifier info + # TODO - PRIVACY + # seems like there can be identifiable information (serial numbers, dates). 
+ # MNE anonymization functions only overwrite "serial" and "site", though + # TODO - FUTURE ENHANCEMENT + # # there can be filter details in AmplifierInfo, too + amp_info = ( + re.compile(r"AmplifierInfo\s*=.*\n") + .search(content_hdr) + .group() + .strip("\n") + .split("= ")[-1] + .strip() + ) - """ - save_lines = {} - data_dict = {} + device_info = ( + dict(serial=amp_info) + if amp_info != "" + else None # model="", serial="", site="" + ) + + return meas_date, is_ascii, device_info - for regex in regex_list: - save_lines[regex] = False - data_dict[regex] = [] - with open(fname) as fid: - for line in fid: - for regex in regex_list: - if re.match(regex + " END_LIST", line): - save_lines[regex] = False +def _get_curry_recording_type(fname): + _soft_import("curryreader", "read recording modality") - if save_lines[regex] and line != "\n": - result = line.replace("\n", "") - if "\t" in result: - result = result.split("\t") - data_dict[regex].append(result) + import curryreader - if re.match(regex + " START_LIST", line): - save_lines[regex] = True + epochinfo = curryreader.read(str(fname), plotdata=0, verbosity=1)["epochinfo"] + if epochinfo.size == 0: + return "raw" + else: + n_average = epochinfo[:, 0] + if (n_average == 1).all(): + return "epochs" + else: + return "evoked" - return data_dict +def _get_curry_epoch_info(fname): + _soft_import("curryreader", "read epoch info") + _soft_import("pandas", "dataframe integration") -def _read_curry_parameters(fname): - """Extract Curry params from a Curry info file.""" - _msg_match = ( - "The sampling frequency and the time steps extracted from " - "the parameter file do not match." + import curryreader + import pandas as pd + + # use curry-python-reader + currydata = curryreader.read(str(fname), plotdata=0, verbosity=1) + + # get epoch info + sfreq = currydata["info"]["samplingfreq"] + n_samples = currydata["info"]["samples"] + n_epochs = len(currydata["epochlabels"]) + epochinfo = currydata["epochinfo"] + epochtypes = epochinfo[:, 2].astype(int).tolist() + epochlabels = currydata["epochlabels"] + epochmetainfo = pd.DataFrame( + epochinfo[:, -4:], columns=["accept", "correct", "response", "response time"] + ) + # create mne events + events = np.array( + [[i * n_samples for i in range(n_epochs)], [0] * n_epochs, epochtypes] + ).T + event_id = dict(zip(epochlabels, epochtypes)) + return dict( + events=events, + event_id=event_id, + tmin=0.0, + tmax=(n_samples - 1) / sfreq, + baseline=None, + detrend=None, + verbose=False, + metadata=epochmetainfo, + reject_by_annotation=False, + reject=None, ) - _msg_invalid = "sfreq must be greater than 0. 
Got sfreq = {0}" - - var_names = [ - "NumSamples", - "SampleFreqHz", - "DataFormat", - "SampleTimeUsec", - "NumChannels", - "StartYear", - "StartMonth", - "StartDay", - "StartHour", - "StartMin", - "StartSec", - "StartMillisec", - "NUM_SAMPLES", - "SAMPLE_FREQ_HZ", - "DATA_FORMAT", - "SAMPLE_TIME_USEC", - "NUM_CHANNELS", - "START_YEAR", - "START_MONTH", - "START_DAY", - "START_HOUR", - "START_MIN", - "START_SEC", - "START_MILLISEC", + + +def _get_curry_meg_normals(fname): + fname_lbl = _check_curry_labels_filename(fname) + normals_str = fname_lbl.read_text().split("\n") + # i_start, i_stop = [ + # i + # for i, ll in enumerate(normals_str) + # if ("NORMALS" in ll and "START_LIST" in ll) + # or ("NORMALS" in ll and "END_LIST" in ll) + # ] + # normals_str = [nn.split("\t") for nn in normals_str[i_start + 1 : i_stop]] + i_list = [ + i + for i, ll in enumerate(normals_str) + if ("NORMALS" in ll and "START_LIST" in ll) + or ("NORMALS" in ll and "END_LIST" in ll) ] + assert len(i_list) % 2 == 0 + i_start_list = i_list[::2] + i_stop_list = i_list[1::2] + normals_str = [ + nn.split("\t") + for i_start, i_stop in zip(i_start_list, i_stop_list) + for nn in normals_str[i_start + 1 : i_stop] + ] + return np.array([[float(nnn.strip()) for nnn in nn] for nn in normals_str]) - param_dict = dict() - unit_dict = dict() - - with open(fname) as fid: - for line in iter(fid): - if any(var_name in line for var_name in var_names): - key, val = line.replace(" ", "").replace("\n", "").split("=") - param_dict[key.lower().replace("_", "")] = val - for key, type_ in CHANTYPES.items(): - if f"DEVICE_PARAMETERS{type_} START" in line: - data_unit = next(fid) - unit_dict[key] = ( - data_unit.replace(" ", "").replace("\n", "").split("=")[-1] - ) - # look for CHAN_IN_FILE sections, which may or may not exist; issue #8391 - types = ["meg", "eeg", "misc"] - chanidx_in_file = _read_curry_lines( - fname, ["CHAN_IN_FILE" + CHANTYPES[key] for key in types] - ) +def _extract_curry_info(fname): + _soft_import("curryreader", "read file header") - n_samples = int(param_dict["numsamples"]) - sfreq = float(param_dict["samplefreqhz"]) - time_step = float(param_dict["sampletimeusec"]) * 1e-6 - is_ascii = param_dict["dataformat"] == "ASCII" - n_channels = int(param_dict["numchannels"]) - try: - dt_start = datetime( - int(param_dict["startyear"]), - int(param_dict["startmonth"]), - int(param_dict["startday"]), - int(param_dict["starthour"]), - int(param_dict["startmin"]), - int(param_dict["startsec"]), - int(param_dict["startmillisec"]) * 1000, - timezone.utc, + import curryreader + + # check if sfreq values make sense + fname_hdr = _check_curry_header_filename(fname) + _check_curry_sfreq_consistency(fname_hdr) + + # use curry-python-reader + currydata = curryreader.read(str(fname), plotdata=0, verbosity=1) + + # basic info + sfreq = currydata["info"]["samplingfreq"] + n_samples = currydata["info"]["samples"] + if n_samples != currydata["data"].shape[0]: # normal in epoched data + n_samples = currydata["data"].shape[0] + if _get_curry_recording_type(fname) == "raw": + warn( + "sample count from header doesn't match actual data! " + "file corrupted? 
will use data shape" + ) + + # channel information + n_ch = currydata["info"]["channels"] + ch_names = currydata["labels"] + ch_pos = currydata["sensorpos"] + landmarks = currydata["landmarks"] + if not isinstance(landmarks, np.ndarray): + landmarks = np.array(landmarks) + landmarkslabels = currydata["landmarkslabels"] + hpimatrix = currydata["hpimatrix"] + if isinstance(currydata["hpimatrix"], np.ndarray) and hpimatrix.ndim == 1: + hpimatrix = hpimatrix[np.newaxis, :] + + # data + orig_format = "int" + # curryreader.py always reads float32, but this is probably just numpy. + # legacy MNE code states int. + + # events + events = currydata["events"] + annotations = currydata["annotations"] + assert len(annotations) == len(events) + if len(events) > 0: + event_desc = dict() + for k, v in zip(events[:, 1], annotations): + if int(k) not in event_desc.keys(): + event_desc[int(k)] = v.strip() if (v.strip() != "") else str(int(k)) + else: + event_desc = None + + # impedance measurements + # moved to standalone def; see read_impedances_curry + # impedances = currydata["impedances"] + + # get other essential info not provided by curryreader + # channel types and units + ch_types, units = [], [] + ch_groups = fname_hdr.read_text().split("DEVICE_PARAMETERS")[1::2] + for ch_group in ch_groups: + ch_group = re.compile(r"\s+").sub(" ", ch_group).strip() + groupid = ch_group.split()[0] + unit = ch_group.split("DataUnit = ")[1].split()[0] + n_ch_group = int(ch_group.split("NumChanThisGroup = ")[1].split()[0]) + ch_type = ( + "mag" if ("MAG" in groupid) else "misc" if ("OTHER" in groupid) else "eeg" ) - # Note that the time zone information is not stored in the Curry info - # file, and it seems the start time info is in the local timezone - # of the acquisition system (which is unknown); therefore, just set - # the timezone to be UTC. If the user knows otherwise, they can - # change it later. (Some Curry files might include StartOffsetUTCMin, - # but its presence is unpredictable, so we won't rely on it.) - except (ValueError, KeyError): - dt_start = None # if missing keywords or illegal values, don't set - - if time_step == 0: - true_sfreq = sfreq - elif sfreq == 0: - true_sfreq = 1 / time_step - elif not np.isclose(sfreq, 1 / time_step): - raise ValueError(_msg_match) - else: # they're equal and != 0 - true_sfreq = sfreq - if true_sfreq <= 0: - raise ValueError(_msg_invalid.format(true_sfreq)) - - return CurryParameters( + # combine info + ch_types += [ch_type] * n_ch_group + units += [unit] * n_ch_group + + # This for Git issue #8391. In some cases, the 'labels' (.rs3 file will + # list channels that are not actually saved in the datafile (such as the + # 'Ref' channel). These channels are denoted in the 'info' (.dap) file + # in the CHAN_IN_FILE section with a '0' as their index. + # + # current curryreader cannot cope with this - loads the list of channels solely + # based on their order, so can be false. fix it here! 
+    if not len(ch_types) == len(units) == len(ch_names) == n_ch:
+        # read relevant info
+        fname_lbl = _check_curry_labels_filename(fname)
+        lbl = fname_lbl.read_text().split("START_LIST")
+        ch_names_full = []
+        for i in range(1, len(lbl)):
+            if "LABELS" in lbl[i - 1].split()[-1]:
+                for ll in lbl[i].split("\n")[1:]:
+                    if "LABELS" not in ll:
+                        ch_names_full.append(ll.strip())
+                    else:
+                        break
+        hdr = fname_hdr.read_text().split("START_LIST")
+        chaninfile_full = []
+        for i in range(1, len(hdr)):
+            if "CHAN_IN_FILE" in hdr[i - 1].split()[-1]:
+                for ll in hdr[i].split("\n")[1:]:
+                    if "CHAN_IN_FILE" not in ll:
+                        chaninfile_full.append(int(ll.strip()))
+                    else:
+                        break
+        # drop channels with chan_in_file==0, account for order
+        i_drop = [i for i, ich in enumerate(chaninfile_full) if ich == 0]
+        ch_names = [
+            ch_names_full[i] for i in np.argsort(chaninfile_full) if i not in i_drop
+        ]
+        ch_pos = np.array(
+            [
+                ch_pos[i]
+                for i in np.argsort(chaninfile_full)
+                if (i not in i_drop) and (i < len(ch_pos))
+            ]
+        )
+        ch_types = [ch_types[i] for i in np.argsort(chaninfile_full) if i not in i_drop]
+        units = [units[i] for i in np.argsort(chaninfile_full) if i not in i_drop]
+
+    assert len(ch_types) == len(units) == len(ch_names) == n_ch
+    assert len(ch_pos) == ch_types.count("eeg") + ch_types.count("mag")
+
+    # fine-tune channel types (e.g. stim, eog, etc. might be identified by name)
+    # TODO - FUTURE ENHANCEMENT
+
+    # scale data to SI units
+    orig_units = dict(zip(ch_names, units))
+    cals = [
+        1.0 / 1e15 if (u == "fT") else 1.0 / 1e6 if (u == "uV") else 1.0 for u in units
+    ]
+
+    return (
+        sfreq,
+        n_samples,
+        ch_names,
+        ch_types,
+        ch_pos,
+        landmarks,
+        landmarkslabels,
+        hpimatrix,
+        events,
+        event_desc,
+        orig_format,
+        orig_units,
+        cals,
+    )
+
+
+def _read_annotations_curry(fname, sfreq="auto"):
+    r"""Read events from Curry event files.
+
+    Parameters
+    ----------
+    fname : path-like
+        The filename.
+    sfreq : float | 'auto'
+        The sampling frequency in the file. If set to 'auto' then the
+        ``sfreq`` is taken from the file header.
+
+    Returns
+    -------
+    annot : instance of Annotations | None
+        The annotations.
+    """
+    fname = _check_curry_filename(fname)
+
+    (sfreq_fromfile, _, _, _, _, _, _, _, events, event_desc, _, _, _) = (
+        _extract_curry_info(fname)
+    )
+    if sfreq == "auto":
+        sfreq = sfreq_fromfile
+    elif np.isreal(sfreq):
+        if float(sfreq) != float(sfreq_fromfile):
+            warn(
+                f"provided sfreq ({sfreq} Hz) does not match freq from file header "
+                f"({sfreq_fromfile} Hz)!"
+ ) + else: + raise ValueError("'sfreq' must be numeric or 'auto'") + + if isinstance(events, np.ndarray): # if there are events + events = events.astype("int") + events = np.insert(events, 1, np.diff(events[:, 2:]).flatten(), axis=1)[:, :3] + return annotations_from_events(events, sfreq, event_desc=event_desc) + else: + warn("no event annotations found") + return None + + +def _set_chanloc_curry( + inst, ch_types, ch_pos, landmarks, landmarkslabels, hpimatrix, on_bad_hpi_match +): + ch_names = inst.info["ch_names"] + + # scale ch_pos to m?! + ch_pos /= 1000.0 + landmarks /= 1000.0 + # channel locations + # what about misc without pos? can they mess things up if unordered? + assert len(ch_pos) >= (ch_types.count("mag") + ch_types.count("eeg")) + assert len(ch_pos) == (ch_types.count("mag") + ch_types.count("eeg")) + ch_pos_meg = { + ch_names[i]: ch_pos[i, :3] for i, t in enumerate(ch_types) if t == "mag" + } + ch_pos_eeg = { + ch_names[i]: ch_pos[i, :3] for i, t in enumerate(ch_types) if t == "eeg" + } + + # landmarks and headshape + # FIX: one of the test files (c,rfDC*.cdt) names landmarks differently: + NAS_NAMES = ["nasion", "nas"] + LPA_NAMES = ["left ear", "lpa"] + RPA_NAMES = ["right ear", "rpa"] + landmarkslabels = [ + "Nas" + if (ll.lower() in NAS_NAMES) + else "LPA" + if (ll.lower() in LPA_NAMES) + else "RPA" + if (ll.lower() in RPA_NAMES) + else ll + for ll in landmarkslabels + ] + landmark_dict = dict(zip(landmarkslabels, landmarks)) + for k in ["Nas", "RPA", "LPA"]: + if k not in landmark_dict.keys(): + landmark_dict[k] = None + if len(landmarkslabels) > 0: + hpi_pos = landmarks[ + [i for i, n in enumerate(landmarkslabels) if re.match("HPI.?[1-99]", n)], + :, + ] + else: + hpi_pos = None + if len(landmarkslabels) > 0: + hsp_pos = landmarks[ + [i for i, n in enumerate(landmarkslabels) if re.match("H.?[1-99]", n)], : + ] + else: + hsp_pos = None + + has_cards = ( + False + if ( + isinstance(landmark_dict["Nas"], type(None)) + and isinstance(landmark_dict["LPA"], type(None)) + and isinstance(landmark_dict["RPA"], type(None)) + ) + else True ) - normals = _read_curry_lines( - label_fname, ["NORMALS" + CHANTYPES[key] for key in types] + has_hpi = True if isinstance(hpi_pos, np.ndarray) else False + + add_missing_fiducials = not has_cards # raises otherwise + dig = _make_dig_points( + nasion=landmark_dict["Nas"], + lpa=landmark_dict["LPA"], + rpa=landmark_dict["RPA"], + hpi=hpi_pos, + extra_points=hsp_pos, + dig_ch_pos=ch_pos_eeg, + coord_frame="head", + add_missing_fiducials=add_missing_fiducials, ) - assert len(labels) == len(sensors) == len(normals) - - all_chans = list() - dig_ch_pos = dict() - for key in ["meg", "eeg", "misc"]: - chanidx_is_explicit = ( - len(curry_params.chanidx_in_file["CHAN_IN_FILE" + CHANTYPES[key]]) > 0 - ) # channel index - # position in the datafile may or may not be explicitly declared, - # based on the CHAN_IN_FILE section in info file - for ind, chan in enumerate(labels["LABELS" + CHANTYPES[key]]): - chanidx = len(all_chans) + 1 # by default, just assume the - # channel index in the datafile is in order of the channel - # names as we found them in the labels file - if chanidx_is_explicit: # but, if explicitly declared, use - # that index number - chanidx = int( - curry_params.chanidx_in_file["CHAN_IN_FILE" + CHANTYPES[key]][ind] - ) - if chanidx <= 0: # if chanidx was explicitly declared to be ' 0', - # it means the channel is not actually saved in the data file - # (e.g. the "Ref" channel), so don't add it to our list. 
- # Git issue #8391 - continue - ch = { - "ch_name": chan, - "unit": curry_params.unit_dict[key], - "kind": FIFFV_CHANTYPES[key], - "coil_type": FIFFV_COILTYPES[key], - "ch_idx": chanidx, - } - if key == "eeg": - loc = np.array(sensors["SENSORS" + CHANTYPES[key]][ind], float) - # XXX just the sensor, where is ref (next 3)? - assert loc.shape == (3,) - loc /= 1000.0 # to meters - loc = np.concatenate([loc, np.zeros(9)]) - ch["loc"] = loc - # XXX need to check/ensure this + with inst.info._unlock(): + inst.info["dig"] = dig + + # loc transformation for meg sensors (taken from previous version) + if len(ch_pos_meg) > 0: + R = np.eye(4) + R[[0, 1], [0, 1]] = -1 # rotate 180 deg + # shift down and back + # (chosen by eyeballing to make the helmet look roughly correct) + R[:3, 3] = [0.0, -0.015, -0.12] + curry_dev_dev_t = Transform("ctf_meg", "meg", R) + + ch_normals_meg = _get_curry_meg_normals(inst.filenames[0]) + assert len(ch_normals_meg) == len(ch_pos_meg) + else: + curry_dev_dev_t, ch_normals_meg = None, None + # fill up chanlocs + assert len(ch_names) == len(ch_types) >= len(ch_pos) + for i, (ch_name, ch_type, ch_loc) in enumerate(zip(ch_names, ch_types, ch_pos)): + assert inst.info["ch_names"][i] == ch_name + ch = inst.info["chs"][i] + if ch_type == "eeg": + with inst.info._unlock(): + ch["loc"][:3] = ch_loc[:3] ch["coord_frame"] = FIFF.FIFFV_COORD_HEAD - dig_ch_pos[chan] = loc[:3] - elif key == "meg": - pos = np.array(sensors["SENSORS" + CHANTYPES[key]][ind], float) - pos /= 1000.0 # to meters - pos = pos[:3] # just the inner coil - pos = apply_trans(curry_dev_dev_t, pos) - nn = np.array(normals["NORMALS" + CHANTYPES[key]][ind], float) - assert np.isclose(np.linalg.norm(nn), 1.0, atol=1e-4) - nn /= np.linalg.norm(nn) - nn = apply_trans(curry_dev_dev_t, nn, move=False) - trans = np.eye(4) - trans[:3, 3] = pos - trans[:3, :3] = _normal_orth(nn).T + elif ch_type == "mag": + # transform mode + pos = ch_loc[:3] # just the inner coil for MEG + pos = apply_trans(curry_dev_dev_t, pos) + nn = ch_normals_meg[i] + assert np.isclose(np.linalg.norm(nn), 1.0, atol=1e-4) + nn /= np.linalg.norm(nn) + nn = apply_trans(curry_dev_dev_t, nn, move=False) + trans = np.eye(4) + trans[:3, 3] = pos + trans[:3, :3] = _normal_orth(nn).T + with inst.info._unlock(): ch["loc"] = _coil_trans_to_loc(trans) + # TODO: We should figure out if all files are Compumedics, + # and even then figure out if it's adult or child + ch["coil_type"] = FIFF.FIFFV_COIL_COMPUMEDICS_ADULT_GRAD ch["coord_frame"] = FIFF.FIFFV_COORD_DEVICE - all_chans.append(ch) - dig = _make_dig_points( - dig_ch_pos=dig_ch_pos, coord_frame="head", add_missing_fiducials=True + elif ch_type == "misc": + pass + else: + raise NotImplementedError + + # TODO - REVIEW NEEDED + # do we need further transformations for MEG channel positions? + # the testfiles i got look good to me.. + _make_trans_dig( + inst.info, + curry_dev_dev_t, + landmark_dict, + has_cards, + has_hpi, + hpimatrix, + on_bad_hpi_match, ) - del dig_ch_pos - ch_count = len(all_chans) - assert ch_count == curry_params.n_chans # ensure that we have assembled - # the same number of channels as declared in the info (.DAP) file in the - # DATA_PARAMETERS section. Git issue #8391 - # sort the channels to assure they are in the order that matches how - # recorded in the datafile. In general they most likely are already in - # the correct order, but if the channel index in the data file was - # explicitly declared we might as well use it. 
- all_chans = sorted(all_chans, key=lambda ch: ch["ch_idx"]) +def _make_trans_dig( + info, + curry_dev_dev_t, + landmark_dict, + has_cards, + has_hpi, + chpidata, + on_bad_hpi_match, +): + cards = { + FIFF.FIFFV_POINT_LPA: landmark_dict["LPA"], + FIFF.FIFFV_POINT_NASION: landmark_dict["Nas"], + FIFF.FIFFV_POINT_RPA: landmark_dict["RPA"], + } - ch_names = [chan["ch_name"] for chan in all_chans] - info = create_info(ch_names, curry_params.sfreq) - with info._unlock(): - info["meas_date"] = curry_params.dt_start # for Git issue #8398 - info["dig"] = dig - _make_trans_dig(curry_paths, info, curry_dev_dev_t) - - for ind, ch_dict in enumerate(info["chs"]): - all_chans[ind].pop("ch_idx") - ch_dict.update(all_chans[ind]) - assert ch_dict["loc"].shape == (12,) - ch_dict["unit"] = SI_UNITS[all_chans[ind]["unit"][1]] - ch_dict["cal"] = SI_UNIT_SCALE[all_chans[ind]["unit"][0]] - - return info, curry_params.n_samples, curry_params.is_ascii - - -_card_dict = { - "Left ear": FIFF.FIFFV_POINT_LPA, - "Nasion": FIFF.FIFFV_POINT_NASION, - "Right ear": FIFF.FIFFV_POINT_RPA, -} - - -def _make_trans_dig(curry_paths, info, curry_dev_dev_t): # Coordinate frame transformations and definitions no_msg = "Leaving device<->head transform as None" info["dev_head_t"] = None - label_fname = curry_paths["labels"] - key = "LANDMARKS" + CHANTYPES["meg"] - lm = _read_curry_lines(label_fname, [key])[key] - lm = np.array(lm, float) - lm.shape = (-1, 3) + lm = [v for v in landmark_dict.values() if isinstance(v, np.ndarray)] if len(lm) == 0: # no dig logger.info(no_msg + " (no landmarks found)") return - lm /= 1000.0 - key = "LM_REMARKS" + CHANTYPES["meg"] - remarks = _read_curry_lines(label_fname, [key])[key] - assert len(remarks) == len(lm) - with info._unlock(): - info["dig"] = list() - cards = dict() - for remark, r in zip(remarks, lm): - kind = ident = None - if remark in _card_dict: - kind = FIFF.FIFFV_POINT_CARDINAL - ident = _card_dict[remark] - cards[ident] = r - elif remark.startswith("HPI"): - kind = FIFF.FIFFV_POINT_HPI - ident = int(remark[3:]) - 1 - if kind is not None: - info["dig"].append( - dict(kind=kind, ident=ident, r=r, coord_frame=FIFF.FIFFV_COORD_UNKNOWN) - ) - with info._unlock(): - info["dig"].sort(key=lambda x: (x["kind"], x["ident"])) - has_cards = len(cards) == 3 - has_hpi = "hpi" in curry_paths + if has_cards and has_hpi: # have all three logger.info("Composing device<->head transformation from dig points") hpi_u = np.array( [d["r"] for d in info["dig"] if d["kind"] == FIFF.FIFFV_POINT_HPI], float ) - hpi_c = np.ascontiguousarray(_first_hpi(curry_paths["hpi"])[: len(hpi_u), 1:4]) - unknown_curry_t = _quaternion_align("unknown", "ctf_meg", hpi_u, hpi_c, 1e-2) + hpi_c = np.ascontiguousarray(chpidata[0][: len(hpi_u), 1:4]) + bad_hpi_match = False + try: + with catch_logging() as log: + unknown_curry_t = _quaternion_align( + "unknown", + "ctf_meg", + hpi_u.astype("float64"), + hpi_c.astype("float64"), + 1e-2, + ) + except RuntimeError: + bad_hpi_match = True + with catch_logging() as log: + unknown_curry_t = _quaternion_align( + "unknown", + "ctf_meg", + hpi_u.astype("float64"), + hpi_c.astype("float64"), + 1e-1, + ) + logger.info(log.getvalue()) + angle = np.rad2deg( _angle_between_quats( np.zeros(3), rot_to_quat(unknown_curry_t["trans"][:3, :3]) @@ -423,6 +664,14 @@ def _make_trans_dig(curry_paths, info, curry_dev_dev_t): ) dist = 1000 * np.linalg.norm(unknown_curry_t["trans"][:3, 3]) logger.info(f" Fit a {angle:0.1f}° rotation, {dist:0.1f} mm translation") + + if bad_hpi_match: + _on_missing( + 
on_bad_hpi_match, + "Poor HPI matching (see log above)!", + name="on_bad_hpi_match", + ) + unknown_dev_t = combine_transforms( unknown_curry_t, curry_dev_dev_t, "unknown", "meg" ) @@ -459,95 +708,21 @@ def _make_trans_dig(curry_paths, info, curry_dev_dev_t): logger.info(no_msg) -def _first_hpi(fname): - # Get the first HPI result - with open(fname) as fid: - for line in fid: - line = line.strip() - if any(x in line for x in ("FileVersion", "NumCoils")) or not line: - continue - hpi = np.array(line.split(), float) - break - else: - raise RuntimeError(f"Could not find valid HPI in {fname}") - # t is the first entry - assert hpi.ndim == 1 - hpi = hpi[1:] - hpi.shape = (-1, 5) - hpi /= 1000.0 - return hpi - - -def _read_events_curry(fname): - """Read events from Curry event files. - - Parameters - ---------- - fname : path-like - Path to a curry event file with extensions .cef, .ceo, - .cdt.cef, or .cdt.ceo - - Returns - ------- - events : ndarray, shape (n_events, 3) - The array of events. - """ - check_fname( - fname, - "curry event", - (".cef", ".ceo", ".cdt.cef", ".cdt.ceo"), - endings_err=(".cef", ".ceo", ".cdt.cef", ".cdt.ceo"), - ) - - events_dict = _read_curry_lines(fname, ["NUMBER_LIST"]) - # The first 3 column seem to contain the event information - curry_events = np.array(events_dict["NUMBER_LIST"], dtype=int)[:, 0:3] - - return curry_events - - -def _read_annotations_curry(fname, sfreq="auto"): - r"""Read events from Curry event files. - - Parameters - ---------- - fname : str - The filename. - sfreq : float | 'auto' - The sampling frequency in the file. If set to 'auto' then the - ``sfreq`` is taken from the respective info file of the same name with - according file extension (\*.dap for Curry 7; \*.cdt.dpa for Curry8). - So data.cef looks in data.dap and data.cdt.cef looks in data.cdt.dpa. - - Returns - ------- - annot : instance of Annotations | None - The annotations. - """ - required = ["events", "info"] if sfreq == "auto" else ["events"] - curry_paths = _get_curry_file_structure(fname, required) - events = _read_events_curry(curry_paths["events"]) - - if sfreq == "auto": - sfreq = _read_curry_parameters(curry_paths["info"]).sfreq - - onset = events[:, 0] / sfreq - duration = np.zeros(events.shape[0]) - description = events[:, 2] - - return Annotations(onset, duration, description) - - @verbose -def read_raw_curry(fname, preload=False, verbose=None) -> "RawCurry": +def read_raw_curry( + fname, preload=False, on_bad_hpi_match="warn", verbose=None +) -> "RawCurry": """Read raw data from Curry files. + .. versionchanged:: 1.11 + This function now requires ``curryreader`` to be installed. + Parameters ---------- fname : path-like - Path to a curry file with extensions ``.dat``, ``.dap``, ``.rs3``, - ``.cdt``, ``.cdt.dpa``, ``.cdt.cef`` or ``.cef``. + Path to a valid curry file. %(preload)s + %(on_bad_hpi_match)s %(verbose)s Returns @@ -560,7 +735,20 @@ def read_raw_curry(fname, preload=False, verbose=None) -> "RawCurry": -------- mne.io.Raw : Documentation of attributes and methods of RawCurry. 
""" - return RawCurry(fname, preload, verbose) + fname = _check_curry_filename(fname) + fname_hdr = _check_curry_header_filename(fname) + + _check_curry_sfreq_consistency(fname_hdr) + + rectype = _get_curry_recording_type(fname) + + inst = RawCurry(fname, preload, on_bad_hpi_match, verbose) + if rectype in ["epochs", "evoked"]: + curry_epoch_info = _get_curry_epoch_info(fname) + inst = Epochs(inst, **curry_epoch_info) + if rectype == "evoked": + raise NotImplementedError # not sure this is even supported format + return inst class RawCurry(BaseRaw): @@ -569,9 +757,9 @@ class RawCurry(BaseRaw): Parameters ---------- fname : path-like - Path to a curry file with extensions ``.dat``, ``.dap``, ``.rs3``, - ``.cdt``, ``.cdt.dpa``, ``.cdt.cef`` or ``.cef``. + Path to a valid curry file. %(preload)s + %(on_bad_hpi_match)s %(verbose)s See Also @@ -581,39 +769,94 @@ class RawCurry(BaseRaw): """ @verbose - def __init__(self, fname, preload=False, verbose=None): - curry_paths = _get_curry_file_structure( - fname, required=["info", "data", "labels"] - ) - - data_fname = op.abspath(curry_paths["data"]) - - info, n_samples, is_ascii = _read_curry_info(curry_paths) - + def __init__(self, fname, preload=False, on_bad_hpi_match="warn", verbose=None): + fname = _check_curry_filename(fname) + + ( + sfreq, + n_samples, + ch_names, + ch_types, + ch_pos, + landmarks, + landmarkslabels, + hpimatrix, + events, + event_desc, + orig_format, + orig_units, + cals, + ) = _extract_curry_info(fname) + + meas_date, is_ascii, device_info = _get_curry_meas_info(fname) + + # construct info + info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) + info["device_info"] = device_info + + # create raw object last_samps = [n_samples - 1] raw_extras = dict(is_ascii=is_ascii) - super().__init__( info, - preload, - filenames=[data_fname], + preload=False, + filenames=[fname], last_samps=last_samps, - orig_format="int", + orig_format=orig_format, raw_extras=[raw_extras], + orig_units=orig_units, verbose=verbose, ) - if "events" in curry_paths: - logger.info( - "Event file found. Extracting Annotations from " - f"{curry_paths['events']}..." - ) - annots = _read_annotations_curry( - curry_paths["events"], sfreq=self.info["sfreq"] - ) - self.set_annotations(annots) - else: - logger.info("Event file not found. No Annotations set.") + # set meas_date + self.set_meas_date(meas_date) + + # scale data to SI units + self._cals = np.array(cals) + if isinstance(preload, bool | np.bool_) and preload: + self.load_data() + + # set events / annotations + # format from curryreader: sample, etype, startsample, endsample + if isinstance(events, np.ndarray): # if there are events + events = events.astype("int") + events = np.insert(events, 1, np.diff(events[:, 2:]).flatten(), axis=1)[ + :, :3 + ] + annot = annotations_from_events(events, sfreq, event_desc=event_desc) + self.set_annotations(annot) + + # add HPI data (if present) + # TODO - FUTURE ENHANCEMENT + # from curryreader docstring: + # "HPI-coil measurements matrix (Orion-MEG only) where every row is: + # [measurementsample, dipolefitflag, x, y, z, deviation]" + # + # that's incorrect, though. it ratehr seems to be: + # [sample, dipole_1, x_1,y_1, z_1, dev_1, ..., dipole_n, x_n, ...] + # for all n coils. + # + # Do not implement cHPI reader for now. + # Can be used for dev-head transform, though! 
+ if not isinstance(hpimatrix, list): + # warn("cHPI data found, but reader not implemented.") + hpisamples = hpimatrix[:, 0] + n_coil = int((hpimatrix.shape[1] - 1) / 5) + hpimatrix = hpimatrix[:, 1:].reshape(hpimatrix.shape[0], n_coil, 5) / 1000 + logger.info(f"found {len(hpisamples)} cHPI samples for {n_coil} coils") + + # add sensor locations + # TODO - REVIEW NEEDED + assert len(self.info["ch_names"]) == len(ch_types) >= len(ch_pos) + _set_chanloc_curry( + inst=self, + ch_types=ch_types, + ch_pos=ch_pos, + landmarks=landmarks, + landmarkslabels=landmarkslabels, + hpimatrix=hpimatrix, + on_bad_hpi_match=on_bad_hpi_match, + ) def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): """Read a chunk of raw data.""" @@ -629,3 +872,49 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): _read_segments_file( self, data, idx, fi, start, stop, cals, mult, dtype=">> events[:, 2] >>= 8 # doctest:+SKIP - TAL channels called 'EDF Annotations' or 'BDF Annotations' are parsed and + TAL channels called 'EDF Annotations' are parsed and extracted annotations are stored in raw.annotations. Use :func:`mne.events_from_annotations` to obtain events from these annotations. @@ -147,8 +170,10 @@ def __init__( *, verbose=None, ): - logger.info(f"Extracting EDF parameters from {input_fname}...") - input_fname = os.path.abspath(input_fname) + if not _file_like(input_fname): + logger.info(f"Extracting EDF parameters from {input_fname}...") + input_fname = os.path.abspath(input_fname) + info, edf_info, orig_units = _get_info( input_fname, stim_channel, @@ -156,11 +181,224 @@ def __init__( misc, exclude, infer_types, + FileType.EDF, + include, + exclude_after_unique, + ) + logger.info("Creating raw.info structure...") + edf_info["blob"] = input_fname if _file_like(input_fname) else None + + _validate_type(units, (str, None, dict), "units") + if units is None: + units = dict() + elif isinstance(units, str): + units = {ch_name: units for ch_name in info["ch_names"]} + + for k, (this_ch, this_unit) in enumerate(orig_units.items()): + if this_ch not in units: + continue + if this_unit not in ("", units[this_ch]): + raise ValueError( + f"Unit for channel {this_ch} is present in the file as " + f"{repr(this_unit)}, cannot overwrite it with the units " + f"argument {repr(units[this_ch])}." 
+ ) + if this_unit == "": + orig_units[this_ch] = units[this_ch] + ch_type = edf_info["ch_types"][k] + scaling = _get_scaling(ch_type.lower(), orig_units[this_ch]) + edf_info["units"][k] /= scaling + + # Raw attributes + last_samps = [edf_info["nsamples"] - 1] + super().__init__( + info, preload, + filenames=[_path_from_fname(input_fname)], + raw_extras=[edf_info], + last_samps=last_samps, + orig_format="int", + orig_units=orig_units, + verbose=verbose, + ) + + # Read annotations from file and set it + if len(edf_info["tal_idx"]) > 0: + # Read TAL data exploiting the header info (no regexp) + idx = np.empty(0, int) + tal_data = self._read_segment_file( + np.empty((0, self.n_times)), + idx, + 0, + 0, + int(self.n_times), + np.ones((len(idx), 1)), + None, + ) + annotations = _read_annotations_edf( + tal_data[0], + ch_names=info["ch_names"], + encoding=encoding, + ) + self.set_annotations(annotations, on_missing="warn") + + def _read_segment_file(self, data, idx, fi, start, stop, cals, mult): + """Read a chunk of raw data.""" + return _read_segment_file( + data, + idx, + fi, + start, + stop, + self._raw_extras[fi], + self.filenames[fi] + if self._raw_extras[fi]["blob"] is None + else self._raw_extras[fi]["blob"], + cals, + mult, + ) + + +def _path_from_fname(fname) -> Path | None: + if isinstance(fname, str | Path): + return Path(fname) + + # Try to get a filename from the file-like object + try: + return Path(fname.name) + except Exception: + return None + + +@fill_doc +class RawBDF(BaseRaw): + """Raw object from BDF file. + + Parameters + ---------- + input_fname : path-like | file-like + Path to the BDF file. If a file-like object is provided, + preloading must be used. + + .. versionchanged:: 1.10 + Added support for file-like objects + eog : list or tuple + Names of channels or list of indices that should be designated EOG + channels. Values should correspond to the electrodes in the file. + Default is None. + misc : list or tuple + Names of channels or list of indices that should be designated MISC + channels. Values should correspond to the electrodes in the file. + Default is None. + stim_channel : ``'auto'`` | str | list of str | int | list of int + Defaults to ``'auto'``, which means that channels named ``'status'`` or + ``'trigger'`` (case insensitive) are set to STIM. If str (or list of + str), all channels matching the name(s) are set to STIM. If int (or + list of ints), the channels corresponding to the indices are set to + STIM. + exclude : list of str + Channel names to exclude. This can help when reading data with + different sampling rates to avoid unnecessary resampling. + infer_types : bool + If True, try to infer channel types from channel labels. If a channel + label starts with a known type (such as 'EEG') followed by a space and + a name (such as 'Fp1'), the channel type will be set accordingly, and + the channel will be renamed to the original label without the prefix. + For unknown prefixes, the type will be 'EEG' and the name will not be + modified. If False, do not infer types and assume all channels are of + type 'EEG'. + + .. versionadded:: 0.24.1 + include : list of str | str + Channel names to be included. A str is interpreted as a regular + expression. 'exclude' must be empty if include is assigned. + + .. versionadded:: 1.1 + %(preload)s + %(units_edf_bdf_io)s + %(encoding_edf)s + %(exclude_after_unique)s + %(verbose)s + + See Also + -------- + mne.io.Raw : Documentation of attributes and methods. + mne.io.read_raw_bdf : Recommended way to read BDF files. 
+
+    Notes
+    -----
+    %(edf_resamp_note)s
+
+    Biosemi device trigger codes are encoded in 16-bit format, whereas system
+    codes (CMS in/out-of range, battery low, etc.) are coded in bits 16-23 of
+    the status channel (see http://www.biosemi.com/faq/trigger_signals.htm).
+    To retrieve correct event values (bits 1-16), one could do:
+
+        >>> events = mne.find_events(...)  # doctest:+SKIP
+        >>> events[:, 2] &= (2**16 - 1)  # doctest:+SKIP
+
+    The above operation can be carried out directly in :func:`mne.find_events`
+    using the ``mask`` and ``mask_type`` parameters (see
+    :func:`mne.find_events` for more details).
+
+    It is also possible to retrieve system codes, but no particular effort has
+    been made to decode these in MNE. In case it is necessary, for instance to
+    check the CMS bit, the following operation can be carried out:
+
+        >>> cms_bit = 20  # doctest:+SKIP
+        >>> cms_high = (events[:, 2] & (1 << cms_bit)) != 0  # doctest:+SKIP
+
+    It is worth noting that in some special cases, it may be necessary to shift
+    event values in order to retrieve correct event triggers. This depends on
+    the triggering device used to perform the synchronization. For instance, in
+    some files events need to be shifted by 8 bits:
+
+        >>> events[:, 2] >>= 8  # doctest:+SKIP
+
+    TAL channels called 'BDF Annotations' are parsed and
+    extracted annotations are stored in raw.annotations. Use
+    :func:`mne.events_from_annotations` to obtain events from these
+    annotations.
+
+    If channels named 'status' or 'trigger' are present, they are considered
+    STIM channels by default. Use :func:`mne.find_events` to parse events
+    encoded in such analog stim channels.
+    """
+
+    @verbose
+    def __init__(
+        self,
+        input_fname,
+        eog=None,
+        misc=None,
+        stim_channel="auto",
+        exclude=(),
+        infer_types=False,
+        preload=False,
+        include=None,
+        units=None,
+        encoding="utf8",
+        exclude_after_unique=False,
+        *,
+        verbose=None,
+    ):
+        if not _file_like(input_fname):
+            logger.info(f"Extracting BDF parameters from {input_fname}...")
+            input_fname = os.path.abspath(input_fname)
+
+        info, edf_info, orig_units = _get_info(
+            input_fname,
+            stim_channel,
+            eog,
+            misc,
+            exclude,
+            infer_types,
+            FileType.BDF,
             include,
             exclude_after_unique,
         )
         logger.info("Creating raw.info structure...")
+        edf_info["blob"] = input_fname if _file_like(input_fname) else None
 
         _validate_type(units, (str, None, dict), "units")
         if units is None:
@@ -188,7 +426,7 @@ def __init__(
         super().__init__(
             info,
             preload,
-            filenames=[input_fname],
+            filenames=[_path_from_fname(input_fname)],
             raw_extras=[edf_info],
             last_samps=last_samps,
             orig_format="int",
@@ -225,7 +463,9 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
             start,
             stop,
             self._raw_extras[fi],
-            self.filenames[fi],
+            self.filenames[fi]
+            if self._raw_extras[fi]["blob"] is None
+            else self._raw_extras[fi]["blob"],
             cals,
             mult,
         )
@@ -237,8 +477,12 @@ class RawGDF(BaseRaw):
 
     Parameters
     ----------
-    input_fname : path-like
-        Path to the GDF file.
+    input_fname : path-like | file-like
+        Path to the GDF file. If a file-like object is provided,
+        preloading must be used.
+
+        .. versionchanged:: 1.10
+           Added support for file-like objects
     eog : list or tuple
         Names of channels or list of indices that should be designated EOG
         channels. Values should correspond to the electrodes in the file.
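The Biosemi bit arithmetic documented in the notes above is easy to verify without a recording; a minimal, self-contained illustration using synthetic status values (not from any real file):

import numpy as np

# Synthetic status values: trigger code 42 in bits 1-16, with the CMS bit
# (bit 20) additionally set on the second event.
cms_bit = 20
status = np.array([42, 42 | (1 << cms_bit)], dtype=np.int64)

# Keep only the 16 trigger bits, as in `events[:, 2] &= (2**16 - 1)`.
triggers = status & (2**16 - 1)
assert (triggers == 42).all()

# Check the CMS bit, as in the docstring above.
cms_high = (status & (1 << cms_bit)) != 0
assert cms_high.tolist() == [False, True]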
@@ -289,19 +533,29 @@ def __init__(
         include=None,
         verbose=None,
     ):
-        logger.info(f"Extracting EDF parameters from {input_fname}...")
-        input_fname = os.path.abspath(input_fname)
+        if not _file_like(input_fname):
+            logger.info(f"Extracting GDF parameters from {input_fname}...")
+            input_fname = os.path.abspath(input_fname)
+
         info, edf_info, orig_units = _get_info(
-            input_fname, stim_channel, eog, misc, exclude, True, preload, include
+            input_fname,
+            stim_channel,
+            eog,
+            misc,
+            exclude,
+            True,
+            FileType.GDF,
+            include,
         )
         logger.info("Creating raw.info structure...")
+        edf_info["blob"] = input_fname if _file_like(input_fname) else None
 
         # Raw attributes
         last_samps = [edf_info["nsamples"] - 1]
         super().__init__(
             info,
             preload,
-            filenames=[input_fname],
+            filenames=[_path_from_fname(input_fname)],
             raw_extras=[edf_info],
             last_samps=last_samps,
             orig_format="int",
@@ -327,7 +581,9 @@ def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
             start,
             stop,
             self._raw_extras[fi],
-            self.filenames[fi],
+            self.filenames[fi]
+            if self._raw_extras[fi]["blob"] is None
+            else self._raw_extras[fi]["blob"],
             cals,
             mult,
         )
@@ -337,7 +593,7 @@ def _read_ch(fid, subtype, samp, dtype_byte, dtype=None):
     """Read a number of samples for a single channel."""
     # BDF
     if subtype == "bdf":
-        ch_data = np.fromfile(fid, dtype=dtype, count=samp * dtype_byte)
+        ch_data = read_from_file_or_buffer(fid, dtype=dtype, count=samp * dtype_byte)
         ch_data = ch_data.reshape(-1, 3).astype(INT32)
         ch_data = (ch_data[:, 0]) + (ch_data[:, 1] << 8) + (ch_data[:, 2] << 16)
         # 24th bit determines the sign
@@ -345,7 +601,7 @@
     # GDF data and EDF data
     else:
-        ch_data = np.fromfile(fid, dtype=dtype, count=samp)
+        ch_data = read_from_file_or_buffer(fid, dtype=dtype, count=samp)
 
     return ch_data
@@ -379,7 +635,8 @@ def _read_segment_file(data, idx, fi, start, stop, raw_extras, filenames, cals,
     # Otherwise we can end up with e.g. 18,181 chunks for a 20 MB file!
     # Let's do ~10 MB chunks:
     n_per = max(10 * 1024 * 1024 // (ch_offsets[-1] * dtype_byte), 1)
-    with open(filenames, "rb", buffering=0) as fid:
+
+    with _gdf_edf_get_fid(filenames, buffering=0) as fid:
         # Extract data
         start_offset = data_offset + block_start_idx * ch_offsets[-1] * dtype_byte
@@ -481,13 +738,20 @@
 
 @fill_doc
-def _read_header(fname, exclude, infer_types, include=None, exclude_after_unique=False):
+def _read_header(
+    fname,
+    exclude,
+    infer_types,
+    file_type,
+    include=None,
+    exclude_after_unique=False,
+):
     """Unify EDF, BDF and GDF _read_header call.
 
     Parameters
     ----------
     fname : str
-        Path to the EDF+, BDF, or GDF file.
+        Path to the EDF+, BDF, or GDF file, or a file-like object.
     exclude : list of str | str
         Channel names to exclude. This can help when reading data with
         different sampling rates to avoid unnecessary resampling.
A str is @@ -509,18 +773,19 @@ def _read_header(fname, exclude, infer_types, include=None, exclude_after_unique ------- (edf_info, orig_units) : tuple """ - ext = os.path.splitext(fname)[1][1:].lower() - logger.info(f"{ext.upper()} file detected") - if ext in ("bdf", "edf"): + if file_type in (FileType.BDF, FileType.EDF): return _read_edf_header( - fname, exclude, infer_types, include, exclude_after_unique + fname, + exclude, + infer_types, + file_type, + include, + exclude_after_unique, ) - elif ext == "gdf": + elif file_type == FileType.GDF: return _read_gdf_header(fname, exclude, include), None else: - raise NotImplementedError( - f"Only GDF, EDF, and BDF files are supported, got {ext}." - ) + raise NotImplementedError("Only GDF, EDF, and BDF files are supported.") def _get_info( @@ -530,7 +795,7 @@ def _get_info( misc, exclude, infer_types, - preload, + file_type, include=None, exclude_after_unique=False, ): @@ -539,7 +804,7 @@ def _get_info( misc = misc if misc is not None else [] edf_info, orig_units = _read_header( - fname, exclude, infer_types, include, exclude_after_unique + fname, exclude, infer_types, file_type, include, exclude_after_unique ) # XXX: `tal_ch_names` to pass to `_check_stim_channel` should be computed @@ -801,12 +1066,17 @@ def _edf_str_num(x): def _read_edf_header( - fname, exclude, infer_types, include=None, exclude_after_unique=False + fname, + exclude, + infer_types, + file_type, + include=None, + exclude_after_unique=False, ): """Read header information from EDF+ or BDF file.""" edf_info = {"events": []} - with open(fname, "rb") as fid: + with _gdf_edf_get_fid(fname) as fid: fid.read(8) # version (unused here) # patient ID @@ -877,14 +1147,20 @@ def _read_edf_header( fid.read(8) # skip the file's measurement time warn("Invalid measurement date encountered in the header.") - header_nbytes = int(_edf_str(fid.read(8))) + try: + header_nbytes = int(_edf_str(fid.read(8))) + except ValueError: + raise ValueError( + f"Bad {'EDF' if file_type == FileType.EDF else 'BDF'} file provided." + ) + # The following 44 bytes sometimes identify the file type, but this is - # not guaranteed. Therefore, we skip this field and use the file - # extension to determine the subtype (EDF or BDF, which differ in the + # not guaranteed. Therefore, we skip this field and use the file_type + # to determine the subtype (EDF or BDF, which differ in the # number of bytes they use for the data records; EDF uses 2 bytes # whereas BDF uses 3 bytes). 
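The `FileType` members driving this dispatch are defined outside the hunks shown here; a plausible sketch of the enum and the branching it enables, assuming a plain `enum.Enum` (the actual definition may differ), with `_dispatch_header_parser` as a hypothetical stand-in for `_read_header`:

from enum import Enum, auto


class FileType(Enum):  # assumed shape; the real definition lives elsewhere
    EDF = auto()
    BDF = auto()
    GDF = auto()


def _dispatch_header_parser(file_type):
    # Mirrors the dispatch above: EDF and BDF share one header parser, GDF
    # has its own, and any other value (e.g. the bare int 4 used in the
    # tests below) is rejected.
    if file_type in (FileType.BDF, FileType.EDF):
        return "EDF/BDF header parser"
    elif file_type == FileType.GDF:
        return "GDF header parser"
    raise NotImplementedError("Only GDF, EDF, and BDF files are supported.")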
fid.read(44) - subtype = os.path.splitext(fname)[1][1:].lower() + subtype = file_type n_records = int(_edf_str(fid.read(8))) record_length = float(_edf_str(fid.read(8))) @@ -996,7 +1272,7 @@ def _read_edf_header( physical_max=physical_max, physical_min=physical_min, record_length=record_length, - subtype=subtype, + subtype="bdf" if subtype == FileType.BDF else "edf", tal_idx=tal_idx, ) @@ -1006,7 +1282,9 @@ def _read_edf_header( fid.seek(0, 2) n_bytes = fid.tell() n_data_bytes = n_bytes - header_nbytes - total_samps = n_data_bytes // 3 if subtype == "bdf" else n_data_bytes // 2 + total_samps = ( + n_data_bytes // 3 if subtype == FileType.BDF else n_data_bytes // 2 + ) read_records = total_samps // np.sum(n_samps) if n_records != read_records: warn( @@ -1017,7 +1295,7 @@ def _read_edf_header( edf_info["n_records"] = read_records del n_records - if subtype == "bdf": + if subtype == FileType.BDF: edf_info["dtype_byte"] = 3 # 24-bit (3 byte) integers edf_info["dtype_np"] = UINT8 else: @@ -1074,10 +1352,15 @@ def _read_gdf_header(fname, exclude, include=None): """Read GDF 1.x and GDF 2.x header info.""" edf_info = dict() events = None - with open(fname, "rb") as fid: - version = fid.read(8).decode() - edf_info["type"] = edf_info["subtype"] = version[:3] - edf_info["number"] = float(version[4:]) + + with _gdf_edf_get_fid(fname) as fid: + try: + version = fid.read(8).decode() + edf_info["type"] = edf_info["subtype"] = version[:3] + edf_info["number"] = float(version[4:]) + except ValueError: + raise ValueError("Bad GDF file provided.") + meas_date = None # GDF 1.x @@ -1113,22 +1396,22 @@ def _read_gdf_header(fname, exclude, include=None): except Exception: pass - header_nbytes = np.fromfile(fid, INT64, 1)[0] - meas_id["equipment"] = np.fromfile(fid, UINT8, 8)[0] - meas_id["hospital"] = np.fromfile(fid, UINT8, 8)[0] - meas_id["technician"] = np.fromfile(fid, UINT8, 8)[0] + header_nbytes = read_from_file_or_buffer(fid, INT64, 1)[0] + meas_id["equipment"] = read_from_file_or_buffer(fid, UINT8, 8)[0] + meas_id["hospital"] = read_from_file_or_buffer(fid, UINT8, 8)[0] + meas_id["technician"] = read_from_file_or_buffer(fid, UINT8, 8)[0] fid.seek(20, 1) # 20bytes reserved - n_records = np.fromfile(fid, INT64, 1)[0] + n_records = read_from_file_or_buffer(fid, INT64, 1)[0] # record length in seconds - record_length = np.fromfile(fid, UINT32, 2) + record_length = read_from_file_or_buffer(fid, UINT32, 2) if record_length[0] == 0: record_length[0] = 1.0 warn( "Header information is incorrect for record length. " "Default record length set to 1." 
) - nchan = int(np.fromfile(fid, UINT32, 1)[0]) + nchan = int(read_from_file_or_buffer(fid, UINT32, 1)[0]) channels = list(range(nchan)) ch_names = [_edf_str(fid.read(16)).strip() for ch in channels] exclude = _find_exclude_idx(ch_names, exclude, include) @@ -1146,18 +1429,18 @@ def _read_gdf_header(fname, exclude, include=None): edf_info["units"] = np.array(edf_info["units"], float) ch_names = [ch_names[idx] for idx in sel] - physical_min = np.fromfile(fid, FLOAT64, len(channels)) - physical_max = np.fromfile(fid, FLOAT64, len(channels)) - digital_min = np.fromfile(fid, INT64, len(channels)) - digital_max = np.fromfile(fid, INT64, len(channels)) + physical_min = read_from_file_or_buffer(fid, FLOAT64, len(channels)) + physical_max = read_from_file_or_buffer(fid, FLOAT64, len(channels)) + digital_min = read_from_file_or_buffer(fid, INT64, len(channels)) + digital_max = read_from_file_or_buffer(fid, INT64, len(channels)) prefiltering = [_edf_str(fid.read(80)) for ch in channels] highpass, lowpass = _parse_prefilter_string(prefiltering) # n samples per record - n_samps = np.fromfile(fid, INT32, len(channels)) + n_samps = read_from_file_or_buffer(fid, INT32, len(channels)) # channel data type - dtype = np.fromfile(fid, INT32, len(channels)) + dtype = read_from_file_or_buffer(fid, INT32, len(channels)) # total number of bytes for data bytes_tot = np.sum( @@ -1197,19 +1480,21 @@ def _read_gdf_header(fname, exclude, include=None): etp = header_nbytes + n_records * edf_info["bytes_tot"] # skip data to go to event table fid.seek(etp) - etmode = np.fromfile(fid, UINT8, 1)[0] + etmode = read_from_file_or_buffer(fid, UINT8, 1)[0] if etmode in (1, 3): - sr = np.fromfile(fid, UINT8, 3).astype(np.uint32) + sr = read_from_file_or_buffer(fid, UINT8, 3).astype(np.uint32) event_sr = sr[0] for i in range(1, len(sr)): event_sr = event_sr + sr[i] * 2 ** (i * 8) - n_events = np.fromfile(fid, UINT32, 1)[0] - pos = np.fromfile(fid, UINT32, n_events) - 1 # 1-based inds - typ = np.fromfile(fid, UINT16, n_events) + n_events = read_from_file_or_buffer(fid, UINT32, 1)[0] + pos = ( + read_from_file_or_buffer(fid, UINT32, n_events) - 1 + ) # 1-based inds + typ = read_from_file_or_buffer(fid, UINT16, n_events) if etmode == 3: - chn = np.fromfile(fid, UINT16, n_events) - dur = np.fromfile(fid, UINT32, n_events) + chn = read_from_file_or_buffer(fid, UINT16, n_events) + dur = read_from_file_or_buffer(fid, UINT32, n_events) else: chn = np.zeros(n_events, dtype=np.int32) dur = np.ones(n_events, dtype=UINT32) @@ -1234,20 +1519,20 @@ def _read_gdf_header(fname, exclude, include=None): fid.seek(10, 1) # 10bytes reserved # Smoking / Alcohol abuse / drug abuse / medication - sadm = np.fromfile(fid, UINT8, 1)[0] + sadm = read_from_file_or_buffer(fid, UINT8, 1)[0] patient["smoking"] = scale[sadm % 4] patient["alcohol_abuse"] = scale[(sadm >> 2) % 4] patient["drug_abuse"] = scale[(sadm >> 4) % 4] patient["medication"] = scale[(sadm >> 6) % 4] - patient["weight"] = np.fromfile(fid, UINT8, 1)[0] + patient["weight"] = read_from_file_or_buffer(fid, UINT8, 1)[0] if patient["weight"] == 0 or patient["weight"] == 255: patient["weight"] = None - patient["height"] = np.fromfile(fid, UINT8, 1)[0] + patient["height"] = read_from_file_or_buffer(fid, UINT8, 1)[0] if patient["height"] == 0 or patient["height"] == 255: patient["height"] = None # Gender / Handedness / Visual Impairment - ghi = np.fromfile(fid, UINT8, 1)[0] + ghi = read_from_file_or_buffer(fid, UINT8, 1)[0] patient["sex"] = gender[ghi % 4] patient["handedness"] = handedness[(ghi >> 2) 
% 4] patient["visual"] = scale[(ghi >> 4) % 4] @@ -1255,7 +1540,7 @@ def _read_gdf_header(fname, exclude, include=None): # Recording identification meas_id = {} meas_id["recording_id"] = _edf_str(fid.read(64)).strip() - vhsv = np.fromfile(fid, UINT8, 4) + vhsv = read_from_file_or_buffer(fid, UINT8, 4) loc = {} if vhsv[3] == 0: loc["vertpre"] = 10 * int(vhsv[0] >> 4) + int(vhsv[0] % 16) @@ -1266,12 +1551,16 @@ def _read_gdf_header(fname, exclude, include=None): loc["horzpre"] = 29 loc["size"] = 29 loc["version"] = 0 - loc["latitude"] = float(np.fromfile(fid, UINT32, 1)[0]) / 3600000 - loc["longitude"] = float(np.fromfile(fid, UINT32, 1)[0]) / 3600000 - loc["altitude"] = float(np.fromfile(fid, INT32, 1)[0]) / 100 + loc["latitude"] = ( + float(read_from_file_or_buffer(fid, UINT32, 1)[0]) / 3600000 + ) + loc["longitude"] = ( + float(read_from_file_or_buffer(fid, UINT32, 1)[0]) / 3600000 + ) + loc["altitude"] = float(read_from_file_or_buffer(fid, INT32, 1)[0]) / 100 meas_id["loc"] = loc - meas_date = np.fromfile(fid, UINT64, 1)[0] + meas_date = read_from_file_or_buffer(fid, UINT64, 1)[0] if meas_date != 0: meas_date = datetime(1, 1, 1, tzinfo=timezone.utc) + timedelta( meas_date * pow(2, -32) - 367 @@ -1279,7 +1568,7 @@ def _read_gdf_header(fname, exclude, include=None): else: meas_date = None - birthday = np.fromfile(fid, UINT64, 1).tolist()[0] + birthday = read_from_file_or_buffer(fid, UINT64, 1).tolist()[0] if birthday == 0: birthday = datetime(1, 1, 1, tzinfo=timezone.utc) else: @@ -1298,22 +1587,22 @@ def _read_gdf_header(fname, exclude, include=None): else: patient["age"] = None - header_nbytes = np.fromfile(fid, UINT16, 1)[0] * 256 + header_nbytes = read_from_file_or_buffer(fid, UINT16, 1)[0] * 256 fid.seek(6, 1) # 6 bytes reserved - meas_id["equipment"] = np.fromfile(fid, UINT8, 8) - meas_id["ip"] = np.fromfile(fid, UINT8, 6) - patient["headsize"] = np.fromfile(fid, UINT16, 3) + meas_id["equipment"] = read_from_file_or_buffer(fid, UINT8, 8) + meas_id["ip"] = read_from_file_or_buffer(fid, UINT8, 6) + patient["headsize"] = read_from_file_or_buffer(fid, UINT16, 3) patient["headsize"] = np.asarray(patient["headsize"], np.float32) patient["headsize"] = np.ma.masked_array( patient["headsize"], np.equal(patient["headsize"], 0), None ).filled() - ref = np.fromfile(fid, FLOAT32, 3) - gnd = np.fromfile(fid, FLOAT32, 3) - n_records = np.fromfile(fid, INT64, 1)[0] + ref = read_from_file_or_buffer(fid, FLOAT32, 3) + gnd = read_from_file_or_buffer(fid, FLOAT32, 3) + n_records = read_from_file_or_buffer(fid, INT64, 1)[0] # record length in seconds - record_length = np.fromfile(fid, UINT32, 2) + record_length = read_from_file_or_buffer(fid, UINT32, 2) if record_length[0] == 0: record_length[0] = 1.0 warn( @@ -1321,7 +1610,7 @@ def _read_gdf_header(fname, exclude, include=None): "Default record length set to 1." 
) - nchan = int(np.fromfile(fid, UINT16, 1)[0]) + nchan = int(read_from_file_or_buffer(fid, UINT16, 1)[0]) fid.seek(2, 1) # 2bytes reserved # Channels (variable header) @@ -1339,7 +1628,7 @@ def _read_gdf_header(fname, exclude, include=None): - Decimal factors codes: https://sourceforge.net/p/biosig/svn/HEAD/tree/trunk/biosig/doc/DecimalFactors.txt """ # noqa - units = np.fromfile(fid, UINT16, len(channels)).tolist() + units = read_from_file_or_buffer(fid, UINT16, len(channels)).tolist() unitcodes = np.array(units[:]) edf_info["units"] = list() for i, unit in enumerate(units): @@ -1363,32 +1652,36 @@ def _read_gdf_header(fname, exclude, include=None): edf_info["units"] = np.array(edf_info["units"], float) ch_names = [ch_names[idx] for idx in sel] - physical_min = np.fromfile(fid, FLOAT64, len(channels)) - physical_max = np.fromfile(fid, FLOAT64, len(channels)) - digital_min = np.fromfile(fid, FLOAT64, len(channels)) - digital_max = np.fromfile(fid, FLOAT64, len(channels)) + physical_min = read_from_file_or_buffer(fid, FLOAT64, len(channels)) + physical_max = read_from_file_or_buffer(fid, FLOAT64, len(channels)) + digital_min = read_from_file_or_buffer(fid, FLOAT64, len(channels)) + digital_max = read_from_file_or_buffer(fid, FLOAT64, len(channels)) fid.seek(68 * len(channels), 1) # obsolete - lowpass = np.fromfile(fid, FLOAT32, len(channels)) - highpass = np.fromfile(fid, FLOAT32, len(channels)) - notch = np.fromfile(fid, FLOAT32, len(channels)) + lowpass = read_from_file_or_buffer(fid, FLOAT32, len(channels)) + highpass = read_from_file_or_buffer(fid, FLOAT32, len(channels)) + notch = read_from_file_or_buffer(fid, FLOAT32, len(channels)) # number of samples per record - n_samps = np.fromfile(fid, INT32, len(channels)) + n_samps = read_from_file_or_buffer(fid, INT32, len(channels)) # data type - dtype = np.fromfile(fid, INT32, len(channels)) + dtype = read_from_file_or_buffer(fid, INT32, len(channels)) channel = {} - channel["xyz"] = [np.fromfile(fid, FLOAT32, 3)[0] for ch in channels] + channel["xyz"] = [ + read_from_file_or_buffer(fid, FLOAT32, 3)[0] for ch in channels + ] if edf_info["number"] < 2.19: - impedance = np.fromfile(fid, UINT8, len(channels)).astype(float) + impedance = read_from_file_or_buffer(fid, UINT8, len(channels)).astype( + float + ) impedance[impedance == 255] = np.nan channel["impedance"] = pow(2, impedance / 8) fid.seek(19 * len(channels), 1) # reserved else: - tmp = np.fromfile(fid, FLOAT32, 5 * len(channels)) + tmp = read_from_file_or_buffer(fid, FLOAT32, 5 * len(channels)) tmp = tmp[::5] fZ = tmp[:] impedance = tmp[:] @@ -1446,22 +1739,24 @@ def _read_gdf_header(fname, exclude, include=None): etmode = np.fromstring(etmode, UINT8).tolist()[0] if edf_info["number"] < 1.94: - sr = np.fromfile(fid, UINT8, 3) + sr = read_from_file_or_buffer(fid, UINT8, 3) event_sr = sr[0] for i in range(1, len(sr)): event_sr = event_sr + sr[i] * 2 ** (i * 8) - n_events = np.fromfile(fid, UINT32, 1)[0] + n_events = read_from_file_or_buffer(fid, UINT32, 1)[0] else: - ne = np.fromfile(fid, UINT8, 3) + ne = read_from_file_or_buffer(fid, UINT8, 3) n_events = sum(int(ne[i]) << (i * 8) for i in range(len(ne))) - event_sr = np.fromfile(fid, FLOAT32, 1)[0] + event_sr = read_from_file_or_buffer(fid, FLOAT32, 1)[0] - pos = np.fromfile(fid, UINT32, n_events) - 1 # 1-based inds - typ = np.fromfile(fid, UINT16, n_events) + pos = ( + read_from_file_or_buffer(fid, UINT32, n_events) - 1 + ) # 1-based inds + typ = read_from_file_or_buffer(fid, UINT16, n_events) if etmode == 3: - chn = np.fromfile(fid, 
UINT16, n_events)
-            dur = np.fromfile(fid, UINT32, n_events)
+            chn = read_from_file_or_buffer(fid, UINT16, n_events)
+            dur = read_from_file_or_buffer(fid, UINT32, n_events)
         else:
             chn = np.zeros(n_events, dtype=np.uint32)
             dur = np.ones(n_events, dtype=np.uint32)
@@ -1576,6 +1871,20 @@ def _find_tal_idx(ch_names):
     return tal_channel_idx
 
 
+def _check_args(input_fname, preload, target_ext):
+    if not _file_like(input_fname):
+        input_fname = _check_fname(fname=input_fname, overwrite="read", must_exist=True)
+        ext = input_fname.suffix[1:].lower()
+
+        if ext != target_ext:
+            raise NotImplementedError(
+                f"Only {target_ext.upper()} files are supported, got {ext}."
+            )
+    else:
+        if not preload:
+            raise ValueError("preload must be used with file-like objects")
+
+
 @fill_doc
 def read_raw_edf(
     input_fname,
@@ -1597,7 +1906,11 @@
     Parameters
     ----------
     input_fname : path-like
-        Path to the EDF or EDF+ file.
+        Path to the EDF or EDF+ file, or the file itself as a file-like
+        object. If a file-like object is provided, preload must be used.
+
+        .. versionchanged:: 1.10
+           Added support for file-like objects
     eog : list or tuple
         Names of channels or list of indices that should be designated EOG
         channels. Values should correspond to the electrodes in the file.
@@ -1693,10 +2006,8 @@
     The EDF specification allows storage of subseconds in measurement date.
     However, this reader currently sets subseconds to 0 by default.
     """
-    input_fname = os.path.abspath(input_fname)
-    ext = os.path.splitext(input_fname)[1][1:].lower()
-    if ext != "edf":
-        raise NotImplementedError(f"Only EDF files are supported, got {ext}.")
+    _check_args(input_fname, preload, "edf")
+
     return RawEDF(
         input_fname=input_fname,
         eog=eog,
@@ -1728,13 +2039,17 @@ def read_raw_bdf(
     exclude_after_unique=False,
     *,
     verbose=None,
-) -> RawEDF:
+) -> RawBDF:
     """Reader function for BDF files.
 
     Parameters
     ----------
-    input_fname : path-like
-        Path to the BDF file.
+    input_fname : path-like | file-like
+        Path to the BDF file, or the BDF file itself as a file-like object.
+        If a file-like object is provided, preload must be used.
+
+        .. versionchanged:: 1.10
+           Added support for file-like objects
     eog : list or tuple
         Names of channels or list of indices that should be designated EOG
         channels. Values should correspond to the electrodes in the file.
@@ -1827,11 +2142,9 @@ STIM channels by default. Use func:`mne.find_events` to parse events
     encoded in such analog stim channels.
     """
-    input_fname = os.path.abspath(input_fname)
-    ext = os.path.splitext(input_fname)[1][1:].lower()
-    if ext != "bdf":
-        raise NotImplementedError(f"Only BDF files are supported, got {ext}.")
-    return RawEDF(
+    _check_args(input_fname, preload, "bdf")
+
+    return RawBDF(
         input_fname=input_fname,
         eog=eog,
         misc=misc,
@@ -1862,8 +2175,12 @@ def read_raw_gdf(
 
     Parameters
     ----------
-    input_fname : path-like
-        Path to the GDF file.
+    input_fname : path-like | file-like
+        Path to the GDF file, or the GDF file itself as a file-like object.
+        If a file-like object is provided, preload must be used.
+
+        .. versionchanged:: 1.10
+           Added support for file-like objects
     eog : list or tuple
         Names of channels or list of indices that should be designated EOG
         channels. Values should correspond to the electrodes in the file.
@@ -1905,10 +2222,8 @@ STIM channels by default. Use func:`mne.find_events` to parse events
     encoded in such analog stim channels.
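Two helpers carry the file-like support throughout these hunks but are defined outside them: `_gdf_edf_get_fid` and `read_from_file_or_buffer`. A rough sketch of what they plausibly do, stated as assumptions rather than the actual implementations:

import contextlib
import io

import numpy as np


def _gdf_edf_get_fid(fname, **open_kwargs):
    # Assumed behavior: reuse an already-open file-like object (rewound,
    # and not closed on exiting the with-block), otherwise open the path.
    if hasattr(fname, "read"):
        fname.seek(0)
        return contextlib.nullcontext(fname)
    return open(fname, "rb", **open_kwargs)


def read_from_file_or_buffer(fid, dtype=np.uint8, count=-1):
    # Assumed behavior: np.fromfile needs a real file descriptor, so fall
    # back to np.frombuffer for in-memory streams such as io.BytesIO.
    try:
        return np.fromfile(fid, dtype=dtype, count=count)
    except (OSError, io.UnsupportedOperation, AttributeError):
        n_bytes = -1 if count < 0 else np.dtype(dtype).itemsize * count
        return np.frombuffer(fid.read(n_bytes), dtype=dtype, count=count)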
""" - input_fname = os.path.abspath(input_fname) - ext = os.path.splitext(input_fname)[1][1:].lower() - if ext != "gdf": - raise NotImplementedError(f"Only GDF files are supported, got {ext}.") + _check_args(input_fname, preload, "gdf") + return RawGDF( input_fname=input_fname, eog=eog, diff --git a/mne/io/edf/tests/test_edf.py b/mne/io/edf/tests/test_edf.py index ce671ca7e81..1760081bac4 100644 --- a/mne/io/edf/tests/test_edf.py +++ b/mne/io/edf/tests/test_edf.py @@ -5,6 +5,7 @@ import datetime from contextlib import nullcontext from functools import partial +from io import BytesIO from pathlib import Path import numpy as np @@ -174,7 +175,7 @@ def test_bdf_data(): test_scaling=test_scaling, ) assert len(raw_py.ch_names) == 71 - assert "RawEDF" in repr(raw_py) + assert "RawBDF" in repr(raw_py) picks = pick_types(raw_py.info, meg=False, eeg=True, exclude="bads") data_py, _ = raw_py[picks] @@ -958,11 +959,17 @@ def test_degenerate(): read_raw_edf, read_raw_bdf, read_raw_gdf, - partial(_read_header, exclude=(), infer_types=False), ): with pytest.raises(NotImplementedError, match="Only.*txt.*"): func(edf_txt_stim_channel_path) + with pytest.raises( + NotImplementedError, match="Only GDF, EDF, and BDF files are supported." + ): + partial(_read_header, exclude=(), infer_types=False, file_type=4)( + edf_txt_stim_channel_path + ) + def test_exclude(): """Test exclude parameter.""" @@ -1208,3 +1215,49 @@ def test_anonymization(): assert bday == datetime.date(1967, 10, 9) raw.anonymize() assert raw.info["subject_info"]["birthday"] != bday + + +@pytest.mark.filterwarnings( + "ignore:Invalid measurement date encountered in the header." +) +@testing.requires_testing_data +def test_bdf_read_from_bad_file_like(): + """Test that RawEDF is NOT able to read from file-like objects for non BDF files.""" + with pytest.raises(Exception, match="Bad BDF file provided."): + with open(edf_txt_stim_channel_path, "rb") as blob: + read_raw_bdf(BytesIO(blob.read()), preload=True) + + +@testing.requires_testing_data +def test_bdf_read_from_file_like(): + """Test that RawEDF is able to read from file-like objects for BDF files.""" + with open(bdf_path, "rb") as blob: + raw = read_raw_bdf(BytesIO(blob.read()), preload=True) + assert len(raw.ch_names) == 73 + + +@pytest.mark.filterwarnings( + "ignore:Invalid measurement date encountered in the header." 
+)
+@testing.requires_testing_data
+def test_edf_read_from_bad_file_like():
+    """Test that RawEDF is NOT able to read from file-like objects for non-EDF files."""
+    with pytest.raises(Exception, match="Bad EDF file provided."):
+        with open(edf_txt_stim_channel_path, "rb") as blob:
+            read_raw_edf(BytesIO(blob.read()), preload=True)
+
+
+@testing.requires_testing_data
+def test_edf_read_from_file_like():
+    """Test that RawEDF is able to read from file-like objects for EDF files."""
+    with open(edf_path, "rb") as blob:
+        raw = read_raw_edf(BytesIO(blob.read()), preload=True)
+    channels = [
+        *[f"{prefix}{num}" for prefix in "ABCDEFGH" for num in range(1, 17)],
+        *[f"I{num}" for num in range(1, 9)],
+        "Ergo-Left",
+        "Ergo-Right",
+        "Status",
+    ]
+
+    assert raw.ch_names == channels
diff --git a/mne/io/edf/tests/test_gdf.py b/mne/io/edf/tests/test_gdf.py
index 1dc5dc00a47..7b1d03f1960 100644
--- a/mne/io/edf/tests/test_gdf.py
+++ b/mne/io/edf/tests/test_gdf.py
@@ -4,10 +4,12 @@
 
 import shutil
 from datetime import date, datetime, timedelta, timezone
+from io import BytesIO
 
 import numpy as np
+import pytest
 import scipy.io as sio
-from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_equal
+from numpy.testing import assert_allclose, assert_array_equal, assert_equal
 
 from mne import events_from_annotations, find_events, pick_types
 from mne.datasets import testing
@@ -68,7 +70,7 @@ def test_gdf_data():
     data_biosig = raw_biosig[picks]
 
     # Assert data are almost equal
-    assert_array_almost_equal(data, data_biosig, 8)
+    assert_allclose(data, data_biosig, rtol=1e-8)
 
     # Test for events
     assert len(raw.annotations.duration == 963)
@@ -127,7 +129,7 @@ def test_gdf2_data():
     data_biosig = data_biosig[picks]
 
     # Assert data are almost equal
-    assert_array_almost_equal(data, data_biosig, 8)
+    assert_allclose(data, data_biosig, rtol=1e-8)
 
     # Find events
     events = find_events(raw, verbose=1)
@@ -181,3 +183,19 @@ def test_gdf_include():
         gdf1_path.with_name(gdf1_path.name + ".gdf"), include=("FP1", "O1")
     )
     assert sorted(raw.ch_names) == ["FP1", "O1"]
+
+
+@testing.requires_testing_data
+def test_gdf_read_from_file_like():
+    """Test that RawGDF is able to read from file-like objects for GDF files."""
+    channels = "FP1 FP2 F5 AFz F6 T7 Cz T8 P7 P3 Pz P4 P8 O1 Oz O2".split()
+    fname = gdf1_path.with_name(gdf1_path.name + ".gdf")
+    with open(fname, "rb") as blob:
+        raw = read_raw_gdf(blob, preload=True)
+    assert raw.ch_names == channels
+    data = raw.get_data()
+    data_2 = read_raw_gdf(fname, preload=True).get_data()
+    assert_allclose(data, data_2)
+
+    with pytest.raises(Exception, match="Bad GDF file provided."):
+        read_raw_gdf(BytesIO(), preload=True)
diff --git a/mne/utils/config.py b/mne/utils/config.py
index 620d356c666..9d5cf9f3c40 100644
--- a/mne/utils/config.py
+++ b/mne/utils/config.py
@@ -837,6 +837,7 @@ def sys_info(
         "neo",
         "eeglabio",
         "edfio",
+        "curryreader",
         "mffpy",
         "pybv",
         "antio",
diff --git a/mne/utils/docs.py b/mne/utils/docs.py
index 337d50b462a..898fb9c34e1 100644
--- a/mne/utils/docs.py
+++ b/mne/utils/docs.py
@@ -3035,6 +3035,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
     .. versionadded:: 0.12
 """
 
+docdict["on_bad_hpi_match"] = """
+on_bad_hpi_match : str
+    Can be ``'raise'`` to raise an error, ``'warn'`` (default) to emit a
+    warning, or ``'ignore'`` to ignore poor matching of HPI coordinates
+    (>10 mm difference) when computing the device-to-head transform.
+""" + docdict["on_baseline_ica"] = """ on_baseline : str How to handle baseline-corrected epochs or evoked data. diff --git a/pyproject.toml b/pyproject.toml index 38ec45e76f0..84536bf174b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,7 @@ full = ["mne[full-no-qt]", "PyQt6 != 6.6.0", "PyQt6-Qt6 != 6.6.0, != 6.7.0"] # and mne[full-pyside6], which will install PySide6 instead of PyQt6. full-no-qt = [ "antio >= 0.5.0", + "curryreader >= 0.1.2", "darkdetect", "defusedxml", "dipy", diff --git a/tools/install_pre_requirements.sh b/tools/install_pre_requirements.sh index 2e13a6b2dfe..c2b60366a18 100755 --- a/tools/install_pre_requirements.sh +++ b/tools/install_pre_requirements.sh @@ -61,7 +61,7 @@ python -m pip install $STD_ARGS \ git+https://github.com/python-quantities/python-quantities \ trame trame-vtk trame-vuetify jupyter ipyevents ipympl openmeeg \ imageio-ffmpeg xlrd mffpy traitlets pybv eeglabio defusedxml \ - antio + antio curryreader echo "::endgroup::" echo "::group::Make sure we're on a NumPy 2.0 variant" diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index 881d265d2d2..0a3e58c059a 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -82,6 +82,7 @@ _cleanup_agg _notebook_vtk_works _.drop_inds_ +_.required # mne/io/ant/tests/test_ant.py andy_101