Commit ba90da8

minimal example of GridSearchCV showing compilation issue
1 parent 8e27664 commit ba90da8

File tree: 3 files changed, +148 -34 lines changed

skglm/estimators.py

Lines changed: 89 additions & 33 deletions
@@ -188,24 +188,105 @@ class GeneralizedLinearEstimator(LinearModel):
         Number of subproblems solved to reach the specified tolerance.
     """
 
+    _parameter_constraints: dict = {
+        "datafit": [None, "object"],
+        "penalty": [None, "object"],
+        "solver": [None, "object"],
+    }
+
     def __init__(self, datafit=None, penalty=None, solver=None):
         super(GeneralizedLinearEstimator, self).__init__()
         self.penalty = penalty
         self.datafit = datafit
         self.solver = solver
 
     def __repr__(self):
-        """Get string representation of the estimator.
+        """String representation."""
+        penalty_name = self.penalty.__class__.__name__ if self.penalty else "None"
+        datafit_name = self.datafit.__class__.__name__ if self.datafit else "None"
+
+        if self.penalty and hasattr(self.penalty, 'alpha'):
+            penalty_alpha = self.penalty.alpha
+        else:
+            penalty_alpha = None
 
-        Returns
-        -------
-        repr : str
-            String representation.
-        """
         return (
             'GeneralizedLinearEstimator(datafit=%s, penalty=%s, alpha=%s)'
-            % (self.datafit.__class__.__name__, self.penalty.__class__.__name__,
-               self.penalty.alpha))
+            % (datafit_name, penalty_name, penalty_alpha))
+
+    def get_params(self, deep=True):
+        """Get parameters, including nested datafit, penalty, solver hyper-parameters."""
+        # First get the top-level params (datafit, penalty, solver)
+        params = super().get_params(deep=False)
+
+        if not deep:
+            return params
+
+        # For each sub-estimator, ask its own get_params if available
+        for comp_name in ('datafit', 'penalty', 'solver'):
+            comp = getattr(self, comp_name)
+            if comp is None:
+                continue
+
+            # If it implements sklearn API, use get_params
+            if hasattr(comp, 'get_params'):
+                sub_params = comp.get_params(deep=True)
+            else:
+                # Otherwise fall back to its public attributes
+                sub_params = {k: v for k, v in vars(comp).items()
+                              if not k.startswith('_')}
+
+            for sub_key, sub_val in sub_params.items():
+                params[f'{comp_name}__{sub_key}'] = sub_val
+
+        return params
+
+    def set_params(self, **params):
+        """Set parameters, including nested ones for datafit, penalty, solver."""
+        if not params:
+            return self
+
+        # Top-level valid args are exactly those in __init__
+        valid_top = set(self.get_params(deep=False))
+
+        # Collect sub-params for each component
+        nested = {'datafit': {}, 'penalty': {}, 'solver': {}}
+
+        # First pass: split top-level vs nested
+        for key, val in params.items():
+            if '__' in key:
+                comp, sub_key = key.split('__', 1)
+                if comp not in nested:
+                    raise ValueError(f"Invalid parameter: {key}")
+                nested[comp][sub_key] = val
+            else:
+                if key not in valid_top:
+                    raise ValueError(f"Invalid parameter: {key}")
+                setattr(self, key, val)
+
+        # Second pass: apply nested updates
+        for comp, comp_params in nested.items():
+            if not comp_params:
+                continue
+            current = getattr(self, comp)
+            if current is not None and hasattr(current, 'set_params'):
+                current.set_params(**comp_params)
+            elif current is not None:
+                # fallback to simple setattr
+                for sub_key, sub_val in comp_params.items():
+                    if not hasattr(current, sub_key):
+                        raise ValueError(f"{comp} has no parameter {sub_key}")
+                    setattr(current, sub_key, sub_val)
+            else:
+                # instantiate missing component
+                if comp == 'datafit':
+                    self.datafit = Quadratic(**comp_params)
+                elif comp == 'penalty':
+                    self.penalty = L1(**comp_params)
+                elif comp == 'solver':
+                    self.solver = AndersonCD(**comp_params)
+
+        return self
 
     def fit(self, X, y):
         """Fit estimator.

@@ -269,31 +350,6 @@ def predict(self, X):
         else:
             return self._decision_function(X)
 
-    def get_params(self, deep=False):
-        """Get parameters of the estimators including the datafit's and penalty's.
-
-        Parameters
-        ----------
-        deep : bool
-            Whether or not return the parameters for contained subobjects estimators.
-
-        Returns
-        -------
-        params : dict
-            The parameters of the estimator.
-        """
-        params = super().get_params(deep)
-        filtered_types = (float, int, str, np.ndarray)
-        penalty_params = [('penalty__', p, getattr(self.penalty, p)) for p in
-                          dir(self.penalty) if p[0] != "_" and
-                          type(getattr(self.penalty, p)) in filtered_types]
-        datafit_params = [('datafit__', p, getattr(self.datafit, p)) for p in
-                          dir(self.datafit) if p[0] != "_" and
-                          type(getattr(self.datafit, p)) in filtered_types]
-        for p_prefix, p_key, p_val in penalty_params + datafit_params:
-            params[p_prefix + p_key] = p_val
-        return params
-
 
 class Lasso(RegressorMixin, LinearModel):
     r"""Lasso estimator based on Celer solver and primal extrapolation.

skglm/tests/test_gridsearch.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+from sklearn.model_selection import GridSearchCV
+from sklearn.datasets import make_regression
+from skglm import GeneralizedLinearEstimator
+from skglm.datafits import Quadratic
+from skglm.penalties import L1
+from skglm.solvers import AndersonCD
+
+
+def test_gridsearch_compatibility():
+    # Generate synthetic data
+    X, y = make_regression(n_samples=100, n_features=10, random_state=42)
+
+    # Create base estimator
+    base_estimator = GeneralizedLinearEstimator(
+        datafit=Quadratic(),
+        penalty=L1(1.0),
+        solver=AndersonCD()
+    )
+
+    # Define parameter grid
+    param_grid = {
+        'penalty__alpha': [0.1, 1.0, 10.0],
+        'solver__max_iter': [10, 20],
+        'solver__tol': [1e-3, 1e-4]
+    }
+
+    # Create GridSearchCV
+    grid_search = GridSearchCV(
+        base_estimator,
+        param_grid,
+        cv=3,
+        scoring='neg_mean_squared_error'
+    )
+
+    # Fit GridSearchCV
+    grid_search.fit(X, y)
+
+    # Verify that GridSearchCV worked
+    assert hasattr(grid_search, 'best_params_')
+    assert hasattr(grid_search, 'best_score_')
+    assert hasattr(grid_search, 'best_estimator_')
+
+    # Verify that best_estimator_ has the correct parameters
+    best_estimator = grid_search.best_estimator_
+    assert isinstance(best_estimator, GeneralizedLinearEstimator)
+    assert best_estimator.penalty.alpha in [0.1, 1.0, 10.0]
+    assert best_estimator.solver.max_iter in [10, 20]
+    assert best_estimator.solver.tol in [1e-3, 1e-4]
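
The test exercises the same machinery GridSearchCV uses internally: clone the estimator, apply one grid point via set_params, fit. A hypothetical re-implementation of that inner step, for illustration only (fit_grid_point is a made-up name):

    from sklearn.base import clone

    def fit_grid_point(estimator, grid_point, X, y):
        # clone() rebuilds the estimator from get_params(deep=False) ...
        candidate = clone(estimator)
        # ... and the grid point is applied through set_params,
        # e.g. {'penalty__alpha': 0.1, 'solver__max_iter': 10}.
        candidate.set_params(**grid_point)
        return candidate.fit(X, y)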

skglm/utils/jit_compilation.py

Lines changed: 11 additions & 1 deletion
@@ -59,7 +59,17 @@ def jit_cached_compile(klass, spec, to_float32=False):
     if to_float32:
         spec = spec_to_float32(spec)
 
-    return jitclass(spec)(klass)
+    # Create a new class without slots
+    class CompiledClass:
+        pass
+
+    # Copy over all methods and attributes from the original class
+    for name, value in klass.__dict__.items():
+        # Skip __slots__ and __slotnames__ but keep other special methods
+        if name not in ['__slots__', '__slotnames__']:
+            setattr(CompiledClass, name, value)
+
+    return jitclass(spec)(CompiledClass)
 
 
 def compiled_clone(instance, to_float32=False):
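
For context, this compilation path is reached via compiled_clone just below; a minimal sketch of the kind of call that triggers it (assuming skglm's public L1 penalty):

    from skglm.penalties import L1
    from skglm.utils.jit_compilation import compiled_clone

    penalty = L1(alpha=0.1)
    # Builds a numba jitclass from L1's class dict; the workaround
    # above copies that dict onto a fresh class without __slots__ /
    # __slotnames__ before handing it to jitclass.
    compiled_penalty = compiled_clone(penalty)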
