Skip to content

Commit 3851715

Browse files
FEAT - Implement PoissonGroup Datafit (#318)
1 parent 9f7a3ce commit 3851715

File tree

5 files changed

+136
-7
lines changed

5 files changed

+136
-7
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ Datafits
7070
Logistic
7171
LogisticGroup
7272
Poisson
73+
PoissonGroup
7374
Quadratic
7475
QuadraticGroup
7576
QuadraticHessian

doc/changes/0.5.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ Version 0.5 (in progress)
55
- Add support for fitting an intercept in :ref:`SqrtLasso <skglm.experimental.sqrt_lasso.SqrtLasso>` (PR: :gh:`298`)
66
- Add experimental :ref:`QuantileHuber <skglm.experimental.quantile_huber.QuantileHuber>` and :ref:`SmoothQuantileRegressor <skglm.experimental.quantile_huber.SmoothQuantileRegressor>` for quantile regression, and an example script (PR: :gh:`312`).
77
- Add :ref:`GeneralizedLinearEstimatorCV <skglm.cv.GeneralizedLinearEstimatorCV>` for cross-validation with automatic parameter selection for L1 and elastic-net penalties (PR: :gh:`299`)
8+
- Add :class:`skglm.datafits.group.PoissonGroup` datafit for group-structured Poisson regression. (PR: :gh:`318`)

skglm/datafits/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
from .single_task import (Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma,
33
Cox, WeightedQuadratic, QuadraticHessian,)
44
from .multi_task import QuadraticMultiTask
5-
from .group import QuadraticGroup, LogisticGroup
5+
from .group import QuadraticGroup, LogisticGroup, PoissonGroup
66

77

88
__all__ = [
99
BaseDatafit, BaseMultitaskDatafit,
1010
Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox,
1111
QuadraticMultiTask,
12-
QuadraticGroup, LogisticGroup, WeightedQuadratic,
12+
QuadraticGroup, LogisticGroup, PoissonGroup, WeightedQuadratic,
1313
QuadraticHessian
1414
]

skglm/datafits/group.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from numba import int32, float64
44

55
from skglm.datafits.base import BaseDatafit
6-
from skglm.datafits.single_task import Logistic
6+
from skglm.datafits.single_task import Logistic, Poisson
77
from skglm.utils.sparse_ops import spectral_norm, sparse_columns_slice
88

99

@@ -161,3 +161,52 @@ def gradient_g(self, X, y, w, Xw, g):
161161
grad_g[idx] = X[:, j] @ raw_grad_val
162162

163163
return grad_g
164+
165+
166+
class PoissonGroup(Poisson):
    r"""Poisson datafit to be combined with group penalties.

    The datafit reads:

    .. math::
        \frac{1}{n_{\text{samples}}} \sum_{i=1}^{n_{\text{samples}}}
        (\exp((Xw)_i) - y_i (Xw)_i)

    Attributes
    ----------
    grp_indices : array, shape (n_features,)
        The group indices stacked contiguously
        ``[grp1_indices, grp2_indices, ...]``.

    grp_ptr : array, shape (n_groups + 1,)
        The group pointers such that two consecutive elements delimit
        the indices of a group in ``grp_indices``.
    """

    def __init__(self, grp_ptr, grp_indices):
        self.grp_ptr = grp_ptr
        self.grp_indices = grp_indices

    def get_spec(self):
        """Return the (name, numba type) pairs describing the attributes."""
        spec = (
            ('grp_ptr', int32[:]),
            ('grp_indices', int32[:]),
        )
        return spec

    def params_to_dict(self):
        """Return the constructor parameters as a dictionary."""
        return {'grp_ptr': self.grp_ptr, 'grp_indices': self.grp_indices}

    def gradient_g(self, X, y, w, Xw, g):
        """Compute the gradient restricted to the features of group ``g``.

        Parameters
        ----------
        X : array, shape (n_samples, n_features)
            Dense design matrix.

        y : array, shape (n_samples,)
            Target vector.

        w : array, shape (n_features,)
            Coefficient vector. Not used here; kept in the signature.

        Xw : array, shape (n_samples,)
            Model fit ``X @ w``.

        g : int
            Index of the group.

        Returns
        -------
        grad_g : array, shape (n_features_in_group,)
            One entry per feature of group ``g``, in the order they appear
            in ``grp_indices``.
        """
        start, stop = self.grp_ptr[g], self.grp_ptr[g + 1]
        features_g = self.grp_indices[start:stop]

        # ``raw_grad`` is inherited from ``Poisson``; it is evaluated once and
        # then projected on each column of the group.
        sample_grad = self.raw_grad(y, Xw)

        grad_g = np.zeros(len(features_g))
        for pos, feature in enumerate(features_g):
            grad_g[pos] = X[:, feature] @ sample_grad
        return grad_g

    def gradient_g_sparse(self, X_data, X_indptr, X_indices, y, w, Xw, g):
        """Compute the gradient of group ``g`` for a sparse design matrix.

        ``X_data``, ``X_indptr`` and ``X_indices`` are forwarded unchanged to
        ``gradient_scalar_sparse`` (inherited from ``Poisson``), which is
        called once per feature of the group. ``w`` is not used here.

        Returns
        -------
        grad_g : array, shape (n_features_in_group,)
            One entry per feature of group ``g``, in the order they appear
            in ``grp_indices``.
        """
        start, stop = self.grp_ptr[g], self.grp_ptr[g + 1]
        features_g = self.grp_indices[start:stop]

        grad_g = np.zeros(len(features_g))
        for pos, feature in enumerate(features_g):
            grad_g[pos] = self.gradient_scalar_sparse(
                X_data, X_indptr, X_indices, y, Xw, feature)
        return grad_g

skglm/tests/test_group.py

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,22 @@
44
import numpy as np
55
from numpy.linalg import norm
66

7-
from skglm.penalties import L1
8-
from skglm.datafits import Quadratic
7+
from skglm.penalties import L1, L2
8+
from skglm.datafits import Quadratic, Poisson
99
from skglm import GeneralizedLinearEstimator
1010
from skglm.penalties.block_separable import (
1111
WeightedL1GroupL2, WeightedGroupL2
1212
)
13-
from skglm.datafits.group import QuadraticGroup, LogisticGroup
14-
from skglm.solvers import GroupBCD, GroupProxNewton
13+
from skglm.datafits.group import QuadraticGroup, LogisticGroup, PoissonGroup
14+
from skglm.solvers import GroupBCD, GroupProxNewton, LBFGS
1515

1616
from skglm.utils.anderson import AndersonAcceleration
1717
from skglm.utils.data import (make_correlated_data, grp_converter,
1818
_alpha_max_group_lasso)
1919

2020
from celer import GroupLasso, Lasso
2121
from sklearn.linear_model import LogisticRegression
22+
from scipy import sparse
2223

2324

2425
def _generate_random_grp(n_groups, n_features, shuffle=True):
@@ -312,5 +313,82 @@ def test_anderson_acceleration():
312313
np.testing.assert_array_equal(n_iter, 99)
313314

314315

316+
def test_poisson_group_gradient():
    """Check PoissonGroup group gradients on dense and sparse designs.

    For every group, the dense gradient must equal ``X_g.T @ raw_grad`` and
    the gradient computed from the CSC representation must match the dense one.
    """
    n_samples, n_features = 15, 6
    n_groups = 2

    np.random.seed(0)
    X = np.random.randn(n_samples, n_features)
    # zero out the negative entries so the CSC matrix has actual zeros
    X[X < 0] = 0
    X_sparse = sparse.csc_matrix(X)
    y = np.random.poisson(1.0, n_samples)
    w = np.random.randn(n_features) * 0.1
    Xw = X @ w

    grp_indices, grp_ptr = grp_converter(n_groups, n_features)
    poisson_group = PoissonGroup(grp_ptr=grp_ptr, grp_indices=grp_indices)

    # the raw gradient does not depend on the group: compute it once
    raw_grad = poisson_group.raw_grad(y, Xw)

    for group_id in range(n_groups):
        group_idx = grp_indices[grp_ptr[group_id]:grp_ptr[group_id + 1]]

        # a single dense evaluation serves both checks below
        grad_dense = poisson_group.gradient_g(X, y, w, Xw, group_id)

        # dense gradient against the closed-form expression
        expected = X[:, group_idx].T @ raw_grad
        np.testing.assert_allclose(grad_dense, expected, rtol=1e-10)

        # sparse gradient against the dense one
        grad_sparse = poisson_group.gradient_g_sparse(
            X_sparse.data, X_sparse.indptr, X_sparse.indices, y, w, Xw, group_id
        )
        np.testing.assert_allclose(grad_sparse, grad_dense, rtol=1e-8)
346+
347+
348+
def test_poisson_group_solver():
    """Run GroupProxNewton on a PoissonGroup problem and check it converges.

    The test passes when the returned stopping criterion is below ``1e-8``
    and all the coefficients are finite.
    """
    alpha = 0.1
    n_groups = 3
    n_samples, n_features = 30, 9

    np.random.seed(0)
    X = np.random.randn(n_samples, n_features)
    y = np.random.poisson(np.exp(alpha * X.sum(axis=1)))

    grp_indices, grp_ptr = grp_converter(n_groups, n_features)
    group_weights = np.array([1.0, 0.5, 2.0])
    datafit = PoissonGroup(grp_ptr=grp_ptr, grp_indices=grp_indices)
    penalty = WeightedGroupL2(
        alpha=alpha, grp_ptr=grp_ptr, grp_indices=grp_indices,
        weights=group_weights)

    solver = GroupProxNewton(fit_intercept=False, tol=1e-8)
    w, _, stop_crit = solver.solve(X, y, datafit, penalty)

    converged = stop_crit < 1e-8
    all_finite = np.all(np.isfinite(w))
    assert converged and all_finite
368+
369+
370+
def test_poisson_vs_poisson_group_equivalence():
    """Check Poisson and PoissonGroup coincide when every group has one feature."""
    n_samples, n_features = 20, 8
    alpha = 0.05

    np.random.seed(42)
    X = np.random.randn(n_samples, n_features)
    y = np.random.poisson(np.exp(0.1 * X.sum(axis=1)))

    # reference solution: plain Poisson datafit with an L2 penalty
    w_poisson = LBFGS(tol=1e-10).solve(X, y, Poisson(), L2(alpha=alpha))[0]

    # same problem and solver settings, with as many groups as features
    grp_indices, grp_ptr = grp_converter(n_features, n_features)
    group_datafit = PoissonGroup(grp_ptr=grp_ptr, grp_indices=grp_indices)
    w_group = LBFGS(tol=1e-10).solve(X, y, group_datafit, L2(alpha=alpha))[0]

    np.testing.assert_equal(w_poisson, w_group)
391+
392+
315393
if __name__ == "__main__":
316394
pass

0 commit comments

Comments
 (0)