diff --git a/eegdash/__init__.py b/eegdash/__init__.py index 86bde0ca..08069bed 100644 --- a/eegdash/__init__.py +++ b/eegdash/__init__.py @@ -18,4 +18,4 @@ __all__ = ["EEGDash", "EEGDashDataset", "EEGChallengeDataset", "preprocessing"] -__version__ = "0.4.1" +__version__ = "0.5.0" diff --git a/eegdash/features/__init__.py b/eegdash/features/__init__.py index b0acd3c0..315fde09 100644 --- a/eegdash/features/__init__.py +++ b/eegdash/features/__init__.py @@ -15,17 +15,13 @@ UnivariateFeature, ) from .feature_bank import ( # Complexity; Connectivity; CSP; Dimensionality; Signal; Spectral - CoherenceFeatureExtractor, CommonSpatialPattern, - DBSpectralFeatureExtractor, - EntropyFeatureExtractor, - HilbertFeatureExtractor, - NormalizedSpectralFeatureExtractor, - SpectralFeatureExtractor, complexity_approx_entropy, + complexity_entropy_preprocessor, complexity_lempel_ziv, complexity_sample_entropy, complexity_svd_entropy, + connectivity_coherency_preprocessor, connectivity_imaginary_coherence, connectivity_lagged_coherence, connectivity_magnitude_square_coherence, @@ -35,6 +31,7 @@ dimensionality_katz_fractal_dim, dimensionality_petrosian_fractal_dim, signal_decorrelation_time, + signal_hilbert_preprocessor, signal_hjorth_activity, signal_hjorth_complexity, signal_hjorth_mobility, @@ -49,18 +46,21 @@ signal_variance, signal_zero_crossings, spectral_bands_power, + spectral_db_preprocessor, spectral_edge, spectral_entropy, spectral_hjorth_activity, spectral_hjorth_complexity, spectral_hjorth_mobility, spectral_moment, + spectral_normalized_preprocessor, + spectral_preprocessor, spectral_root_total_power, spectral_slope, ) from .inspect import ( - get_all_feature_extractors, get_all_feature_kinds, + get_all_feature_preprocessors, get_all_features, get_feature_kind, get_feature_predecessors, @@ -82,7 +82,7 @@ "MultivariateFeature", "TrainableFeature", "UnivariateFeature", - "get_all_feature_extractors", + "get_all_feature_preprocessors", "get_all_feature_kinds", "get_all_features", "get_feature_kind", @@ -92,13 +92,13 @@ "fit_feature_extractors", # Feature part # Complexity - "EntropyFeatureExtractor", + "complexity_entropy_preprocessor", "complexity_approx_entropy", "complexity_sample_entropy", "complexity_svd_entropy", "complexity_lempel_ziv", # Connectivity - "CoherenceFeatureExtractor", + "connectivity_coherency_preprocessor", "connectivity_magnitude_square_coherence", "connectivity_imaginary_coherence", "connectivity_lagged_coherence", @@ -111,7 +111,7 @@ "dimensionality_hurst_exp", "dimensionality_detrended_fluctuation_analysis", # Signal - "HilbertFeatureExtractor", + "signal_hilbert_preprocessor", "signal_mean", "signal_variance", "signal_skewness", @@ -127,9 +127,9 @@ "signal_hjorth_complexity", "signal_decorrelation_time", # Spectral - "SpectralFeatureExtractor", - "NormalizedSpectralFeatureExtractor", - "DBSpectralFeatureExtractor", + "spectral_preprocessor", + "spectral_normalized_preprocessor", + "spectral_db_preprocessor", "spectral_root_total_power", "spectral_moment", "spectral_entropy", diff --git a/eegdash/features/decorators.py b/eegdash/features/decorators.py index 5a63fb89..2d7d8a67 100644 --- a/eegdash/features/decorators.py +++ b/eegdash/features/decorators.py @@ -1,10 +1,9 @@ from collections.abc import Callable -from typing import List, Type +from typing import List from .extractors import ( BivariateFeature, DirectedBivariateFeature, - FeatureExtractor, MultivariateFeature, UnivariateFeature, _get_underlying_func, @@ -22,26 +21,26 @@ class FeaturePredecessor: """A decorator to specify parent extractors for a feature function. - This decorator attaches a list of parent extractor types to a feature - extraction function. This information can be used to build a dependency - graph of features. + This decorator attaches a list of immediate parent preprocessing steps to a feature + extraction function. This information can be used to build a dependency graph of + features. Parameters ---------- *parent_extractor_type : list of Type - A list of feature extractor classes (subclasses of - :class:`~eegdash.features.extractors.FeatureExtractor`) that this - feature depends on. + A list of preprocessing functions (subclasses of + :class:`~collections.abc.Callable` or None) that this feature immediately depends + on. """ - def __init__(self, *parent_extractor_type: List[Type]): - parent_cls = parent_extractor_type - if not parent_cls: - parent_cls = [FeatureExtractor] - for p_cls in parent_cls: - assert issubclass(p_cls, FeatureExtractor) - self.parent_extractor_type = parent_cls + def __init__(self, *parent_extractor_type: List[Callable | None]): + parent_func = parent_extractor_type + if not parent_func: + parent_func = [None] + for p_func in parent_func: + assert p_func is None or callable(p_func) + self.parent_extractor_type = parent_func def __call__(self, func: Callable) -> Callable: """Apply the decorator to a function. diff --git a/eegdash/features/extractors.py b/eegdash/features/extractors.py index 451f1636..b651bccd 100644 --- a/eegdash/features/extractors.py +++ b/eegdash/features/extractors.py @@ -102,28 +102,28 @@ class FeatureExtractor(TrainableFeature): feature_extractors : dict[str, callable] A dictionary where keys are feature names and values are the feature extraction functions or other `FeatureExtractor` instances. - **preprocess_kwargs - Keyword arguments to be passed to the `preprocess` method. + preprocessor + A shared preprocessing function for all child feature extraction functions. """ def __init__( - self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict + self, + feature_extractors: Dict[str, Callable], + preprocessor: Callable | None = None, ): + self.preprocessor = preprocessor self.feature_extractors_dict = self._validate_execution_tree(feature_extractors) self._is_trainable = self._check_is_trainable(feature_extractors) super().__init__() # bypassing FeaturePredecessor to avoid circular import if not hasattr(self, "parent_extractor_type"): - self.parent_extractor_type = [FeatureExtractor] - - self.preprocess_kwargs = preprocess_kwargs - if self.preprocess_kwargs is None: - self.preprocess_kwargs = dict() - self.features_kwargs = { - "preprocess_kwargs": preprocess_kwargs, - } + self.parent_extractor_type = [None] + + self.features_kwargs = dict() + if preprocessor is not None and isinstance(preprocessor, partial): + self.features_kwargs["preprocess_kwargs"] = preprocessor.args for fn, fe in feature_extractors.items(): if isinstance(fe, FeatureExtractor): self.features_kwargs[fn] = fe.features_kwargs @@ -132,12 +132,21 @@ def __init__( def _validate_execution_tree(self, feature_extractors: dict) -> dict: """Validate the feature dependency graph.""" + preprocessor = ( + None + if self.preprocessor is None + else _get_underlying_func(self.preprocessor) + ) for fname, f in feature_extractors.items(): + if isinstance(f, FeatureExtractor): + f = f.preprocessor f = _get_underlying_func(f) - pe_type = getattr(f, "parent_extractor_type", [FeatureExtractor]) - if type(self) not in pe_type: + pe_type = getattr(f, "parent_extractor_type", [None]) + if preprocessor not in pe_type: + parent = getattr(preprocessor, "__name__", preprocessor) + child = getattr(f, "__name__", f) raise TypeError( - f"Feature '{fname}' cannot be a child of {type(self).__name__}" + f"Feature '{fname}: {child}' cannot be a child of {parent}" ) return feature_extractors @@ -151,15 +160,13 @@ def _check_is_trainable(self, feature_extractors: dict) -> bool: return True return False - def preprocess(self, *x, **kwargs): + def preprocess(self, *x): """Apply pre-processing to the input data. Parameters ---------- *x : tuple Input data. - **kwargs - Additional keyword arguments. Returns ------- @@ -167,7 +174,10 @@ def preprocess(self, *x, **kwargs): The pre-processed data. """ - return (*x,) + if self.preprocessor is None: + return (*x,) + else: + return self.preprocessor(*x) def __call__(self, *x, _batch_size=None, _ch_names=None) -> dict: """Apply all feature extractors to the input data. @@ -193,7 +203,7 @@ def __call__(self, *x, _batch_size=None, _ch_names=None) -> dict: if self._is_trainable: super().__call__() results_dict = dict() - z = self.preprocess(*x, **self.preprocess_kwargs) + z = self.preprocess(*x) if not isinstance(z, tuple): z = (z,) for fname, f in self.feature_extractors_dict.items(): @@ -227,26 +237,29 @@ def clear(self): if not self._is_trainable: return for f in self.feature_extractors_dict.values(): - if isinstance(_get_underlying_func(f), TrainableFeature): - _get_underlying_func(f).clear() + f = _get_underlying_func(f) + if isinstance(f, TrainableFeature): + f.clear() def partial_fit(self, *x, y=None): """Partially fit all trainable sub-features.""" if not self._is_trainable: return - z = self.preprocess(*x, **self.preprocess_kwargs) + z = self.preprocess(*x) if not isinstance(z, tuple): z = (z,) for f in self.feature_extractors_dict.values(): - if isinstance(_get_underlying_func(f), TrainableFeature): - _get_underlying_func(f).partial_fit(*z, y=y) + f = _get_underlying_func(f) + if isinstance(f, TrainableFeature): + f.partial_fit(*z, y=y) def fit(self): """Fit all trainable sub-features.""" if not self._is_trainable: return for f in self.feature_extractors_dict.values(): - if isinstance(_get_underlying_func(f), TrainableFeature): + f = _get_underlying_func(f) + if isinstance(f, TrainableFeature): f.fit() super().fit() diff --git a/eegdash/features/feature_bank/__init__.py b/eegdash/features/feature_bank/__init__.py index 23f90236..9048ae38 100644 --- a/eegdash/features/feature_bank/__init__.py +++ b/eegdash/features/feature_bank/__init__.py @@ -6,14 +6,14 @@ """ from .complexity import ( - EntropyFeatureExtractor, complexity_approx_entropy, + complexity_entropy_preprocessor, complexity_lempel_ziv, complexity_sample_entropy, complexity_svd_entropy, ) from .connectivity import ( - CoherenceFeatureExtractor, + connectivity_coherency_preprocessor, connectivity_imaginary_coherence, connectivity_lagged_coherence, connectivity_magnitude_square_coherence, @@ -27,8 +27,8 @@ dimensionality_petrosian_fractal_dim, ) from .signal import ( - HilbertFeatureExtractor, signal_decorrelation_time, + signal_hilbert_preprocessor, signal_hjorth_activity, signal_hjorth_complexity, signal_hjorth_mobility, @@ -44,29 +44,29 @@ signal_zero_crossings, ) from .spectral import ( - DBSpectralFeatureExtractor, - NormalizedSpectralFeatureExtractor, - SpectralFeatureExtractor, spectral_bands_power, + spectral_db_preprocessor, spectral_edge, spectral_entropy, spectral_hjorth_activity, spectral_hjorth_complexity, spectral_hjorth_mobility, spectral_moment, + spectral_normalized_preprocessor, + spectral_preprocessor, spectral_root_total_power, spectral_slope, ) __all__ = [ # Complexity - "EntropyFeatureExtractor", + "complexity_entropy_preprocessor", "complexity_approx_entropy", "complexity_sample_entropy", "complexity_svd_entropy", "complexity_lempel_ziv", # Connectivity - "CoherenceFeatureExtractor", + "connectivity_coherency_preprocessor", "connectivity_magnitude_square_coherence", "connectivity_imaginary_coherence", "connectivity_lagged_coherence", @@ -79,7 +79,7 @@ "dimensionality_hurst_exp", "dimensionality_detrended_fluctuation_analysis", # Signal - "HilbertFeatureExtractor", + "signal_hilbert_preprocessor", "signal_mean", "signal_variance", "signal_skewness", @@ -95,9 +95,9 @@ "signal_hjorth_complexity", "signal_decorrelation_time", # Spectral - "SpectralFeatureExtractor", - "NormalizedSpectralFeatureExtractor", - "DBSpectralFeatureExtractor", + "spectral_preprocessor", + "spectral_normalized_preprocessor", + "spectral_db_preprocessor", "spectral_root_total_power", "spectral_moment", "spectral_entropy", diff --git a/eegdash/features/feature_bank/complexity.py b/eegdash/features/feature_bank/complexity.py index 3aeef2fb..3198bd76 100644 --- a/eegdash/features/feature_bank/complexity.py +++ b/eegdash/features/feature_bank/complexity.py @@ -3,11 +3,10 @@ from sklearn.neighbors import KDTree from ..decorators import FeaturePredecessor, univariate_feature -from ..extractors import FeatureExtractor from .signal import SIGNAL_PREDECESSORS __all__ = [ - "EntropyFeatureExtractor", + "complexity_entropy_preprocessor", "complexity_approx_entropy", "complexity_sample_entropy", "complexity_svd_entropy", @@ -30,32 +29,31 @@ def _channel_app_samp_entropy_counts(x, m, r, l): @FeaturePredecessor(*SIGNAL_PREDECESSORS) -class EntropyFeatureExtractor(FeatureExtractor): - def preprocess(self, x, m=2, r=0.2, l=1): - rr = r * x.std(axis=-1) - counts_m = np.empty((*x.shape[:-1], (x.shape[-1] - m + 1) // l)) - counts_mp1 = np.empty((*x.shape[:-1], (x.shape[-1] - m) // l)) - for i in np.ndindex(x.shape[:-1]): - counts_m[i + (slice(None),)] = _channel_app_samp_entropy_counts( - x[i], m, rr[i], l - ) - counts_mp1[i + (slice(None),)] = _channel_app_samp_entropy_counts( - x[i], m + 1, rr[i], l - ) - return counts_m, counts_mp1 - - -@FeaturePredecessor(EntropyFeatureExtractor) +def complexity_entropy_preprocessor(x, /, m=2, r=0.2, l=1): + rr = r * x.std(axis=-1) + counts_m = np.empty((*x.shape[:-1], (x.shape[-1] - m + 1) // l)) + counts_mp1 = np.empty((*x.shape[:-1], (x.shape[-1] - m) // l)) + for i in np.ndindex(x.shape[:-1]): + counts_m[i + (slice(None),)] = _channel_app_samp_entropy_counts( + x[i], m, rr[i], l + ) + counts_mp1[i + (slice(None),)] = _channel_app_samp_entropy_counts( + x[i], m + 1, rr[i], l + ) + return counts_m, counts_mp1 + + +@FeaturePredecessor(complexity_entropy_preprocessor) @univariate_feature -def complexity_approx_entropy(counts_m, counts_mp1): +def complexity_approx_entropy(counts_m, counts_mp1, /): phi_m = np.log(counts_m / counts_m.shape[-1]).mean(axis=-1) phi_mp1 = np.log(counts_mp1 / counts_mp1.shape[-1]).mean(axis=-1) return phi_m - phi_mp1 -@FeaturePredecessor(EntropyFeatureExtractor) +@FeaturePredecessor(complexity_entropy_preprocessor) @univariate_feature -def complexity_sample_entropy(counts_m, counts_mp1): +def complexity_sample_entropy(counts_m, counts_mp1, /): A = np.sum(counts_mp1 - 1, axis=-1) B = np.sum(counts_m - 1, axis=-1) return -np.log(A / B) @@ -63,7 +61,7 @@ def complexity_sample_entropy(counts_m, counts_mp1): @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def complexity_svd_entropy(x, m=10, tau=1): +def complexity_svd_entropy(x, /, m=10, tau=1): x_emb = np.empty((*x.shape[:-1], (x.shape[-1] - m + 1) // tau, m)) for i in np.ndindex(x.shape[:-1]): x_emb[i + (slice(None), slice(None))] = _create_embedding(x[i], m, tau) @@ -75,7 +73,7 @@ def complexity_svd_entropy(x, m=10, tau=1): @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature @nb.njit(cache=True, fastmath=True) -def complexity_lempel_ziv(x, threshold=None, normalize=True): +def complexity_lempel_ziv(x, /, threshold=None, normalize=True): lzc = np.empty(x.shape[:-1]) for i in np.ndindex(x.shape[:-1]): t = np.median(x[i]) if threshold is None else threshold diff --git a/eegdash/features/feature_bank/connectivity.py b/eegdash/features/feature_bank/connectivity.py index 414d9a45..0bcd15ed 100644 --- a/eegdash/features/feature_bank/connectivity.py +++ b/eegdash/features/feature_bank/connectivity.py @@ -4,56 +4,55 @@ from scipy.signal import csd from ..decorators import FeaturePredecessor, bivariate_feature -from ..extractors import BivariateFeature, FeatureExtractor +from ..extractors import BivariateFeature from . import utils +from .signal import SIGNAL_PREDECESSORS __all__ = [ - "CoherenceFeatureExtractor", + "connectivity_coherency_preprocessor", "connectivity_magnitude_square_coherence", "connectivity_imaginary_coherence", "connectivity_lagged_coherence", ] -class CoherenceFeatureExtractor(FeatureExtractor): - def preprocess(self, x, **kwargs): - f_min = kwargs.pop("f_min") if "f_min" in kwargs else None - f_max = kwargs.pop("f_max") if "f_max" in kwargs else None - assert "fs" in kwargs and "nperseg" in kwargs - kwargs["axis"] = -1 - n = x.shape[1] - idx_x, idx_y = BivariateFeature.get_pair_iterators(n) - ix, iy = list(chain(range(n), idx_x)), list(chain(range(n), idx_y)) - f, s = csd(x[:, ix], x[:, iy], **kwargs) - f_min, f_max = utils.get_valid_freq_band( - kwargs["fs"], x.shape[-1], f_min, f_max - ) - f, s = utils.slice_freq_band(f, s, f_min=f_min, f_max=f_max) - p, sxy = np.split(s, [n], axis=1) - sxx, syy = p[:, idx_x].real, p[:, idx_y].real - c = sxy / np.sqrt(sxx * syy) - return f, c - - -@FeaturePredecessor(CoherenceFeatureExtractor) +@FeaturePredecessor(*SIGNAL_PREDECESSORS) +def connectivity_coherency_preprocessor(x, /, **kwargs): + f_min = kwargs.pop("f_min") if "f_min" in kwargs else None + f_max = kwargs.pop("f_max") if "f_max" in kwargs else None + assert "fs" in kwargs and "nperseg" in kwargs + kwargs["axis"] = -1 + n = x.shape[1] + idx_x, idx_y = BivariateFeature.get_pair_iterators(n) + ix, iy = list(chain(range(n), idx_x)), list(chain(range(n), idx_y)) + f, s = csd(x[:, ix], x[:, iy], **kwargs) + f_min, f_max = utils.get_valid_freq_band(kwargs["fs"], x.shape[-1], f_min, f_max) + f, s = utils.slice_freq_band(f, s, f_min=f_min, f_max=f_max) + p, sxy = np.split(s, [n], axis=1) + sxx, syy = p[:, idx_x].real, p[:, idx_y].real + c = sxy / np.sqrt(sxx * syy) + return f, c + + +@FeaturePredecessor(connectivity_coherency_preprocessor) @bivariate_feature -def connectivity_magnitude_square_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS): +def connectivity_magnitude_square_coherence(f, c, /, bands=utils.DEFAULT_FREQ_BANDS): # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity coher = c.real**2 + c.imag**2 return utils.reduce_freq_bands(f, coher, bands, np.mean) -@FeaturePredecessor(CoherenceFeatureExtractor) +@FeaturePredecessor(connectivity_coherency_preprocessor) @bivariate_feature -def connectivity_imaginary_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS): +def connectivity_imaginary_coherence(f, c, /, bands=utils.DEFAULT_FREQ_BANDS): # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity coher = c.imag return utils.reduce_freq_bands(f, coher, bands, np.mean) -@FeaturePredecessor(CoherenceFeatureExtractor) +@FeaturePredecessor(connectivity_coherency_preprocessor) @bivariate_feature -def connectivity_lagged_coherence(f, c, bands=utils.DEFAULT_FREQ_BANDS): +def connectivity_lagged_coherence(f, c, /, bands=utils.DEFAULT_FREQ_BANDS): # https://neuroimage.usc.edu/brainstorm/Tutorials/Connectivity coher = c.imag / np.sqrt(1 - c.real) return utils.reduce_freq_bands(f, coher, bands, np.mean) diff --git a/eegdash/features/feature_bank/csp.py b/eegdash/features/feature_bank/csp.py index 7ff8fffc..733a732e 100644 --- a/eegdash/features/feature_bank/csp.py +++ b/eegdash/features/feature_bank/csp.py @@ -3,8 +3,9 @@ import scipy import scipy.linalg -from ..decorators import multivariate_feature +from ..decorators import FeaturePredecessor, multivariate_feature from ..extractors import TrainableFeature +from .signal import SIGNAL_PREDECESSORS __all__ = [ "CommonSpatialPattern", @@ -21,6 +22,7 @@ def _update_mean_cov(count, mean, cov, x_count, x_mean, x_cov): cov[:] -= np.outer(mean, mean) +@FeaturePredecessor(*SIGNAL_PREDECESSORS) @multivariate_feature class CommonSpatialPattern(TrainableFeature): def __init__(self): diff --git a/eegdash/features/feature_bank/dimensionality.py b/eegdash/features/feature_bank/dimensionality.py index 336744e4..f38ddf70 100644 --- a/eegdash/features/feature_bank/dimensionality.py +++ b/eegdash/features/feature_bank/dimensionality.py @@ -17,7 +17,7 @@ @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature @nb.njit(cache=True, fastmath=True) -def dimensionality_higuchi_fractal_dim(x, k_max=10, eps=1e-7): +def dimensionality_higuchi_fractal_dim(x, /, k_max=10, eps=1e-7): N = x.shape[-1] hfd = np.empty(x.shape[:-1]) log_k = np.vstack((-np.log(np.arange(1, k_max + 1)), np.ones(k_max))).T @@ -35,7 +35,7 @@ def dimensionality_higuchi_fractal_dim(x, k_max=10, eps=1e-7): @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def dimensionality_petrosian_fractal_dim(x): +def dimensionality_petrosian_fractal_dim(x, /): nd = signal_zero_crossings(np.diff(x, axis=-1)) log_n = np.log(x.shape[-1]) return log_n / (np.log(nd) + log_n) @@ -43,7 +43,7 @@ def dimensionality_petrosian_fractal_dim(x): @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def dimensionality_katz_fractal_dim(x): +def dimensionality_katz_fractal_dim(x, /): dists = np.abs(np.diff(x, axis=-1)) L = dists.sum(axis=-1) a = dists.mean(axis=-1) @@ -79,7 +79,7 @@ def _hurst_exp(x, ns, a, gamma_ratios, log_n): @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def dimensionality_hurst_exp(x): +def dimensionality_hurst_exp(x, /): ns = np.unique(np.power(2, np.arange(2, np.log2(x.shape[-1]) - 1)).astype(int)) idx = ns > 340 gamma_ratios = np.empty(ns.shape[0]) @@ -94,7 +94,7 @@ def dimensionality_hurst_exp(x): @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature @nb.njit(cache=True, fastmath=True) -def dimensionality_detrended_fluctuation_analysis(x): +def dimensionality_detrended_fluctuation_analysis(x, /): ns = np.unique(np.floor(np.power(2, np.arange(2, np.log2(x.shape[-1]) - 1)))) a = np.vstack((np.arange(ns[-1]), np.ones(int(ns[-1])))).T log_n = np.vstack((np.log(ns), np.ones(ns.shape[0]))).T diff --git a/eegdash/features/feature_bank/signal.py b/eegdash/features/feature_bank/signal.py index df8fdb4e..eb9ea556 100644 --- a/eegdash/features/feature_bank/signal.py +++ b/eegdash/features/feature_bank/signal.py @@ -4,10 +4,9 @@ from scipy import signal, stats from ..decorators import FeaturePredecessor, univariate_feature -from ..extractors import FeatureExtractor __all__ = [ - "HilbertFeatureExtractor", + "signal_hilbert_preprocessor", "SIGNAL_PREDECESSORS", "signal_decorrelation_time", "signal_hjorth_activity", @@ -26,72 +25,71 @@ ] -@FeaturePredecessor(FeatureExtractor) -class HilbertFeatureExtractor(FeatureExtractor): - def preprocess(self, x): - return np.abs(signal.hilbert(x - x.mean(axis=-1, keepdims=True), axis=-1)) +@FeaturePredecessor() +def signal_hilbert_preprocessor(x, /): + return np.abs(signal.hilbert(x - x.mean(axis=-1, keepdims=True), axis=-1)) -SIGNAL_PREDECESSORS = [FeatureExtractor, HilbertFeatureExtractor] +SIGNAL_PREDECESSORS = [None, signal_hilbert_preprocessor] @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_mean(x): +def signal_mean(x, /): return x.mean(axis=-1) @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_variance(x, **kwargs): +def signal_variance(x, /, **kwargs): return x.var(axis=-1, **kwargs) @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_std(x, **kwargs): +def signal_std(x, /, **kwargs): return x.std(axis=-1, **kwargs) @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_skewness(x, **kwargs): +def signal_skewness(x, /, **kwargs): return stats.skew(x, axis=x.ndim - 1, **kwargs) @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_kurtosis(x, **kwargs): +def signal_kurtosis(x, /, **kwargs): return stats.kurtosis(x, axis=x.ndim - 1, **kwargs) @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_root_mean_square(x): +def signal_root_mean_square(x, /): return np.sqrt(np.power(x, 2).mean(axis=-1)) @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_peak_to_peak(x, **kwargs): +def signal_peak_to_peak(x, /, **kwargs): return np.ptp(x, axis=-1, **kwargs) @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_quantile(x, q: numbers.Number = 0.5, **kwargs): +def signal_quantile(x, /, q: numbers.Number = 0.5, **kwargs): return np.quantile(x, q=q, axis=-1, **kwargs) @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_line_length(x): +def signal_line_length(x, /): return np.abs(np.diff(x, axis=-1)).mean(axis=-1) @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_zero_crossings(x, threshold=1e-15): +def signal_zero_crossings(x, /, threshold=1e-15): zero_ind = np.logical_and(x > -threshold, x < threshold) zero_cross = np.diff(zero_ind, axis=-1).astype(int).sum(axis=-1) y = x.copy() @@ -102,13 +100,13 @@ def signal_zero_crossings(x, threshold=1e-15): @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_hjorth_mobility(x): +def signal_hjorth_mobility(x, /): return np.diff(x, axis=-1).std(axis=-1) / x.std(axis=-1) @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_hjorth_complexity(x): +def signal_hjorth_complexity(x, /): return (np.diff(x, 2, axis=-1).std(axis=-1) * x.std(axis=-1)) / np.diff( x, axis=-1 ).var(axis=-1) @@ -116,7 +114,7 @@ def signal_hjorth_complexity(x): @FeaturePredecessor(*SIGNAL_PREDECESSORS) @univariate_feature -def signal_decorrelation_time(x, fs=1): +def signal_decorrelation_time(x, /, fs=1): f = np.fft.fft(x - x.mean(axis=-1, keepdims=True), axis=-1) ac = np.fft.ifft(f.real**2 + f.imag**2, axis=-1)[..., : x.shape[-1] // 2] dct = np.empty(x.shape[:-1]) diff --git a/eegdash/features/feature_bank/spectral.py b/eegdash/features/feature_bank/spectral.py index a497a43d..68371197 100644 --- a/eegdash/features/feature_bank/spectral.py +++ b/eegdash/features/feature_bank/spectral.py @@ -3,13 +3,13 @@ from scipy.signal import welch from ..decorators import FeaturePredecessor, univariate_feature -from ..extractors import FeatureExtractor from . import utils +from .signal import SIGNAL_PREDECESSORS __all__ = [ - "SpectralFeatureExtractor", - "NormalizedSpectralFeatureExtractor", - "DBSpectralFeatureExtractor", + "spectral_preprocessor", + "spectral_normalized_preprocessor", + "spectral_db_preprocessor", "spectral_root_total_power", "spectral_moment", "spectral_entropy", @@ -22,84 +22,80 @@ ] -class SpectralFeatureExtractor(FeatureExtractor): - def preprocess(self, x, **kwargs): - f_min = kwargs.pop("f_min") if "f_min" in kwargs else None - f_max = kwargs.pop("f_max") if "f_max" in kwargs else None - assert "fs" in kwargs - kwargs["axis"] = -1 - f, p = welch(x, **kwargs) - f_min, f_max = utils.get_valid_freq_band( - kwargs["fs"], x.shape[-1], f_min, f_max - ) - f, p = utils.slice_freq_band(f, p, f_min=f_min, f_max=f_max) - return f, p +@FeaturePredecessor(*SIGNAL_PREDECESSORS) +def spectral_preprocessor(x, /, **kwargs): + f_min = kwargs.pop("f_min") if "f_min" in kwargs else None + f_max = kwargs.pop("f_max") if "f_max" in kwargs else None + assert "fs" in kwargs + kwargs["axis"] = -1 + f, p = welch(x, **kwargs) + f_min, f_max = utils.get_valid_freq_band(kwargs["fs"], x.shape[-1], f_min, f_max) + f, p = utils.slice_freq_band(f, p, f_min=f_min, f_max=f_max) + return f, p -@FeaturePredecessor(SpectralFeatureExtractor) -class NormalizedSpectralFeatureExtractor(FeatureExtractor): - def preprocess(self, *x): - return (*x[:-1], x[-1] / x[-1].sum(axis=-1, keepdims=True)) +@FeaturePredecessor(spectral_preprocessor) +def spectral_normalized_preprocessor(f, p, /): + return f, p / p.sum(axis=-1, keepdims=True) -@FeaturePredecessor(SpectralFeatureExtractor) -class DBSpectralFeatureExtractor(FeatureExtractor): - def preprocess(self, *x, eps=1e-15): - return (*x[:-1], 10 * np.log10(x[-1] + eps)) +@FeaturePredecessor(spectral_preprocessor) +def spectral_db_preprocessor(f, p, /, eps=1e-15): + return f, 10 * np.log10(p + eps) -@FeaturePredecessor(SpectralFeatureExtractor) +@FeaturePredecessor(spectral_preprocessor) @univariate_feature -def spectral_root_total_power(f, p): +def spectral_root_total_power(f, p, /): return np.sqrt(p.sum(axis=-1)) -@FeaturePredecessor(NormalizedSpectralFeatureExtractor) +@FeaturePredecessor(spectral_normalized_preprocessor) @univariate_feature -def spectral_moment(f, p): +def spectral_moment(f, p, /): return np.sum(f * p, axis=-1) -@FeaturePredecessor(SpectralFeatureExtractor) +@FeaturePredecessor(spectral_preprocessor) @univariate_feature -def spectral_hjorth_activity(f, p): +def spectral_hjorth_activity(f, p, /): return np.sum(p, axis=-1) -@FeaturePredecessor(NormalizedSpectralFeatureExtractor) +@FeaturePredecessor(spectral_normalized_preprocessor) @univariate_feature -def spectral_hjorth_mobility(f, p): +def spectral_hjorth_mobility(f, p, /): return np.sqrt(np.sum(np.power(f, 2) * p, axis=-1)) -@FeaturePredecessor(NormalizedSpectralFeatureExtractor) +@FeaturePredecessor(spectral_normalized_preprocessor) @univariate_feature -def spectral_hjorth_complexity(f, p): +def spectral_hjorth_complexity(f, p, /): return np.sqrt(np.sum(np.power(f, 4) * p, axis=-1)) -@FeaturePredecessor(NormalizedSpectralFeatureExtractor) +@FeaturePredecessor(spectral_normalized_preprocessor) @univariate_feature -def spectral_entropy(f, p): +def spectral_entropy(f, p, /): idx = p > 0 plogp = np.zeros_like(p) plogp[idx] = p[idx] * np.log(p[idx]) return -np.sum(plogp, axis=-1) -@FeaturePredecessor(NormalizedSpectralFeatureExtractor) +@FeaturePredecessor(spectral_normalized_preprocessor) @univariate_feature @nb.njit(cache=True, fastmath=True) -def spectral_edge(f, p, edge=0.9): +def spectral_edge(f, p, /, edge=0.9): se = np.empty(p.shape[:-1]) for i in np.ndindex(p.shape[:-1]): se[i] = f[np.searchsorted(np.cumsum(p[i]), edge)] return se -@FeaturePredecessor(DBSpectralFeatureExtractor) +@FeaturePredecessor(spectral_db_preprocessor) @univariate_feature -def spectral_slope(f, p): +def spectral_slope(f, p, /): log_f = np.vstack((np.log(f), np.ones(f.shape[0]))).T r = np.linalg.lstsq(log_f, p.reshape(-1, p.shape[-1]).T)[0] r = r.reshape(2, *p.shape[:-1]) @@ -107,10 +103,10 @@ def spectral_slope(f, p): @FeaturePredecessor( - SpectralFeatureExtractor, - NormalizedSpectralFeatureExtractor, - DBSpectralFeatureExtractor, + spectral_preprocessor, + spectral_normalized_preprocessor, + spectral_db_preprocessor, ) @univariate_feature -def spectral_bands_power(f, p, bands=utils.DEFAULT_FREQ_BANDS): +def spectral_bands_power(f, p, /, bands=utils.DEFAULT_FREQ_BANDS): return utils.reduce_freq_bands(f, p, bands, np.sum) diff --git a/eegdash/features/inspect.py b/eegdash/features/inspect.py index 1496713e..165eb555 100644 --- a/eegdash/features/inspect.py +++ b/eegdash/features/inspect.py @@ -7,7 +7,7 @@ from .extractors import _get_underlying_func __all__ = [ - "get_all_feature_extractors", + "get_all_feature_preprocessors", "get_all_feature_kinds", "get_all_features", "get_feature_kind", @@ -15,7 +15,7 @@ ] -def get_feature_predecessors(feature_or_extractor: Callable) -> list: +def get_feature_predecessors(feature_or_extractor: Callable | None) -> list: """Get the dependency hierarchy for a feature or feature extractor. This function recursively traverses the `parent_extractor_type` attribute @@ -37,12 +37,13 @@ class to inspect. multiple dependencies, it will contain tuples of sub-dependencies. """ + current = feature_or_extractor + if current is None: + return [None] + if isinstance(current, extractors.FeatureExtractor): + current = current.preprocessor current = _get_underlying_func(feature_or_extractor) - if current is extractors.FeatureExtractor: - return [current] - predecessor = getattr( - current, "parent_extractor_type", [extractors.FeatureExtractor] - ) + predecessor = getattr(current, "parent_extractor_type", [None]) if len(predecessor) == 1: return [current, *get_feature_predecessors(predecessor[0])] else: @@ -92,27 +93,27 @@ def isfeature(x): return inspect.getmembers(feature_bank, isfeature) -def get_all_feature_extractors() -> list[tuple[str, type[extractors.FeatureExtractor]]]: - """Get a list of all available :class:`~eegdash.features.extractors.FeatureExtractor` classes. +def get_all_feature_preprocessors() -> list[tuple[str, Callable]]: + """Get a list of all available preprocessor functions. - Scans the `eegdash.features.feature_bank` module for all classes that - subclass :class:`~eegdash.features.extractors.FeatureExtractor`. + Scans the `eegdash.features.feature_bank` module for all preprocessor functions. Returns ------- - list[tuple[str, type[eegdash.features.extractors.FeatureExtractor]]] - A list of (name, class) tuples for all discovered feature extractors, - including the base :class:`~eegdash.features.extractors.FeatureExtractor` itself. + list[tuple[str, Callable]] + A list of (name, function) tuples for all discovered feature preprocessors. """ def isfeatureextractor(x): - return inspect.isclass(x) and issubclass(x, extractors.FeatureExtractor) - - return [ - ("FeatureExtractor", extractors.FeatureExtractor), - *inspect.getmembers(feature_bank, isfeatureextractor), - ] + y = _get_underlying_func(x) + return ( + callable(y) + and not hasattr(y, "feature_kind") + and hasattr(y, "parent_extractor_type") + ) + + return inspect.getmembers(feature_bank, isfeatureextractor) def get_all_feature_kinds() -> list[tuple[str, type[extractors.MultivariateFeature]]]: diff --git a/pyproject.toml b/pyproject.toml index 63b445ce..b8f173cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "eeglabio", "tabulate", "rich", + "pyarrow", ] [project.urls] diff --git a/tests/features/__init__.py b/tests/features/__init__.py new file mode 100644 index 00000000..eac7268f --- /dev/null +++ b/tests/features/__init__.py @@ -0,0 +1,4 @@ +# Authors: EEG DaSh team +# +# License: +# General Public License (GPL) v3.0ยง diff --git a/tests/features/conftest.py b/tests/features/conftest.py new file mode 100644 index 00000000..97545d1e --- /dev/null +++ b/tests/features/conftest.py @@ -0,0 +1,103 @@ +"""Fixtures for features module test.""" + +import shutil +from pathlib import Path + +import pytest + +from braindecode.datautil import load_concat_dataset +from braindecode.preprocessing import ( + Preprocessor, + create_windows_from_events, + preprocess, +) +from eegdash import EEGDashDataset +from eegdash.hbn.preprocessing import hbn_ec_ec_reannotation +from eegdash.logging import logger + + +@pytest.fixture(scope="session") +def eeg_dash_dataset(cache_dir: Path): + """Fixture to create an instance of EEGDashDataset.""" + return EEGDashDataset( + query={ + "dataset": "ds005514", + "task": "RestingState", + "subject": "NDARDB033FW5", + }, + cache_dir=cache_dir, + ) + + +@pytest.fixture(scope="session") +def preprocess_instance(eeg_dash_dataset, cache_dir: Path): + """Fixture to create an instance of EEGDashDataset with preprocessing.""" + selected_channels = [ + "E22", + "E9", + "E33", + "E24", + "E11", + "E124", + "E122", + "E29", + "E6", + "E111", + "E45", + "E36", + "E104", + "E108", + "E42", + "E55", + "E93", + "E58", + "E52", + "E62", + "E92", + "E96", + "E70", + "Cz", + ] + pre_processed_dir = cache_dir / "preprocessed" + pre_processed_dir.mkdir(parents=True, exist_ok=True) + try: + eeg_dash_dataset = load_concat_dataset( + pre_processed_dir, + preload=True, + ) + return eeg_dash_dataset + + except Exception as e: + logger.warning(f"Failed to load dataset creating a new instance: {e}. ") + if pre_processed_dir.exists(): + # folder with issue, erasing and creating again + shutil.rmtree(pre_processed_dir) + pre_processed_dir.mkdir(parents=True, exist_ok=True) + + preprocessors = [ + hbn_ec_ec_reannotation(), + Preprocessor( + "pick_channels", + ch_names=selected_channels, + ), + Preprocessor("resample", sfreq=128), + Preprocessor("filter", l_freq=1, h_freq=55), + ] + + eeg_dash_dataset = preprocess( + eeg_dash_dataset, preprocessors, n_jobs=-1, save_dir=pre_processed_dir + ) + + return eeg_dash_dataset + + +@pytest.fixture(scope="session") +def windows_ds(preprocess_instance): + """Fixture to create windows from the preprocessed EEG dataset.""" + windows = create_windows_from_events( + preprocess_instance, + trial_start_offset_samples=0, + trial_stop_offset_samples=256, + preload=True, + ) + return windows diff --git a/tests/features/test_features_extraction.py b/tests/features/test_features_extraction.py new file mode 100644 index 00000000..0ea8be07 --- /dev/null +++ b/tests/features/test_features_extraction.py @@ -0,0 +1,100 @@ +"""Test for features module Python 3.10+ compatibility.""" + +from functools import partial + +import pytest + +from eegdash import features +from eegdash.features import FeatureExtractor, FeaturesConcatDataset, extract_features + + +@pytest.fixture(scope="module") +def feature_dict(windows_ds): + """Fixture to create a feature extraction tree.""" + sfreq = windows_ds.datasets[0].raw.info["sfreq"] + filter_freqs = dict(windows_ds.datasets[0].raw_preproc_kwargs)["filter"] + + feats = { + "sig": features.FeatureExtractor( + { + "mean": features.signal_mean, + "var": features.signal_variance, + "std": features.signal_std, + "skew": features.signal_skewness, + "kurt": features.signal_kurtosis, + "rms": features.signal_root_mean_square, + "ptp": features.signal_peak_to_peak, + "quan.1": partial(features.signal_quantile, q=0.1), + "quan.9": partial(features.signal_quantile, q=0.9), + "line_len": features.signal_line_length, + "zero_x": features.signal_zero_crossings, + }, + ), + "spec": features.FeatureExtractor( + preprocessor=partial( + features.spectral_preprocessor, + fs=sfreq, + f_min=filter_freqs["l_freq"], + f_max=filter_freqs["h_freq"], + nperseg=2 * sfreq, + noverlap=int(1.5 * sfreq), + ), + feature_extractors={ + "rtot_power": features.spectral_root_total_power, + "band_power": partial( + features.spectral_bands_power, + bands={ + "theta": (4.5, 8), + "alpha": (8, 12), + "beta": (12, 30), + }, + ), + 0: features.FeatureExtractor( + preprocessor=features.spectral_normalized_preprocessor, + feature_extractors={ + "moment": features.spectral_moment, + "entropy": features.spectral_entropy, + "edge": partial(features.spectral_edge, edge=0.9), + }, + ), + 1: features.FeatureExtractor( + preprocessor=features.spectral_db_preprocessor, + feature_extractors={ + "slope": features.spectral_slope, + }, + ), + }, + ), + } + return feats + + +@pytest.fixture(scope="module") +def feature_extractor(feature_dict): + """Fixture to create a feature extractor.""" + feats = FeatureExtractor(feature_dict) + return feats + + +def test_feature_extraction_benchmark( + benchmark, windows_ds, feature_extractor, batch_size=512, n_jobs=1 +): + """Benchmark feature extraction function.""" + feats = benchmark( + extract_features, + windows_ds, + feature_extractor, + batch_size=batch_size, + n_jobs=n_jobs, + ) + assert isinstance(feats, FeaturesConcatDataset) + assert len(windows_ds.datasets) == len(feats.datasets) + + +@pytest.fixture(scope="module") +def features_ds(windows_ds, feature_extractor, batch_size=512, n_jobs=1): + """Fixture to create a features dataset.""" + feats = extract_features( + windows_ds, feature_extractor, batch_size=batch_size, n_jobs=n_jobs + ) + return feats diff --git a/tests/test_features.py b/tests/features/test_features_init.py similarity index 77% rename from tests/test_features.py rename to tests/features/test_features_init.py index 8a721ba1..d62a66dd 100644 --- a/tests/test_features.py +++ b/tests/features/test_features_init.py @@ -46,22 +46,3 @@ def test_import_features_submodules(): # Some imports might fail due to missing dependencies, that's ok # We only care about SyntaxError pass - - -def test_features_basic_functionality(): - """Test basic features module functionality.""" - from eegdash.features import ( - get_all_feature_extractors, - get_all_feature_kinds, - get_all_features, - ) - - # These should return lists without errors - features = get_all_features() - assert isinstance(features, list) - - extractors = get_all_feature_extractors() - assert isinstance(extractors, list) - - kinds = get_all_feature_kinds() - assert isinstance(kinds, list) diff --git a/tests/features/test_features_inspect.py b/tests/features/test_features_inspect.py new file mode 100644 index 00000000..18059d9c --- /dev/null +++ b/tests/features/test_features_inspect.py @@ -0,0 +1,20 @@ +"""Test for features module Python 3.10+ compatibility.""" + +from eegdash.features import ( + get_all_feature_kinds, + get_all_feature_preprocessors, + get_all_features, +) + + +def test_features_basic_functionality(): + """Test basic features module functionality.""" + # These should return lists without errors + features = get_all_features() + assert isinstance(features, list) + + extractors = get_all_feature_preprocessors() + assert isinstance(extractors, list) + + kinds = get_all_feature_kinds() + assert isinstance(kinds, list)