Skip to content

Commit 238ce8a

Browse files
authored
ENH - Add LogisitcGroup datafit (#94)
1 parent 5457eda commit 238ce8a

File tree

4 files changed

+121
-3
lines changed

4 files changed

+121
-3
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ Datafits
5555

5656
Huber
5757
Logistic
58+
LogisticGroup
5859
Quadratic
5960
QuadraticGroup
6061
QuadraticSVC

skglm/datafits/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from .base import BaseDatafit, BaseMultitaskDatafit
22
from .single_task import Quadratic, QuadraticSVC, Logistic, Huber, Poisson
33
from .multi_task import QuadraticMultiTask
4-
from .group import QuadraticGroup
4+
from .group import QuadraticGroup, LogisticGroup
55

66

77
__all__ = [
88
BaseDatafit, BaseMultitaskDatafit,
99
Quadratic, QuadraticSVC, Logistic, Huber, Poisson,
1010
QuadraticMultiTask,
11-
QuadraticGroup
11+
QuadraticGroup, LogisticGroup
1212
]

skglm/datafits/group.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from numba import int32, float64
44

55
from skglm.datafits.base import BaseDatafit
6+
from skglm.datafits.single_task import Logistic
67

78

89
class QuadraticGroup(BaseDatafit):
@@ -71,3 +72,63 @@ def gradient_scalar(self, X, y, w, Xw, j):
7172

7273
def intercept_update_step(self, y, Xw):
7374
return np.mean(Xw - y)
75+
76+
77+
class LogisticGroup(Logistic):
78+
r"""Logistic datafit used with group penalties.
79+
80+
The datafit reads::
81+
82+
(1 / n_samples) * \sum_i log(1 + exp(-y_i * Xw_i))
83+
84+
Attributes
85+
----------
86+
grp_indices : array, shape (n_features,)
87+
The group indices stacked contiguously
88+
([grp1_indices, grp2_indices, ...]).
89+
90+
grp_ptr : array, shape (n_groups + 1,)
91+
The group pointers such that two consecutive elements delimit
92+
the indices of a group in ``grp_indices``.
93+
94+
lipschitz : array, shape (n_groups,)
95+
The lipschitz constants for each group.
96+
"""
97+
98+
def __init__(self, grp_ptr, grp_indices):
99+
self.grp_ptr, self.grp_indices = grp_ptr, grp_indices
100+
101+
def get_spec(self):
102+
spec = (
103+
('grp_ptr', int32[:]),
104+
('grp_indices', int32[:]),
105+
('lipschitz', float64[:])
106+
)
107+
return spec
108+
109+
def params_to_dict(self):
110+
return dict(grp_ptr=self.grp_ptr,
111+
grp_indices=self.grp_indices)
112+
113+
def initialize(self, X, y):
114+
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
115+
n_groups = len(grp_ptr) - 1
116+
117+
lipschitz = np.zeros(n_groups)
118+
for g in range(n_groups):
119+
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
120+
X_g = X[:, grp_g_indices]
121+
lipschitz[g] = norm(X_g, ord=2) ** 2 / (4 * len(y))
122+
123+
self.lipschitz = lipschitz
124+
125+
def gradient_g(self, X, y, w, Xw, g):
126+
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
127+
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
128+
raw_grad_val = self.raw_grad(y, Xw)
129+
130+
grad_g = np.zeros(len(grp_g_indices))
131+
for idx, j in enumerate(grp_g_indices):
132+
grad_g[idx] = X[:, j] @ raw_grad_val
133+
134+
return grad_g

skglm/tests/test_group.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from skglm.penalties import L1
66
from skglm.datafits import Quadratic
77
from skglm.penalties.block_separable import WeightedGroupL2
8-
from skglm.datafits.group import QuadraticGroup
8+
from skglm.datafits.group import QuadraticGroup, LogisticGroup
99
from skglm.solvers import GroupBCD
1010
from skglm.utils import (
1111
_alpha_max_group_lasso, grp_converter, make_correlated_data, compiled_clone,
1212
AndersonAcceleration)
1313
from celer import GroupLasso, Lasso
14+
from sklearn.linear_model import LogisticRegression
1415

1516

1617
def _generate_random_grp(n_groups, n_features, shuffle=True):
@@ -160,6 +161,61 @@ def test_intercept_grouplasso():
160161
np.testing.assert_allclose(model.intercept_, w[-1], atol=1e-5)
161162

162163

164+
@pytest.mark.parametrize("rho", [1e-1, 1e-2])
165+
def test_equivalence_logreg(rho):
166+
n_samples, n_features = 30, 50
167+
rng = np.random.RandomState(1123)
168+
X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)
169+
y = np.sign(y)
170+
171+
grp_indices, grp_ptr = grp_converter(1, n_features)
172+
weights = np.ones(n_features)
173+
alpha_max = norm(X.T @ y, ord=np.inf) / (2 * n_samples)
174+
alpha = rho * alpha_max / 10.
175+
176+
group_logistic = LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices)
177+
group_penalty = WeightedGroupL2(
178+
alpha=alpha, grp_ptr=grp_ptr,
179+
grp_indices=grp_indices, weights=weights)
180+
181+
group_logistic = compiled_clone(group_logistic, to_float32=X.dtype == np.float32)
182+
group_penalty = compiled_clone(group_penalty)
183+
w = GroupBCD(tol=1e-12).solve(X, y, group_logistic, group_penalty)[0]
184+
185+
sk_logreg = LogisticRegression(penalty='l1', C=1/(n_samples * alpha),
186+
fit_intercept=False, tol=1e-12, solver='liblinear')
187+
sk_logreg.fit(X, y)
188+
189+
np.testing.assert_allclose(sk_logreg.coef_.flatten(), w, atol=1e-6, rtol=1e-5)
190+
191+
192+
@pytest.mark.parametrize("n_groups, rho", [[15, 1e-1], [25, 1e-2]])
193+
def test_group_logreg(n_groups, rho):
194+
n_samples, n_features, shuffle = 30, 60, True
195+
random_state = 123
196+
rng = np.random.RandomState(random_state)
197+
198+
X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)
199+
y = np.sign(y)
200+
201+
rng.seed(random_state)
202+
weights = np.abs(rng.randn(n_groups))
203+
grp_indices, grp_ptr, _ = _generate_random_grp(n_groups, n_features, shuffle)
204+
205+
alpha_max = _alpha_max_group_lasso(X, y, grp_indices, grp_ptr, weights)
206+
alpha = rho * alpha_max
207+
208+
# skglm
209+
group_logistic = LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices)
210+
group_penalty = WeightedGroupL2(alpha, weights, grp_ptr, grp_indices)
211+
212+
group_logistic = compiled_clone(group_logistic, to_float32=X.dtype == np.float32)
213+
group_penalty = compiled_clone(group_penalty)
214+
stop_crit = GroupBCD(tol=1e-12).solve(X, y, group_logistic, group_penalty)[2]
215+
216+
np.testing.assert_array_less(stop_crit, 1e-12)
217+
218+
163219
def test_anderson_acceleration():
164220
# VAR: w = rho * w + 1 with |rho| < 1
165221
# converges to w_star = 1 / (1 - rho)

0 commit comments

Comments
 (0)