Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions skglm/datafits/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ class PoissonGroup(Poisson):
def __init__(self, grp_ptr, grp_indices):
self.grp_ptr, self.grp_indices = grp_ptr, grp_indices

@staticmethod
def inverse_link(x):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

return np.exp(x)

def get_spec(self):
return (
('grp_ptr', int32[:]),
Expand Down
8 changes: 8 additions & 0 deletions skglm/datafits/single_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,10 @@ class Poisson(BaseDatafit):
def __init__(self):
pass

@staticmethod
def inverse_link(x):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

return np.exp(x)

def get_spec(self):
pass

Expand Down Expand Up @@ -664,6 +668,10 @@ class Gamma(BaseDatafit):
def __init__(self):
pass

@staticmethod
def inverse_link(x):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

return np.exp(x)

def get_spec(self):
pass

Expand Down
6 changes: 3 additions & 3 deletions skglm/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD, ProxNewton, LBFGS
from skglm.datafits import (
Cox, Quadratic, Logistic, Poisson, PoissonGroup, QuadraticSVC,
Cox, Quadratic, Logistic, QuadraticSVC,
QuadraticMultiTask, QuadraticGroup,)
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
Expand Down Expand Up @@ -266,8 +266,8 @@ def predict(self, X):
else:
indices = scores.argmax(axis=1)
return self.classes_[indices]
elif isinstance(self.datafit, (Poisson, PoissonGroup)):
return np.exp(self._decision_function(X))
elif hasattr(self.datafit, "inverse_link"):
return self.datafit.inverse_link(self._decision_function(X))
else:
return self._decision_function(X)

Expand Down
36 changes: 34 additions & 2 deletions skglm/tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sklearn.linear_model import ElasticNet as ElasticNet_sklearn
from sklearn.linear_model import LogisticRegression as LogReg_sklearn
from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn
from sklearn.linear_model import PoissonRegressor, GammaRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC as LinearSVC_sklearn
from sklearn.utils.estimator_checks import check_estimator
Expand All @@ -23,8 +24,12 @@
from skglm.estimators import (
GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet,
MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator)
from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox
from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE
from skglm.datafits import (
Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox, Poisson, Gamma
)
from skglm.penalties import (
L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE
)
from skglm.solvers import AndersonCD, FISTA, ProxNewton

n_samples = 50
Expand Down Expand Up @@ -629,5 +634,32 @@ def test_SLOPE_printing():
assert isinstance(res, str)


@pytest.mark.parametrize(
"sklearn_reg, skglm_datafit, y_gen",
[
(
PoissonRegressor, Poisson,
lambda X: np.random.poisson(np.exp(X.sum(axis=1) * 0.1))
),
(
GammaRegressor, Gamma,
lambda X: np.random.gamma(2.0, np.exp(X.sum(axis=1) * 0.1))
),
]
)
def test_inverse_link_prediction(sklearn_reg, skglm_datafit, y_gen):
np.random.seed(42)
X = np.random.randn(20, 5)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One last thing: IMO it makes sense to run the test on completely random values of y. They don't have to fit the model well; they could be random integers between 0 and 5. We're not checking statistical validity; we're checking that the optimizer works well and that we return the same thing as sklearn. This would make the test simpler.

y = y_gen(X)
sklearn_pred = sklearn_reg(alpha=0.0, max_iter=10_000,
tol=1e-8).fit(X, y).predict(X)
skglm_pred = GeneralizedLinearEstimator(
datafit=skglm_datafit(),
penalty=L1_plus_L2(0.0, l1_ratio=0.0),
solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8)
).fit(X, y).predict(X)
np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8)


if __name__ == "__main__":
pass