Commit 395af5e

ENH - Implement Cox with Efron estimate (#159)
Co-authored-by: mathurinm <[email protected]>
1 parent 9bce414 commit 395af5e

File tree

5 files changed (+221, -56 lines)

examples/plot_survival_analysis.py

Lines changed: 99 additions & 5 deletions
```diff
@@ -73,15 +73,17 @@
 datafit.initialize(X, (tm, s))

 # init solver
-solver = ProxNewton(fit_intercept=False, max_iter=50,)
+solver = ProxNewton(fit_intercept=False, max_iter=50)

 # solve the problem
 w_sk = solver.solve(X, (tm, s), datafit, penalty)[0]

 # %%
 # For this data and regularization value, a relatively sparse solution is found:
-print(f"Number of nonzero coefficients in solution: {(w_sk != 0).sum()} out of {len(w_sk)}.")
-
+print(
+    "Number of nonzero coefficients in solution: "
+    f"{(w_sk != 0).sum()} out of {len(w_sk)}."
+)

 # %%
 # Let's solve the problem with ``lifelines`` through its ``CoxPHFitter``
```
```diff
@@ -195,11 +197,103 @@
 ax.set_ylabel("objective suboptimality")
 _ = ax.set_xlabel("time in seconds")

-
-
 # %%
 # According to the printed ratio, using ``skglm`` we get the same result as
 # ``lifelines`` more than 100x faster!
 speed_up = records["lifelines"]["times"][-1] / records["skglm"]["times"][-1]
 print(f"speed up ratio: {speed_up:.0f}")

+# %%
+# Efron estimate
+# --------------
+#
+# The previous results, namely the closeness of the solutions and the timings,
+# carry over to the case of tied observations handled with the Efron estimate.
+#
+# Let's start by generating data with tied observations. This can be achieved
+# by passing ``with_ties=True`` to the ``make_dummy_survival_data`` function.
+tm, s, X = make_dummy_survival_data(
+    n_samples, n_features,
+    normalize=True,
+    with_ties=True,
+    random_state=0
+)
+
+# check the data has tied observations
+print(f"Number of unique times {len(np.unique(tm))} out of {n_samples}")
+
+# %%
+# It is straightforward to fit an :math:`\ell_1` Cox estimator with the Efron estimate.
+# We only need to pass ``use_efron=True`` to the ``Cox`` datafit.
+
+# ensure using Efron estimate
+datafit = compiled_clone(Cox(use_efron=True))
+datafit.initialize(X, (tm, s))
+
+# solve the problem
+solver = ProxNewton(fit_intercept=False, max_iter=50)
+w_sk = solver.solve(X, (tm, s), datafit, penalty)[0]
+
+# %%
+# Again, a relatively sparse solution is found:
+print(
+    "Number of nonzero coefficients in solution: "
+    f"{(w_sk != 0).sum()} out of {len(w_sk)}."
+)
+
+# %%
+# Let's do the same with ``lifelines`` and compare the results
+
+# format data
+stacked_tm_s_X = np.hstack((tm[:, None], s[:, None], X))
+df = pd.DataFrame(stacked_tm_s_X)
+
+# fit lifelines estimator on the new data
+lifelines_estimator = CoxPHFitter(penalizer=alpha, l1_ratio=1.).fit(
+    df,
+    duration_col=0,
+    event_col=1
+)
+w_ll = lifelines_estimator.params_.values
+
+# Check that both solvers find solutions with the same objective value
+obj_sk = datafit.value((tm, s), w_sk, X @ w_sk) + penalty.value(w_sk)
+obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)
+
+print(f"Objective skglm: {obj_sk:.6f}")
+print(f"Objective lifelines: {obj_ll:.6f}")
+print(f"Difference: {(obj_sk - obj_ll):.2e}")
+
+# Check that both solutions are close
+print(f"Euclidean distance between solutions: {np.linalg.norm(w_sk - w_ll):.3e}")
+
+# %%
+# Finally, let's compare the timings of both solvers
+
+# time skglm
+start = time.perf_counter()
+solver.solve(X, (tm, s), datafit, penalty)[0]
+end = time.perf_counter()
+
+total_time_skglm = end - start
+
+# time lifelines
+lifelines_estimator = CoxPHFitter(penalizer=alpha, l1_ratio=1.)
+
+start = time.perf_counter()
+lifelines_estimator.fit(
+    df,
+    duration_col=0,
+    event_col=1
+)
+end = time.perf_counter()
+
+total_time_lifelines = end - start
+
+# deduce the speed up ratio
+speed_up = total_time_lifelines / total_time_skglm
+print(f"speed up ratio: {speed_up:.0f}")
+
+# %%
+# As shown by the last print, we still preserve the roughly 100x speed up
+# even with the Efron estimate.
```
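As a companion to the example above, here is a minimal standalone sketch of the objective that ``use_efron=True`` optimizes. It is written for this page and is not part of the commit; the helper name ``efron_neg_loglik`` is hypothetical. Up to numerical error, its output should match ``datafit.value((tm, s), w_sk, X @ w_sk)``.

```python
import numpy as np

def efron_neg_loglik(tm, s, Xw):
    # hypothetical helper, not part of skglm: negative Efron partial
    # log-likelihood, normalized by n_samples like the Cox datafit
    n_samples = Xw.shape[0]
    exp_Xw = np.exp(Xw)
    out = -(s @ Xw)  # linear term over uncensored samples
    for t in np.unique(tm[s == 1]):
        # indices of uncensored observations tied at time t
        H = np.flatnonzero((tm == t) & (s == 1))
        d = len(H)
        risk_set = exp_Xw[tm >= t].sum()  # Breslow risk-set sum at time t
        tied_sum = exp_Xw[H].sum()        # contribution of the tied events
        for k in range(d):
            # the k-th tie discounts a fraction k / d of the tied mass
            out += np.log(risk_set - (k / d) * tied_sum)
    return out / n_samples
```

When every event time is unique, all tie groups have size one, the discount vanishes, and the function reduces to the Breslow objective used in the first part of the example.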

skglm/datafits/single_task.py

Lines changed: 99 additions & 40 deletions
```diff
@@ -1,7 +1,7 @@
 import numpy as np
 from numpy.linalg import norm
 from numba import njit
-from numba import float64
+from numba import float64, int64, bool_

 from skglm.datafits.base import BaseDatafit
 from skglm.utils.sparse_ops import spectral_norm
```
```diff
@@ -547,90 +547,100 @@ def intercept_update_self(self, y, Xw):


 class Cox(BaseDatafit):
-    r"""Cox datafit for survival analysis with Breslow estimate.
+    r"""Cox datafit for survival analysis.

-    The datafit reads [1]
-
-    .. math::
-
-        1 / n_"samples" \sum_(i=1)^(n_"samples") -s_i \langle x_i, w \rangle
-            + \log (\sum_(j | y_j \geq y_i) e^{\langle x_j, w \rangle})
-
-    where :math:`s_i` indicates the sample censorship and :math:`tm`
-    is the vector recording the time of event occurrences.
-
-    Defining the matrix :math:`B` with
-    :math:`B_{i,j} = 1` if :math:`tm_j \geq tm_i` and :math:`0` otherwise,
-    the datafit can be rewritten in the following compact form
-
-    .. math::
-
-        - 1 / n_"samples" \langle s, Xw \rangle
-            + 1 / n_"samples" \langle s, \log B e^{Xw} \rangle
+    Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>` for details.

+    Parameters
+    ----------
+    use_efron : bool, default=False
+        If ``True``, use the Efron estimate to handle tied observations.

     Attributes
     ----------
     B : array-like, shape (n_samples, n_samples)
         Matrix where every ``(i, j)`` entry (row, column) equals ``1``
-        if ``tm[j] >= tm[i]`` and `0` otherwise. This matrix is initialized
+        if ``tm[j] >= tm[i]`` and ``0`` otherwise. This matrix is initialized
         using the ``.initialize`` method.

-    References
-    ----------
-    .. [1] DY Lin. On the Breslow estimator.
-        Lifetime data analysis, 13:471–480, 2007.
+    H_indices : array-like, shape (n_samples,)
+        Indices of observations with the same occurrence times, stacked
+        horizontally as ``[group_1, group_2, ...]``. This array is initialized
+        by the ``.initialize`` method when ``use_efron=True``.
+
+    H_indptr : array-like, shape (np.unique(tm).shape[0] + 1,)
+        Array where two consecutive elements delimit a group of observations
+        having the same occurrence times.
     """

-    def __init__(self):
-        pass
+    def __init__(self, use_efron=False):
+        self.use_efron = use_efron

     def get_spec(self):
         return (
+            ('use_efron', bool_),
             ('B', float64[:, ::1]),
+            ('H_indptr', int64[:]),
+            ('H_indices', int64[:]),
         )

     def params_to_dict(self):
-        return dict()
+        return dict(use_efron=self.use_efron)

     def value(self, y, w, Xw):
         """Compute the value of the datafit."""
         tm, s = y
         n_samples = Xw.shape[0]

-        out = -(s @ Xw) + s @ np.log(self.B @ np.exp(Xw))
+        # compute the term inside the log
+        exp_Xw = np.exp(Xw)
+        B_exp_Xw = self.B @ exp_Xw
+        if self.use_efron:
+            B_exp_Xw -= self._A_dot_vec(exp_Xw)
+
+        out = -(s @ Xw) + s @ np.log(B_exp_Xw)
         return out / n_samples

     def raw_grad(self, y, Xw):
         r"""Compute the gradient of the datafit w.r.t. ``Xw``.

-        The raw gradient reads
-
-            (-s + exp_Xw * (B.T @ (s / (B @ exp_Xw)))) / n_samples
+        Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`,
+        equation 4, for details.
         """
         tm, s = y
         n_samples = Xw.shape[0]

         exp_Xw = np.exp(Xw)
         B_exp_Xw = self.B @ exp_Xw
+        if self.use_efron:
+            B_exp_Xw -= self._A_dot_vec(exp_Xw)
+
+        s_over_B_exp_Xw = s / B_exp_Xw
+        out = -s + exp_Xw * (self.B.T @ s_over_B_exp_Xw)
+        if self.use_efron:
+            out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)

-        out = -s + exp_Xw * (self.B.T @ (s / B_exp_Xw))
         return out / n_samples

     def raw_hessian(self, y, Xw):
         """Compute a diagonal upper bound of the datafit's Hessian w.r.t. ``Xw``.

-        The diagonal upper bound reads
-
-            exp_Xw * (B.T @ (s / B_exp_Xw)) / n_samples
+        Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`,
+        equation 6, for details.
         """
         tm, s = y
         n_samples = Xw.shape[0]

         exp_Xw = np.exp(Xw)
         B_exp_Xw = self.B @ exp_Xw
+        if self.use_efron:
+            B_exp_Xw -= self._A_dot_vec(exp_Xw)
+
+        s_over_B_exp_Xw = s / B_exp_Xw
+        out = exp_Xw * (self.B.T @ s_over_B_exp_Xw)
+        if self.use_efron:
+            out -= exp_Xw * self._AT_dot_vec(s_over_B_exp_Xw)

-        out = exp_Xw * (self.B.T @ (s / B_exp_Xw))
         return out / n_samples

     def initialize(self, X, y):
```
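Since the commit moves the formulas out of the docstring and into the documentation, it is worth stating here what the Efron variant of ``value`` computes. The following is our reconstruction from the code (the ``B_exp_Xw - self._A_dot_vec(exp_Xw)`` term), not a quote from the referenced maths page; the tie-group notation :math:`H_l`, :math:`d_l` is ours:

```latex
% negative Efron partial log-likelihood, as implemented by Cox.value
% with use_efron=True; H_l is the group of uncensored samples tied at
% time t_l, d_l = |H_l|, and l_k is its k-th member in sorted order
\frac{1}{n_\text{samples}} \sum_{l} \sum_{k=0}^{d_l - 1} \Big[
    -\langle x_{l_k}, w \rangle
    + \log \Big( \sum_{j \,:\, tm_j \geq t_l} e^{\langle x_j, w \rangle}
    - \frac{k}{d_l} \sum_{j \in H_l} e^{\langle x_j, w \rangle} \Big)
\Big]
```

With all :math:`d_l = 1` the correction term vanishes and the Breslow objective from the removed docstring is recovered.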
```diff
@@ -640,9 +650,58 @@ def initialize(self, X, y):
         tm_as_col = tm.reshape((-1, 1))
         self.B = (tm >= tm_as_col).astype(X.dtype)

+        if self.use_efron:
+            H_indices = np.argsort(tm)
+            # filter out censored data
+            H_indices = H_indices[s[H_indices] != 0]
+            n_uncensored_samples = H_indices.shape[0]
+
+            # build H_indptr
+            H_indptr = [0]
+            count = 1
+            for i in range(1, n_uncensored_samples):
+                if tm[H_indices[i-1]] == tm[H_indices[i]]:
+                    count += 1
+                else:
+                    H_indptr.append(count + H_indptr[-1])
+                    count = 1
+            H_indptr.append(n_uncensored_samples)
+            H_indptr = np.asarray(H_indptr, dtype=np.int64)
+
+            # save in instance
+            self.H_indptr = H_indptr
+            self.H_indices = H_indices
+
     def initialize_sparse(self, X_data, X_indptr, X_indices, y):
         """Initialize the datafit attributes in the sparse dataset case."""
-        tm, s = y
+        # initialize_sparse and initialize have the same implementation
+        # small hack to avoid repeating code: pass in X_data, as only its dtype is used
+        self.initialize(X_data, y)

-        tm_as_col = tm.reshape((-1, 1))
-        self.B = (tm >= tm_as_col).astype(X_data.dtype)
+    def _A_dot_vec(self, vec):
+        out = np.zeros_like(vec)
+        n_H = self.H_indptr.shape[0] - 1
+
+        for idx in range(n_H):
+            current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx+1]]
+            size_current_H = current_H_idx.shape[0]
+            frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H
+
+            sum_vec_H = np.sum(vec[current_H_idx])
+            out[current_H_idx] = sum_vec_H * frac_range
+
+        return out
+
+    def _AT_dot_vec(self, vec):
+        out = np.zeros_like(vec)
+        n_H = self.H_indptr.shape[0] - 1
+
+        for idx in range(n_H):
+            current_H_idx = self.H_indices[self.H_indptr[idx]: self.H_indptr[idx+1]]
+            size_current_H = current_H_idx.shape[0]
+            frac_range = np.arange(size_current_H, dtype=vec.dtype) / size_current_H
+
+            weighted_sum_vec_H = vec[current_H_idx] @ frac_range
+            out[current_H_idx] = weighted_sum_vec_H * np.ones(size_current_H)
+
+        return out
```
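To make the ``H_indptr`` / ``H_indices`` layout and the implicit matrix behind ``_A_dot_vec`` concrete, here is a small self-contained check, written for this page rather than taken from the commit. It builds the groups the same way ``initialize`` does and verifies the helper's output against an explicitly materialized dense matrix ``A``:

```python
import numpy as np

# toy data: occurrence times with ties; s: 1 = event observed, 0 = censored
tm = np.array([2., 1., 2., 3., 1., 2.])
s = np.array([1., 1., 1., 0., 1., 1.])

# group construction, as in Cox.initialize with use_efron=True:
# sort by time, drop censored samples, delimit equal-time groups
H_indices = np.argsort(tm)
H_indices = H_indices[s[H_indices] != 0]  # uncensored samples, sorted by time
H_indptr = np.array([0, 2, 5])            # two groups: time 1. (size 2), time 2. (size 3)

# materialize A so that A @ vec reproduces _A_dot_vec(vec): within a
# group of size d, the row of the k-th member holds k / d on every
# column of the group and 0 elsewhere
n = tm.shape[0]
A = np.zeros((n, n))
for g in range(len(H_indptr) - 1):
    grp = H_indices[H_indptr[g]:H_indptr[g + 1]]
    d = grp.shape[0]
    A[np.ix_(grp, grp)] = (np.arange(d) / d)[:, None]

# replicate _A_dot_vec and compare against the dense product
vec = np.random.default_rng(0).random(n)
out = np.zeros_like(vec)
for g in range(len(H_indptr) - 1):
    grp = H_indices[H_indptr[g]:H_indptr[g + 1]]
    d = grp.shape[0]
    out[grp] = vec[grp].sum() * (np.arange(d) / d)

assert np.allclose(out, A @ vec)
```

The CSR-like storage avoids ever forming ``A``: both helpers run in time linear in the number of uncensored samples, instead of the quadratic cost of a dense matrix-vector product.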

skglm/tests/test_datafits.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -116,7 +116,8 @@ def test_gamma():
     np.testing.assert_allclose(clf.coef_, gamma_results.params, rtol=1e-6)


-def test_cox():
+@pytest.mark.parametrize("use_efron", [True, False])
+def test_cox(use_efron):
     rng = np.random.RandomState(1265)
     n_samples, n_features = 10, 30

@@ -131,7 +132,7 @@ def test_cox():
     Xw = X @ w

     # check datafit
-    cox_df = compiled_clone(Cox(use_efron))
+    cox_df = compiled_clone(Cox(use_efron))

     cox_df.initialize(X, (tm, s))
     cox_df.value(y, w, Xw)
```
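A quick numerical sanity check that complements this test, ours rather than part of the commit: compare ``raw_grad`` against centered finite differences of ``value`` for the Efron datafit. It assumes ``Cox`` is exposed in ``skglm.datafits``, and uses a plain (uncompiled) ``Cox`` instance so that the unused ``w`` argument of ``value`` can be a dummy; with ``compiled_clone`` the numba signature would be stricter.

```python
import numpy as np
from skglm.datafits import Cox

rng = np.random.RandomState(0)
n_samples, n_features = 20, 5

tm = rng.choice(8, size=n_samples).astype(float)  # times, with ties
s = (rng.rand(n_samples) < 0.7).astype(float)     # censorship indicator
Xw = rng.randn(n_samples, n_features) @ rng.randn(n_features)

datafit = Cox(use_efron=True)
# only X's dtype is used by initialize, so an empty array suffices
datafit.initialize(np.empty((n_samples, n_features)), (tm, s))

eps = 1e-6
fd_grad = np.zeros(n_samples)
for i in range(n_samples):
    e_i = np.zeros(n_samples)
    e_i[i] = eps
    # Cox.value only uses Xw, so w can be a dummy
    fd_grad[i] = (
        datafit.value((tm, s), None, Xw + e_i)
        - datafit.value((tm, s), None, Xw - e_i)
    ) / (2 * eps)

np.testing.assert_allclose(datafit.raw_grad((tm, s), Xw), fd_grad,
                           rtol=1e-5, atol=1e-8)
```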
