Skip to content

Commit 512464f

Browse files
DOC add survival analysis example (#162)
Co-authored-by: Badr-MOUFAD <[email protected]>
1 parent 399dfc6 commit 512464f

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed

doc/doc-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ sphinx_copybutton
99
sphinx-gallery
1010
pytest
1111
furo
12+
lifelines

examples/plot_survival_analysis.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Authors: Badr Moufad
2+
# Mathurin Massias
3+
"""
4+
========================================================
5+
Comparison of lifelines with skglm for survival analysis
6+
========================================================
7+
This example shows that ``skglm`` fits a Cox model exactly as ``lifelines`` but with
8+
x100 less time.
9+
"""
10+
11+
# %%
12+
# Data
13+
# ----
14+
#
15+
# Let's first generate synthetic data on which to run the Cox estimator,
16+
# using ``skglm`` data utils.
17+
#
18+
from skglm.utils.data import make_dummy_survival_data
19+
20+
n_samples, n_features = 500, 100
21+
tm, s, X = make_dummy_survival_data(
22+
n_samples, n_features,
23+
normalize=True,
24+
random_state=0
25+
)
26+
27+
# %%
28+
# The synthetic data has the following properties:
29+
#
30+
# * ``tm`` is the vector of occurrence times which follows a Weibull(1) distribution
31+
# * ``s`` indicates the observations censorship and follows a Bernoulli(0.5) distribution
32+
# * ``X`` is the matrix of predictors, generated using standard normal distribution with Toeplitz covariance.
33+
#
34+
# Let's inspect the data quickly:
35+
import matplotlib.pyplot as plt
36+
37+
fig, axes = plt.subplots(
38+
1, 3,
39+
figsize=(6, 2),
40+
tight_layout=True,
41+
)
42+
43+
dists = (tm, s, X[:, 5])
44+
axes_title = ("times", "censorship", "fifth predictor")
45+
46+
for idx, (dist, name) in enumerate(zip(dists, axes_title)):
47+
axes[idx].hist(dist, bins="auto")
48+
axes[idx].set_title(name)
49+
50+
_ = axes[0].set_ylabel("count")
51+
52+
# %%
53+
# Fitting the Cox Estimator
54+
# -----------------
55+
#
56+
# After generating the synthetic data, we can now fit a L1-regularized Cox estimator.
57+
# Todo so, we need to combine a Cox datafit and a :math:`\ell_1` penalty
58+
# and solve the resulting problem using skglm Proximal Newton solver ``ProxNewton``.
59+
# We set the intensity of the :math:`\ell_1` regularization to ``alpha=1e-2``.
60+
from skglm.datafits import Cox
61+
from skglm.penalties import L1
62+
from skglm.solvers import ProxNewton
63+
64+
from skglm.utils.jit_compilation import compiled_clone
65+
66+
# regularization intensity
67+
alpha = 1e-2
68+
69+
# skglm internals: init datafit and penalty
70+
datafit = compiled_clone(Cox())
71+
penalty = compiled_clone(L1(alpha))
72+
73+
datafit.initialize(X, (tm, s))
74+
75+
# init solver
76+
solver = ProxNewton(fit_intercept=False, max_iter=50,)
77+
78+
# solve the problem
79+
w_sk = solver.solve(X, (tm, s), datafit, penalty)[0]
80+
81+
# %%
82+
# For this data a regularization value a relatively sparse solution is found:
83+
print(f"Number of nonzero coefficients in solution: {(w_sk != 0).sum()} out of {len(w_sk)}.")
84+
85+
86+
# %%
87+
# Let's solve the problem with ``lifelines`` through its ``CoxPHFitter``
88+
# estimator and compare the objectives found by the two packages.
89+
import numpy as np
90+
import pandas as pd
91+
from lifelines import CoxPHFitter
92+
93+
# format data
94+
stacked_tm_s_X = np.hstack((tm[:, None], s[:, None], X))
95+
df = pd.DataFrame(stacked_tm_s_X)
96+
97+
# fit lifelines estimator
98+
lifelines_estimator = CoxPHFitter(penalizer=alpha, l1_ratio=1.).fit(
99+
df,
100+
duration_col=0,
101+
event_col=1
102+
)
103+
w_ll = lifelines_estimator.params_.values
104+
105+
# %%
106+
# Check that both solvers find solutions having the same objective value:
107+
obj_sk = datafit.value((tm, s), w_sk, X @ w_sk) + penalty.value(w_sk)
108+
obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)
109+
110+
print(f"Objective skglm: {obj_sk:.6f}")
111+
print(f"Objective lifelines: {obj_ll:.6f}")
112+
print(f"Difference: {(obj_sk - obj_ll):.2e}")
113+
# %%
114+
# We can do the same to check how close the two solutions are.
115+
print(f"Euclidean distance between solutions: {np.linalg.norm(w_sk - w_ll):.3e}")
116+
117+
# %%
118+
# Timing comparison
119+
# -----------------
120+
#
121+
# Now that we checked that both ``skglm`` and ``lifelines`` yield the same results,
122+
# let's compare their execution time. To get the evolution of the suboptimality
123+
# (objective - optimal objective) we run both estimators with increasing number of
124+
# iterations.
125+
import time
126+
import warnings
127+
128+
warnings.filterwarnings('ignore')
129+
130+
# where to save records
131+
records = {
132+
"skglm": {"times": [], "objs": []},
133+
"lifelines": {"times": [], "objs": []},
134+
}
135+
136+
# time skglm
137+
max_runs = 20
138+
for n_iter in range(1, max_runs + 1):
139+
solver.max_iter = n_iter
140+
141+
start = time.perf_counter()
142+
w = solver.solve(X, (tm, s), datafit, penalty)[0]
143+
end = time.perf_counter()
144+
145+
records["skglm"]["objs"].append(
146+
datafit.value((tm, s), w, X @ w) + penalty.value(w)
147+
)
148+
records["skglm"]["times"].append(end - start)
149+
150+
# time lifelines
151+
max_runs = 50
152+
for n_iter in list(range(10)) + list(range(10, max_runs + 1, 5)):
153+
start = time.perf_counter()
154+
lifelines_estimator.fit(
155+
df,
156+
duration_col=0,
157+
event_col=1,
158+
fit_options={"max_steps": n_iter},
159+
)
160+
end = time.perf_counter()
161+
162+
w = lifelines_estimator.params_.values
163+
164+
records["lifelines"]["objs"].append(
165+
datafit.value((tm, s), w, X @ w) + penalty.value(w)
166+
)
167+
records["lifelines"]["times"].append(end - start)
168+
169+
170+
# cast records as numpy array
171+
for idx, label in enumerate(("skglm", "lifelines")):
172+
for metric in ("objs", "times"):
173+
records[label][metric] = np.asarray(records[label][metric])
174+
175+
# %%
176+
# Results
177+
# -------
178+
179+
fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=(6, 3))
180+
solvers = ("skglm", "lifelines")
181+
182+
optimal_obj = min(records[solver]["objs"].min() for solver in solvers)
183+
184+
# plot evolution of suboptimality
185+
for solver in solvers:
186+
ax.semilogy(
187+
records[solver]["times"],
188+
records[solver]["objs"] - optimal_obj,
189+
label=solver,
190+
marker='o',
191+
)
192+
ax.legend()
193+
ax.set_title("Time to fit a Cox model")
194+
195+
ax.set_ylabel("objective suboptimality")
196+
_ = ax.set_xlabel("time in seconds")
197+
198+
199+
200+
# %%
201+
# According to printed ratio, using ``skglm`` we get the same result as ``lifelines``
202+
# with more than x100 less time!
203+
speed_up = records["lifelines"]["times"][-1] / records["skglm"]["times"][-1]
204+
print(f"speed up ratio: {speed_up:.0f}")
205+

0 commit comments

Comments
 (0)