|  | 
| 14 | 14 | from sklearn.linear_model import ElasticNet as ElasticNet_sklearn | 
| 15 | 15 | from sklearn.linear_model import LogisticRegression as LogReg_sklearn | 
| 16 | 16 | from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn | 
|  | 17 | +from sklearn.linear_model import PoissonRegressor, GammaRegressor | 
| 17 | 18 | from sklearn.model_selection import GridSearchCV | 
| 18 | 19 | from sklearn.svm import LinearSVC as LinearSVC_sklearn | 
| 19 | 20 | from sklearn.utils.estimator_checks import check_estimator | 
|  | 
| 23 | 24 | from skglm.estimators import ( | 
| 24 | 25 |     GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet, | 
| 25 | 26 |     MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator) | 
| 26 |  | -from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox | 
| 27 |  | -from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE | 
|  | 27 | +from skglm.datafits import ( | 
|  | 28 | +    Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox, Poisson, Gamma | 
|  | 29 | +) | 
|  | 30 | +from skglm.penalties import ( | 
|  | 31 | +    L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE | 
|  | 32 | +) | 
| 28 | 33 | from skglm.solvers import AndersonCD, FISTA, ProxNewton | 
| 29 | 34 | 
 | 
| 30 | 35 | n_samples = 50 | 
| @@ -629,5 +634,23 @@ def test_SLOPE_printing(): | 
| 629 | 634 |     assert isinstance(res, str) | 
| 630 | 635 | 
 | 
| 631 | 636 | 
 | 
|  | 637 | +@pytest.mark.parametrize( | 
|  | 638 | +    "sklearn_reg, skglm_datafit", | 
|  | 639 | +    [(PoissonRegressor, Poisson), (GammaRegressor, Gamma)] | 
|  | 640 | +) | 
|  | 641 | +def test_inverse_link_prediction(sklearn_reg, skglm_datafit): | 
|  | 642 | +    np.random.seed(42) | 
|  | 643 | +    X = np.random.randn(20, 5) | 
|  | 644 | +    y = np.random.randint(1, 6, size=20)  # Use 1-6 for both (Gamma needs y>0) | 
|  | 645 | +    sklearn_pred = sklearn_reg(alpha=0.0, max_iter=10_000, | 
|  | 646 | +                               tol=1e-8).fit(X, y).predict(X) | 
|  | 647 | +    skglm_pred = GeneralizedLinearEstimator( | 
|  | 648 | +        datafit=skglm_datafit(), | 
|  | 649 | +        penalty=L1_plus_L2(0.0, l1_ratio=0.0), | 
|  | 650 | +        solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8) | 
|  | 651 | +    ).fit(X, y).predict(X) | 
|  | 652 | +    np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8) | 
|  | 653 | + | 
|  | 654 | + | 
| 632 | 655 | if __name__ == "__main__": | 
| 633 | 656 |     pass | 
0 commit comments