Skip to content

Commit ba5d9d9

Browse files
TST: Add sklearn-skglm match tests for Poisson and Gamma's predict() (#323)
1 parent 27288ed commit ba5d9d9

File tree

5 files changed

+55
-6
lines changed

5 files changed

+55
-6
lines changed

skglm/datafits/base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,22 @@
22
class BaseDatafit:
33
"""Base class for datafits."""
44

5+
@staticmethod
6+
def inverse_link(Xw):
7+
"""Inverse link function (identity by default).
8+
9+
Parameters
10+
----------
11+
Xw : array-like
12+
Linear predictor values.
13+
14+
Returns
15+
-------
16+
array-like
17+
Transformed values in response scale.
18+
"""
19+
return Xw
20+
521
def get_spec(self):
622
"""Specify the numba types of the class attributes.
723

skglm/datafits/group.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ class PoissonGroup(Poisson):
184184
def __init__(self, grp_ptr, grp_indices):
185185
self.grp_ptr, self.grp_indices = grp_ptr, grp_indices
186186

187+
@staticmethod
188+
def inverse_link(Xw):
189+
return np.exp(Xw)
190+
187191
def get_spec(self):
188192
return (
189193
('grp_ptr', int32[:]),

skglm/datafits/single_task.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,10 @@ class Poisson(BaseDatafit):
590590
def __init__(self):
591591
pass
592592

593+
@staticmethod
594+
def inverse_link(Xw):
595+
return np.exp(Xw)
596+
593597
def get_spec(self):
594598
pass
595599

@@ -664,6 +668,10 @@ class Gamma(BaseDatafit):
664668
def __init__(self):
665669
pass
666670

671+
@staticmethod
672+
def inverse_link(Xw):
673+
return np.exp(Xw)
674+
667675
def get_spec(self):
668676
pass
669677

skglm/estimators.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD, ProxNewton, LBFGS
2121
from skglm.datafits import (
22-
Cox, Quadratic, Logistic, Poisson, PoissonGroup, QuadraticSVC,
22+
Cox, Quadratic, Logistic, QuadraticSVC,
2323
QuadraticMultiTask, QuadraticGroup,)
2424
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
2525
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
@@ -266,10 +266,8 @@ def predict(self, X):
266266
else:
267267
indices = scores.argmax(axis=1)
268268
return self.classes_[indices]
269-
elif isinstance(self.datafit, (Poisson, PoissonGroup)):
270-
return np.exp(self._decision_function(X))
271269
else:
272-
return self._decision_function(X)
270+
return self.datafit.inverse_link(self._decision_function(X))
273271

274272
def get_params(self, deep=False):
275273
"""Get parameters of the estimators including the datafit's and penalty's.

skglm/tests/test_estimators.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.linear_model import ElasticNet as ElasticNet_sklearn
1515
from sklearn.linear_model import LogisticRegression as LogReg_sklearn
1616
from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn
17+
from sklearn.linear_model import PoissonRegressor, GammaRegressor
1718
from sklearn.model_selection import GridSearchCV
1819
from sklearn.svm import LinearSVC as LinearSVC_sklearn
1920
from sklearn.utils.estimator_checks import check_estimator
@@ -23,8 +24,12 @@
2324
from skglm.estimators import (
2425
GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet,
2526
MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator)
26-
from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox
27-
from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE
27+
from skglm.datafits import (
28+
Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox, Poisson, Gamma
29+
)
30+
from skglm.penalties import (
31+
L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE
32+
)
2833
from skglm.solvers import AndersonCD, FISTA, ProxNewton
2934

3035
n_samples = 50
@@ -629,5 +634,23 @@ def test_SLOPE_printing():
629634
assert isinstance(res, str)
630635

631636

637+
@pytest.mark.parametrize(
638+
"sklearn_reg, skglm_datafit",
639+
[(PoissonRegressor, Poisson), (GammaRegressor, Gamma)]
640+
)
641+
def test_inverse_link_prediction(sklearn_reg, skglm_datafit):
642+
np.random.seed(42)
643+
X = np.random.randn(20, 5)
644+
y = np.random.randint(1, 6, size=20) # Use 1-6 for both (Gamma needs y>0)
645+
sklearn_pred = sklearn_reg(alpha=0.0, max_iter=10_000,
646+
tol=1e-8).fit(X, y).predict(X)
647+
skglm_pred = GeneralizedLinearEstimator(
648+
datafit=skglm_datafit(),
649+
penalty=L1_plus_L2(0.0, l1_ratio=0.0),
650+
solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8)
651+
).fit(X, y).predict(X)
652+
np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8)
653+
654+
632655
if __name__ == "__main__":
633656
pass

0 commit comments

Comments
 (0)