@@ -188,24 +188,105 @@ class GeneralizedLinearEstimator(LinearModel):
188188 Number of subproblems solved to reach the specified tolerance.
189189 """
190190
191+ _parameter_constraints : dict = {
192+ "datafit" : [None , "object" ],
193+ "penalty" : [None , "object" ],
194+ "solver" : [None , "object" ],
195+ }
196+
191197 def __init__ (self , datafit = None , penalty = None , solver = None ):
192198 super (GeneralizedLinearEstimator , self ).__init__ ()
193199 self .penalty = penalty
194200 self .datafit = datafit
195201 self .solver = solver
196202
197203 def __repr__ (self ):
198- """Get string representation of the estimator.
204+ """String representation."""
205+ penalty_name = self .penalty .__class__ .__name__ if self .penalty else "None"
206+ datafit_name = self .datafit .__class__ .__name__ if self .datafit else "None"
207+
208+ if self .penalty and hasattr (self .penalty , 'alpha' ):
209+ penalty_alpha = self .penalty .alpha
210+ else :
211+ penalty_alpha = None
199212
200- Returns
201- -------
202- repr : str
203- String representation.
204- """
205213 return (
206214 'GeneralizedLinearEstimator(datafit=%s, penalty=%s, alpha=%s)'
207- % (self .datafit .__class__ .__name__ , self .penalty .__class__ .__name__ ,
208- self .penalty .alpha ))
215+ % (datafit_name , penalty_name , penalty_alpha ))
216+
217+ def get_params (self , deep = True ):
218+ """Get parameters, including nested datafit, penalty, solver hyper-parameters."""
219+ # First get the top-level params (datafit, penalty, solver)
220+ params = super ().get_params (deep = False )
221+
222+ if not deep :
223+ return params
224+
225+ # For each sub-estimator, ask its own get_params if available
226+ for comp_name in ('datafit' , 'penalty' , 'solver' ):
227+ comp = getattr (self , comp_name )
228+ if comp is None :
229+ continue
230+
231+ # If it implements sklearn API, use get_params
232+ if hasattr (comp , 'get_params' ):
233+ sub_params = comp .get_params (deep = True )
234+ else :
235+ # Otherwise fall back to its public attributes
236+ sub_params = {k : v for k , v in vars (
237+ comp ).items () if not k .startswith ('_' )}
238+
239+ for sub_key , sub_val in sub_params .items ():
240+ params [f'{ comp_name } __{ sub_key } ' ] = sub_val
241+
242+ return params
243+
244+ def set_params (self , ** params ):
245+ """Set parameters, including nested ones for datafit, penalty, solver."""
246+ if not params :
247+ return self
248+
249+ # Top-level valid args are exactly those in __init__
250+ valid_top = set (self .get_params (deep = False ))
251+
252+ # Collect sub-params for each component
253+ nested = {'datafit' : {}, 'penalty' : {}, 'solver' : {}}
254+
255+ # First pass: split top-level vs nested
256+ for key , val in params .items ():
257+ if '__' in key :
258+ comp , sub_key = key .split ('__' , 1 )
259+ if comp not in nested :
260+ raise ValueError (f"Invalid parameter: { key } " )
261+ nested [comp ][sub_key ] = val
262+ else :
263+ if key not in valid_top :
264+ raise ValueError (f"Invalid parameter: { key } " )
265+ setattr (self , key , val )
266+
267+ # Second pass: apply nested updates
268+ for comp , comp_params in nested .items ():
269+ if not comp_params :
270+ continue
271+ current = getattr (self , comp )
272+ if current is not None and hasattr (current , 'set_params' ):
273+ current .set_params (** comp_params )
274+ elif current is not None :
275+ # fallback to simple setattr
276+ for sub_key , sub_val in comp_params .items ():
277+ if not hasattr (current , sub_key ):
278+ raise ValueError (f"{ comp } has no parameter { sub_key } " )
279+ setattr (current , sub_key , sub_val )
280+ else :
281+ # instantiate missing component
282+ if comp == 'datafit' :
283+ self .datafit = Quadratic (** comp_params )
284+ elif comp == 'penalty' :
285+ self .penalty = L1 (** comp_params )
286+ elif comp == 'solver' :
287+ self .solver = AndersonCD (** comp_params )
288+
289+ return self
209290
210291 def fit (self , X , y ):
211292 """Fit estimator.
@@ -269,31 +350,6 @@ def predict(self, X):
269350 else :
270351 return self ._decision_function (X )
271352
272- def get_params (self , deep = False ):
273- """Get parameters of the estimators including the datafit's and penalty's.
274-
275- Parameters
276- ----------
277- deep : bool
278- Whether or not return the parameters for contained subobjects estimators.
279-
280- Returns
281- -------
282- params : dict
283- The parameters of the estimator.
284- """
285- params = super ().get_params (deep )
286- filtered_types = (float , int , str , np .ndarray )
287- penalty_params = [('penalty__' , p , getattr (self .penalty , p )) for p in
288- dir (self .penalty ) if p [0 ] != "_" and
289- type (getattr (self .penalty , p )) in filtered_types ]
290- datafit_params = [('datafit__' , p , getattr (self .datafit , p )) for p in
291- dir (self .datafit ) if p [0 ] != "_" and
292- type (getattr (self .datafit , p )) in filtered_types ]
293- for p_prefix , p_key , p_val in penalty_params + datafit_params :
294- params [p_prefix + p_key ] = p_val
295- return params
296-
297353
298354class Lasso (RegressorMixin , LinearModel ):
299355 r"""Lasso estimator based on Celer solver and primal extrapolation.
0 commit comments