Commit e761de7

attempt at fixing warm_start, addressing comments, timing measure, different datasets in unit test
1 parent 70a07fd commit e761de7

File tree: 2 files changed (+37, -12 lines)

skglm/cv.py

Lines changed: 3 additions & 1 deletion
@@ -62,7 +62,9 @@ def _solve_fold(k, train, test, alpha, l1, w_init):
             est = GeneralizedLinearEstimator(
                 datafit=self.datafit, penalty=pen, solver=self.solver
             )
-            est.solver.warm_start = True
+            if w_init is not None:
+                est.coef_ = w_init[0]
+                est.intercept_ = w_init[1]
             est.fit(X[train], y[train])
             y_pred = est.predict(X[test])
             return est.coef_, est.intercept_, self._score(y[test], y_pred)
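
For context, the snippet below is a minimal, hypothetical sketch (not part of the commit) of how a w_init mechanism like the one in _solve_fold is typically used: the previous hyperparameter's solution is carried forward and assigned to coef_/intercept_ before fit, so the solver starts close to the new optimum. The looping code and variable names are illustrative assumptions; only GeneralizedLinearEstimator, Quadratic, L1_plus_L2, and AndersonCD are real skglm objects, and whether the preset coefficients are actually picked up depends on how the estimator wires warm starting, which is exactly what this commit is attempting to fix.

# Hypothetical sketch: warm-starting along a decreasing alpha path by
# carrying (coef_, intercept_) forward, mirroring what _solve_fold does
# with w_init in the diff above.
import numpy as np
from sklearn.datasets import make_regression
from skglm import GeneralizedLinearEstimator
from skglm.datafits import Quadratic
from skglm.penalties import L1_plus_L2
from skglm.solvers import AndersonCD

X, y = make_regression(n_samples=100, n_features=50, random_state=0)
alpha_max = np.max(np.abs(X.T @ y)) / len(y)
alphas = alpha_max * np.geomspace(1, 1e-3, num=5)

w_init = None
for alpha in alphas:
    est = GeneralizedLinearEstimator(
        datafit=Quadratic(),
        penalty=L1_plus_L2(alpha=alpha, l1_ratio=0.5),
        solver=AndersonCD(warm_start=True),
    )
    if w_init is not None:
        # seed the next solve with the previous alpha's solution
        est.coef_, est.intercept_ = w_init
    est.fit(X, y)
    w_init = (est.coef_, est.intercept_)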

skglm/tests/test_cv.py

Lines changed: 34 additions & 11 deletions
@@ -1,40 +1,63 @@
 import numpy as np
-import pytest
+import time
 from sklearn.datasets import make_regression
 from sklearn.linear_model import ElasticNet
 from sklearn.model_selection import GridSearchCV, KFold
 from skglm.datafits import Quadratic
 from skglm.penalties import L1_plus_L2
 from skglm.solvers import AndersonCD
 from skglm.cv import GeneralizedLinearEstimatorCV
+import pytest
 
 
-@pytest.mark.parametrize("seed", [0, 42])
-def test_elasticnet_cv_matches_sklearn(seed):
+@pytest.mark.parametrize("n_samples,n_features,noise",
+                         [(100, 10, 0.1), (100, 500, 0.2), (100, 500, 0.3)])
+def test_elasticnet_cv_matches_sklearn(n_samples, n_features, noise):
     """Test GeneralizedLinearEstimatorCV matches sklearn GridSearchCV for ElasticNet."""
-    X, y = make_regression(n_samples=100, n_features=20, noise=0.1, random_state=seed)
+    seed = 42
+    X, y = make_regression(n_samples=n_samples,
+                           n_features=n_features, noise=noise, random_state=seed)
 
-    alphas = np.array([0.001, 0.01, 0.1, 1.0])
+    n = X.shape[0]
+    alpha_max = np.max(np.abs(X.T @ y)) / n
+    alphas = alpha_max * np.array([1, 0.1, 0.01, 0.001])
     l1_ratios = np.array([0.2, 0.5, 0.8])
     cv = KFold(n_splits=5, shuffle=True, random_state=seed)
 
+    start_time = time.time()
     sklearn_model = GridSearchCV(
         ElasticNet(max_iter=10000, tol=1e-8),
         {'alpha': alphas, 'l1_ratio': l1_ratios},
         cv=cv, scoring='neg_mean_squared_error', n_jobs=1
     ).fit(X, y)
+    sklearn_time = time.time() - start_time
 
+    start_time = time.time()
     skglm_model = GeneralizedLinearEstimatorCV(
         Quadratic(), L1_plus_L2(0.1, 0.5), AndersonCD(max_iter=10000, tol=1e-8),
         alphas=alphas, l1_ratio=l1_ratios, cv=5, random_state=seed, n_jobs=1
     ).fit(X, y)
+    skglm_time = time.time() - start_time
 
-    assert sklearn_model.best_params_['alpha'] == skglm_model.alpha_
-    assert sklearn_model.best_params_['l1_ratio'] == skglm_model.l1_ratio_
-    np.testing.assert_allclose(sklearn_model.best_estimator_.coef_,
-                               skglm_model.coef_.ravel(), rtol=1e-4, atol=1e-6)
-    np.testing.assert_allclose(sklearn_model.best_estimator_.intercept_,
-                               skglm_model.intercept_, rtol=1e-4, atol=1e-6)
+    print(f"\nTest case: {n_samples} samples, {n_features} features, noise={noise}")
+    print(f"Timing comparison (seed={seed}):")
+    print(f"sklearn: {sklearn_time:.2f}s")
+    print(f"skglm: {skglm_time:.2f}s")
+    print(f"speedup: {sklearn_time/skglm_time:.1f}x")
+
+    try:
+        assert sklearn_model.best_params_['alpha'] == skglm_model.alpha_
+        assert sklearn_model.best_params_['l1_ratio'] == skglm_model.l1_ratio_
+        np.testing.assert_allclose(sklearn_model.best_estimator_.coef_,
+                                   skglm_model.coef_.ravel(), rtol=1e-4, atol=1e-6)
+        np.testing.assert_allclose(sklearn_model.best_estimator_.intercept_,
+                                   skglm_model.intercept_, rtol=1e-4, atol=1e-6)
+    except AssertionError:
+        print("\nBest parameters:")
+        print(f"sklearn: alpha={sklearn_model.best_params_['alpha']}, "
+              f"l1_ratio={sklearn_model.best_params_['l1_ratio']}")
+        print(f"skglm: alpha={skglm_model.alpha_}, l1_ratio={skglm_model.l1_ratio_}")
+        raise
 
 
 if __name__ == "__main__":
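
A quick note on the grid used in the test: for the pure L1 penalty, alpha_max = max|X^T y| / n_samples is the smallest regularization strength at which the all-zero coefficient vector is optimal, so scaling the grid by alpha_max keeps every candidate in the regime where the solution is neither trivially zero nor effectively unregularized. The snippet below is an illustrative check of that fact using sklearn's Lasso; it is not part of the commit, and the centering steps are an assumption made so the formula is exact when an intercept is fitted.

# Illustrative check (not from the commit): alpha >= alpha_max zeroes out
# the Lasso solution, so the test's grid alpha_max * [1, 0.1, 0.01, 0.001]
# spans the useful range of regularization strengths.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso

X, y = make_regression(n_samples=100, n_features=30, noise=0.1, random_state=0)
X = X - X.mean(axis=0)  # center so the intercept does not affect alpha_max
y = y - y.mean()
alpha_max = np.max(np.abs(X.T @ y)) / X.shape[0]

print(np.count_nonzero(Lasso(alpha=alpha_max).fit(X, y).coef_))         # expected: 0
print(np.count_nonzero(Lasso(alpha=0.01 * alpha_max).fit(X, y).coef_))  # expected: > 0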
