Skip to content

Commit 182cae8

Browse files
PABannier and mathurinm authored
ENH Add Poisson datafit (#78)
Co-authored-by: mathurinm <[email protected]>
1 parent 6ea09b2 commit 182cae8

File tree

3 files changed

+96
-5
lines changed

3 files changed

+96
-5
lines changed

skglm/datafits/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from .base import BaseDatafit, BaseMultitaskDatafit
2-
from .single_task import Quadratic, QuadraticSVC, Logistic, Huber
2+
from .single_task import Quadratic, QuadraticSVC, Logistic, Huber, Poisson
33
from .multi_task import QuadraticMultiTask
44
from .group import QuadraticGroup
55

66

77
__all__ = [
88
BaseDatafit, BaseMultitaskDatafit,
9-
Quadratic, QuadraticSVC, Logistic, Huber,
9+
Quadratic, QuadraticSVC, Logistic, Huber, Poisson,
1010
QuadraticMultiTask,
1111
QuadraticGroup
1212
]

skglm/datafits/single_task.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,66 @@ def intercept_update_step(self, y, Xw):
360360
else:
361361
update -= np.sign(residual) * self.delta
362362
return update / n_samples
363+
364+
365+
class Poisson(BaseDatafit):
366+
r"""Poisson datafit.
367+
368+
The datafit reads::
369+
370+
(1 / n_samples) * \sum_i (exp(Xw_i) - y_i * Xw_i)
371+
372+
Note:
373+
----
374+
The class is jit compiled at fit time using Numba compiler.
375+
This allows for faster computations.
376+
"""
377+
378+
def __init__(self):
    """Create the datafit; it takes no parameters and stores no state."""
380+
381+
def get_spec(self):
    """Return ``None``: this datafit declares no attribute specification."""
    return None
383+
384+
def params_to_dict(self):
    """Return the datafit parameters as a dict (empty: there are none)."""
    return {}
386+
387+
def initialize(self, X, y):
    """Do nothing: no quantity is precomputed from dense ``X`` and ``y``."""
    return None
389+
390+
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
    """Do nothing: no quantity is precomputed from the sparse design and ``y``."""
    return None
392+
393+
def raw_grad(self, y, Xw):
    """Return the gradient of the datafit with respect to ``Xw``.

    Entry ``i`` of the result is ``(exp(Xw[i]) - y[i]) / len(y)``.
    """
    n_samples = len(y)
    residual = np.exp(Xw) - y
    return residual / n_samples
396+
397+
def raw_hessian(self, y, Xw):
    """Return the element-wise second derivative of the datafit w.r.t. ``Xw``.

    Entry ``i`` of the result is ``exp(Xw[i]) / len(y)``; ``y`` is only used
    for its length.
    """
    n_samples = len(y)
    curvature = np.exp(Xw)
    return curvature / n_samples
400+
401+
def value(self, y, w, Xw):
    """Return the datafit value ``sum_i(exp(Xw_i) - y_i * Xw_i) / len(y)``.

    ``w`` is accepted but not used: the value depends only on ``y`` and ``Xw``.
    """
    per_sample = np.exp(Xw) - y * Xw
    return np.sum(per_sample) / len(y)
403+
404+
def gradient_scalar(self, X, y, w, Xw, j):
    """Return the ``j``-th coordinate of the datafit gradient for dense ``X``.

    Computed as ``X[:, j] @ (exp(Xw) - y) / len(y)``; ``w`` is not used.
    """
    column = X[:, j]
    residual = np.exp(Xw) - y
    return (column @ residual) / len(y)
406+
407+
def full_grad_sparse(self, X_data, X_indptr, X_indices, y, Xw):
    """Return the full datafit gradient for a column-compressed sparse design.

    Parameters
    ----------
    X_data, X_indptr, X_indices : arrays
        Column-wise compressed representation of the design matrix: the
        entries of column ``j`` are ``X_data[X_indptr[j]:X_indptr[j + 1]]``
        and their row indices are the matching slice of ``X_indices``.
    y : array
        Target vector.
    Xw : array
        Model fit, indexed by sample.

    Returns
    -------
    grad : array, shape (n_features,)
        ``grad[j] = sum_i X[i, j] * (exp(Xw[i]) - y[i]) / len(y)``, i.e. the
        same quantity ``gradient_scalar_sparse`` returns for each ``j``.
    """
    n_features = X_indptr.shape[0] - 1
    n_samples = len(y)
    grad = np.zeros(n_features, dtype=X_data.dtype)
    for j in range(n_features):
        acc = 0.
        for i in range(X_indptr[j], X_indptr[j + 1]):
            idx_i = X_indices[i]
            # The residual is exp(Xw_i) - y_i: y must stay outside the
            # exponential (it was previously subtracted inside it).
            acc += X_data[i] * (np.exp(Xw[idx_i]) - y[idx_i])
        grad[j] = acc / n_samples
    return grad
416+
417+
def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
    """Return the ``j``-th coordinate of the datafit gradient for sparse ``X``.

    Only the stored entries of column ``j`` are visited, i.e. positions
    ``X_indptr[j]`` up to (excluding) ``X_indptr[j + 1]`` of ``X_data``.
    """
    start, stop = X_indptr[j], X_indptr[j + 1]
    total = 0.
    for ptr in range(start, stop):
        row = X_indices[ptr]
        total += X_data[ptr] * (np.exp(Xw[row]) - y[row])
    return total / len(y)
423+
424+
def intercept_update_self(self, y, Xw):
    """Do nothing and return ``None``: no intercept update is implemented.

    NOTE(review): the datafit defined just above this class names the
    corresponding hook ``intercept_update_step``; this name may be a typo --
    confirm against the base class before relying on intercept fitting.
    """
    return None

skglm/tests/test_datafits.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from sklearn.linear_model import HuberRegressor
55
from numpy.testing import assert_allclose, assert_array_less
66

7-
from skglm.datafits import Huber, Logistic
8-
from skglm.penalties import WeightedL1
9-
from skglm.solvers import AndersonCD
7+
from skglm.datafits import Huber, Logistic, Poisson
8+
from skglm.penalties import L1, WeightedL1
9+
from skglm.solvers import AndersonCD, ProxNewton
1010
from skglm import GeneralizedLinearEstimator
1111
from skglm.utils import make_correlated_data
1212

@@ -56,5 +56,33 @@ def test_log_datafit():
5656
np.testing.assert_almost_equal(-grad * (y + n_samples * grad), hess)
5757

5858

59+
def test_poisson():
    """Compare the L1-penalized Poisson fit with statsmodels' solution."""
    try:
        from statsmodels.discrete.discrete_model import Poisson as PoissonRegressor  # noqa
    except ImportError:
        pytest.xfail("`statsmodels` not found. `Poisson` datafit can't be tested.")

    tol = 1e-14
    n_samples, n_features = 10, 22
    X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
    # Work with non-negative targets.
    y = np.abs(y)

    # Infinity norm of the Poisson datafit gradient at w = 0 (exp(0) = 1);
    # the penalty strength is set to a tenth of it.
    grad_at_zero = X.T @ (np.ones(n_samples) - y)
    alpha_max = np.linalg.norm(grad_at_zero, ord=np.inf) / n_samples
    alpha = alpha_max * 0.1

    estimator = GeneralizedLinearEstimator(
        Poisson(), L1(alpha), ProxNewton(tol=tol, fit_intercept=False))
    model = estimator.fit(X, y)

    # statsmodels is called with alpha multiplied by n_samples, whereas the
    # datafit above divides its sum by n_samples.
    reference = PoissonRegressor(y, X, offset=None).fit_regularized(
        method="l1", size_trim_tol=tol, alpha=alpha * n_samples, trim_mode="size")

    assert_allclose(model.coef_, reference.params, rtol=1e-4)
85+
86+
5987
if __name__ == '__main__':
6088
pass

0 commit comments

Comments
 (0)