# Authors: Badr Moufad
#          Mathurin Massias
"""
========================================================
Comparison of lifelines with skglm for survival analysis
========================================================
This example shows that ``skglm`` fits a Cox model exactly as ``lifelines``
does, but more than 100 times faster.
"""

# %%
# Data
# ----
#
# Let's first generate synthetic data on which to run the Cox estimator,
# using ``skglm`` data utils.
#
from skglm.utils.data import make_dummy_survival_data

n_samples, n_features = 500, 100
tm, s, X = make_dummy_survival_data(
    n_samples, n_features,
    normalize=True,
    random_state=0
)

# %%
# The synthetic data has the following properties:
#
# * ``tm`` is the vector of occurrence times, which follows a Weibull(1) distribution
# * ``s`` indicates each observation's censorship status and follows a Bernoulli(0.5) distribution
# * ``X`` is the matrix of predictors, generated from a standard normal distribution with Toeplitz covariance
#
# Let's inspect the data quickly:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(
    1, 3,
    figsize=(6, 2),
    tight_layout=True,
)

dists = (tm, s, X[:, 5])
axes_title = ("times", "censorship", "fifth predictor")

for idx, (dist, name) in enumerate(zip(dists, axes_title)):
    axes[idx].hist(dist, bins="auto")
    axes[idx].set_title(name)

_ = axes[0].set_ylabel("count")
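
# %%
# As a quick, illustrative check against the distributions stated above
# (a Weibull(1) has mean 1, a Bernoulli(0.5) has mean 0.5):
print(f"mean occurrence time: {tm.mean():.2f}")
print(f"censorship rate: {s.mean():.2f}")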

# %%
# Fitting the Cox Estimator
# -------------------------
#
# After generating the synthetic data, we can now fit an L1-regularized Cox
# estimator. To do so, we combine a Cox datafit and an :math:`\ell_1` penalty,
# and solve the resulting problem using skglm's Proximal Newton solver,
# ``ProxNewton``. We set the intensity of the :math:`\ell_1` regularization
# to ``alpha=1e-2``.
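#
# As a sketch (up to the exact normalization ``skglm`` uses internally), the
# problem being solved is the :math:`\ell_1`-regularized negative Cox partial
# log-likelihood:
#
# .. math::
#    \min_{w} \, - \sum_{i: s_i = 1} \left( x_i^\top w
#    - \log \sum_{j: t_j \geq t_i} e^{x_j^\top w} \right) + \alpha \Vert w \Vert_1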
from skglm.datafits import Cox
from skglm.penalties import L1
from skglm.solvers import ProxNewton

from skglm.utils.jit_compilation import compiled_clone

# regularization intensity
alpha = 1e-2

# skglm internals: init datafit and penalty
datafit = compiled_clone(Cox())
penalty = compiled_clone(L1(alpha))

datafit.initialize(X, (tm, s))

# init solver
solver = ProxNewton(fit_intercept=False, max_iter=50)

# solve the problem; the first return value is the coefficient vector
w_sk = solver.solve(X, (tm, s), datafit, penalty)[0]

# %%
# For this data and regularization value, a relatively sparse solution is found:
print(
    f"Number of nonzero coefficients in solution: {(w_sk != 0).sum()} "
    f"out of {len(w_sk)}."
)

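# %%
# For illustration, here is one way to list the indices of the active
# (nonzero) coefficients, using plain Python:
support = [j for j, coef in enumerate(w_sk) if coef != 0]
print(f"first active features: {support[:5]}")
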
# %%
# Let's solve the problem with ``lifelines`` through its ``CoxPHFitter``
# estimator and compare the objectives found by the two packages.
import numpy as np
import pandas as pd
from lifelines import CoxPHFitter

# format data as a DataFrame: column 0 holds durations, column 1 the events
stacked_tm_s_X = np.hstack((tm[:, None], s[:, None], X))
df = pd.DataFrame(stacked_tm_s_X)

# fit lifelines estimator
lifelines_estimator = CoxPHFitter(penalizer=alpha, l1_ratio=1.).fit(
    df,
    duration_col=0,
    event_col=1
)
w_ll = lifelines_estimator.params_.values

# %%
# Check that both solvers find solutions having the same objective value:
obj_sk = datafit.value((tm, s), w_sk, X @ w_sk) + penalty.value(w_sk)
obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)

print(f"Objective skglm: {obj_sk:.6f}")
print(f"Objective lifelines: {obj_ll:.6f}")
print(f"Difference: {(obj_sk - obj_ll):.2e}")

# %%
# Similarly, we can check how close the two solutions are:
print(f"Euclidean distance between solutions: {np.linalg.norm(w_sk - w_ll):.3e}")

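# %%
# As an illustrative complement, the largest per-coefficient deviation:
print(f"Max absolute coefficient difference: {np.max(np.abs(w_sk - w_ll)):.3e}")
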
# %%
# Timing comparison
# -----------------
#
# Now that we have checked that ``skglm`` and ``lifelines`` yield the same results,
# let's compare their execution time. To track the evolution of the suboptimality
# (objective minus optimal objective), we run both estimators with an increasing
# number of iterations.
import time
import warnings

warnings.filterwarnings('ignore')

# where to save records
records = {
    "skglm": {"times": [], "objs": []},
    "lifelines": {"times": [], "objs": []},
}

# time skglm
max_runs = 20
for n_iter in range(1, max_runs + 1):
    solver.max_iter = n_iter

    start = time.perf_counter()
    w = solver.solve(X, (tm, s), datafit, penalty)[0]
    end = time.perf_counter()

    records["skglm"]["objs"].append(
        datafit.value((tm, s), w, X @ w) + penalty.value(w)
    )
    records["skglm"]["times"].append(end - start)

# time lifelines
max_runs = 50
for n_iter in list(range(10)) + list(range(10, max_runs + 1, 5)):
    start = time.perf_counter()
    lifelines_estimator.fit(
        df,
        duration_col=0,
        event_col=1,
        fit_options={"max_steps": n_iter},
    )
    end = time.perf_counter()

    w = lifelines_estimator.params_.values

    records["lifelines"]["objs"].append(
        datafit.value((tm, s), w, X @ w) + penalty.value(w)
    )
    records["lifelines"]["times"].append(end - start)


# cast records as numpy arrays
for label in ("skglm", "lifelines"):
    for metric in ("objs", "times"):
        records[label][metric] = np.asarray(records[label][metric])

# %%
# Results
# -------

fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=(6, 3))
solver_names = ("skglm", "lifelines")

optimal_obj = min(records[name]["objs"].min() for name in solver_names)

# plot evolution of suboptimality
for name in solver_names:
    ax.semilogy(
        records[name]["times"],
        records[name]["objs"] - optimal_obj,
        label=name,
        marker='o',
    )
ax.legend()
ax.set_title("Time to fit a Cox model")

ax.set_ylabel("objective suboptimality")
_ = ax.set_xlabel("time in seconds")


# %%
# According to the printed ratio, ``skglm`` reaches the same result as
# ``lifelines`` more than 100 times faster!
speed_up = records["lifelines"]["times"][-1] / records["skglm"]["times"][-1]
print(f"speed up ratio: {speed_up:.0f}")