diff --git a/skglm/estimators.py b/skglm/estimators.py
index 8e327d89e..4085c2542 100644
--- a/skglm/estimators.py
+++ b/skglm/estimators.py
@@ -188,6 +188,12 @@ class GeneralizedLinearEstimator(LinearModel):
         Number of subproblems solved to reach the specified tolerance.
     """
 
+    _parameter_constraints: dict = {
+        "datafit": [None, "object"],
+        "penalty": [None, "object"],
+        "solver": [None, "object"],
+    }
+
     def __init__(self, datafit=None, penalty=None, solver=None):
         super(GeneralizedLinearEstimator, self).__init__()
         self.penalty = penalty
@@ -195,17 +201,92 @@ def __init__(self, datafit=None, penalty=None, solver=None):
         self.solver = solver
 
     def __repr__(self):
-        """Get string representation of the estimator.
+        """String representation."""
+        penalty_name = self.penalty.__class__.__name__ if self.penalty else "None"
+        datafit_name = self.datafit.__class__.__name__ if self.datafit else "None"
+
+        if self.penalty and hasattr(self.penalty, 'alpha'):
+            penalty_alpha = self.penalty.alpha
+        else:
+            penalty_alpha = None
 
-        Returns
-        -------
-        repr : str
-            String representation.
-        """
         return (
             'GeneralizedLinearEstimator(datafit=%s, penalty=%s, alpha=%s)'
-            % (self.datafit.__class__.__name__, self.penalty.__class__.__name__,
-               self.penalty.alpha))
+            % (datafit_name, penalty_name, penalty_alpha))
+
+    def get_params(self, deep=True):
+        """Get parameters, including nested datafit, penalty, solver hyper-parameters."""
+        # First get the top-level params (datafit, penalty, solver)
+        params = super().get_params(deep=False)
+
+        if not deep:
+            return params
+
+        # For each sub-estimator, ask its own get_params if available
+        for comp_name in ('datafit', 'penalty', 'solver'):
+            comp = getattr(self, comp_name)
+            if comp is None:
+                continue
+
+            # If it implements the sklearn API, use get_params
+            if hasattr(comp, 'get_params'):
+                sub_params = comp.get_params(deep=True)
+            else:
+                # Otherwise fall back to its public attributes
+                sub_params = {k: v for k, v in vars(comp).items()
+                              if not k.startswith('_')}
+
+            for sub_key, sub_val in sub_params.items():
+                params[f'{comp_name}__{sub_key}'] = sub_val
+
+        return params
+
+    def set_params(self, **params):
+        """Set parameters, including nested ones for datafit, penalty, solver."""
+        if not params:
+            return self
+
+        # Top-level valid args are exactly those in __init__
+        valid_top = set(self.get_params(deep=False))
+
+        # Collect sub-params for each component
+        nested = {'datafit': {}, 'penalty': {}, 'solver': {}}
+
+        # First pass: split top-level vs nested
+        for key, val in params.items():
+            if '__' in key:
+                comp, sub_key = key.split('__', 1)
+                if comp not in nested:
+                    raise ValueError(f"Invalid parameter: {key}")
+                nested[comp][sub_key] = val
+            else:
+                if key not in valid_top:
+                    raise ValueError(f"Invalid parameter: {key}")
+                setattr(self, key, val)
+
+        # Second pass: apply nested updates
+        for comp, comp_params in nested.items():
+            if not comp_params:
+                continue
+            current = getattr(self, comp)
+            if current is not None and hasattr(current, 'set_params'):
+                current.set_params(**comp_params)
+            elif current is not None:
+                # Fall back to simple setattr
+                for sub_key, sub_val in comp_params.items():
+                    if not hasattr(current, sub_key):
+                        raise ValueError(f"{comp} has no parameter {sub_key}")
+                    setattr(current, sub_key, sub_val)
+            else:
+                # Instantiate the missing component
+                if comp == 'datafit':
+                    self.datafit = Quadratic(**comp_params)
+                elif comp == 'penalty':
+                    self.penalty = L1(**comp_params)
+                elif comp == 'solver':
+                    self.solver = AndersonCD(**comp_params)
+
+        return self
+
     def fit(self, X, y):
         """Fit estimator.
@@ -269,31 +350,6 @@ def predict(self, X):
         else:
             return self._decision_function(X)
 
-    def get_params(self, deep=False):
-        """Get parameters of the estimators including the datafit's and penalty's.
-
-        Parameters
-        ----------
-        deep : bool
-            Whether or not return the parameters for contained subobjects estimators.
-
-        Returns
-        -------
-        params : dict
-            The parameters of the estimator.
-        """
-        params = super().get_params(deep)
-        filtered_types = (float, int, str, np.ndarray)
-        penalty_params = [('penalty__', p, getattr(self.penalty, p)) for p in
-                          dir(self.penalty) if p[0] != "_" and
-                          type(getattr(self.penalty, p)) in filtered_types]
-        datafit_params = [('datafit__', p, getattr(self.datafit, p)) for p in
-                          dir(self.datafit) if p[0] != "_" and
-                          type(getattr(self.datafit, p)) in filtered_types]
-        for p_prefix, p_key, p_val in penalty_params + datafit_params:
-            params[p_prefix + p_key] = p_val
-        return params
-
 
 class Lasso(RegressorMixin, LinearModel):
     r"""Lasso estimator based on Celer solver and primal extrapolation.
diff --git a/skglm/tests/test_gridsearch.py b/skglm/tests/test_gridsearch.py
new file mode 100644
index 000000000..7c99dfb0f
--- /dev/null
+++ b/skglm/tests/test_gridsearch.py
@@ -0,0 +1,48 @@
+from sklearn.model_selection import GridSearchCV
+from sklearn.datasets import make_regression
+from skglm import GeneralizedLinearEstimator
+from skglm.datafits import Quadratic
+from skglm.penalties import L1
+from skglm.solvers import AndersonCD
+
+
+def test_gridsearch_compatibility():
+    # Generate synthetic data
+    X, y = make_regression(n_samples=100, n_features=10, random_state=42)
+
+    # Create base estimator
+    base_estimator = GeneralizedLinearEstimator(
+        datafit=Quadratic(),
+        penalty=L1(1.0),
+        solver=AndersonCD()
+    )
+
+    # Define parameter grid
+    param_grid = {
+        'penalty__alpha': [0.1, 1.0, 10.0],
+        'solver__max_iter': [10, 20],
+        'solver__tol': [1e-3, 1e-4]
+    }
+
+    # Create GridSearchCV
+    grid_search = GridSearchCV(
+        base_estimator,
+        param_grid,
+        cv=3,
+        scoring='neg_mean_squared_error'
+    )
+
+    # Fit GridSearchCV
+    grid_search.fit(X, y)
+
+    # Verify that GridSearchCV worked
+    assert hasattr(grid_search, 'best_params_')
+    assert hasattr(grid_search, 'best_score_')
+    assert hasattr(grid_search, 'best_estimator_')
+
+    # Verify that best_estimator_ has the correct parameters
+    best_estimator = grid_search.best_estimator_
+    assert isinstance(best_estimator, GeneralizedLinearEstimator)
+    assert best_estimator.penalty.alpha in [0.1, 1.0, 10.0]
+    assert best_estimator.solver.max_iter in [10, 20]
+    assert best_estimator.solver.tol in [1e-3, 1e-4]
diff --git a/skglm/utils/jit_compilation.py b/skglm/utils/jit_compilation.py
index cf63e357e..5f55516a5 100644
--- a/skglm/utils/jit_compilation.py
+++ b/skglm/utils/jit_compilation.py
@@ -59,7 +59,17 @@ def jit_cached_compile(klass, spec, to_float32=False):
     if to_float32:
         spec = spec_to_float32(spec)
 
-    return jitclass(spec)(klass)
+    # Create a new class without slots
+    class CompiledClass:
+        pass
+
+    # Copy over all methods and attributes from the original class
+    for name, value in klass.__dict__.items():
+        # Skip __slots__ and __slotnames__ but keep other special methods
+        if name not in ['__slots__', '__slotnames__']:
+            setattr(CompiledClass, name, value)
+
+    return jitclass(spec)(CompiledClass)
 
 
 def compiled_clone(instance, to_float32=False):
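
Note for reviewers (illustrative, not part of the patch): the `get_params`/`set_params` pair added in `estimators.py` is exactly what `sklearn.base.clone`, and hence `GridSearchCV`, exercise. Below is a minimal sketch of that round-trip, using only names already present in this diff; the `alpha=0.5` and `tol=1e-6` values are arbitrary, and the assertions assume the penalty and solver expose `alpha`/`tol` as public attributes (the fallback path in the new `get_params` relies on the same assumption).

```python
from sklearn.base import clone

from skglm import GeneralizedLinearEstimator
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import AndersonCD

est = GeneralizedLinearEstimator(
    datafit=Quadratic(), penalty=L1(alpha=1.0), solver=AndersonCD())

# Nested keys use the sklearn '<component>__<param>' convention and are
# routed to the penalty/solver objects by the new set_params.
est.set_params(penalty__alpha=0.5, solver__tol=1e-6)
assert est.get_params(deep=True)['penalty__alpha'] == 0.5

# clone() rebuilds the estimator from get_params(); GridSearchCV relies on
# this round-trip to create one configured estimator per grid point.
est_clone = clone(est)
assert est_clone.penalty.alpha == 0.5
```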