Commit 18e6456

MNT - Switch response y convention in Cox estimation (#175)
1 parent d80c9aa commit 18e6456

File tree

6 files changed: +60 −54 lines changed

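In short, this commit replaces the `(tm, s)` tuple convention for the Cox response (and the `tm, s, X` return order of the data helper) with the `X, y` ordering used by the rest of skglm, where `y` is a two-column array holding times and censoring indicators. A minimal before/after sketch of the new calling convention (names and shapes follow the diffs below):

```python
import numpy as np
from skglm.utils.data import make_dummy_survival_data

# before this commit, the helper returned times, censorship and X separately:
#   tm, s, X = make_dummy_survival_data(500, 100, random_state=0)

# after this commit, X comes first and the response is a (n_samples, 2) array
X, y = make_dummy_survival_data(500, 100, normalize=True, random_state=0)
tm, s = y[:, 0], y[:, 1]  # occurrence times and censoring indicators
```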

examples/plot_survival_analysis.py

Lines changed: 20 additions & 17 deletions
```diff
@@ -18,18 +18,20 @@
 from skglm.utils.data import make_dummy_survival_data
 
 n_samples, n_features = 500, 100
-tm, s, X = make_dummy_survival_data(
+X, y = make_dummy_survival_data(
     n_samples, n_features,
     normalize=True,
     random_state=0
 )
 
+tm, s = y[:, 0], y[:, 1]
+
 # %%
 # The synthetic data has the following properties:
 #
+# * ``X`` is the matrix of predictors, generated using standard normal distribution with Toeplitz covariance.
 # * ``tm`` is the vector of occurrence times which follows a Weibull(1) distribution
 # * ``s`` indicates the observations censorship and follows a Bernoulli(0.5) distribution
-# * ``X`` is the matrix of predictors, generated using standard normal distribution with Toeplitz covariance.
 #
 # Let's inspect the data quickly:
 import matplotlib.pyplot as plt
@@ -70,13 +72,13 @@
 datafit = compiled_clone(Cox())
 penalty = compiled_clone(L1(alpha))
 
-datafit.initialize(X, (tm, s))
+datafit.initialize(X, y)
 
 # init solver
 solver = ProxNewton(fit_intercept=False, max_iter=50)
 
 # solve the problem
-w_sk = solver.solve(X, (tm, s), datafit, penalty)[0]
+w_sk = solver.solve(X, y, datafit, penalty)[0]
 
 # %%
 # For this data a regularization value a relatively sparse solution is found:
@@ -93,8 +95,8 @@
 from lifelines import CoxPHFitter
 
 # format data
-stacked_tm_s_X = np.hstack((tm[:, None], s[:, None], X))
-df = pd.DataFrame(stacked_tm_s_X)
+stacked_y_X = np.hstack((y, X))
+df = pd.DataFrame(stacked_y_X)
 
 # fit lifelines estimator
 lifelines_estimator = CoxPHFitter(penalizer=alpha, l1_ratio=1.).fit(
@@ -106,8 +108,8 @@
 
 # %%
 # Check that both solvers find solutions having the same objective value:
-obj_sk = datafit.value((tm, s), w_sk, X @ w_sk) + penalty.value(w_sk)
-obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)
+obj_sk = datafit.value(y, w_sk, X @ w_sk) + penalty.value(w_sk)
+obj_ll = datafit.value(y, w_ll, X @ w_ll) + penalty.value(w_ll)
 
 print(f"Objective skglm: {obj_sk:.6f}")
 print(f"Objective lifelines: {obj_ll:.6f}")
@@ -141,11 +143,11 @@
     solver.max_iter = n_iter
 
     start = time.perf_counter()
-    w = solver.solve(X, (tm, s), datafit, penalty)[0]
+    w = solver.solve(X, y, datafit, penalty)[0]
     end = time.perf_counter()
 
     records["skglm"]["objs"].append(
-        datafit.value((tm, s), w, X @ w) + penalty.value(w)
+        datafit.value(y, w, X @ w) + penalty.value(w)
     )
     records["skglm"]["times"].append(end - start)
 
@@ -164,7 +166,7 @@
     w = lifelines_estimator.params_.values
 
     records["lifelines"]["objs"].append(
-        datafit.value((tm, s), w, X @ w) + penalty.value(w)
+        datafit.value(y, w, X @ w) + penalty.value(w)
     )
     records["lifelines"]["times"].append(end - start)
 
@@ -212,12 +214,13 @@
 #
 # Let's start by generating data with tied observations. This can be achieved
 # by passing in a ``with_ties=True`` to ``make_dummy_survival_data`` function.
-tm, s, X = make_dummy_survival_data(
+X, y = make_dummy_survival_data(
     n_samples, n_features,
     normalize=True,
     with_ties=True,
     random_state=0
 )
+tm, s = y[:, 0], y[:, 1]
 
 # check the data has tied observations
 print(f"Number of unique times {len(np.unique(tm))} out of {n_samples}")
@@ -228,11 +231,11 @@
 
 # ensure using Efron estimate
 datafit = compiled_clone(Cox(use_efron=True))
-datafit.initialize(X, (tm, s))
+datafit.initialize(X, y)
 
 # solve the problem
 solver = ProxNewton(fit_intercept=False, max_iter=50)
-w_sk = solver.solve(X, (tm, s), datafit, penalty)[0]
+w_sk = solver.solve(X, y, datafit, penalty)[0]
 
 # %%
 # Again a relatively sparse solution is found:
@@ -257,8 +260,8 @@
 w_ll = lifelines_estimator.params_.values
 
 # Check that both solvers find solutions with the same objective value
-obj_sk = datafit.value((tm, s), w_sk, X @ w_sk) + penalty.value(w_sk)
-obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)
+obj_sk = datafit.value(y, w_sk, X @ w_sk) + penalty.value(w_sk)
+obj_ll = datafit.value(y, w_ll, X @ w_ll) + penalty.value(w_ll)
 
 print(f"Objective skglm: {obj_sk:.6f}")
 print(f"Objective lifelines: {obj_ll:.6f}")
@@ -272,7 +275,7 @@
 
 # time skglm
 start = time.perf_counter()
-solver.solve(X, (tm, s), datafit, penalty)[0]
+solver.solve(X, y, datafit, penalty)[0]
 end = time.perf_counter()
 
 total_time_skglm = end - start
```
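Read end to end, the example now threads `y` through data generation, datafit initialization, solving, and objective evaluation. A condensed sketch of that pipeline, assuming the `skglm.datafits`/`skglm.penalties`/`skglm.solvers` import paths (only the `compiled_clone` and `make_dummy_survival_data` paths are confirmed by this diff) and a hypothetical `alpha`:

```python
from skglm.datafits import Cox          # assumed import path
from skglm.penalties import L1          # assumed import path
from skglm.solvers import ProxNewton    # assumed import path
from skglm.utils.data import make_dummy_survival_data
from skglm.utils.jit_compilation import compiled_clone

X, y = make_dummy_survival_data(500, 100, normalize=True, random_state=0)

alpha = 1e-2  # hypothetical value; the example derives alpha from the data
datafit = compiled_clone(Cox())
penalty = compiled_clone(L1(alpha))
datafit.initialize(X, y)

solver = ProxNewton(fit_intercept=False, max_iter=50)
w_sk = solver.solve(X, y, datafit, penalty)[0]
obj_sk = datafit.value(y, w_sk, X @ w_sk) + penalty.value(w_sk)
```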

skglm/datafits/single_task.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -607,7 +607,7 @@ def params_to_dict(self):
 
     def value(self, y, w, Xw):
         """Compute the value of the datafit."""
-        tm, s = y
+        tm, s = y[:, 0], y[:, 1]  # noqa
         n_samples = Xw.shape[0]
 
         # compute inside log term
@@ -625,7 +625,7 @@ def raw_grad(self, y, Xw):
         Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`
         equation 4 for details.
         """
-        tm, s = y
+        tm, s = y[:, 0], y[:, 1]  # noqa
         n_samples = Xw.shape[0]
 
         exp_Xw = np.exp(Xw)
@@ -646,7 +646,7 @@ def raw_hessian(self, y, Xw):
         Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`
         equation 6 for details.
         """
-        tm, s = y
+        tm, s = y[:, 0], y[:, 1]  # noqa
         n_samples = Xw.shape[0]
 
         exp_Xw = np.exp(Xw)
@@ -678,7 +678,7 @@ def gradient_sparse(self, X_data, X_indptr, X_indices, y, Xw):
 
     def initialize(self, X, y):
         """Initialize the datafit attributes."""
-        tm, s = y
+        tm, s = y[:, 0], y[:, 1]  # noqa
 
         self.T_indices = np.argsort(tm)
         self.T_indptr = self._get_indptr(tm, self.T_indices)
```
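Two details are worth noting in these hunks. The `# noqa` comments presumably silence an unused-variable lint warning, since some methods use only one of the two columns. More importantly, plain tuple unpacking no longer applies: unpacking a `(n_samples, 2)` array iterates over rows, not columns. A small sketch of the distinction:

```python
import numpy as np

# y stacks times and censoring indicators column-wise
y = np.column_stack((np.array([2., 5., 3.]),    # tm: event times
                     np.array([1., 0., 1.])))   # s: censoring indicators

# tm, s = y               # ValueError: unpacking iterates over the 3 rows
tm, s = y[:, 0], y[:, 1]  # the convention adopted by the Cox datafit
```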

skglm/tests/test_datafits.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -11,6 +11,7 @@
 from skglm import GeneralizedLinearEstimator
 from skglm.utils.data import make_correlated_data
 from skglm.utils.jit_compilation import compiled_clone
+from skglm.utils.data import make_dummy_survival_data
 
 
 @pytest.mark.parametrize('fit_intercept', [False, True])
@@ -122,10 +123,8 @@ def test_cox(use_efron):
     n_samples, n_features = 10, 30
 
     # generate data
-    X = rng.randn(n_samples, n_features)
-    tm = rng.choice(n_samples*n_features, size=n_samples, replace=True).astype(float)
-    s = rng.choice(2, size=n_samples).astype(float)
-    y = (tm, s)
+    X, y = make_dummy_survival_data(n_samples, n_features, normalize=True,
+                                    with_ties=use_efron, random_state=0)
 
     # generate dummy w, Xw
     w = rng.randn(n_features)
@@ -134,7 +133,7 @@
     # check datafit
     cox_df = compiled_clone(Cox(use_efron))
 
-    cox_df.initialize(X, (tm, s))
+    cox_df.initialize(X, y)
     cox_df.value(y, w, Xw)
 
     # perform test 10 times to consider truncation errors
```
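With the helper in place, the test no longer hand-rolls survival data via `rng.choice`. A minimal sketch of the updated setup under the pytest parametrization already used in this file (the shape assertion is illustrative, not part of the original test):

```python
import pytest
from skglm.utils.data import make_dummy_survival_data


@pytest.mark.parametrize("use_efron", [True, False])
def test_cox_response_shape(use_efron):
    # ties only matter for the Efron estimate, hence with_ties=use_efron
    X, y = make_dummy_survival_data(10, 30, normalize=True,
                                    with_ties=use_efron, random_state=0)
    assert X.shape == (10, 30) and y.shape == (10, 2)
```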

skglm/tests/test_estimators.py

Lines changed: 16 additions & 14 deletions
```diff
@@ -184,9 +184,10 @@ def test_CoxEstimator(use_efron, use_float_32):
     n_samples, n_features = 100, 30
     random_state = 1265
 
-    tm, s, X = make_dummy_survival_data(n_samples, n_features, normalize=True,
-                                        with_ties=use_efron, use_float_32=use_float_32,
-                                        random_state=random_state)
+    X, y = make_dummy_survival_data(n_samples, n_features, normalize=True,
+                                    with_ties=use_efron, use_float_32=use_float_32,
+                                    random_state=random_state)
+    tm, s = y[:, 0], y[:, 1]
 
     # compute alpha_max
     B = (tm >= tm[:, None]).astype(X.dtype)
@@ -199,17 +200,17 @@
     datafit = compiled_clone(Cox(use_efron))
     penalty = compiled_clone(L1(alpha))
 
-    datafit.initialize(X, (tm, s))
+    datafit.initialize(X, y)
 
     w, *_ = ProxNewton(
         fit_intercept=False, tol=1e-6, max_iter=50
     ).solve(
-        X, (tm, s), datafit, penalty
+        X, y, datafit, penalty
     )
 
     # fit lifeline estimator
-    stacked_tm_s_X = np.hstack((tm[:, None], s[:, None], X))
-    df = pd.DataFrame(stacked_tm_s_X)
+    stacked_y_X = np.hstack((y, X))
+    df = pd.DataFrame(stacked_y_X)
 
     estimator = CoxPHFitter(penalizer=alpha, l1_ratio=1.)
     estimator.fit(
@@ -218,8 +219,8 @@
     )
     w_ll = estimator.params_.values.astype(X.dtype)
 
-    p_obj_skglm = datafit.value((tm, s), w, X @ w) + penalty.value(w)
-    p_obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)
+    p_obj_skglm = datafit.value(y, w, X @ w) + penalty.value(w)
+    p_obj_ll = datafit.value(y, w_ll, X @ w_ll) + penalty.value(w_ll)
 
     # though norm of solution might differ
     np.testing.assert_allclose(p_obj_skglm, p_obj_ll, atol=1e-6)
@@ -232,9 +233,10 @@ def test_CoxEstimator_sparse(use_efron, use_float_32):
     n_samples, n_features = 100, 30
     X_density, random_state = 0.5, 1265
 
-    tm, s, X = make_dummy_survival_data(n_samples, n_features, X_density=X_density,
-                                        use_float_32=use_float_32, with_ties=use_efron,
-                                        random_state=random_state)
+    X, y = make_dummy_survival_data(n_samples, n_features, X_density=X_density,
+                                    use_float_32=use_float_32, with_ties=use_efron,
+                                    random_state=random_state)
+    tm, s = y[:, 0], y[:, 1]
 
     # compute alpha_max
     B = (tm >= tm[:, None]).astype(X.dtype)
@@ -247,12 +249,12 @@
     datafit = compiled_clone(Cox(use_efron))
     penalty = compiled_clone(L1(alpha))
 
-    datafit.initialize_sparse(X.data, X.indptr, X.indices, (tm, s))
+    datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
 
     *_, stop_crit = ProxNewton(
         fit_intercept=False, tol=1e-6, max_iter=50
     ).solve(
-        X, (tm, s), datafit, penalty
+        X, y, datafit, penalty
     )
 
     np.testing.assert_allclose(stop_crit, 0., atol=1e-6)
```
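In the sparse test, only `X` is sparse; `y` stays a dense two-column array, so `initialize_sparse` receives the raw CSC buffers plus `y` as-is. A sketch of that path, relying on the docstring (see the `skglm/utils/data.py` diff below) stating that `X_density < 1` yields a CSC matrix:

```python
from skglm.datafits import Cox  # assumed import path
from skglm.utils.data import make_dummy_survival_data
from skglm.utils.jit_compilation import compiled_clone

# X is returned as a scipy CSC matrix when X_density < 1; y remains dense
X, y = make_dummy_survival_data(100, 30, X_density=0.5, random_state=1265)

datafit = compiled_clone(Cox(use_efron=False))
# sparse initialization consumes the raw CSC buffers, plus the dense y
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
```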

skglm/tests/test_lbfgs_solver.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -52,27 +52,27 @@ def test_L2_Cox(use_efron):
     alpha = 10.
     n_samples, n_features = 100, 50
 
-    tm, s, X = make_dummy_survival_data(
+    X, y = make_dummy_survival_data(
         n_samples, n_features, normalize=True,
         with_ties=use_efron, random_state=0)
 
     datafit = compiled_clone(Cox(use_efron))
     penalty = compiled_clone(L2(alpha))
 
-    datafit.initialize(X, (tm, s))
-    w, *_ = LBFGS().solve(X, (tm, s), datafit, penalty)
+    datafit.initialize(X, y)
+    w, *_ = LBFGS().solve(X, y, datafit, penalty)
 
     # fit lifeline estimator
-    stacked_tm_s_X = np.hstack((tm[:, None], s[:, None], X))
-    df = pd.DataFrame(stacked_tm_s_X)
+    stacked_y_X = np.hstack((y, X))
+    df = pd.DataFrame(stacked_y_X)
 
     estimator = CoxPHFitter(penalizer=alpha, l1_ratio=0.).fit(
         df, duration_col=0, event_col=1
     )
     w_ll = estimator.params_.values
 
-    p_obj_skglm = datafit.value((tm, s), w, X @ w) + penalty.value(w)
-    p_obj_ll = datafit.value((tm, s), w_ll, X @ w_ll) + penalty.value(w_ll)
+    p_obj_skglm = datafit.value(y, w, X @ w) + penalty.value(w)
+    p_obj_ll = datafit.value(y, w_ll, X @ w_ll) + penalty.value(w_ll)
 
     # despite increasing tol in lifelines, solutions are quite far apart
     # suspecting lifelines https://github.com/CamDavidsonPilon/lifelines/pull/1534
```
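A side benefit visible here: building the lifelines DataFrame no longer needs per-column reshaping, since `y` already stacks `tm` and `s` in the first two columns. A self-contained sketch of the comparison setup, using lifelines' documented `duration_col`/`event_col` arguments:

```python
import numpy as np
import pandas as pd
from lifelines import CoxPHFitter
from skglm.utils.data import make_dummy_survival_data

X, y = make_dummy_survival_data(100, 50, normalize=True, random_state=0)

# columns 0 and 1 of the frame are tm and s; the rest are the features
df = pd.DataFrame(np.hstack((y, X)))
estimator = CoxPHFitter(penalizer=10., l1_ratio=0.).fit(
    df, duration_col=0, event_col=1
)
w_ll = estimator.params_.values
```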

skglm/utils/data.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -160,14 +160,13 @@ def make_dummy_survival_data(n_samples, n_features, normalize=False, X_density=1
 
     Returns
     -------
-    tm : array-like, shape (n_samples,)
-        The vector of recording the time of event occurrences
-
-    s : array-like, shape (n_samples,)
-        The vector of indicating samples censorship
-
     X : array-like, shape (n_samples, n_features)
         The matrix of predictors. If ``density < 1``, a CSC sparse matrix is returned.
+
+    y : array-like, shape (n_samples, 2)
+        Two-column array where the first column ``tm`` is the vector
+        recording the time of event occurrences, and the second column ``s``
+        is the vector of sample censoring.
     """
     rng = np.random.RandomState(random_state)
     dtype = np.float64 if use_float_32 is False else np.float32
@@ -189,7 +188,10 @@
     if normalize and X_density == 1.:
         X = StandardScaler().fit_transform(X)
 
-    return tm, s, X
+    # stack (tm, s)
+    y = np.column_stack((tm, s)).astype(dtype, order='F')
+
+    return X, y
```
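Finally, a quick sketch of what the new return value looks like. The Fortran order makes each column of `y` contiguous in memory, which is convenient when the numba-compiled datafit slices out `y[:, 0]` and `y[:, 1]` (the contiguity rationale is an inference, not stated in the commit):

```python
import numpy as np

tm = np.array([2., 5., 3.])  # event times
s = np.array([1., 0., 1.])   # censoring indicators

y = np.column_stack((tm, s)).astype(np.float64, order='F')

print(y.shape)                   # (3, 2)
print(y.flags['F_CONTIGUOUS'])   # True: columns are contiguous
print(y[:, 0], y[:, 1])          # recovers tm and s
```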
