-
Notifications
You must be signed in to change notification settings - Fork 38
TST: Add sklearn <-> skglm match tests for Poisson and Gamma predictions #323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
9df9d80
2479119
58587cc
05f1edb
a1302a5
4041b1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -184,6 +184,10 @@ class PoissonGroup(Poisson): | |
| def __init__(self, grp_ptr, grp_indices): | ||
| self.grp_ptr, self.grp_indices = grp_ptr, grp_indices | ||
|
|
||
| @staticmethod | ||
| def inverse_link(x): | ||
|
||
| return np.exp(x) | ||
|
|
||
| def get_spec(self): | ||
| return ( | ||
| ('grp_ptr', int32[:]), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -590,6 +590,10 @@ class Poisson(BaseDatafit): | |
| def __init__(self): | ||
| pass | ||
|
|
||
| @staticmethod | ||
| def inverse_link(x): | ||
|
||
| return np.exp(x) | ||
|
|
||
| def get_spec(self): | ||
| pass | ||
|
|
||
|
|
@@ -664,6 +668,10 @@ class Gamma(BaseDatafit): | |
| def __init__(self): | ||
| pass | ||
|
|
||
| @staticmethod | ||
| def inverse_link(x): | ||
|
||
| return np.exp(x) | ||
|
|
||
| def get_spec(self): | ||
| pass | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
| from sklearn.linear_model import ElasticNet as ElasticNet_sklearn | ||
| from sklearn.linear_model import LogisticRegression as LogReg_sklearn | ||
| from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn | ||
| from sklearn.linear_model import PoissonRegressor, GammaRegressor | ||
| from sklearn.model_selection import GridSearchCV | ||
| from sklearn.svm import LinearSVC as LinearSVC_sklearn | ||
| from sklearn.utils.estimator_checks import check_estimator | ||
|
|
@@ -23,8 +24,12 @@ | |
| from skglm.estimators import ( | ||
| GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet, | ||
| MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator) | ||
| from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox | ||
| from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE | ||
| from skglm.datafits import ( | ||
| Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox, Poisson, Gamma | ||
| ) | ||
| from skglm.penalties import ( | ||
| L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE | ||
| ) | ||
| from skglm.solvers import AndersonCD, FISTA, ProxNewton | ||
|
|
||
| n_samples = 50 | ||
|
|
@@ -629,5 +634,32 @@ def test_SLOPE_printing(): | |
| assert isinstance(res, str) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "sklearn_reg, skglm_datafit, y_gen", | ||
| [ | ||
| ( | ||
| PoissonRegressor, Poisson, | ||
| lambda X: np.random.poisson(np.exp(X.sum(axis=1) * 0.1)) | ||
| ), | ||
| ( | ||
| GammaRegressor, Gamma, | ||
| lambda X: np.random.gamma(2.0, np.exp(X.sum(axis=1) * 0.1)) | ||
| ), | ||
| ] | ||
| ) | ||
| def test_inverse_link_prediction(sklearn_reg, skglm_datafit, y_gen): | ||
| np.random.seed(42) | ||
| X = np.random.randn(20, 5) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. one last thing: IMO it makes sense to run the test on completely random values of [comment truncated in page capture] |
||
| y = y_gen(X) | ||
| sklearn_pred = sklearn_reg(alpha=0.0, max_iter=10_000, | ||
| tol=1e-8).fit(X, y).predict(X) | ||
| skglm_pred = GeneralizedLinearEstimator( | ||
| datafit=skglm_datafit(), | ||
| penalty=L1_plus_L2(0.0, l1_ratio=0.0), | ||
| solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8) | ||
| ).fit(X, y).predict(X) | ||
| np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pass | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just call the argument Xw for clarity