1717 TransformerMixin ,
1818 clone ,
1919 is_classifier ,
20+ is_regressor ,
2021)
2122from sklearn .linear_model import LogisticRegression
2223from sklearn .metrics import check_scoring
2324from sklearn .model_selection import KFold , StratifiedKFold , check_cv
24- from sklearn .utils import check_array , check_X_y , indexable
25+ from sklearn .utils import indexable
2526from sklearn .utils .validation import check_is_fitted
2627
2728from ..parallel import parallel_func
28- from ..utils import _check_option , _pl , _validate_type , logger , pinv , verbose , warn
29+ from ..utils import (
30+ _check_option ,
31+ _pl ,
32+ _validate_type ,
33+ logger ,
34+ pinv ,
35+ verbose ,
36+ warn ,
37+ )
38+ from ._fixes import validate_data
2939from ._ged import (
3040 _handle_restr_mat ,
3141 _is_cov_pos_semidef ,
@@ -340,7 +350,8 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
340350 model : object | None
341351 A linear model from scikit-learn with a fit method
342352 that updates a ``coef_`` attribute.
343- If None the model will be LogisticRegression.
353+ If None the model will be
354+ :class:`sklearn.linear_model.LogisticRegression`.
344355
345356 Attributes
346357 ----------
@@ -364,46 +375,66 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
364375 .. footbibliography::
365376 """
366377
367- # TODO: Properly refactor this using
368- # https://github.com/scikit-learn/scikit-learn/issues/30237#issuecomment-2465572885
369378 _model_attr_wrap = (
370379 "transform" ,
380+ "fit_transform" ,
371381 "predict" ,
372382 "predict_proba" ,
373- "_estimator_type " ,
374- "__tags__" ,
383+ "predict_log_proba " ,
384+ "_estimator_type" , # remove after sklearn 1.6
375385 "decision_function" ,
376386 "score" ,
377387 "classes_" ,
378388 )
379389
380390 def __init__ (self , model = None ):
381- # TODO: We need to set this to get our tag checking to work properly
382- if model is None :
383- model = LogisticRegression (solver = "liblinear" )
384391 self .model = model
385392
386393 def __sklearn_tags__ (self ):
387394 """Get sklearn tags."""
388- from sklearn .utils import get_tags # added in 1.6
389-
390- # fit method below does not allow sparse data via check_data, we could
391- # eventually make it smarter if we had to
392- tags = get_tags (self .model )
393- tags .input_tags .sparse = False
395+ tags = super ().__sklearn_tags__ ()
396+ model = self .model if self .model is not None else LogisticRegression ()
397+ model_tags = model .__sklearn_tags__ ()
398+ tags .estimator_type = model_tags .estimator_type
399+ if tags .estimator_type is not None :
400+ model_type_tags = getattr (model_tags , f"{ tags .estimator_type } _tags" )
401+ setattr (tags , f"{ tags .estimator_type } _tags" , model_type_tags )
394402 return tags
395403
396404 def __getattr__ (self , attr ):
397405 """Wrap to model for some attributes."""
398406 if attr in LinearModel ._model_attr_wrap :
399- return getattr (self .model , attr )
400- elif attr == "fit_transform" and hasattr (self .model , "fit_transform" ):
401- return super ().__getattr__ (self , "_fit_transform" )
402- return super ().__getattr__ (self , attr )
407+ model = self .model_ if "model_" in self .__dict__ else self .model
408+ if attr == "fit_transform" and hasattr (model , "fit_transform" ):
409+ return self ._fit_transform
410+ else :
411+ return getattr (model , attr )
412+ else :
413+ raise AttributeError (
414+ f"'{ type (self ).__name__ } ' object has no attribute '{ attr } '"
415+ )
403416
404417 def _fit_transform (self , X , y ):
405418 return self .fit (X , y ).transform (X )
406419
420+ def _validate_params (self , X ):
421+ if self .model is not None :
422+ model = self .model
423+ if isinstance (model , MetaEstimatorMixin ):
424+ model = model .estimator
425+ is_predictor = is_regressor (model ) or is_classifier (model )
426+ if not is_predictor :
427+ raise ValueError (
428+ "Linear model should be a supervised predictor "
429+ "(classifier or regressor)"
430+ )
431+
432+ # For sklearn < 1.6
433+ try :
434+ self ._check_n_features (X , reset = True )
435+ except AttributeError :
436+ pass
437+
407438 def fit (self , X , y , ** fit_params ):
408439 """Estimate the coefficients of the linear model.
409440
@@ -424,25 +455,18 @@ def fit(self, X, y, **fit_params):
424455 self : instance of LinearModel
425456 Returns the modified instance.
426457 """
427- if y is not None :
428- X = check_array (X )
429- else :
430- X , y = check_X_y (X , y )
431- self .n_features_in_ = X .shape [1 ]
432- if y is not None :
433- y = check_array (y , dtype = None , ensure_2d = False , input_name = "y" )
434- if y .ndim > 2 :
435- raise ValueError (
436- f"LinearModel only accepts up to 2-dimensional y, got { y .shape } "
437- "instead."
438- )
458+ self ._validate_params (X )
459+ X , y = validate_data (self , X , y , multi_output = True )
439460
440461 # fit the Model
441- self .model .fit (X , y , ** fit_params )
442- self .model_ = self .model # for better sklearn compat
462+ self .model_ = (
463+ clone (self .model )
464+ if self .model is not None
465+ else LogisticRegression (solver = "liblinear" )
466+ )
467+ self .model_ .fit (X , y , ** fit_params )
443468
444469 # Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y
445-
446470 inv_Y = 1.0
447471 X = X - X .mean (0 , keepdims = True )
448472 if y .ndim == 2 and y .shape [1 ] != 1 :
@@ -454,12 +478,17 @@ def fit(self, X, y, **fit_params):
454478
455479 @property
456480 def filters_ (self ):
457- if hasattr (self .model , "coef_" ):
481+ if hasattr (self .model_ , "coef_" ):
458482 # Standard Linear Model
459- filters = self .model .coef_
460- elif hasattr (self .model .best_estimator_ , "coef_" ):
483+ filters = self .model_ .coef_
484+ elif hasattr (self .model_ , "estimators_" ):
485+ # Linear model with OneVsRestClassifier
486+ filters = np .vstack ([est .coef_ for est in self .model_ .estimators_ ])
487+ elif hasattr (self .model_ , "best_estimator_" ) and hasattr (
488+ self .model_ .best_estimator_ , "coef_"
489+ ):
461490 # Linear Model with GridSearchCV
462- filters = self .model .best_estimator_ .coef_
491+ filters = self .model_ .best_estimator_ .coef_
463492 else :
464493 raise ValueError ("model does not have a `coef_` attribute." )
465494 if filters .ndim == 2 and filters .shape [0 ] == 1 :
0 commit comments