Skip to content

Commit 70a07fd

Browse files
committed
try other setups test cv
1 parent 2c7162d commit 70a07fd

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

skglm/tests/test_cv.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,29 @@ def test_elasticnet_cv_matches_sklearn(seed):
3535
skglm_model.coef_.ravel(), rtol=1e-4, atol=1e-6)
3636
np.testing.assert_allclose(sklearn_model.best_estimator_.intercept_,
3737
skglm_model.intercept_, rtol=1e-4, atol=1e-6)
38+
39+
40+
if __name__ == "__main__":
41+
X, y = make_regression(n_samples=40, n_features=60, noise=2, random_state=0)
42+
43+
alphas = np.array([1e-4, 0.001, 0.01, 0.1])
44+
l1_ratios = np.array([0.2, 0.5, 0.8])
45+
cv = KFold(n_splits=5, shuffle=True, random_state=0)
46+
47+
sklearn_model = GridSearchCV(
48+
ElasticNet(max_iter=10000, tol=1e-8),
49+
{'alpha': alphas, 'l1_ratio': l1_ratios},
50+
cv=cv, scoring='neg_mean_squared_error', n_jobs=1
51+
).fit(X, y)
52+
53+
skglm_model = GeneralizedLinearEstimatorCV(
54+
Quadratic(), L1_plus_L2(0.1, 0.5), AndersonCD(max_iter=10000, tol=1e-8),
55+
alphas=alphas, l1_ratio=l1_ratios, cv=5, random_state=0, n_jobs=1
56+
).fit(X, y)
57+
58+
assert sklearn_model.best_params_['alpha'] == skglm_model.alpha_
59+
assert sklearn_model.best_params_['l1_ratio'] == skglm_model.l1_ratio_
60+
np.testing.assert_allclose(sklearn_model.best_estimator_.coef_,
61+
skglm_model.coef_.ravel(), rtol=1e-4, atol=1e-6)
62+
np.testing.assert_allclose(sklearn_model.best_estimator_.intercept_,
63+
skglm_model.intercept_, rtol=1e-4, atol=1e-6)

0 commit comments

Comments
 (0)