Skip to content

Commit 415c1e9

Browse files
committed
fix cox model
1 parent 0facfd7 commit 415c1e9

File tree

3 files changed

+31
-17
lines changed

3 files changed

+31
-17
lines changed

skglm/solvers/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,9 @@ def custom_checks(self, X, y, datafit, penalty):
9494
"""
9595
pass
9696

97-
def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
98-
*, run_checks=True):
97+
def solve(
98+
self, X, y, datafit, penalty, w_init=None, Xw_init=None, *, run_checks=True
99+
):
99100
"""Solve the optimization problem after validating its compatibility.
100101
101102
A proxy of ``_solve`` method that implicitly ensures the compatibility
@@ -108,7 +109,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
108109
"""
109110
if "jitclass" in str(type(datafit)):
110111
warnings.warn(
111-
"Do not pass a compiled datafit, compilation is done inside solver now")
112+
"Do not pass a compiled datafit, compilation is done inside solver now"
113+
)
112114
else:
113115
if datafit is not None:
114116
datafit = compiled_clone(datafit, to_float32=X.dtype == np.float32)

skglm/solvers/lbfgs.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ def __init__(self, max_iter=50, tol=1e-4, verbose=False):
3838

3939
def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
4040

41+
# TODO: to be isolated in a separate method
42+
is_sparse = issparse(X)
43+
if is_sparse:
44+
datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
45+
else:
46+
datafit.initialize(X, y)
47+
4148
def objective(w):
4249
Xw = X @ w
4350
datafit_value = datafit.value(y, w, Xw)
@@ -70,8 +77,7 @@ def callback_post_iter(w_k):
7077

7178
it = len(p_objs_out)
7279
print(
73-
f"Iteration {it}: {p_obj:.10f}, "
74-
f"stopping crit: {stop_crit:.2e}"
80+
f"Iteration {it}: {p_obj:.10f}, " f"stopping crit: {stop_crit:.2e}"
7581
)
7682

7783
n_features = X.shape[1]
@@ -87,7 +93,7 @@ def callback_post_iter(w_k):
8793
options=dict(
8894
maxiter=self.max_iter,
8995
gtol=self.tol,
90-
ftol=0. # set ftol=0. to control convergence using only gtol
96+
ftol=0.0, # set ftol=0. to control convergence using only gtol
9197
),
9298
callback=callback_post_iter,
9399
)
@@ -97,7 +103,7 @@ def callback_post_iter(w_k):
97103
f"`LBFGS` did not converge for tol={self.tol:.3e} "
98104
f"and max_iter={self.max_iter}.\n"
99105
"Consider increasing `max_iter` and/or `tol`.",
100-
category=ConvergenceWarning
106+
category=ConvergenceWarning,
101107
)
102108

103109
w = result.x
@@ -110,7 +116,8 @@ def callback_post_iter(w_k):
110116
def custom_checks(self, X, y, datafit, penalty):
111117
# check datafit support sparse data
112118
check_attrs(
113-
datafit, solver=self,
119+
datafit,
120+
solver=self,
114121
required_attr=self._datafit_required_attr,
115-
support_sparse=issparse(X)
122+
support_sparse=issparse(X),
116123
)

skglm/tests/test_lbfgs_solver.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313

1414
@pytest.mark.parametrize("X_sparse", [True, False])
1515
def test_lbfgs_L2_logreg(X_sparse):
16-
reg = 1.
17-
X_density = 1. if not X_sparse else 0.5
16+
reg = 1.0
17+
X_density = 1.0 if not X_sparse else 0.5
1818
n_samples, n_features = 100, 50
1919

2020
X, y, _ = make_correlated_data(
21-
n_samples, n_features, random_state=0, X_density=X_density,
21+
n_samples,
22+
n_features,
23+
random_state=0,
24+
X_density=X_density,
2225
)
2326
y = np.sign(y)
2427

@@ -29,7 +32,7 @@ def test_lbfgs_L2_logreg(X_sparse):
2932

3033
# fit scikit learn
3134
estimator = LogisticRegression(
32-
penalty='l2',
35+
penalty="l2",
3336
C=1 / (n_samples * reg),
3437
fit_intercept=False,
3538
tol=1e-12,
@@ -48,24 +51,26 @@ def test_L2_Cox(use_efron):
4851
"Run `pip install lifelines`"
4952
)
5053

51-
alpha = 10.
54+
alpha = 10.0
5255
n_samples, n_features = 100, 50
5356

5457
X, y = make_dummy_survival_data(
55-
n_samples, n_features, normalize=True,
56-
with_ties=use_efron, random_state=0)
58+
n_samples, n_features, normalize=True, with_ties=use_efron, random_state=0
59+
)
5760

5861
datafit = Cox(use_efron)
5962
penalty = L2(alpha)
6063

64+
# XXX: initialize is needed here although it is done in LBFGS
65+
# as the datafit is used to evaluate the objective
6166
datafit.initialize(X, y)
6267
w, *_ = LBFGS().solve(X, y, datafit, penalty)
6368

6469
# fit lifeline estimator
6570
stacked_y_X = np.hstack((y, X))
6671
df = pd.DataFrame(stacked_y_X)
6772

68-
estimator = CoxPHFitter(penalizer=alpha, l1_ratio=0.).fit(
73+
estimator = CoxPHFitter(penalizer=alpha, l1_ratio=0.0).fit(
6974
df, duration_col=0, event_col=1
7075
)
7176
w_ll = estimator.params_.values

0 commit comments

Comments
 (0)