Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions skglm/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD, ProxNewton, LBFGS
from skglm.datafits import (
Cox, Quadratic, Logistic, Poisson, PoissonGroup, QuadraticSVC,
Cox, Quadratic, Logistic, Poisson, PoissonGroup, Gamma, QuadraticSVC,
QuadraticMultiTask, QuadraticGroup,)
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
Expand Down Expand Up @@ -266,7 +266,7 @@ def predict(self, X):
else:
indices = scores.argmax(axis=1)
return self.classes_[indices]
elif isinstance(self.datafit, (Poisson, PoissonGroup)):
elif isinstance(self.datafit, (Poisson, PoissonGroup, Gamma)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if __hasattr__(self.datafit, "inverse_link_function"):
     return self.datafit.inverse_link_function(self._decision_function(X))

return np.exp(self._decision_function(X))
else:
return self._decision_function(X)
Expand Down
49 changes: 47 additions & 2 deletions skglm/tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sklearn.linear_model import ElasticNet as ElasticNet_sklearn
from sklearn.linear_model import LogisticRegression as LogReg_sklearn
from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn
from sklearn.linear_model import PoissonRegressor, GammaRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC as LinearSVC_sklearn
from sklearn.utils.estimator_checks import check_estimator
Expand All @@ -23,8 +24,12 @@
from skglm.estimators import (
GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet,
MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator)
from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox
from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE
from skglm.datafits import (
Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox, Poisson, Gamma
)
from skglm.penalties import (
L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE
)
from skglm.solvers import AndersonCD, FISTA, ProxNewton

n_samples = 50
Expand Down Expand Up @@ -629,5 +634,45 @@ def test_SLOPE_printing():
assert isinstance(res, str)


def test_poisson_predictions_match_sklearn():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge in a single parametrized test test_inverse_link_prediction

"""Test that skglm Poisson estimator predictions match sklearn PoissonRegressor."""
np.random.seed(42)
X = np.random.randn(20, 5)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one last thing : IMO it makes sense to run the test on completely random values of y. They don't have to fit the model well, thay could be random integers between 0 and 5. We're notchecking statistical validity, we're checking that the optimizer works well and we return the same thing as sklearn. This would make the test simpler.

y = np.random.poisson(np.exp(X.sum(axis=1) * 0.1))

# Fit sklearn PoissonRegressor (no regularization due to different alpha scaling)
sklearn_pred = PoissonRegressor(
alpha=0.0, max_iter=10_000, tol=1e-8).fit(X, y).predict(X)

# Fit skglm equivalent (no regularization)
skglm_pred = GeneralizedLinearEstimator(
datafit=Poisson(),
penalty=L1_plus_L2(0.0, l1_ratio=0.0),
solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8)
).fit(X, y).predict(X)

np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8)


def test_gamma_predictions_match_sklearn():
"""Test that skglm Gamma estimator predictions match sklearn GammaRegressor."""
np.random.seed(42)
X = np.random.randn(20, 5)
y = np.random.gamma(2.0, np.exp(X.sum(axis=1) * 0.1))

# Fit sklearn GammaRegressor (no regularization due to different alpha scaling)
sklearn_pred = GammaRegressor(
alpha=0.0, max_iter=10_000, tol=1e-8).fit(X, y).predict(X)

# Fit skglm equivalent (no regularization)
skglm_pred = GeneralizedLinearEstimator(
datafit=Gamma(),
penalty=L1_plus_L2(0.0, l1_ratio=0.0),
solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8)
).fit(X, y).predict(X)

np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8)


if __name__ == "__main__":
pass