diff --git a/skglm/datafits/base.py b/skglm/datafits/base.py index 2f70d957..c5e103f0 100644 --- a/skglm/datafits/base.py +++ b/skglm/datafits/base.py @@ -2,6 +2,22 @@ class BaseDatafit: """Base class for datafits.""" + @staticmethod + def inverse_link(Xw): + """Inverse link function (identity by default). + + Parameters + ---------- + Xw : array-like + Linear predictor values. + + Returns + ------- + array-like + Transformed values in response scale. + """ + return Xw + def get_spec(self): """Specify the numba types of the class attributes. diff --git a/skglm/datafits/group.py b/skglm/datafits/group.py index 264fc0f3..1bf02c8b 100644 --- a/skglm/datafits/group.py +++ b/skglm/datafits/group.py @@ -184,6 +184,10 @@ class PoissonGroup(Poisson): def __init__(self, grp_ptr, grp_indices): self.grp_ptr, self.grp_indices = grp_ptr, grp_indices + @staticmethod + def inverse_link(Xw): + return np.exp(Xw) + def get_spec(self): return ( ('grp_ptr', int32[:]), diff --git a/skglm/datafits/single_task.py b/skglm/datafits/single_task.py index a0bd2873..987350e2 100644 --- a/skglm/datafits/single_task.py +++ b/skglm/datafits/single_task.py @@ -590,6 +590,10 @@ class Poisson(BaseDatafit): def __init__(self): pass + @staticmethod + def inverse_link(Xw): + return np.exp(Xw) + def get_spec(self): pass @@ -664,6 +668,10 @@ class Gamma(BaseDatafit): def __init__(self): pass + @staticmethod + def inverse_link(Xw): + return np.exp(Xw) + def get_spec(self): pass diff --git a/skglm/estimators.py b/skglm/estimators.py index 9a847240..4d695949 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -19,7 +19,7 @@ from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD, ProxNewton, LBFGS from skglm.datafits import ( - Cox, Quadratic, Logistic, Poisson, PoissonGroup, QuadraticSVC, + Cox, Quadratic, Logistic, QuadraticSVC, QuadraticMultiTask, QuadraticGroup,) from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2, MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1) 
@@ -266,10 +266,8 @@ def predict(self, X): else: indices = scores.argmax(axis=1) return self.classes_[indices] - elif isinstance(self.datafit, (Poisson, PoissonGroup)): - return np.exp(self._decision_function(X)) else: - return self._decision_function(X) + return self.datafit.inverse_link(self._decision_function(X)) def get_params(self, deep=False): """Get parameters of the estimators including the datafit's and penalty's. diff --git a/skglm/tests/test_estimators.py b/skglm/tests/test_estimators.py index 7ca2fd3c..1bafa21a 100644 --- a/skglm/tests/test_estimators.py +++ b/skglm/tests/test_estimators.py @@ -14,6 +14,7 @@ from sklearn.linear_model import ElasticNet as ElasticNet_sklearn from sklearn.linear_model import LogisticRegression as LogReg_sklearn from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn +from sklearn.linear_model import PoissonRegressor, GammaRegressor from sklearn.model_selection import GridSearchCV from sklearn.svm import LinearSVC as LinearSVC_sklearn from sklearn.utils.estimator_checks import check_estimator @@ -23,8 +24,12 @@ from skglm.estimators import ( GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet, MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator) -from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox -from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE +from skglm.datafits import ( + Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox, Poisson, Gamma +) +from skglm.penalties import ( + L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE +) from skglm.solvers import AndersonCD, FISTA, ProxNewton n_samples = 50 @@ -629,5 +634,23 @@ def test_SLOPE_printing(): assert isinstance(res, str) +@pytest.mark.parametrize( + "sklearn_reg, skglm_datafit", + [(PoissonRegressor, Poisson), (GammaRegressor, Gamma)] +) +def test_inverse_link_prediction(sklearn_reg, skglm_datafit): + 
np.random.seed(42) + X = np.random.randn(20, 5) + y = np.random.randint(1, 6, size=20) # Use 1-5 for both (Gamma needs y>0) + sklearn_pred = sklearn_reg(alpha=0.0, max_iter=10_000, + tol=1e-8).fit(X, y).predict(X) + skglm_pred = GeneralizedLinearEstimator( + datafit=skglm_datafit(), + penalty=L1_plus_L2(0.0, l1_ratio=0.0), + solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8) + ).fit(X, y).predict(X) + np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8) + + if __name__ == "__main__": pass