
Commit e7048b6

MNT - compatibility of Cox datafit with L2 regularization (#167)
1 parent 0e5c938 commit e7048b6
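
This change makes the Cox datafit usable with the L2 penalty through the LBFGS solver. As a rough illustration of the combination it enables (a sketch mirroring the new test further down; the data is synthetic and the call pattern is copied from that test, not additional API), an L2-regularized Cox model can be fit like this:

# Sketch only: fit an L2-regularized Cox model with LBFGS, mirroring the new test.
from skglm.datafits import Cox
from skglm.penalties import L2
from skglm.solvers import LBFGS
from skglm.utils.data import make_dummy_survival_data
from skglm.utils.jit_compilation import compiled_clone

use_efron = False  # Breslow handling of ties; the test also exercises Efron

# synthetic survival data: times tm, event indicators s, design matrix X
tm, s, X = make_dummy_survival_data(
    100, 50, normalize=True, with_ties=use_efron, random_state=0)

datafit = compiled_clone(Cox(use_efron))
penalty = compiled_clone(L2(10.))

datafit.initialize(X, (tm, s))
w, *_ = LBFGS().solve(X, (tm, s), datafit, penalty)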

File tree: 2 files changed, +56 -10 lines

  skglm/datafits/single_task.py
  skglm/tests/test_lbfgs_solver.py


skglm/datafits/single_task.py

Lines changed: 4 additions & 0 deletions
@@ -646,6 +646,10 @@ def raw_hessian(self, y, Xw):
 
         return out / n_samples
 
+    def gradient(self, X, y, Xw):
+        """Compute gradient of the datafit."""
+        return X.T @ self.raw_grad(y, Xw)
+
     def initialize(self, X, y):
         """Initialize the datafit attributes."""
         tm, s = y
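
The new gradient method maps raw_grad, the gradient of the datafit with respect to Xw, back to coefficient space via X.T, which is the quantity a quasi-Newton solver needs. Below is a minimal sketch of how such a solver can consume a datafit exposing value(y, w, Xw) and gradient(X, y, Xw) as above; the scipy.optimize call and the alpha / 2 * ||w||^2 form of the L2 term are illustrative assumptions, not skglm's LBFGS implementation.

# Minimal sketch (not skglm's LBFGS solver): run L-BFGS on a smooth datafit
# plus an L2 term, using the datafit's value() and the gradient() added above.
import numpy as np
from scipy.optimize import minimize


def lbfgs_fit(X, y, datafit, alpha):
    def objective(w):
        Xw = X @ w
        # L2 term assumed to be alpha / 2 * ||w||^2, consistent with the gradient below
        return datafit.value(y, w, Xw) + 0.5 * alpha * np.dot(w, w)

    def gradient(w):
        Xw = X @ w
        # datafit.gradient returns X.T @ raw_grad(y, Xw), i.e. the gradient in w space
        return datafit.gradient(X, y, Xw) + alpha * w

    w0 = np.zeros(X.shape[1])
    res = minimize(objective, w0, jac=gradient, method="L-BFGS-B")
    return res.x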

skglm/tests/test_lbfgs_solver.py

Lines changed: 52 additions & 10 deletions
@@ -1,18 +1,20 @@
+import pytest
 import numpy as np
+import pandas as pd
 
-from skglm.solvers import LBFGS
 from skglm.penalties import L2
-from skglm.datafits import Logistic
+from skglm.solvers import LBFGS
+from skglm.datafits import Logistic, Cox
 
 from sklearn.linear_model import LogisticRegression
 
-from skglm.utils.data import make_correlated_data
 from skglm.utils.jit_compilation import compiled_clone
+from skglm.utils.data import make_correlated_data, make_dummy_survival_data
 
 
 def test_lbfgs_L2_logreg():
     reg = 1.
-    n_samples, n_features = 50, 10
+    n_samples, n_features = 100, 50
 
     X, y, _ = make_correlated_data(
         n_samples, n_features, random_state=0)
@@ -21,19 +23,59 @@ def test_lbfgs_L2_logreg():
     # fit L-BFGS
     datafit = compiled_clone(Logistic())
     penalty = compiled_clone(L2(reg))
-    w, *_ = LBFGS().solve(X, y, datafit, penalty)
+    w, *_ = LBFGS(tol=1e-12).solve(X, y, datafit, penalty)
 
     # fit scikit learn
     estimator = LogisticRegression(
         penalty='l2',
         C=1 / (n_samples * reg),
-        fit_intercept=False
-    )
-    estimator.fit(X, y)
+        fit_intercept=False,
+        tol=1e-12,
+    ).fit(X, y)
+
+    np.testing.assert_allclose(w, estimator.coef_.flatten())
+
+
+@pytest.mark.parametrize("use_efron", [True, False])
+def test_L2_Cox(use_efron):
+    try:
+        from lifelines import CoxPHFitter
+    except ModuleNotFoundError:
+        pytest.xfail(
+            "Testing L2 Cox Estimator requires `lifelines` packages\n"
+            "Run `pip install lifelines`"
+        )
+
+    alpha = 10.
+    n_samples, n_features = 100, 50
 
-    np.testing.assert_allclose(
-        w, estimator.coef_.flatten(), atol=1e-4
+    tm, s, X = make_dummy_survival_data(
+        n_samples, n_features, normalize=True,
+        with_ties=use_efron, random_state=0)
+
+    datafit = compiled_clone(Cox(use_efron))
+    penalty = compiled_clone(L2(alpha))
+
+    datafit.initialize(X, (tm, s))
+    w, *_ = LBFGS().solve(X, (tm, s), datafit, penalty)
+
+    # fit lifeline estimator
+    stacked_tm_s_X = np.hstack((tm[:, None], s[:, None], X))
+    df = pd.DataFrame(stacked_tm_s_X)
+
+    estimator = CoxPHFitter(penalizer=alpha, l1_ratio=0.).fit(
+        df, duration_col=0, event_col=1
     )
+    w_ll = estimator.params_.values
+
+    p_obj_skglm = datafit.value((tm, s), w, X @ w) + penalty.value(w)
+    p_obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)
+
+    # despite increasing tol in lifelines, solutions are quite far apart
+    # suspecting lifelines https://github.com/CamDavidsonPilon/lifelines/pull/1534
+    # as our solution gives the lowest objective value
+    np.testing.assert_allclose(w, w_ll, rtol=1e-1)
+    np.testing.assert_allclose(p_obj_skglm, p_obj_ll, rtol=1e-6)
 
 
 if __name__ == "__main__":
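
For reference, in the no-ties (Breslow) case the objective both fits above are expected to minimize is, assuming skglm's 1/n normalization of the Cox datafit and an alpha/2 convention for the L2 penalty (tm are the survival times, s the event indicators):

\min_w \; -\frac{1}{n} \sum_{i:\, s_i = 1} \Big( x_i^\top w - \log \sum_{j:\, tm_j \ge tm_i} e^{x_j^\top w} \Big) + \frac{\alpha}{2} \lVert w \rVert_2^2

The test compares the skglm and lifelines solutions coefficient-wise only loosely (rtol=1e-1) and relies on this objective value for the tight check (rtol=1e-6), since the objective is a more robust yardstick when the two solvers stop at slightly different points.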
