Skip to content

Commit 58587cc

Browse files
merge unit tests in single parametrized unit test
1 parent 2479119 commit 58587cc

File tree

1 file changed

+18
-31
lines changed

1 file changed

+18
-31
lines changed

skglm/tests/test_estimators.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -634,43 +634,30 @@ def test_SLOPE_printing():
634634
assert isinstance(res, str)
635635

636636

637-
def test_poisson_predictions_match_sklearn():
638-
"""Test that skglm Poisson estimator predictions match sklearn PoissonRegressor."""
639-
np.random.seed(42)
640-
X = np.random.randn(20, 5)
641-
y = np.random.poisson(np.exp(X.sum(axis=1) * 0.1))
642-
643-
# Fit sklearn PoissonRegressor (no regularization due to different alpha scaling)
644-
sklearn_pred = PoissonRegressor(
645-
alpha=0.0, max_iter=10_000, tol=1e-8).fit(X, y).predict(X)
646-
647-
# Fit skglm equivalent (no regularization)
648-
skglm_pred = GeneralizedLinearEstimator(
649-
datafit=Poisson(),
650-
penalty=L1_plus_L2(0.0, l1_ratio=0.0),
651-
solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8)
652-
).fit(X, y).predict(X)
653-
654-
np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8)
655-
656-
657-
def test_gamma_predictions_match_sklearn():
658-
"""Test that skglm Gamma estimator predictions match sklearn GammaRegressor."""
637+
@pytest.mark.parametrize(
638+
"sklearn_reg, skglm_datafit, y_gen",
639+
[
640+
(
641+
PoissonRegressor, Poisson,
642+
lambda X: np.random.poisson(np.exp(X.sum(axis=1) * 0.1))
643+
),
644+
(
645+
GammaRegressor, Gamma,
646+
lambda X: np.random.gamma(2.0, np.exp(X.sum(axis=1) * 0.1))
647+
),
648+
]
649+
)
650+
def test_inverse_link_prediction(sklearn_reg, skglm_datafit, y_gen):
659651
np.random.seed(42)
660652
X = np.random.randn(20, 5)
661-
y = np.random.gamma(2.0, np.exp(X.sum(axis=1) * 0.1))
662-
663-
# Fit sklearn GammaRegressor (no regularization due to different alpha scaling)
664-
sklearn_pred = GammaRegressor(
665-
alpha=0.0, max_iter=10_000, tol=1e-8).fit(X, y).predict(X)
666-
667-
# Fit skglm equivalent (no regularization)
653+
y = y_gen(X)
654+
sklearn_pred = sklearn_reg(alpha=0.0, max_iter=10_000,
655+
tol=1e-8).fit(X, y).predict(X)
668656
skglm_pred = GeneralizedLinearEstimator(
669-
datafit=Gamma(),
657+
datafit=skglm_datafit(),
670658
penalty=L1_plus_L2(0.0, l1_ratio=0.0),
671659
solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8)
672660
).fit(X, y).predict(X)
673-
674661
np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8)
675662

676663

0 commit comments

Comments
 (0)