
Commit 399dfc6

ENH - implement Cox datafit with Breslow estimate (#157)
Co-authored-by: mathurinm <[email protected]>
1 parent: 7323e27 · commit: 399dfc6

File tree: 8 files changed, +312 additions, -6 deletions


.github/workflows/main.yml

Lines changed: 3 additions & 0 deletions

@@ -26,5 +26,8 @@ jobs:
         pip install .
         pip install statsmodels cvxopt
         pip install git+https://github.com/jolars/pyslope.git
+        # for testing Cox estimator
+        pip install lifelines
+        pip install pandas
     - name: Test with pytest
       run: pytest -v skglm/

LICENSE

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 BSD 3-Clause License
 
-Copyright (c) 2022, scikit-learn-contrib
+Copyright (c) 2023, scikit-learn-contrib
 All rights reserved.
 
 Redistribution and use in source and binary forms, with or without

doc/api.rst

Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@ Datafits
 .. autosummary::
    :toctree: generated/
 
+   Cox
    Gamma
    Huber
    Logistic

skglm/datafits/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,12 +1,12 @@
 from .base import BaseDatafit, BaseMultitaskDatafit
-from .single_task import Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma
+from .single_task import Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox
 from .multi_task import QuadraticMultiTask
 from .group import QuadraticGroup, LogisticGroup
 
 
 __all__ = [
     BaseDatafit, BaseMultitaskDatafit,
-    Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma,
+    Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox,
     QuadraticMultiTask,
     QuadraticGroup, LogisticGroup
 ]
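With the export above, the new datafit becomes importable alongside the existing ones. A minimal smoke check (assuming skglm is installed from this commit):

    from skglm.datafits import Cox

    # the datafit carries no hyperparameters, see params_to_dict below
    datafit = Cox()
    print(datafit.params_to_dict())  # -> {}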

skglm/datafits/single_task.py

Lines changed: 102 additions & 0 deletions

@@ -544,3 +544,105 @@ def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
 
     def intercept_update_self(self, y, Xw):
         pass
+
+
+class Cox(BaseDatafit):
+    r"""Cox datafit for survival analysis with the Breslow estimate.
+
+    The datafit reads [1]
+
+    .. math::
+
+        \frac{1}{n_{\text{samples}}} \sum_{i=1}^{n_{\text{samples}}}
+        s_i \Big( -\langle x_i, w \rangle
+        + \log \sum_{j | tm_j \geq tm_i} e^{\langle x_j, w \rangle} \Big)
+
+    where :math:`s_i` indicates the censorship of sample :math:`i` and
+    :math:`tm` is the vector recording the times of event occurrence.
+
+    Defining the matrix :math:`B` with
+    :math:`B_{i,j} = 1` if :math:`tm_j \geq tm_i` and :math:`0` otherwise,
+    the datafit can be rewritten in the following compact form
+
+    .. math::
+
+        - \frac{1}{n_{\text{samples}}} \langle s, Xw \rangle
+        + \frac{1}{n_{\text{samples}}} \langle s, \log (B e^{Xw}) \rangle
+
+    Attributes
+    ----------
+    B : array-like, shape (n_samples, n_samples)
+        Matrix whose ``(i, j)`` entry (row, column) equals ``1``
+        if ``tm[j] >= tm[i]`` and ``0`` otherwise. This matrix is built
+        by the ``initialize`` method.
+
+    References
+    ----------
+    .. [1] DY Lin. On the Breslow estimator.
+        Lifetime Data Analysis, 13:471-480, 2007.
+    """
+
+    def __init__(self):
+        pass
+
+    def get_spec(self):
+        return (
+            ('B', float64[:, ::1]),
+        )
+
+    def params_to_dict(self):
+        return dict()
+
+    def value(self, y, w, Xw):
+        """Compute the value of the datafit."""
+        tm, s = y
+        n_samples = Xw.shape[0]
+
+        out = -(s @ Xw) + s @ np.log(self.B @ np.exp(Xw))
+        return out / n_samples
+
+    def raw_grad(self, y, Xw):
+        r"""Compute the gradient of the datafit w.r.t. ``Xw``.
+
+        The raw gradient reads
+
+            (-s + exp_Xw * (B.T @ (s / B_exp_Xw))) / n_samples
+
+        with ``B_exp_Xw = B @ exp_Xw``.
+        """
+        tm, s = y
+        n_samples = Xw.shape[0]
+
+        exp_Xw = np.exp(Xw)
+        B_exp_Xw = self.B @ exp_Xw
+
+        out = -s + exp_Xw * (self.B.T @ (s / B_exp_Xw))
+        return out / n_samples
+
+    def raw_hessian(self, y, Xw):
+        """Compute a diagonal upper bound of the datafit's Hessian w.r.t. ``Xw``.
+
+        The diagonal upper bound reads
+
+            exp_Xw * (B.T @ (s / B_exp_Xw)) / n_samples
+        """
+        tm, s = y
+        n_samples = Xw.shape[0]
+
+        exp_Xw = np.exp(Xw)
+        B_exp_Xw = self.B @ exp_Xw
+
+        out = exp_Xw * (self.B.T @ (s / B_exp_Xw))
+        return out / n_samples
+
+    def initialize(self, X, y):
+        """Initialize the datafit attributes."""
+        tm, s = y
+
+        tm_as_col = tm.reshape((-1, 1))
+        self.B = (tm >= tm_as_col).astype(X.dtype)
+
+    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
+        """Initialize the datafit attributes in the sparse dataset case."""
+        tm, s = y
+
+        tm_as_col = tm.reshape((-1, 1))
+        self.B = (tm >= tm_as_col).astype(X_data.dtype)
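To make the compact form above concrete, here is a small self-contained NumPy sketch of the same computations (the toy ``tm``, ``s``, and ``w`` are made up for illustration; it mirrors ``initialize``, ``value``, and ``raw_grad`` without skglm's jit-compiled class):

    import numpy as np

    rng = np.random.default_rng(0)
    n_samples, n_features = 5, 3

    X = rng.standard_normal((n_samples, n_features))
    tm = rng.integers(1, 10, size=n_samples).astype(float)  # times of event occurrence
    s = rng.integers(0, 2, size=n_samples).astype(float)    # censorship indicators
    w = rng.standard_normal(n_features)
    Xw = X @ w

    # B[i, j] = 1 if tm[j] >= tm[i], as in Cox.initialize
    B = (tm >= tm[:, None]).astype(X.dtype)

    # datafit value: (-<s, Xw> + <s, log(B e^Xw)>) / n_samples
    value = (-(s @ Xw) + s @ np.log(B @ np.exp(Xw))) / n_samples

    # raw gradient w.r.t. Xw
    exp_Xw = np.exp(Xw)
    raw_grad = (-s + exp_Xw * (B.T @ (s / (B @ exp_Xw)))) / n_samples

Since ``tm[i] >= tm[i]`` always holds, the diagonal of ``B`` is all ones, hence ``B @ np.exp(Xw)`` is strictly positive and the logarithm is always well defined.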

skglm/tests/test_datafits.py

Lines changed: 56 additions & 1 deletion

@@ -1,14 +1,16 @@
 import numpy as np
+import scipy.optimize
 import pytest
 
 from sklearn.linear_model import HuberRegressor
 from numpy.testing import assert_allclose, assert_array_less
 
-from skglm.datafits import Huber, Logistic, Poisson, Gamma
+from skglm.datafits import Huber, Logistic, Poisson, Gamma, Cox
 from skglm.penalties import L1, WeightedL1
 from skglm.solvers import AndersonCD, ProxNewton
 from skglm import GeneralizedLinearEstimator
 from skglm.utils.data import make_correlated_data
+from skglm.utils.jit_compilation import compiled_clone
 
 
 @pytest.mark.parametrize('fit_intercept', [False, True])

@@ -114,5 +116,58 @@ def test_gamma():
     np.testing.assert_allclose(clf.coef_, gamma_results.params, rtol=1e-6)
 
 
+def test_cox():
+    rng = np.random.RandomState(1265)
+    n_samples, n_features = 10, 30
+
+    # generate data
+    X = rng.randn(n_samples, n_features)
+    tm = rng.choice(n_samples * n_features, size=n_samples, replace=True).astype(float)
+    s = rng.choice(2, size=n_samples).astype(float)
+    y = (tm, s)
+
+    # generate dummy w, Xw
+    w = rng.randn(n_features)
+    Xw = X @ w
+
+    # check datafit
+    cox_df = compiled_clone(Cox())
+
+    cox_df.initialize(X, (tm, s))
+    cox_df.value(y, w, Xw)
+
+    # run the test 10 times to account for the truncation errors
+    # of the finite differences used to evaluate the grad and Hessian
+    for _ in range(10):
+
+        # generate dummy w, Xw
+        w = rng.randn(n_features)
+        Xw = X @ w
+
+        # check gradient
+        np.testing.assert_allclose(
+            scipy.optimize.check_grad(
+                lambda x: cox_df.value(y, w, x),
+                lambda x: cox_df.raw_grad(y, x),
+                x0=Xw,
+                seed=rng
+            ),
+            0., atol=1e-6
+        )
+
+        # check Hessian upper bound:
+        # the Hessian minus its upper bound must be negative semi-definite
+        hess_upper_bound = np.diag(cox_df.raw_hessian(y, Xw))
+        hess = scipy.optimize.approx_fprime(
+            xk=Xw,
+            f=lambda x: cox_df.raw_grad(y, x),
+        )
+
+        positive_eig = np.linalg.eigh(hess - hess_upper_bound)[0]
+        positive_eig = positive_eig[positive_eig >= 0.]
+
+        np.testing.assert_allclose(positive_eig, 0., atol=1e-6)
+
+
 if __name__ == '__main__':
     pass
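A note on the last check: differentiating ``raw_grad`` once more splits the Hessian of the datafit with respect to ``Xw`` into the diagonal matrix returned by ``raw_hessian`` minus a positive semi-definite correction, so their difference must be negative semi-definite up to finite-difference error. A sketch of the argument, in the notation of the docstrings above (with :math:`u = Xw` and :math:`n = n_{\text{samples}}`):

.. math::

    \nabla^2 F(u)
    = \frac{1}{n} \operatorname{diag}\Big(e^{u} \odot B^\top \frac{s}{B e^{u}}\Big)
    - \frac{1}{n} D^\top \operatorname{diag}\Big(\frac{s}{(B e^{u})^{2}}\Big) D,
    \qquad D = B \operatorname{diag}(e^{u})

Since the entries of :math:`s` are non-negative, the second term is positive semi-definite, which yields the diagonal upper bound in the Loewner order.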

skglm/tests/test_estimators.py

Lines changed: 90 additions & 2 deletions

@@ -16,14 +16,18 @@
 
 from scipy.sparse import csc_matrix, issparse
 
-from skglm.utils.data import make_correlated_data
+from skglm.utils.data import make_correlated_data, make_dummy_survival_data
 from skglm.estimators import (
     GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet,
     MCPRegression, SparseLogisticRegression, LinearSVC)
-from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask
+from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox
 from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1
 from skglm.solvers import AndersonCD
 
+import pandas as pd
+from skglm.solvers import ProxNewton
+from skglm.utils.jit_compilation import compiled_clone
+
 
 n_samples = 50
 n_tasks = 9

@@ -164,6 +168,90 @@ def test_mtl_path():
     np.testing.assert_allclose(coef_ours, coef_sk, rtol=1e-5)
 
 
+def test_CoxEstimator():
+    try:
+        from lifelines import CoxPHFitter
+    except ModuleNotFoundError:
+        pytest.xfail(
+            "Testing the Cox estimator requires the `lifelines` package\n"
+            "Run `pip install lifelines`"
+        )
+
+    reg = 1e-2
+    # norms of solutions differ when n_features > n_samples
+    n_samples, n_features = 100, 30
+    random_state = 1265
+
+    tm, s, X = make_dummy_survival_data(n_samples, n_features,
+                                        normalize=True, random_state=random_state)
+
+    # compute alpha_max
+    B = (tm >= tm[:, None]).astype(X.dtype)
+    grad_0 = -s + B.T @ (s / np.sum(B, axis=1))
+    alpha_max = norm(X.T @ grad_0, ord=np.inf) / n_samples
+
+    alpha = reg * alpha_max
+
+    # fit Cox using the ProxNewton solver
+    datafit = compiled_clone(Cox())
+    penalty = compiled_clone(L1(alpha))
+
+    datafit.initialize(X, (tm, s))
+
+    w, *_ = ProxNewton(
+        fit_intercept=False, tol=1e-6, max_iter=50
+    ).solve(
+        X, (tm, s), datafit, penalty
+    )
+
+    # fit lifelines estimator
+    stacked_tm_s_X = np.hstack((tm[:, None], s[:, None], X))
+    df = pd.DataFrame(stacked_tm_s_X)
+
+    estimator = CoxPHFitter(penalizer=alpha, l1_ratio=1.)
+    estimator.fit(
+        df, duration_col=0, event_col=1,
+        fit_options={"max_steps": 10_000, "precision": 1e-12}
+    )
+    w_ll = estimator.params_.values
+
+    p_obj_skglm = datafit.value((tm, s), w, X @ w) + penalty.value(w)
+    p_obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)
+
+    # objectives must match even though the norms of the solutions may differ
+    np.testing.assert_allclose(p_obj_skglm, p_obj_ll, atol=1e-6)
+
+
+def test_CoxEstimator_sparse():
+    reg = 1e-2
+    n_samples, n_features = 100, 30
+    X_density, random_state = 0.5, 1265
+
+    tm, s, X = make_dummy_survival_data(n_samples, n_features, X_density=X_density,
+                                        random_state=random_state)
+
+    # compute alpha_max
+    B = (tm >= tm[:, None]).astype(X.dtype)
+    grad_0 = -s + B.T @ (s / np.sum(B, axis=1))
+    alpha_max = norm(X.T @ grad_0, ord=np.inf) / n_samples
+
+    alpha = reg * alpha_max
+
+    # fit Cox using the ProxNewton solver
+    datafit = compiled_clone(Cox())
+    penalty = compiled_clone(L1(alpha))
+
+    datafit.initialize_sparse(X.data, X.indptr, X.indices, (tm, s))
+
+    *_, stop_crit = ProxNewton(
+        fit_intercept=False, tol=1e-6, max_iter=50
+    ).solve(
+        X, (tm, s), datafit, penalty
+    )
+
+    np.testing.assert_allclose(stop_crit, 0., atol=1e-6)
+
+
 # Test if GeneralizedLinearEstimator returns the correct coefficients
 @pytest.mark.parametrize("Datafit, Penalty, Estimator, pen_args", [
     (Quadratic, L1, Lasso, [alpha]),
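For context on the ``alpha_max`` computed in both tests: with an L1 penalty, :math:`w = 0` solves the penalized problem exactly when :math:`\alpha \geq \lVert \nabla_w F(0) \rVert_\infty`, and since :math:`e^{X \cdot 0} = \mathbf{1}`, the Cox raw gradient at zero reduces to the ``grad_0`` expression above:

.. math::

    \alpha_{\max}
    = \frac{1}{n_{\text{samples}}}
      \Big\lVert X^\top \Big( -s + B^\top \frac{s}{B \mathbf{1}} \Big) \Big\rVert_\infty

Taking ``alpha = reg * alpha_max`` with ``reg = 1e-2`` therefore guarantees a nonzero solution while keeping the problem well regularized.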
