Skip to content

Commit fc6bc21

Browse files
tvayer and mathurinm authored
FEAT Add QuadraticHessian datafit (#279)
Co-authored-by: mathurinm <[email protected]>
1 parent 1225970 commit fc6bc21

File tree

5 files changed

+61
-5
lines changed

5 files changed

+61
-5
lines changed

doc/api.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ Datafits
6868
Poisson
6969
Quadratic
7070
QuadraticGroup
71+
QuadraticHessian
7172
QuadraticSVC
7273
WeightedQuadratic
7374

@@ -102,4 +103,4 @@ Experimental
102103
PDCD_WS
103104
Pinball
104105
SqrtQuadratic
105-
SqrtLasso
106+
SqrtLasso

skglm/datafits/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .base import BaseDatafit, BaseMultitaskDatafit
22
from .single_task import (Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma,
3-
Cox, WeightedQuadratic,)
3+
Cox, WeightedQuadratic, QuadraticHessian,)
44
from .multi_task import QuadraticMultiTask
55
from .group import QuadraticGroup, LogisticGroup
66

@@ -9,5 +9,6 @@
99
BaseDatafit, BaseMultitaskDatafit,
1010
Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox,
1111
QuadraticMultiTask,
12-
QuadraticGroup, LogisticGroup, WeightedQuadratic
12+
QuadraticGroup, LogisticGroup, WeightedQuadratic,
13+
QuadraticHessian
1314
]

skglm/datafits/single_task.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,41 @@ def intercept_update_step(self, y, Xw):
239239
return np.sum(self.sample_weights * (Xw - y)) / self.sample_weights.sum()
240240

241241

242+
class QuadraticHessian(BaseDatafit):
    r"""Quadratic datafit taking the Hessian matrix ``A`` directly.

    The datafit reads:

    .. math:: 1 / 2 x^\top A x + \langle b, x \rangle

    for a symmetric matrix :math:`A`. Up to a constant, this is the same as a
    Quadratic datafit with :math:`A = 1 / n_{\text{samples}} X^\top X` and
    :math:`b = - 1 / n_{\text{samples}} X^\top y`. When the Hessian is already
    available, this datafit is more efficient than using Quadratic.
    """

    def __init__(self):
        # Stateless datafit: A and b are passed to the methods directly,
        # in place of the usual design matrix X and target y.
        pass

    def get_spec(self):
        # No attributes to declare for jit compilation.
        pass

    def params_to_dict(self):
        return dict()

    def get_lipschitz(self, A, b):
        # Coordinate-wise Lipschitz constants are the diagonal entries of A.
        n_features = A.shape[0]
        lipschitz = np.zeros(n_features, dtype=A.dtype)
        for feat in range(n_features):
            lipschitz[feat] = A[feat, feat]
        return lipschitz

    def gradient_scalar(self, A, b, w, Ax, j):
        # j-th partial derivative of 1/2 w^T A w + <b, w>,
        # using the cached product Ax = A @ w.
        return Ax[j] + b[j]

    def value(self, b, x, Ax):
        # Objective value 1/2 x^T A x + <b, x>, using the cached Ax = A @ x.
        return 0.5 * (x * Ax).sum() + (b * x).sum()
275+
276+
242277
@njit
243278
def sigmoid(x):
244279
"""Vectorwise sigmoid."""

skglm/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from skglm.utils.jit_compilation import compiled_clone
2222
from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD
2323
from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC,
24-
QuadraticMultiTask, QuadraticGroup)
24+
QuadraticMultiTask, QuadraticGroup,)
2525
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
2626
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
2727
from skglm.utils.data import grp_converter

skglm/tests/test_datafits.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from numpy.testing import assert_allclose, assert_array_less
77

88
from skglm.datafits import (Huber, Logistic, Poisson, Gamma, Cox, WeightedQuadratic,
9-
Quadratic,)
9+
Quadratic, QuadraticHessian)
1010
from skglm.penalties import L1, WeightedL1
1111
from skglm.solvers import AndersonCD, ProxNewton
1212
from skglm import GeneralizedLinearEstimator
@@ -219,5 +219,24 @@ def test_sample_weights(fit_intercept):
219219
# np.testing.assert_equal(n_iter, n_iter_overs)
220220

221221

222+
def test_HessianQuadratic():
    # Solving min 1/2 x^T A x + <b, x> + alpha ||x||_1 with A = X^T X / n
    # and b = -X^T y / n must give the same solution as the Lasso on (X, y).
    n_samples, n_features = 20, 10
    X, y, _ = make_correlated_data(
        n_samples=n_samples, n_features=n_features, random_state=0)
    A = X.T @ X / n_samples
    b = -X.T @ y / n_samples
    alpha = np.max(np.abs(b)) / 10

    penalty = L1(alpha)
    solver = AndersonCD(warm_start=False, verbose=2, fit_intercept=False)
    lasso = GeneralizedLinearEstimator(Quadratic(), penalty, solver).fit(X, y)
    qpl1 = GeneralizedLinearEstimator(
        QuadraticHessian(), penalty, solver).fit(A, b)

    np.testing.assert_allclose(lasso.coef_, qpl1.coef_)
    # guard against a trivial pass: alpha must not be so large that all
    # coefficients are zero
    np.testing.assert_array_less(0.1, np.max(np.abs(qpl1.coef_)))
239+
240+
222241
if __name__ == '__main__':
223242
pass

0 commit comments

Comments
 (0)