122 changes: 89 additions & 33 deletions skglm/estimators.py
@@ -188,24 +188,105 @@ class GeneralizedLinearEstimator(LinearModel):
Number of subproblems solved to reach the specified tolerance.
"""

_parameter_constraints: dict = {
"datafit": [None, "object"],
"penalty": [None, "object"],
"solver": [None, "object"],
}

def __init__(self, datafit=None, penalty=None, solver=None):
super(GeneralizedLinearEstimator, self).__init__()
self.penalty = penalty
self.datafit = datafit
self.solver = solver

    def __repr__(self):
-        """Get string representation of the estimator.
-
-        Returns
-        -------
-        repr : str
-            String representation.
-        """
+        """String representation."""
+        penalty_name = self.penalty.__class__.__name__ if self.penalty else "None"
+        datafit_name = self.datafit.__class__.__name__ if self.datafit else "None"
+
+        if self.penalty and hasattr(self.penalty, 'alpha'):
+            penalty_alpha = self.penalty.alpha
+        else:
+            penalty_alpha = None
+
        return (
            'GeneralizedLinearEstimator(datafit=%s, penalty=%s, alpha=%s)'
-            % (self.datafit.__class__.__name__, self.penalty.__class__.__name__,
-               self.penalty.alpha))
+            % (datafit_name, penalty_name, penalty_alpha))

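Note: with the guards above, __repr__ no longer crashes when penalty is None. With all three components provided, it renders roughly like this (a sketch; exact output depends on skglm's constructor defaults):

>>> from skglm import GeneralizedLinearEstimator
>>> from skglm.datafits import Quadratic
>>> from skglm.penalties import L1
>>> from skglm.solvers import AndersonCD
>>> GeneralizedLinearEstimator(Quadratic(), L1(alpha=1.), AndersonCD())
GeneralizedLinearEstimator(datafit=Quadratic, penalty=L1, alpha=1.0)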
def get_params(self, deep=True):
"""Get parameters, including nested datafit, penalty, solver hyper-parameters."""
# First get the top-level params (datafit, penalty, solver)
params = super().get_params(deep=False)

if not deep:
return params

# For each sub-estimator, ask its own get_params if available
for comp_name in ('datafit', 'penalty', 'solver'):
comp = getattr(self, comp_name)
if comp is None:
continue

# If it implements sklearn API, use get_params
if hasattr(comp, 'get_params'):
sub_params = comp.get_params(deep=True)
else:
# Otherwise fall back to its public attributes
sub_params = {k: v for k, v in vars(
comp).items() if not k.startswith('_')}

for sub_key, sub_val in sub_params.items():
params[f'{comp_name}__{sub_key}'] = sub_val

return params

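With deep=True, component hyper-parameters are flattened under sklearn's double-underscore convention, which is what lets GridSearchCV enumerate them. A minimal round-trip sketch (attribute names such as alpha and max_iter assume skglm's L1 and AndersonCD):

est = GeneralizedLinearEstimator(Quadratic(), L1(alpha=1.), AndersonCD())
params = est.get_params(deep=True)
assert params['penalty__alpha'] == 1.0   # picked up via the vars() fallback
assert 'solver__max_iter' in params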
def set_params(self, **params):
"""Set parameters, including nested ones for datafit, penalty, solver."""
if not params:
return self

# Top-level valid args are exactly those in __init__
valid_top = set(self.get_params(deep=False))

# Collect sub-params for each component
nested = {'datafit': {}, 'penalty': {}, 'solver': {}}

# First pass: split top-level vs nested
for key, val in params.items():
if '__' in key:
comp, sub_key = key.split('__', 1)
if comp not in nested:
raise ValueError(f"Invalid parameter: {key}")
nested[comp][sub_key] = val
else:
if key not in valid_top:
raise ValueError(f"Invalid parameter: {key}")
setattr(self, key, val)

# Second pass: apply nested updates
for comp, comp_params in nested.items():
if not comp_params:
continue
current = getattr(self, comp)
if current is not None and hasattr(current, 'set_params'):
current.set_params(**comp_params)
elif current is not None:
# fallback to simple setattr
for sub_key, sub_val in comp_params.items():
if not hasattr(current, sub_key):
raise ValueError(f"{comp} has no parameter {sub_key}")
setattr(current, sub_key, sub_val)
else:
# instantiate missing component
if comp == 'datafit':
self.datafit = Quadratic(**comp_params)
elif comp == 'penalty':
self.penalty = L1(**comp_params)
elif comp == 'solver':
self.solver = AndersonCD(**comp_params)

return self

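Conversely, set_params routes double-underscore keys to the matching component, using the setattr fallback when the component does not expose the sklearn API (skglm datafits and penalties generally do not). A usage sketch:

est = GeneralizedLinearEstimator(Quadratic(), L1(alpha=1.), AndersonCD())
est.set_params(penalty__alpha=0.1, solver__max_iter=50)
assert est.penalty.alpha == 0.1
assert est.solver.max_iter == 50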
def fit(self, X, y):
"""Fit estimator.
@@ -269,31 +350,6 @@ def predict(self, X):
else:
return self._decision_function(X)

-    def get_params(self, deep=False):
-        """Get parameters of the estimators including the datafit's and penalty's.
-
-        Parameters
-        ----------
-        deep : bool
-            Whether or not return the parameters for contained subobjects estimators.
-
-        Returns
-        -------
-        params : dict
-            The parameters of the estimator.
-        """
-        params = super().get_params(deep)
-        filtered_types = (float, int, str, np.ndarray)
-        penalty_params = [('penalty__', p, getattr(self.penalty, p)) for p in
-                          dir(self.penalty) if p[0] != "_" and
-                          type(getattr(self.penalty, p)) in filtered_types]
-        datafit_params = [('datafit__', p, getattr(self.datafit, p)) for p in
-                          dir(self.datafit) if p[0] != "_" and
-                          type(getattr(self.datafit, p)) in filtered_types]
-        for p_prefix, p_key, p_val in penalty_params + datafit_params:
-            params[p_prefix + p_key] = p_val
-        return params


class Lasso(RegressorMixin, LinearModel):
r"""Lasso estimator based on Celer solver and primal extrapolation.
48 changes: 48 additions & 0 deletions skglm/tests/test_gridsearch.py
@@ -0,0 +1,48 @@
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_regression
from skglm import GeneralizedLinearEstimator
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import AndersonCD


def test_gridsearch_compatibility():
# Generate synthetic data
X, y = make_regression(n_samples=100, n_features=10, random_state=42)

# Create base estimator
base_estimator = GeneralizedLinearEstimator(
datafit=Quadratic(),
penalty=L1(1.0),
solver=AndersonCD()
)

# Define parameter grid
param_grid = {
'penalty__alpha': [0.1, 1.0, 10.0],
'solver__max_iter': [10, 20],
'solver__tol': [1e-3, 1e-4]
}

# Create GridSearchCV
grid_search = GridSearchCV(
base_estimator,
param_grid,
cv=3,
scoring='neg_mean_squared_error'
)

# Fit GridSearchCV
grid_search.fit(X, y)

# Verify that GridSearchCV worked
assert hasattr(grid_search, 'best_params_')
assert hasattr(grid_search, 'best_score_')
assert hasattr(grid_search, 'best_estimator_')

# Verify that best_estimator_ has the correct parameters
best_estimator = grid_search.best_estimator_
assert isinstance(best_estimator, GeneralizedLinearEstimator)
assert best_estimator.penalty.alpha in [0.1, 1.0, 10.0]
assert best_estimator.solver.max_iter in [10, 20]
assert best_estimator.solver.tol in [1e-3, 1e-4]
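
GridSearchCV works by cloning the estimator and applying each candidate through set_params, so the same machinery can be sanity-checked directly (a sketch reusing base_estimator from the test above):

from sklearn.base import clone

est = clone(base_estimator).set_params(penalty__alpha=0.1)
assert est.penalty.alpha == 0.1
assert base_estimator.penalty.alpha == 1.0  # clone leaves the original untouched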
12 changes: 11 additions & 1 deletion skglm/utils/jit_compilation.py
@@ -59,7 +59,17 @@ def jit_cached_compile(klass, spec, to_float32=False):
if to_float32:
spec = spec_to_float32(spec)

-    return jitclass(spec)(klass)
+    # Create a new class without slots
+    class CompiledClass:
+        pass
+
+    # Copy over all methods and attributes from the original class
+    for name, value in klass.__dict__.items():
+        # Skip __slots__ and __slotnames__ but keep other special methods
+        if name not in ['__slots__', '__slotnames__']:
+            setattr(CompiledClass, name, value)
+
+    return jitclass(spec)(CompiledClass)


def compiled_clone(instance, to_float32=False):
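The rebuilt class appears to be needed because numba's jitclass does not accept classes carrying __slots__/__slotnames__ entries, which sklearn's cloning and pickling machinery can leave behind; copying the class dict without them sidesteps that. compiled_clone (signature above) remains the entry point; a usage sketch, assuming Quadratic as in the tests and that compiled_clone returns a jit-compiled copy of the instance:

from skglm.datafits import Quadratic
from skglm.utils.jit_compilation import compiled_clone

datafit = Quadratic()
compiled = compiled_clone(datafit)                        # numba-compiled copy
compiled_f32 = compiled_clone(datafit, to_float32=True)   # float32 spec variant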