1616from sklearn .metrics import r2_score
1717
1818from ..utils import _validate_type , fill_doc , pinv
19+ from ._fixes import _check_n_features_3d , validate_data
1920from .base import _check_estimator , get_coef
2021from .time_delaying_ridge import TimeDelayingRidge
2122
@@ -125,7 +126,7 @@ def __init__(
125126 self .tmax = tmax
126127 self .sfreq = sfreq
127128 self .feature_names = feature_names
128- self .estimator = 0.0 if estimator is None else estimator
129+ self .estimator = estimator
129130 self .fit_intercept = fit_intercept
130131 self .scoring = scoring
131132 self .patterns = patterns
@@ -152,6 +153,19 @@ def __repr__(self): # noqa: D105
152153 s += f"scored ({ self .scoring } )"
153154 return f"<ReceptiveField | { s } >"
154155
156+ def __sklearn_tags__ (self ):
157+ """..."""
158+ from sklearn .utils import RegressorTags
159+
160+ tags = super ().__sklearn_tags__ ()
161+ tags .estimator_type = "regressor"
162+ tags .regressor_tags = RegressorTags ()
163+ tags .input_tags .three_d_array = True
164+ tags .target_tags .one_d_labels = True
165+ tags .target_tags .multi_output = True
166+ tags .target_tags .required = True
167+ return tags
168+
155169 def _delay_and_reshape (self , X , y = None ):
156170 """Delay and reshape the variables."""
157171 if not isinstance (self .estimator_ , TimeDelayingRidge ):
@@ -169,6 +183,32 @@ def _delay_and_reshape(self, X, y=None):
169183 y = y .reshape (- 1 , y .shape [- 1 ], order = "F" )
170184 return X , y
171185
186+ def _check_data (self , X , y = None , reset = False ):
187+ if reset :
188+ X , y = validate_data (
189+ self ,
190+ X = X ,
191+ y = y ,
192+ reset = reset ,
193+ validate_separately = ( # to take care of 3D y
194+ dict (allow_nd = True , ensure_2d = False ),
195+ dict (allow_nd = True , ensure_2d = False ),
196+ ),
197+ )
198+ else :
199+ X = validate_data (self , X = X , allow_nd = True , ensure_2d = False , reset = reset )
200+ _check_n_features_3d (self , X , reset )
201+ return X , y
202+
203+ def _validate_params (self , X ):
204+ if self .scoring not in _SCORERS .keys ():
205+ raise ValueError (
206+ f"scoring must be one of { sorted (_SCORERS .keys ())} , got { self .scoring } "
207+ )
208+ self .sfreq_ = float (self .sfreq )
209+ if self .tmin > self .tmax :
210+ raise ValueError (f"tmin ({ self .tmin } ) must be at most tmax ({ self .tmax } )" )
211+
172212 def fit (self , X , y ):
173213 """Fit a receptive field model.
174214
@@ -184,22 +224,18 @@ def fit(self, X, y):
184224 self : instance
185225 The instance so you can chain operations.
186226 """
187- if self .scoring not in _SCORERS .keys ():
188- raise ValueError (
189- f"scoring must be one of { sorted (_SCORERS .keys ())} , got { self .scoring } "
190- )
191- self .sfreq_ = float (self .sfreq )
227+ X , y = self ._check_data (X , y , reset = True )
228+ self ._validate_params (X )
192229 X , y , _ , self ._y_dim = self ._check_dimensions (X , y )
193230
194- if self .tmin > self .tmax :
195- raise ValueError (f"tmin ({ self .tmin } ) must be at most tmax ({ self .tmax } )" )
196231 # Initialize delays
197232 self .delays_ = _times_to_delays (self .tmin , self .tmax , self .sfreq_ )
198233
199234 # Define the slice that we should use in the middle
200235 self .valid_samples_ = _delays_to_slice (self .delays_ )
201236
202- if isinstance (self .estimator , numbers .Real ):
237+ if self .estimator is None or isinstance (self .estimator , numbers .Real ):
238+ alpha = self .estimator if self .estimator is not None else 0.0
203239 if self .fit_intercept is None :
204240 self .fit_intercept_ = True
205241 else :
@@ -208,7 +244,7 @@ def fit(self, X, y):
208244 self .tmin ,
209245 self .tmax ,
210246 self .sfreq_ ,
211- alpha = self . estimator ,
247+ alpha = alpha ,
212248 fit_intercept = self .fit_intercept_ ,
213249 n_jobs = self .n_jobs ,
214250 edge_correction = self .edge_correction ,
@@ -259,6 +295,12 @@ def fit(self, X, y):
259295
260296 # Inverse-transform model weights
261297 if self .patterns :
298+ n_total_samples = n_times * n_epochs
299+ if n_total_samples < 2 :
300+ raise ValueError (
301+ "Cannot compute patterns with only one sample; "
302+ f"got n_samples = { n_total_samples } ."
303+ )
262304 if isinstance (self .estimator_ , TimeDelayingRidge ):
263305 cov_ = self .estimator_ .cov_ / float (n_times * n_epochs - 1 )
264306 y = y .reshape (- 1 , y .shape [- 1 ], order = "F" )
@@ -300,7 +342,10 @@ def predict(self, X):
300342 """
301343 if not hasattr (self , "delays_" ):
302344 raise NotFittedError ("Estimator has not been fit yet." )
345+
346+ X , _ = self ._check_data (X )
303347 X , _ , X_dim = self ._check_dimensions (X , None , predict = True )[:3 ]
348+
304349 del _
305350 # convert to sklearn and back
306351 pred_shape = X .shape [:- 1 ]
@@ -384,7 +429,10 @@ def _check_dimensions(self, X, y, predict=False):
384429 )
385430 else :
386431 raise ValueError (
387- f"X must be shape (n_times[, n_epochs], n_features), got { X .shape } "
432+ "X must be shape (n_times[, n_epochs], n_features), "
433+ f"got { X .shape } . Reshape your data to 2D or 3D "
434+ "(e.g., array.reshape(-1, 1) for a single feature, "
435+ "or array.reshape(1, -1) for a single sample)."
388436 )
389437 if y is not None :
390438 if X .shape [0 ] != y .shape [0 ]:
0 commit comments