1 change: 1 addition & 0 deletions doc/changes/dev/13408.newfeature.rst
@@ -0,0 +1 @@
+Add support for multi-wavelength NIRS processing to :func:`mne.preprocessing.nirs.beer_lambert_law`, :func:`mne.preprocessing.nirs.scalp_coupling_index`, and SNIRF reader :func:`mne.io.read_raw_snirf`, by :newcontrib:`Tamas Fehervari`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
@@ -320,6 +320,7 @@
 .. _Sébastien Marti: https://www.researchgate.net/profile/Sebastien-Marti
 .. _T. Wang: https://github.com/twang5
 .. _Tal Linzen: https://tallinzen.net/
+.. _Tamas Fehervari: https://github.com/zEdS15B3GCwq
 .. _Teon Brooks: https://github.com/teonbrooks
 .. _Tharupahan Jayawardana: https://github.com/tharu-jwd
 .. _Thomas Binns: https://github.com/tsbinns
4 changes: 2 additions & 2 deletions mne/datasets/config.py
@@ -87,7 +87,7 @@
 # update the checksum in the MNE_DATASETS dict below, and change version
 # here: ↓↓↓↓↓↓↓↓
 RELEASES = dict(
-    testing="0.170",
+    testing="0.171",
     misc="0.27",
     phantom_kit="0.2",
     ucl_opm_auditory="0.2",
@@ -115,7 +115,7 @@
 # Testing and misc are at the top as they're updated most often
 MNE_DATASETS["testing"] = dict(
     archive_name=f"{TESTING_VERSIONED}.tar.gz",
-    hash="md5:ebd873ea89507cf5a75043f56119d22b",
+    hash="md5:138caf29bd8a9b0a6b6ea43d92c16201",
     url=(
         "https://codeload.github.com/mne-tools/mne-testing-data/"
         f"tar.gz/{RELEASES['testing']}"
7 changes: 3 additions & 4 deletions mne/io/snirf/_snirf.py
@@ -148,14 +148,13 @@ def __init__(
         # Extract wavelengths
         fnirs_wavelengths = np.array(dat.get("nirs/probe/wavelengths"))
         fnirs_wavelengths = [int(w) for w in fnirs_wavelengths]
-        if len(fnirs_wavelengths) != 2:
+        if len(fnirs_wavelengths) < 2:
             raise RuntimeError(
                 f"The data contains "
                 f"{len(fnirs_wavelengths)}"
                 f" wavelengths: {fnirs_wavelengths}. "
-                f"MNE only supports reading continuous"
-                " wave amplitude SNIRF files "
-                "with two wavelengths."
+                f"MNE requires at least two wavelengths for "
+                "continuous wave amplitude SNIRF files."
             )

         # Extract channels
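With the check relaxed from exactly two to at least two wavelengths, continuous-wave recordings from three-wavelength devices load directly. A minimal sketch of what this enables, assuming a local 3-wavelength CW SNIRF file (the filename is illustrative, not a file shipped with MNE):

```python
import mne

# Hypothetical path to a 3-wavelength continuous-wave SNIRF recording
raw = mne.io.read_raw_snirf("labnirs_3wl_recording.snirf", preload=True)

# The reader now yields one amplitude channel per source-detector pair and
# wavelength, e.g. "S2_D2 780", "S2_D2 805", "S2_D2 830" for three wavelengths
print(raw.ch_names[:3])
```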
28 changes: 27 additions & 1 deletion mne/io/snirf/tests/test_snirf.py
@@ -8,13 +8,19 @@

 import numpy as np
 import pytest
-from numpy.testing import assert_allclose, assert_almost_equal, assert_equal
+from numpy.testing import (
+    assert_allclose,
+    assert_almost_equal,
+    assert_array_equal,
+    assert_equal,
+)

 from mne._fiff.constants import FIFF
 from mne.datasets.testing import data_path, requires_testing_data
 from mne.io import read_raw_nirx, read_raw_snirf
 from mne.io.tests.test_raw import _test_raw_reader
 from mne.preprocessing.nirs import (
+    _channel_frequencies,
     _reorder_nirx,
     beer_lambert_law,
     optical_density,
@@ -68,6 +74,11 @@
 # GowerLabs
 lumo110 = testing_path / "SNIRF" / "GowerLabs" / "lumomat-1-1-0.snirf"

+# Shimadzu Labnirs 3-wavelength converted to snirf using custom tool
+labnirs_multi_wavelength = (
+    testing_path / "SNIRF" / "Labnirs" / "labnirs_3wl_raw_recording.snirf"
+)
+

 def _get_loc(raw, ch_name):
     return raw.copy().pick(ch_name).info["chs"][0]["loc"]
@@ -88,6 +99,7 @@ def _get_loc(raw, ch_name):
             nirx_nirsport2_103_2,
             kernel_hb,
             lumo110,
+            labnirs_multi_wavelength,
         ]
     ),
 )
@@ -574,3 +586,17 @@ def test_sample_rate_jitter(tmp_path):
     f.create_dataset("nirs/data1/time", data=unacceptable_time_jitter)
     with pytest.warns(RuntimeWarning, match="non-uniformly-sampled data"):
         read_raw_snirf(new_file, verbose=True)
+
+
+@requires_testing_data
+def test_snirf_multiple_wavelengths():
+    """Test importing synthetic SNIRF files with >=3 wavelengths."""
+    raw = read_raw_snirf(labnirs_multi_wavelength, preload=True)
+    assert raw._data.shape == (45, 250)
+    assert raw.info["sfreq"] == pytest.approx(19.6, abs=0.01)
+    assert raw.info["ch_names"][:3] == ["S2_D2 780", "S2_D2 805", "S2_D2 830"]
+    assert len(raw.ch_names) == 45
+    freqs = np.unique(_channel_frequencies(raw.info))
+    assert_array_equal(freqs, [780, 805, 830])
+    distances = source_detector_distances(raw.info)
+    assert len(distances) == len(raw.ch_names)
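The new test covers reading and channel naming; downstream, the intended end-to-end use of the three touched functions looks roughly like this. A sketch reusing the `labnirs_multi_wavelength` path defined above; the SCI threshold and PPF values are illustrative assumptions, not recommendations:

```python
from mne.io import read_raw_snirf
from mne.preprocessing.nirs import (
    beer_lambert_law,
    optical_density,
    scalp_coupling_index,
)

raw = read_raw_snirf(labnirs_multi_wavelength, preload=True)
raw_od = optical_density(raw)

# SCI is now the minimum pairwise correlation across a pair's wavelengths;
# aligning it with ch_names like this assumes an all-fNIRS recording
sci = scalp_coupling_index(raw_od)
raw_od.info["bads"] = [ch for ch, s in zip(raw_od.ch_names, sci) if s < 0.5]

# Either one PPF for all wavelengths, or one per wavelength (length must match)
raw_hb = beer_lambert_law(raw_od, ppf=[6.0, 6.0, 6.0])
# The result keeps one hbo and one hbr channel per source-detector pair
```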
92 changes: 69 additions & 23 deletions mne/preprocessing/nirs/_beer_lambert_law.py
@@ -11,7 +11,7 @@
 from ..._fiff.constants import FIFF
 from ...io import BaseRaw
 from ...utils import _validate_type, pinv, warn
-from ..nirs import _validate_nirs_info, source_detector_distances
+from ..nirs import _channel_frequencies, _validate_nirs_info, source_detector_distances


 def beer_lambert_law(raw, ppf=6.0):
@@ -36,23 +36,46 @@ def beer_lambert_law(raw, ppf=6.0):
     _validate_type(raw, BaseRaw, "raw")
     _validate_type(ppf, ("numeric", "array-like"), "ppf")
     ppf = np.array(ppf, float)
-    if ppf.ndim == 0:  # upcast single float to shape (2,)
-        ppf = np.array([ppf, ppf])
-    if ppf.shape != (2,):
+    picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert")
+
+    # Use nominal channel frequencies
+    #
+    # Notes on implementation:
+    # 1. Frequencies are calculated the same way as in nirs._validate_nirs_info().
+    # 2. Wavelength values in the info structure may contain actual frequencies,
+    #    which may be used for more accurate calculation in the future.
+    # 3. nirs._channel_frequencies uses both cw_amplitude and OD data to determine
+    #    frequencies, whereas we only need those from OD here. Is there any chance
+    #    that they're different?
+    # 4. If actual frequencies were used, using np.unique() like below will lead to
+    #    errors. Instead, absorption coefficients will need to be calculated for
+    #    each individual frequency.
+    freqs = _channel_frequencies(raw.info)
+
+    # Get unique wavelengths and determine number of wavelengths
+    unique_freqs = np.unique(freqs)
+    n_wavelengths = len(unique_freqs)
+
+    # PPF validation for multiple wavelengths
+    if ppf.ndim == 0:  # single float
+        # same PPF for all wavelengths, shape (n_wavelengths, 1)
+        ppf = np.full((n_wavelengths, 1), ppf)
+    elif ppf.ndim == 1 and len(ppf) == n_wavelengths:
+        # separate ppf for each wavelength
+        ppf = ppf[:, np.newaxis]  # shape (n_wavelengths, 1)
+    else:
         raise ValueError(
-            f"ppf must be float or array-like of shape (2,), got shape {ppf.shape}"
+            f"ppf must be a single float or an array-like of length {n_wavelengths} "
+            f"(number of wavelengths), got shape {ppf.shape}"
         )
-    ppf = ppf[:, np.newaxis]  # shape (2, 1)
-    picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert")
-    # This is the one place we *really* need the actual/accurate frequencies
-    freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float)
-    abs_coef = _load_absorption(freqs)
+
+    abs_coef = _load_absorption(unique_freqs)  # shape (n_wavelengths, 2)
     distances = source_detector_distances(raw.info, picks="all")
     bad = ~np.isfinite(distances[picks])
     bad |= distances[picks] <= 0
     if bad.any():
         warn(
-            "Source-detector distances are zero on NaN, some resulting "
+            "Source-detector distances are zero or NaN, some resulting "
             "concentrations will be zero. Consider setting a montage "
             "with raw.set_montage."
         )
@@ -64,20 +87,41 @@ def beer_lambert_law(raw, ppf=6.0):
             "likely due to optode locations being stored in a "
             " unit other than meters."
         )

     rename = dict()
-    for ii, jj in zip(picks[::2], picks[1::2]):
-        EL = abs_coef * distances[ii] * ppf
+    channels_to_drop_all = []  # Accumulate all channels to drop
+
+    # Iterate over channel groups ([Si_Di all wavelengths, Sj_Dj all wavelengths, ...])
+    for ii in range(0, len(picks), n_wavelengths):
+        group_picks = picks[ii : ii + n_wavelengths]
+        # Calculate Δc based on the system: ΔOD = E * L * PPF * Δc
+        # where E is (n_wavelengths, 2), Δc is (2, n_timepoints)
+        # using pseudo-inverse
+        EL = abs_coef * distances[group_picks[0]] * ppf
         iEL = pinv(EL)
+        conc_data = iEL @ raw._data[group_picks] * 1e-3

-        raw._data[[ii, jj]] = iEL @ raw._data[[ii, jj]] * 1e-3
+        # Replace the first two channels with HbO and HbR
+        raw._data[group_picks[:2]] = conc_data[:2]  # HbO, HbR

         # Update channel information
         coil_dict = dict(hbo=FIFF.FIFFV_COIL_FNIRS_HBO, hbr=FIFF.FIFFV_COIL_FNIRS_HBR)
-        for ki, kind in zip((ii, jj), ("hbo", "hbr")):
+        for ki, kind in zip(group_picks[:2], ("hbo", "hbr")):
             ch = raw.info["chs"][ki]
             ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL)
             new_name = f"{ch['ch_name'].split(' ')[0]} {kind}"
             rename[ch["ch_name"]] = new_name
+
+        # Accumulate extra wavelength channels to drop (keep only HbO and HbR)
+        if n_wavelengths > 2:
+            channels_to_drop = group_picks[2:]
+            channel_names_to_drop = [raw.ch_names[idx] for idx in channels_to_drop]
+            channels_to_drop_all.extend(channel_names_to_drop)
+
+    # Drop all accumulated extra wavelength channels after processing all groups
+    if channels_to_drop_all:
+        raw.drop_channels(channels_to_drop_all)
+
     raw.rename_channels(rename)

     # Validate the format of data after transformation is valid
@@ -95,7 +139,9 @@ def _load_absorption(freqs):
     # save('extinction_coef.mat', 'extinct_coef')
     #
     # Returns data as [[HbO2(freq1), Hb(freq1)],
-    #                  [HbO2(freq2), Hb(freq2)]]
+    #                  [HbO2(freq2), Hb(freq2)],
+    #                  ...,
+    #                  [HbO2(freqN), Hb(freqN)]]
     extinction_fname = op.join(
         op.dirname(__file__), "..", "..", "data", "extinction_coef.mat"
     )
@@ -104,12 +150,12 @@ def _load_absorption(freqs):
     interp_hbo = interp1d(a[:, 0], a[:, 1], kind="linear")
     interp_hb = interp1d(a[:, 0], a[:, 2], kind="linear")

-    ext_coef = np.array(
-        [
-            [interp_hbo(freqs[0]), interp_hb(freqs[0])],
-            [interp_hbo(freqs[1]), interp_hb(freqs[1])],
-        ]
-    )
-    abs_coef = ext_coef * 0.2303
+    # Build coefficient matrix for all wavelengths
+    # Shape: (n_wavelengths, 2) where columns are [HbO2, Hb]
+    ext_coef = np.zeros((len(freqs), 2))
+    for i, freq in enumerate(freqs):
+        ext_coef[i, 0] = interp_hbo(freq)  # HbO2
+        ext_coef[i, 1] = interp_hb(freq)  # Hb
+
+    abs_coef = ext_coef * 0.2303
     return abs_coef
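In matrix form, the function now solves ΔOD = E · L · PPF · Δc, where E is the (n_wavelengths, 2) extinction matrix; with more than two wavelengths the pseudo-inverse yields the least-squares estimate of Δc. A self-contained numeric sketch with made-up coefficients (the real values are interpolated from extinction_coef.mat; the 1e-3 scaling mirrors the source above):

```python
import numpy as np

# Made-up extinction coefficients, one row per wavelength, columns [HbO2, Hb]
abs_coef = np.array(
    [
        [0.7, 1.5],  # e.g. 780 nm
        [0.9, 0.9],  # e.g. 805 nm
        [1.2, 0.7],  # e.g. 830 nm
    ]
)
distance = 0.03  # source-detector distance in meters
ppf = np.full((3, 1), 6.0)  # one partial pathlength factor per wavelength

EL = abs_coef * distance * ppf  # design matrix, shape (3, 2)
delta_od = np.array([[0.010], [0.018], [0.024]])  # one OD sample per wavelength
delta_c = np.linalg.pinv(EL) @ delta_od * 1e-3  # least-squares [HbO, HbR]
print(delta_c.ravel())
```

With exactly two wavelengths EL is square, so the pseudo-inverse reduces to the ordinary inverse and the previous behaviour is preserved.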
36 changes: 28 additions & 8 deletions mne/preprocessing/nirs/_scalp_coupling_index.py
@@ -6,7 +6,7 @@

 from ...io import BaseRaw
 from ...utils import _validate_type, verbose
-from ..nirs import _validate_nirs_info
+from ..nirs import _channel_frequencies, _validate_nirs_info


 @verbose
@@ -56,14 +56,34 @@ def scalp_coupling_index(
         verbose=verbose,
     ).get_data()

+    # Determine number of wavelengths per source-detector pair
+    # We use nominal wavelengths as the info structure may contain arbitrary data.
+    freqs = _channel_frequencies(raw.info)
+    n_wavelengths = len(np.unique(freqs))
+
     sci = np.zeros(picks.shape)
-    for ii in range(0, len(picks), 2):
-        with np.errstate(invalid="ignore"):
-            c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1]
-        if not np.isfinite(c):  # someone had std=0
-            c = 0
-        sci[ii] = c
-        sci[ii + 1] = c
+
+    # Calculate all pairwise correlations within each group and use the minimum as SCI
+    pair_indices = np.triu_indices(n_wavelengths, k=1)
+
+    for gg in range(0, len(picks), n_wavelengths):
+        group_data = filtered_data[gg : gg + n_wavelengths]
+
+        # Calculate pairwise correlations within the group
+        correlations = np.zeros(pair_indices[0].shape[0])
+
+        for n, (ii, jj) in enumerate(zip(*pair_indices)):
+            with np.errstate(invalid="ignore"):
+                c = np.corrcoef(group_data[ii], group_data[jj])[0][1]
+            if np.isfinite(c):
+                correlations[n] = c
+
+        # Use minimum correlation as SCI
+        group_sci = correlations.min()
+
+        # Assign the same SCI value to all channels in the group
+        sci[gg : gg + n_wavelengths] = group_sci
+
     sci[zero_mask] = 0
     sci = sci[np.argsort(picks)]  # restore original order
     return sci
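With n wavelengths per source-detector pair there are n(n-1)/2 correlations, and taking the minimum is the conservative reduction: the pair is flagged if any one wavelength has decoupled from the scalp. A toy sketch of that reduction on synthetic data:

```python
import numpy as np

rng = np.random.default_rng(0)
cardiac = np.sin(np.linspace(0, 20 * np.pi, 1000))  # shared cardiac-band signal

# Three wavelengths of one source-detector pair; the third is mostly noise
group = np.stack(
    [
        cardiac + 0.1 * rng.standard_normal(1000),
        cardiac + 0.1 * rng.standard_normal(1000),
        0.1 * cardiac + rng.standard_normal(1000),
    ]
)

ii, jj = np.triu_indices(len(group), k=1)
corrs = [np.corrcoef(group[i], group[j])[0, 1] for i, j in zip(ii, jj)]
sci = min(corrs)  # one conservative SCI shared by the whole pair
print(round(sci, 2))
```

For a two-wavelength recording there is a single pairwise correlation, so this reduces to the old behaviour.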