|
4 | 4 | import numpy as np |
5 | 5 | from numpy.linalg import norm |
6 | 6 |
|
7 | | -from skglm.penalties import L1 |
8 | | -from skglm.datafits import Quadratic |
| 7 | +from skglm.penalties import L1, L2 |
| 8 | +from skglm.datafits import Quadratic, Poisson |
9 | 9 | from skglm import GeneralizedLinearEstimator |
10 | 10 | from skglm.penalties.block_separable import ( |
11 | 11 | WeightedL1GroupL2, WeightedGroupL2 |
12 | 12 | ) |
13 | | -from skglm.datafits.group import QuadraticGroup, LogisticGroup |
14 | | -from skglm.solvers import GroupBCD, GroupProxNewton |
| 13 | +from skglm.datafits.group import QuadraticGroup, LogisticGroup, PoissonGroup |
| 14 | +from skglm.solvers import GroupBCD, GroupProxNewton, LBFGS |
15 | 15 |
|
16 | 16 | from skglm.utils.anderson import AndersonAcceleration |
17 | 17 | from skglm.utils.data import (make_correlated_data, grp_converter, |
18 | 18 | _alpha_max_group_lasso) |
19 | 19 |
|
20 | 20 | from celer import GroupLasso, Lasso |
21 | 21 | from sklearn.linear_model import LogisticRegression |
| 22 | +from scipy import sparse |
22 | 23 |
|
23 | 24 |
|
24 | 25 | def _generate_random_grp(n_groups, n_features, shuffle=True): |
@@ -312,5 +313,82 @@ def test_anderson_acceleration(): |
312 | 313 | np.testing.assert_array_equal(n_iter, 99) |
313 | 314 |
|
314 | 315 |
|
def test_poisson_group_gradient():
    """Check PoissonGroup per-group gradients on dense and sparse designs.

    For every group, the dense gradient returned by ``gradient_g`` is compared
    with ``X[:, group].T @ raw_grad``; the gradient returned by
    ``gradient_g_sparse`` on the CSC arrays of the same matrix is then
    compared with the dense one.
    """
    n_samples, n_features = 15, 6
    n_groups = 2

    np.random.seed(0)
    X = np.random.randn(n_samples, n_features)
    # Zero out the negative entries so that the CSC copy is genuinely sparse.
    X[X < 0] = 0
    X_sparse = sparse.csc_matrix(X)
    y = np.random.poisson(1.0, n_samples)
    w = np.random.randn(n_features) * 0.1
    Xw = X @ w

    grp_indices, grp_ptr = grp_converter(n_groups, n_features)
    poisson_group = PoissonGroup(grp_ptr=grp_ptr, grp_indices=grp_indices)

    # ``raw_grad`` is called with (y, Xw) only, neither of which changes from
    # one group to the next, so it is evaluated once outside the loop.
    raw_grad = poisson_group.raw_grad(y, Xw)

    for group_id in range(n_groups):
        group_idx = grp_indices[grp_ptr[group_id]:grp_ptr[group_id + 1]]

        # Dense gradient against the closed-form expectation.
        expected = X[:, group_idx].T @ raw_grad
        grad_dense = poisson_group.gradient_g(X, y, w, Xw, group_id)
        np.testing.assert_allclose(grad_dense, expected, rtol=1e-10)

        # Sparse gradient against the dense one computed just above.
        grad_sparse = poisson_group.gradient_g_sparse(
            X_sparse.data, X_sparse.indptr, X_sparse.indices, y, w, Xw, group_id
        )
        np.testing.assert_allclose(grad_sparse, grad_dense, rtol=1e-8)
| 346 | + |
| 347 | + |
def test_poisson_group_solver():
    """Fit a weighted group-L2 penalized PoissonGroup with GroupProxNewton.

    The test passes when the stopping criterion returned by the solver is
    below 1e-8 and every fitted coefficient is finite.
    """
    alpha = 0.1
    n_groups = 3
    n_samples, n_features = 30, 9

    np.random.seed(0)
    X = np.random.randn(n_samples, n_features)
    linear_predictor = alpha * X.sum(axis=1)
    y = np.random.poisson(np.exp(linear_predictor))

    grp_indices, grp_ptr = grp_converter(n_groups, n_features)
    # The datafit and the penalty are built on the same group structure.
    group_kwargs = dict(grp_ptr=grp_ptr, grp_indices=grp_indices)
    datafit = PoissonGroup(**group_kwargs)
    group_weights = np.array([1.0, 0.5, 2.0])
    penalty = WeightedGroupL2(alpha=alpha, weights=group_weights, **group_kwargs)

    solver = GroupProxNewton(fit_intercept=False, tol=1e-8)
    coefs, _, stop_crit = solver.solve(X, y, datafit, penalty)

    assert stop_crit < 1e-8
    assert np.all(np.isfinite(coefs))
| 368 | + |
| 369 | + |
def test_poisson_vs_poisson_group_equivalence():
    """Compare LBFGS solutions for Poisson and PoissonGroup with group size 1.

    Both fits use the same data, the same L2 penalty and the same solver
    tolerance; the two returned coefficient vectors must be equal.
    """
    n_samples, n_features = 20, 8
    alpha = 0.05
    tol = 1e-10

    np.random.seed(42)
    X = np.random.randn(n_samples, n_features)
    y = np.random.poisson(np.exp(0.1 * X.sum(axis=1)))

    # Reference fit: plain Poisson datafit with an L2 penalty.
    reference_solver = LBFGS(tol=tol)
    w_poisson, _, _ = reference_solver.solve(X, y, Poisson(), L2(alpha=alpha))

    # Grouped fit: as many groups as features (group size = 1), all other
    # settings identical to the reference fit.
    grp_indices, grp_ptr = grp_converter(n_features, n_features)
    grouped_solver = LBFGS(tol=tol)
    singleton_datafit = PoissonGroup(grp_ptr=grp_ptr, grp_indices=grp_indices)
    w_group, _, _ = grouped_solver.solve(
        X, y, singleton_datafit, L2(alpha=alpha)
    )

    np.testing.assert_equal(w_poisson, w_group)
| 391 | + |
| 392 | + |
if __name__ == "__main__":
    # The guard body is empty: running this module as a script performs no
    # extra work.
    pass
0 commit comments