Skip to content

Commit f9f0c42

Browse files
authored
FEAT Add Gamma regressor (#113)
1 parent 73851b3 commit f9f0c42

File tree

4 files changed

+99
-6
lines changed

4 files changed

+99
-6
lines changed

doc/api.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,11 @@ Datafits
5454
.. autosummary::
5555
:toctree: generated/
5656

57+
Gamma
5758
Huber
5859
Logistic
5960
LogisticGroup
61+
Poisson
6062
Quadratic
6163
QuadraticGroup
6264
QuadraticSVC
@@ -87,4 +89,4 @@ Experimental
8789
.. autosummary::
8890
:toctree: generated/
8991

90-
SqrtLasso
92+
SqrtLasso

skglm/datafits/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from .base import BaseDatafit, BaseMultitaskDatafit
2-
from .single_task import Quadratic, QuadraticSVC, Logistic, Huber, Poisson
2+
from .single_task import Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma
33
from .multi_task import QuadraticMultiTask
44
from .group import QuadraticGroup, LogisticGroup
55

66

77
__all__ = [
88
BaseDatafit, BaseMultitaskDatafit,
9-
Quadratic, QuadraticSVC, Logistic, Huber, Poisson,
9+
Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma,
1010
QuadraticMultiTask,
1111
QuadraticGroup, LogisticGroup
1212
]

skglm/datafits/single_task.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,10 +429,16 @@ def params_to_dict(self):
429429
return dict()
430430

431431
def initialize(self, X, y):
432-
pass
432+
if np.any(y <= 0):
433+
raise ValueError(
434+
"Target vector `y` should only take positive values " +
435+
"when fitting a Poisson model.")
433436

434437
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
435-
pass
438+
if np.any(y <= 0):
439+
raise ValueError(
440+
"Target vector `y` should only take positive values " +
441+
"when fitting a Poisson model.")
436442

437443
def raw_grad(self, y, Xw):
438444
"""Compute gradient of datafit w.r.t ``Xw``."""
@@ -467,3 +473,58 @@ def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
467473

468474
def intercept_update_self(self, y, Xw):
469475
pass
476+
477+
478+
class Gamma(BaseDatafit):
479+
r"""Gamma datafit.
480+
481+
The datafit reads::
482+
483+
(1 / n_samples) * \sum_i (Xw_i + y_i * exp(-Xw_i) - 1 - log(y_i))
484+
485+
Note:
486+
----
487+
The class is jit compiled at fit time using Numba compiler.
488+
This allows for faster computations.
489+
"""
490+
491+
def __init__(self):
492+
pass
493+
494+
def get_spec(self):
495+
pass
496+
497+
def params_to_dict(self):
498+
return dict()
499+
500+
def initialize(self, X, y):
501+
if np.any(y <= 0):
502+
raise ValueError(
503+
"Target vector `y` should only take positive values "
504+
"when fitting a Gamma model.")
505+
506+
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
507+
if np.any(y <= 0):
508+
raise ValueError(
509+
"Target vector `y` should only take positive values "
510+
"when fitting a Gamma model.")
511+
512+
def raw_grad(self, y, Xw):
513+
"""Compute gradient of datafit w.r.t. ``Xw``."""
514+
return (1 - y * np.exp(-Xw)) / len(y)
515+
516+
def raw_hessian(self, y, Xw):
517+
"""Compute Hessian of datafit w.r.t. ``Xw``."""
518+
return (y * np.exp(-Xw)) / len(y)
519+
520+
def value(self, y, w, Xw):
521+
return (np.sum(Xw + y * np.exp(-Xw) - np.log(y)) - 1) / len(y)
522+
523+
def gradient_scalar(self, X, y, w, Xw, j):
524+
return X[:, j] @ (1 - y * np.exp(-Xw)) / len(y)
525+
526+
def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
527+
pass
528+
529+
def intercept_update_self(self, y, Xw):
530+
pass

skglm/tests/test_datafits.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sklearn.linear_model import HuberRegressor
55
from numpy.testing import assert_allclose, assert_array_less
66

7-
from skglm.datafits import Huber, Logistic, Poisson
7+
from skglm.datafits import Huber, Logistic, Poisson, Gamma
88
from skglm.penalties import L1, WeightedL1
99
from skglm.solvers import AndersonCD, ProxNewton
1010
from skglm import GeneralizedLinearEstimator
@@ -84,5 +84,35 @@ def test_poisson():
8484
assert_allclose(model.coef_, w_statsmodels, rtol=1e-4)
8585

8686

87+
def test_gamma():
88+
try:
89+
import statsmodels.api as sm
90+
except ImportError:
91+
pytest.xfail("`statsmodels` not found. `Gamma` datafit can't be tested.")
92+
93+
# When n_samples < n_features, the unregularized Gamma objective does not have a
94+
# unique minimizer.
95+
rho = 1e-2
96+
n_samples, n_features = 100, 10
97+
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
98+
y[y <= 0] = 0.1
99+
tol = 1e-14
100+
101+
alpha_max = np.linalg.norm(X.T @ (1 - y), ord=np.inf) / n_samples
102+
alpha = rho * alpha_max
103+
104+
gamma_model = sm.GLM(y, X, family=sm.families.Gamma(sm.families.links.Log()))
105+
gamma_results = gamma_model.fit_regularized(
106+
method="elastic_net", L1_wt=1, cnvrg_tol=tol, alpha=alpha)
107+
108+
clf = GeneralizedLinearEstimator(
109+
datafit=Gamma(),
110+
penalty=L1(alpha),
111+
solver=ProxNewton(fit_intercept=False, tol=tol)
112+
).fit(X, y)
113+
114+
np.testing.assert_allclose(clf.coef_, gamma_results.params, rtol=1e-6)
115+
116+
87117
if __name__ == '__main__':
88118
pass

0 commit comments

Comments
 (0)