|
5 | 5 | from skglm.penalties import L1
|
6 | 6 | from skglm.datafits import Quadratic
|
7 | 7 | from skglm.penalties.block_separable import WeightedGroupL2
|
8 |
| -from skglm.datafits.group import QuadraticGroup |
| 8 | +from skglm.datafits.group import QuadraticGroup, LogisticGroup |
9 | 9 | from skglm.solvers import GroupBCD
|
10 | 10 | from skglm.utils import (
|
11 | 11 | _alpha_max_group_lasso, grp_converter, make_correlated_data, compiled_clone,
|
12 | 12 | AndersonAcceleration)
|
13 | 13 | from celer import GroupLasso, Lasso
|
| 14 | +from sklearn.linear_model import LogisticRegression |
14 | 15 |
|
15 | 16 |
|
16 | 17 | def _generate_random_grp(n_groups, n_features, shuffle=True):
|
@@ -160,6 +161,61 @@ def test_intercept_grouplasso():
|
160 | 161 | np.testing.assert_allclose(model.intercept_, w[-1], atol=1e-5)
|
161 | 162 |
|
162 | 163 |
|
@pytest.mark.parametrize("rho", [1e-1, 1e-2])
def test_equivalence_logreg(rho):
    """Check that GroupBCD with singleton groups matches sklearn's L1 logistic regression.

    With every group containing exactly one feature, the weighted group-L2
    penalty reduces to a plain L1 penalty, so the solutions should coincide.
    """
    n_samples, n_features = 30, 50
    rng = np.random.RandomState(1123)
    X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)
    # binarize regression targets into classification labels in {-1, 1}
    y = np.sign(y)

    # groups of size 1: the group penalty degenerates to L1
    grp_indices, grp_ptr = grp_converter(1, n_features)
    alpha_max = norm(X.T @ y, ord=np.inf) / (2 * n_samples)
    alpha = rho * alpha_max / 10.

    datafit = compiled_clone(
        LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices),
        to_float32=X.dtype == np.float32)
    penalty = compiled_clone(WeightedGroupL2(
        alpha=alpha, grp_ptr=grp_ptr,
        grp_indices=grp_indices, weights=np.ones(n_features)))
    w = GroupBCD(tol=1e-12).solve(X, y, datafit, penalty)[0]

    # sklearn parameterizes the L1 strength as C = 1 / (n_samples * alpha)
    sk_logreg = LogisticRegression(penalty='l1', C=1/(n_samples * alpha),
                                   fit_intercept=False, tol=1e-12, solver='liblinear')
    sk_logreg.fit(X, y)

    np.testing.assert_allclose(sk_logreg.coef_.flatten(), w, atol=1e-6, rtol=1e-5)

| 191 | + |
@pytest.mark.parametrize("n_groups, rho", [[15, 1e-1], [25, 1e-2]])
def test_group_logreg(n_groups, rho):
    """Sparse group logistic regression: GroupBCD must reach the requested tolerance."""
    n_samples, n_features, shuffle = 30, 60, True
    random_state = 123
    rng = np.random.RandomState(random_state)

    X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)
    y = np.sign(y)  # classification labels in {-1, 1}

    # re-seed so the group weights are reproducible regardless of how many
    # draws the data generation consumed
    rng.seed(random_state)
    weights = np.abs(rng.randn(n_groups))
    grp_indices, grp_ptr, _ = _generate_random_grp(n_groups, n_features, shuffle)

    alpha = rho * _alpha_max_group_lasso(X, y, grp_indices, grp_ptr, weights)

    # skglm
    datafit = LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices)
    penalty = WeightedGroupL2(alpha, weights, grp_ptr, grp_indices)

    datafit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
    penalty = compiled_clone(penalty)
    stop_crit = GroupBCD(tol=1e-12).solve(X, y, datafit, penalty)[2]

    np.testing.assert_array_less(stop_crit, 1e-12)

| 218 | + |
163 | 219 | def test_anderson_acceleration():
|
164 | 220 | # VAR: w = rho * w + 1 with |rho| < 1
|
165 | 221 | # converges to w_star = 1 / (1 - rho)
|
|
0 commit comments