Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eegdash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@

__all__ = ["EEGDash", "EEGDashDataset", "EEGChallengeDataset", "preprocessing"]

__version__ = "0.4.1"
__version__ = "0.5.0"
28 changes: 14 additions & 14 deletions eegdash/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@
UnivariateFeature,
)
from .feature_bank import ( # Complexity; Connectivity; CSP; Dimensionality; Signal; Spectral
CoherenceFeatureExtractor,
CommonSpatialPattern,
DBSpectralFeatureExtractor,
EntropyFeatureExtractor,
HilbertFeatureExtractor,
NormalizedSpectralFeatureExtractor,
SpectralFeatureExtractor,
complexity_approx_entropy,
complexity_entropy_preprocessor,
complexity_lempel_ziv,
complexity_sample_entropy,
complexity_svd_entropy,
connectivity_coherency_preprocessor,
connectivity_imaginary_coherence,
connectivity_lagged_coherence,
connectivity_magnitude_square_coherence,
Expand All @@ -35,6 +31,7 @@
dimensionality_katz_fractal_dim,
dimensionality_petrosian_fractal_dim,
signal_decorrelation_time,
signal_hilbert_preprocessor,
signal_hjorth_activity,
signal_hjorth_complexity,
signal_hjorth_mobility,
Expand All @@ -49,18 +46,21 @@
signal_variance,
signal_zero_crossings,
spectral_bands_power,
spectral_db_preprocessor,
spectral_edge,
spectral_entropy,
spectral_hjorth_activity,
spectral_hjorth_complexity,
spectral_hjorth_mobility,
spectral_moment,
spectral_normalized_preprocessor,
spectral_preprocessor,
spectral_root_total_power,
spectral_slope,
)
from .inspect import (
get_all_feature_extractors,
get_all_feature_kinds,
get_all_feature_preprocessors,
get_all_features,
get_feature_kind,
get_feature_predecessors,
Expand All @@ -82,7 +82,7 @@
"MultivariateFeature",
"TrainableFeature",
"UnivariateFeature",
"get_all_feature_extractors",
"get_all_feature_preprocessors",
"get_all_feature_kinds",
"get_all_features",
"get_feature_kind",
Expand All @@ -92,13 +92,13 @@
"fit_feature_extractors",
# Feature part
# Complexity
"EntropyFeatureExtractor",
"complexity_entropy_preprocessor",
"complexity_approx_entropy",
"complexity_sample_entropy",
"complexity_svd_entropy",
"complexity_lempel_ziv",
# Connectivity
"CoherenceFeatureExtractor",
"connectivity_coherency_preprocessor",
"connectivity_magnitude_square_coherence",
"connectivity_imaginary_coherence",
"connectivity_lagged_coherence",
Expand All @@ -111,7 +111,7 @@
"dimensionality_hurst_exp",
"dimensionality_detrended_fluctuation_analysis",
# Signal
"HilbertFeatureExtractor",
"signal_hilbert_preprocessor",
"signal_mean",
"signal_variance",
"signal_skewness",
Expand All @@ -127,9 +127,9 @@
"signal_hjorth_complexity",
"signal_decorrelation_time",
# Spectral
"SpectralFeatureExtractor",
"NormalizedSpectralFeatureExtractor",
"DBSpectralFeatureExtractor",
"spectral_preprocessor",
"spectral_normalized_preprocessor",
"spectral_db_preprocessor",
"spectral_root_total_power",
"spectral_moment",
"spectral_entropy",
Expand Down
29 changes: 14 additions & 15 deletions eegdash/features/decorators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections.abc import Callable
from typing import List, Type
from typing import List

from .extractors import (
BivariateFeature,
DirectedBivariateFeature,
FeatureExtractor,
MultivariateFeature,
UnivariateFeature,
_get_underlying_func,
Expand All @@ -22,26 +21,26 @@
class FeaturePredecessor:
"""A decorator to specify parent extractors for a feature function.

This decorator attaches a list of parent extractor types to a feature
extraction function. This information can be used to build a dependency
graph of features.
This decorator attaches a list of immediate parent preprocessing steps to a feature
extraction function. This information can be used to build a dependency graph of
features.

Parameters
----------
*parent_extractor_type : list of Type
A list of feature extractor classes (subclasses of
:class:`~eegdash.features.extractors.FeatureExtractor`) that this
feature depends on.
A list of preprocessing functions (subclasses of
:class:`~collections.abc.Callable` or None) that this feature immediately depends
on.

"""

def __init__(self, *parent_extractor_type: List[Type]):
parent_cls = parent_extractor_type
if not parent_cls:
parent_cls = [FeatureExtractor]
for p_cls in parent_cls:
assert issubclass(p_cls, FeatureExtractor)
self.parent_extractor_type = parent_cls
def __init__(self, *parent_extractor_type: List[Callable | None]):
parent_func = parent_extractor_type
if not parent_func:
parent_func = [None]
for p_func in parent_func:
assert p_func is None or callable(p_func)
self.parent_extractor_type = parent_func

def __call__(self, func: Callable) -> Callable:
"""Apply the decorator to a function.
Expand Down
63 changes: 38 additions & 25 deletions eegdash/features/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,28 +102,28 @@ class FeatureExtractor(TrainableFeature):
feature_extractors : dict[str, callable]
A dictionary where keys are feature names and values are the feature
extraction functions or other `FeatureExtractor` instances.
**preprocess_kwargs
Keyword arguments to be passed to the `preprocess` method.
preprocessor
A shared preprocessing function for all child feature extraction functions.

"""

def __init__(
self, feature_extractors: Dict[str, Callable], **preprocess_kwargs: Dict
self,
feature_extractors: Dict[str, Callable],
preprocessor: Callable | None = None,
):
self.preprocessor = preprocessor
self.feature_extractors_dict = self._validate_execution_tree(feature_extractors)
self._is_trainable = self._check_is_trainable(feature_extractors)
super().__init__()

# bypassing FeaturePredecessor to avoid circular import
if not hasattr(self, "parent_extractor_type"):
self.parent_extractor_type = [FeatureExtractor]

self.preprocess_kwargs = preprocess_kwargs
if self.preprocess_kwargs is None:
self.preprocess_kwargs = dict()
self.features_kwargs = {
"preprocess_kwargs": preprocess_kwargs,
}
self.parent_extractor_type = [None]

self.features_kwargs = dict()
if preprocessor is not None and isinstance(preprocessor, partial):
self.features_kwargs["preprocess_kwargs"] = preprocessor.args
for fn, fe in feature_extractors.items():
if isinstance(fe, FeatureExtractor):
self.features_kwargs[fn] = fe.features_kwargs
Expand All @@ -132,12 +132,21 @@ def __init__(

def _validate_execution_tree(self, feature_extractors: dict) -> dict:
"""Validate the feature dependency graph."""
preprocessor = (
None
if self.preprocessor is None
else _get_underlying_func(self.preprocessor)
)
for fname, f in feature_extractors.items():
if isinstance(f, FeatureExtractor):
f = f.preprocessor
f = _get_underlying_func(f)
pe_type = getattr(f, "parent_extractor_type", [FeatureExtractor])
if type(self) not in pe_type:
pe_type = getattr(f, "parent_extractor_type", [None])
if preprocessor not in pe_type:
parent = getattr(preprocessor, "__name__", preprocessor)
child = getattr(f, "__name__", f)
raise TypeError(
f"Feature '{fname}' cannot be a child of {type(self).__name__}"
f"Feature '{fname}: {child}' cannot be a child of {parent}"
)
return feature_extractors

Expand All @@ -151,23 +160,24 @@ def _check_is_trainable(self, feature_extractors: dict) -> bool:
return True
return False

def preprocess(self, *x, **kwargs):
def preprocess(self, *x):
"""Apply pre-processing to the input data.

Parameters
----------
*x : tuple
Input data.
**kwargs
Additional keyword arguments.

Returns
-------
tuple
The pre-processed data.

"""
return (*x,)
if self.preprocessor is None:
return (*x,)
else:
return self.preprocessor(*x)

def __call__(self, *x, _batch_size=None, _ch_names=None) -> dict:
"""Apply all feature extractors to the input data.
Expand All @@ -193,7 +203,7 @@ def __call__(self, *x, _batch_size=None, _ch_names=None) -> dict:
if self._is_trainable:
super().__call__()
results_dict = dict()
z = self.preprocess(*x, **self.preprocess_kwargs)
z = self.preprocess(*x)
if not isinstance(z, tuple):
z = (z,)
for fname, f in self.feature_extractors_dict.items():
Expand Down Expand Up @@ -227,26 +237,29 @@ def clear(self):
if not self._is_trainable:
return
for f in self.feature_extractors_dict.values():
if isinstance(_get_underlying_func(f), TrainableFeature):
_get_underlying_func(f).clear()
f = _get_underlying_func(f)
if isinstance(f, TrainableFeature):
f.clear()

def partial_fit(self, *x, y=None):
"""Partially fit all trainable sub-features."""
if not self._is_trainable:
return
z = self.preprocess(*x, **self.preprocess_kwargs)
z = self.preprocess(*x)
if not isinstance(z, tuple):
z = (z,)
for f in self.feature_extractors_dict.values():
if isinstance(_get_underlying_func(f), TrainableFeature):
_get_underlying_func(f).partial_fit(*z, y=y)
f = _get_underlying_func(f)
if isinstance(f, TrainableFeature):
f.partial_fit(*z, y=y)

def fit(self):
"""Fit all trainable sub-features."""
if not self._is_trainable:
return
for f in self.feature_extractors_dict.values():
if isinstance(_get_underlying_func(f), TrainableFeature):
f = _get_underlying_func(f)
if isinstance(f, TrainableFeature):
f.fit()
super().fit()

Expand Down
24 changes: 12 additions & 12 deletions eegdash/features/feature_bank/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
"""

from .complexity import (
EntropyFeatureExtractor,
complexity_approx_entropy,
complexity_entropy_preprocessor,
complexity_lempel_ziv,
complexity_sample_entropy,
complexity_svd_entropy,
)
from .connectivity import (
CoherenceFeatureExtractor,
connectivity_coherency_preprocessor,
connectivity_imaginary_coherence,
connectivity_lagged_coherence,
connectivity_magnitude_square_coherence,
Expand All @@ -27,8 +27,8 @@
dimensionality_petrosian_fractal_dim,
)
from .signal import (
HilbertFeatureExtractor,
signal_decorrelation_time,
signal_hilbert_preprocessor,
signal_hjorth_activity,
signal_hjorth_complexity,
signal_hjorth_mobility,
Expand All @@ -44,29 +44,29 @@
signal_zero_crossings,
)
from .spectral import (
DBSpectralFeatureExtractor,
NormalizedSpectralFeatureExtractor,
SpectralFeatureExtractor,
spectral_bands_power,
spectral_db_preprocessor,
spectral_edge,
spectral_entropy,
spectral_hjorth_activity,
spectral_hjorth_complexity,
spectral_hjorth_mobility,
spectral_moment,
spectral_normalized_preprocessor,
spectral_preprocessor,
spectral_root_total_power,
spectral_slope,
)

__all__ = [
# Complexity
"EntropyFeatureExtractor",
"complexity_entropy_preprocessor",
"complexity_approx_entropy",
"complexity_sample_entropy",
"complexity_svd_entropy",
"complexity_lempel_ziv",
# Connectivity
"CoherenceFeatureExtractor",
"connectivity_coherency_preprocessor",
"connectivity_magnitude_square_coherence",
"connectivity_imaginary_coherence",
"connectivity_lagged_coherence",
Expand All @@ -79,7 +79,7 @@
"dimensionality_hurst_exp",
"dimensionality_detrended_fluctuation_analysis",
# Signal
"HilbertFeatureExtractor",
"signal_hilbert_preprocessor",
"signal_mean",
"signal_variance",
"signal_skewness",
Expand All @@ -95,9 +95,9 @@
"signal_hjorth_complexity",
"signal_decorrelation_time",
# Spectral
"SpectralFeatureExtractor",
"NormalizedSpectralFeatureExtractor",
"DBSpectralFeatureExtractor",
"spectral_preprocessor",
"spectral_normalized_preprocessor",
"spectral_db_preprocessor",
"spectral_root_total_power",
"spectral_moment",
"spectral_entropy",
Expand Down
Loading