@@ -634,43 +634,30 @@ def test_SLOPE_printing():
634634 assert isinstance (res , str )
635635
636636
637- def test_poisson_predictions_match_sklearn ():
638- """Test that skglm Poisson estimator predictions match sklearn PoissonRegressor."""
639- np .random .seed (42 )
640- X = np .random .randn (20 , 5 )
641- y = np .random .poisson (np .exp (X .sum (axis = 1 ) * 0.1 ))
642-
643- # Fit sklearn PoissonRegressor (no regularization due to different alpha scaling)
644- sklearn_pred = PoissonRegressor (
645- alpha = 0.0 , max_iter = 10_000 , tol = 1e-8 ).fit (X , y ).predict (X )
646-
647- # Fit skglm equivalent (no regularization)
648- skglm_pred = GeneralizedLinearEstimator (
649- datafit = Poisson (),
650- penalty = L1_plus_L2 (0.0 , l1_ratio = 0.0 ),
651- solver = ProxNewton (fit_intercept = True , max_iter = 10_000 , tol = 1e-8 )
652- ).fit (X , y ).predict (X )
653-
654- np .testing .assert_allclose (sklearn_pred , skglm_pred , rtol = 1e-6 , atol = 1e-8 )
655-
656-
657- def test_gamma_predictions_match_sklearn ():
658- """Test that skglm Gamma estimator predictions match sklearn GammaRegressor."""
637+ @pytest .mark .parametrize (
638+ "sklearn_reg, skglm_datafit, y_gen" ,
639+ [
640+ (
641+ PoissonRegressor , Poisson ,
642+ lambda X : np .random .poisson (np .exp (X .sum (axis = 1 ) * 0.1 ))
643+ ),
644+ (
645+ GammaRegressor , Gamma ,
646+ lambda X : np .random .gamma (2.0 , np .exp (X .sum (axis = 1 ) * 0.1 ))
647+ ),
648+ ]
649+ )
650+ def test_inverse_link_prediction (sklearn_reg , skglm_datafit , y_gen ):
659651 np .random .seed (42 )
660652 X = np .random .randn (20 , 5 )
661- y = np .random .gamma (2.0 , np .exp (X .sum (axis = 1 ) * 0.1 ))
662-
663- # Fit sklearn GammaRegressor (no regularization due to different alpha scaling)
664- sklearn_pred = GammaRegressor (
665- alpha = 0.0 , max_iter = 10_000 , tol = 1e-8 ).fit (X , y ).predict (X )
666-
667- # Fit skglm equivalent (no regularization)
653+ y = y_gen (X )
654+ sklearn_pred = sklearn_reg (alpha = 0.0 , max_iter = 10_000 ,
655+ tol = 1e-8 ).fit (X , y ).predict (X )
668656 skglm_pred = GeneralizedLinearEstimator (
669- datafit = Gamma (),
657+ datafit = skglm_datafit (),
670658 penalty = L1_plus_L2 (0.0 , l1_ratio = 0.0 ),
671659 solver = ProxNewton (fit_intercept = True , max_iter = 10_000 , tol = 1e-8 )
672660 ).fit (X , y ).predict (X )
673-
674661 np .testing .assert_allclose (sklearn_pred , skglm_pred , rtol = 1e-6 , atol = 1e-8 )
675662
676663
0 commit comments