
Commit e1a27e1

ENH - Implement a Cox scikit-learn-like Estimator (#171)
1 parent 52c8319 commit e1a27e1

File tree: 6 files changed (+248, −13 lines)

doc/api.rst

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ Estimators
    :toctree: generated/

    GeneralizedLinearEstimator
+   CoxEstimator
    ElasticNet
    Lasso
    LinearSVC

skglm/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -2,5 +2,5 @@

 from skglm.estimators import ( # noqa F401
     Lasso, WeightedLasso, ElasticNet, MCPRegression, MultiTaskLasso, LinearSVC,
-    SparseLogisticRegression, GeneralizedLinearEstimator
+    SparseLogisticRegression, GeneralizedLinearEstimator, CoxEstimator
 )
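With this re-export, the new estimator becomes importable from the package root. A quick sanity check (hypothetical snippet, not part of the commit):

from skglm import CoxEstimator

print(CoxEstimator())  # prints "CoxEstimator()" with its default parameters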

skglm/estimators.py

Lines changed: 171 additions & 5 deletions

@@ -4,22 +4,25 @@
 import numpy as np
 from scipy.sparse import issparse
 from scipy.special import expit
-from skglm.solvers.prox_newton import ProxNewton
+from numbers import Integral, Real
+from skglm.solvers import ProxNewton, LBFGS

-from sklearn.utils.validation import check_is_fitted
-from sklearn.utils import check_array, check_consistent_length
+from sklearn.utils.validation import (check_is_fitted, check_array,
+                                      check_consistent_length)
 from sklearn.linear_model._base import (
     LinearModel, RegressorMixin,
     LinearClassifierMixin, SparseCoefMixin, BaseEstimator
 )
 from sklearn.utils.extmath import softmax
 from sklearn.preprocessing import LabelEncoder
+from sklearn.utils._param_validation import Interval, StrOptions
 from sklearn.multiclass import OneVsRestClassifier, check_classification_targets

 from skglm.utils.jit_compilation import compiled_clone
 from skglm.solvers import AndersonCD, MultiTaskBCD
-from skglm.datafits import Quadratic, Logistic, QuadraticSVC, QuadraticMultiTask
-from skglm.penalties import L1, WeightedL1, L1_plus_L2, MCPenalty, IndicatorBox, L2_1
+from skglm.datafits import Cox, Quadratic, Logistic, QuadraticSVC, QuadraticMultiTask
+from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2,
+                             MCPenalty, IndicatorBox, L2_1)


 def _glm_fit(X, y, model, datafit, penalty, solver):
@@ -1159,6 +1162,169 @@ def fit(self, X, y):
 # TODO add predict_proba for LinearSVC


+class CoxEstimator(LinearModel):
+    r"""Elastic net Cox estimator with Efron and Breslow estimates.
+
+    Refer to :ref:`Mathematics behind Cox datafit <maths_cox_datafit>`
+    for details about the datafit expression. The data convention for the
+    estimator is
+
+    - ``X`` the design matrix with ``n_features`` predictors
+    - ``y`` a two-column array whose first column ``tm`` holds the event times
+      and whose second column ``s`` holds the censoring indicators.
+
+    For L2-regularized Cox (``l1_ratio=0.``) the solver used is
+    :ref:`LBFGS <skglm.solvers.LBFGS>`, otherwise it is
+    :ref:`ProxNewton <skglm.solvers.ProxNewton>`.
+
+    Parameters
+    ----------
+    alpha : float, optional
+        Penalty strength. It must be strictly positive.
+
+    l1_ratio : float, default=0.7
+        The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
+        ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
+        is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
+        combination of L1 and L2.
+
+    method : {'efron', 'breslow'}, default='efron'
+        The estimate used for the Cox datafit. Use ``efron`` to
+        handle tied observations.
+
+    tol : float, optional
+        Stopping criterion for the optimization.
+
+    max_iter : int, optional
+        The maximum number of iterations to solve the problem.
+
+    verbose : bool or int
+        Amount of verbosity.
+
+    Attributes
+    ----------
+    coef_ : array, shape (n_features,)
+        Parameter vector of Cox regression.
+
+    stop_crit_ : float
+        The value of the stopping criterion at convergence.
+    """
+
+    _parameter_constraints: dict = {
+        "alpha": [Interval(Real, 0, None, closed="neither")],
+        "l1_ratio": [Interval(Real, 0, 1, closed="both")],
+        "method": [StrOptions({"efron", "breslow"})],
+        "tol": [Interval(Real, 0, None, closed="left")],
+        "max_iter": [Interval(Integral, 1, None, closed="left")],
+        "verbose": ["boolean", Interval(Integral, 0, 2, closed="both")],
+    }
+
+    def __init__(self, alpha=1., l1_ratio=0.7, method="efron", tol=1e-4,
+                 max_iter=50, verbose=False):
+        self.alpha = alpha
+        self.l1_ratio = l1_ratio
+        self.method = method
+        self.tol = tol
+        self.max_iter = max_iter
+        self.verbose = verbose
+
+    def fit(self, X, y):
+        """Fit Cox estimator.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            Design matrix.
+
+        y : array-like, shape (n_samples, 2)
+            Two-column array whose first column holds the event times and
+            whose second holds the censoring indicators. If it is
+            one-dimensional, it is assumed to be the vector of times with
+            no censoring.
+
+        Returns
+        -------
+        self :
+            The fitted estimator.
+        """
+        self._validate_params()
+
+        # validate input data
+        X = check_array(
+            X,
+            accept_sparse="csc",
+            order="F",
+            dtype=[np.float64, np.float32],
+            input_name="X",
+        )
+        if y is None:
+            # Needed to pass check_estimator. The error message is
+            # copy/pasted from https://github.com/scikit-learn/scikit-learn/blob/ \
+            # 23ff51c07ebc03c866984e93c921a8993e96d1f9/sklearn/utils/ \
+            # estimator_checks.py#L3886
+            raise ValueError("requires y to be passed, but the target y is None")
+        y = check_array(
+            y,
+            accept_sparse=False,
+            order="F",
+            dtype=X.dtype,
+            ensure_2d=False,
+            input_name="y",
+        )
+        if y.ndim == 1:
+            warnings.warn(
+                f"{repr(self)} requires the vector of response `y` to have "
+                f"two columns. Got one column.\nAssuming that `y` "
+                "is the vector of times and there is no censoring."
+            )
+            y = np.column_stack((y, np.ones_like(y))).astype(X.dtype, order="F")
+        elif y.shape[1] > 2:
+            raise ValueError(
+                f"{repr(self)} requires the vector of response `y` to have "
+                f"two columns. Got {y.shape[1]} columns."
+            )
+
+        check_consistent_length(X, y)
+
+        # init datafit and penalty
+        datafit = Cox(self.method)
+
+        if self.l1_ratio == 1.:
+            penalty = L1(self.alpha)
+        elif 0. < self.l1_ratio < 1.:
+            penalty = L1_plus_L2(self.alpha, self.l1_ratio)
+        else:
+            penalty = L2(self.alpha)
+
+        # skglm internal: JIT compile classes
+        datafit = compiled_clone(datafit)
+        penalty = compiled_clone(penalty)
+
+        # init solver
+        if self.l1_ratio == 0.:
+            solver = LBFGS(max_iter=self.max_iter, tol=self.tol, verbose=self.verbose)
+        else:
+            solver = ProxNewton(
+                max_iter=self.max_iter, tol=self.tol, verbose=self.verbose,
+                fit_intercept=False,
+            )
+
+        # solve problem
+        if not issparse(X):
+            datafit.initialize(X, y)
+        else:
+            datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
+
+        w, _, stop_crit = solver.solve(X, y, datafit, penalty)
+
+        # save to attributes
+        self.coef_ = w
+        self.stop_crit_ = stop_crit
+
+        self.intercept_ = 0.
+        self.n_features_in_ = X.shape[1]
+        self.feature_names_in_ = np.arange(X.shape[1])
+
+        return self
+
+
 class MultiTaskLasso(LinearModel, RegressorMixin):
     r"""MultiTaskLasso estimator.
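To make the new estimator's data convention concrete, here is a minimal usage sketch (not part of the commit; the synthetic survival data below is made up for illustration):

import numpy as np
from skglm import CoxEstimator

# toy survival data: `tm` holds event times, `s` holds censoring
# indicators (1 = event observed, 0 = censored)
rng = np.random.default_rng(0)
n_samples, n_features = 100, 30
X = rng.standard_normal((n_samples, n_features))
tm = rng.exponential(scale=1., size=n_samples)
s = rng.integers(0, 2, size=n_samples).astype(float)
y = np.column_stack((tm, s))

# l1_ratio=0. selects the L2 penalty and hence the LBFGS solver;
# any l1_ratio > 0. routes the problem to ProxNewton
estimator = CoxEstimator(alpha=1e-2, l1_ratio=0.7, method="efron", tol=1e-6)
estimator.fit(X, y)
print(estimator.coef_.shape)  # (30,)
print(estimator.stop_crit_)   # stopping criterion value at convergence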

skglm/solvers/lbfgs.py

Lines changed: 6 additions & 3 deletions

@@ -62,7 +62,7 @@ def callback_post_iter(w_k):

            if self.verbose:
                grad = jac(w_k)
-                stop_crit = norm(grad)
+                stop_crit = norm(grad, ord=np.inf)

                it = len(p_objs_out)
                print(
@@ -82,7 +82,8 @@ def callback_post_iter(w_k):
            method="L-BFGS-B",
            options=dict(
                maxiter=self.max_iter,
-                gtol=self.tol
+                gtol=self.tol,
+                ftol=0.  # set ftol=0. to control convergence using only gtol
            ),
            callback=callback_post_iter,
        )
@@ -96,6 +97,8 @@ def callback_post_iter(w_k):
        )

        w = result.x
-        stop_crit = norm(result.jac)
+        # scipy LBFGS uses || projected gradient ||_oo to check convergence, cf. `gtol`
+        # in https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html
+        stop_crit = norm(result.jac, ord=np.inf)

        return w, np.asarray(p_objs_out), stop_crit
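For context on the `gtol`/`ftol` interplay, here is a standalone sketch (independent of skglm, with a made-up least-squares objective) showing that zeroing `ftol` leaves `gtol`, the inf-norm of the projected gradient, as the only convergence criterion, which is what the solver's `stop_crit` now mirrors:

import numpy as np
from numpy.linalg import norm
from scipy import optimize

# made-up objective f(w) = 0.5 * ||A w - b||^2 and its gradient
rng = np.random.default_rng(0)
A = rng.standard_normal((20, 5))
b = rng.standard_normal(20)

def f(w):
    return 0.5 * norm(A @ w - b) ** 2

def jac(w):
    return A.T @ (A @ w - b)

result = optimize.minimize(
    fun=f, x0=np.zeros(5), jac=jac, method="L-BFGS-B",
    options=dict(
        maxiter=50,
        gtol=1e-6,  # converged when max_i |projected grad_i| <= gtol
        ftol=0.,    # disable the relative-decrease criterion
    ),
)

# matches scipy's own convergence check, hence the solver's stop_crit
stop_crit = norm(result.jac, ord=np.inf)
assert stop_crit <= 1e-6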

skglm/tests/test_estimators.py

Lines changed: 68 additions & 3 deletions

@@ -27,6 +27,7 @@
 import pandas as pd
 from skglm.solvers import ProxNewton
 from skglm.utils.jit_compilation import compiled_clone
+from skglm.estimators import CoxEstimator


 n_samples = 50
@@ -209,8 +210,7 @@ def test_CoxEstimator(use_efron, use_float_32):
     )

     # fit lifeline estimator
-    stacked_y_X = np.hstack((y, X))
-    df = pd.DataFrame(stacked_y_X)
+    df = pd.DataFrame(np.hstack((y, X)))

     estimator = CoxPHFitter(penalizer=alpha, l1_ratio=1.)
     estimator.fit(
@@ -260,6 +260,72 @@ def test_CoxEstimator_sparse(use_efron, use_float_32):
     np.testing.assert_allclose(stop_crit, 0., atol=1e-6)


+@pytest.mark.parametrize("use_efron, l1_ratio", product([True, False], [1., 0.7, 0.]))
+def test_Cox_sk_like_estimator(use_efron, l1_ratio):
+    try:
+        from lifelines import CoxPHFitter
+    except ModuleNotFoundError:
+        pytest.xfail(
+            "Testing Cox Estimator requires the `lifelines` package\n"
+            "Run `pip install lifelines`"
+        )
+
+    alpha = 1e-2
+    # norms of solutions differ when n_features > n_samples
+    n_samples, n_features = 100, 30
+    method = "efron" if use_efron else "breslow"
+
+    X, y = make_dummy_survival_data(n_samples, n_features, normalize=True,
+                                    with_ties=use_efron, random_state=0)
+
+    estimator_sk = CoxEstimator(
+        alpha, l1_ratio=l1_ratio, method=method, tol=1e-6
+    ).fit(X, y)
+    w_sk = estimator_sk.coef_
+
+    # fit lifelines estimator
+    df = pd.DataFrame(np.hstack((y, X)))
+
+    estimator_ll = CoxPHFitter(penalizer=alpha, l1_ratio=l1_ratio)
+    estimator_ll.fit(
+        df, duration_col=0, event_col=1,
+        fit_options={"max_steps": 10_000, "precision": 1e-12}
+    )
+    w_ll = estimator_ll.params_.values
+
+    # define datafit and penalty to compare objectives
+    datafit = Cox(use_efron)
+    penalty = L1_plus_L2(alpha, l1_ratio)
+    datafit.initialize(X, y)
+
+    p_obj_skglm = datafit.value(y, w_sk, X @ w_sk) + penalty.value(w_sk)
+    p_obj_ll = datafit.value(y, w_ll, X @ w_ll) + penalty.value(w_ll)
+
+    # though the norms of the solutions might differ, the objectives match
+    np.testing.assert_allclose(p_obj_skglm, p_obj_ll, atol=1e-6)
+
+
+@pytest.mark.parametrize("use_efron, l1_ratio", product([True, False], [1., 0.7, 0.]))
+def test_Cox_sk_like_estimator_sparse(use_efron, l1_ratio):
+    alpha = 1e-2
+    n_samples, n_features = 100, 30
+    method = "efron" if use_efron else "breslow"
+
+    X, y = make_dummy_survival_data(n_samples, n_features, X_density=0.1,
+                                    with_ties=use_efron, random_state=0)
+
+    estimator_sk = CoxEstimator(
+        alpha, l1_ratio=l1_ratio, method=method, tol=1e-9
+    ).fit(X, y)
+    stop_crit = estimator_sk.stop_crit_
+
+    np.testing.assert_array_less(stop_crit, 1e-9)
+
+
+def test_Cox_sk_compatibility():
+    check_estimator(CoxEstimator())
+
+
 # Test if GeneralizedLinearEstimator returns the correct coefficients
 @pytest.mark.parametrize("Datafit, Penalty, Estimator, pen_args", [
     (Quadratic, L1, Lasso, [alpha]),
@@ -379,5 +445,4 @@ def test_warm_start(estimator_name):


 if __name__ == "__main__":
-    test_CoxEstimator(True, True)
     pass

skglm/tests/test_lbfgs_solver.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ def test_lbfgs_L2_logreg(X_sparse):
        tol=1e-12,
    ).fit(X, y)

-    np.testing.assert_allclose(w, estimator.coef_.flatten())
+    np.testing.assert_allclose(w, estimator.coef_.flatten(), atol=1e-5)


@pytest.mark.parametrize("use_efron", [True, False])
