Skip to content

Commit bc55a70

Browse files
committed
Handle single sample output from FeatureFunctionTransformer to support scikit-learn 1.7.2 and later (mne-tools#97)
1 parent 2cd03f4 commit bc55a70

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

mne_features/feature_extraction.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def transform(self, X):
6262
6363
Returns
6464
-------
65-
X_out : ndarray, shape (n_output_func,)
65+
X_out : ndarray, shape (1, n_output_func)
6666
Usually, ``n_output_func`` will be equal to ``n_channels`` for most
6767
univariate feature functions and to
6868
``(n_channels * (n_channels + 1)) // 2`` for most bivariate feature
@@ -71,6 +71,8 @@ def transform(self, X):
7171
"""
7272
X_out = super(FeatureFunctionTransformer, self).transform(X)
7373
self.output_shape_ = X_out.shape[0]
74+
if X_out.ndim == 1:
75+
X_out = X_out[np.newaxis, :]
7476
if not hasattr(self, 'feature_names_'):
7577
func_name = _get_func_name(self.func).replace('compute_', '')
7678
if (func_name in get_univariate_func_names() and

0 commit comments

Comments
 (0)