Commit 10fa99c

ENH add global_lipschitz to Cox datafit (#180)
1 parent e1a27e1 commit 10fa99c

File tree: 3 files changed, +110 -5 lines

skglm/datafits/single_task.py
skglm/solvers/fista.py
skglm/tests/test_estimators.py

skglm/datafits/single_task.py

Lines changed: 15 additions & 0 deletions
@@ -599,6 +599,7 @@ def get_spec(self):
             ('use_efron', bool_),
             ('T_indptr', int64[:]), ('T_indices', int64[:]),
             ('H_indptr', int64[:]), ('H_indices', int64[:]),
+            ('global_lipschitz', float64),
         )

     def params_to_dict(self):
@@ -693,6 +694,20 @@ def initialize_sparse(self, X_data, X_indptr, X_indices, y):
         # small hack to avoid repetitive code: pass in X_data as only its dtype is used
         self.initialize(X_data, y)

+    def init_global_lipschitz(self, X, y):
+        s = y[:, 1]
+
+        n_samples = X.shape[0]
+        self.global_lipschitz = s.sum() * norm(X, ord=2) ** 2 / n_samples
+
+    def init_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y):
+        s = y[:, 1]
+
+        n_samples = s.shape[0]
+        norm_X = spectral_norm(X_data, X_indptr, X_indices, n_samples)
+
+        self.global_lipschitz = s.sum() * norm_X ** 2 / n_samples
+
     def _B_dot_vec(self, vec):
         # compute `B @ vec` in O(n) instead of O(n^2)
         out = np.zeros_like(vec)
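
The two new methods give the Cox datafit a global Lipschitz constant for its gradient: with s = y[:, 1] the vector of event indicators, the constant is s.sum() * ||X||_2^2 / n_samples, using the dense spectral norm or its sparse counterpart on the CSC arrays. (Roughly, each per-event term of the raw Hessian has spectral norm at most one, so composing with X gives this bound.) Below is a minimal usage sketch, not part of the commit; it assumes compiled_clone is importable from skglm.utils.jit_compilation and that make_dummy_survival_data keeps the signature used in the tests further down. It simply checks that the stored attribute matches the closed-form expression.

    # hedged sketch: verify the stored constant against the closed-form bound
    import numpy as np
    from numpy.linalg import norm

    from skglm.datafits import Cox
    from skglm.utils.data import make_dummy_survival_data
    from skglm.utils.jit_compilation import compiled_clone  # assumed location

    n_samples, n_features = 100, 10
    X, y = make_dummy_survival_data(n_samples, n_features, random_state=0)

    datafit = compiled_clone(Cox(use_efron=False))
    datafit.initialize(X, y)
    datafit.init_global_lipschitz(X, y)

    # sum(s) * ||X||_2^2 / n_samples, with s = y[:, 1] the event indicators
    expected = y[:, 1].sum() * norm(X, ord=2) ** 2 / n_samples
    np.testing.assert_allclose(datafit.global_lipschitz, expected)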

skglm/solvers/fista.py

Lines changed: 11 additions & 3 deletions
@@ -63,11 +63,19 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
             t_old = t_new
             t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
             w_old = w.copy()
+
             if X_is_sparse:
-                grad = construct_grad_sparse(
-                    X.data, X.indptr, X.indices, y, z, X @ z, datafit, all_features)
+                if hasattr(datafit, "gradient_sparse"):
+                    grad = datafit.gradient_sparse(
+                        X.data, X.indptr, X.indices, y, X @ z)
+                else:
+                    grad = construct_grad_sparse(
+                        X.data, X.indptr, X.indices, y, z, X @ z, datafit, all_features)
             else:
-                grad = construct_grad(X, y, z, X @ z, datafit, all_features)
+                if hasattr(datafit, "gradient"):
+                    grad = datafit.gradient(X, y, X @ z)
+                else:
+                    grad = construct_grad(X, y, z, X @ z, datafit, all_features)

             step = 1 / lipschitz
             z -= step * grad
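
The solver change is a small dispatch: when the datafit exposes a full-batch gradient (gradient, or gradient_sparse for CSC input), FISTA calls it directly and skips the generic feature-by-feature construct_grad helpers. Any datafit can opt into this fast path by implementing the two methods with the signatures used above; the hedged illustration below uses a made-up quadratic-like datafit, not an skglm class.

    # hedged illustration: the two hooks FISTA now looks for on a datafit
    import numpy as np

    class QuadraticLike:
        # toy datafit: gradient of 0.5/n * ||y - Xw||^2 with respect to w

        def gradient(self, X, y, Xw):
            # dense full-batch gradient, called as datafit.gradient(X, y, X @ z)
            return X.T @ (Xw - y) / len(y)

        def gradient_sparse(self, X_data, X_indptr, X_indices, y, Xw):
            # same quantity from CSC arrays, one column at a time
            n_samples = len(y)
            raw = (Xw - y) / n_samples
            grad = np.zeros(len(X_indptr) - 1)
            for j in range(len(grad)):
                sl = slice(X_indptr[j], X_indptr[j + 1])
                grad[j] = X_data[sl] @ raw[X_indices[sl]]
            return grad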

skglm/tests/test_estimators.py

Lines changed: 84 additions & 2 deletions
@@ -14,15 +14,16 @@
 from sklearn.svm import LinearSVC as LinearSVC_sklearn
 from sklearn.utils.estimator_checks import check_estimator

+import scipy.optimize
 from scipy.sparse import csc_matrix, issparse

 from skglm.utils.data import make_correlated_data, make_dummy_survival_data
 from skglm.estimators import (
     GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet,
     MCPRegression, SparseLogisticRegression, LinearSVC)
 from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox
-from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1
-from skglm.solvers import AndersonCD
+from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE
+from skglm.solvers import AndersonCD, FISTA

 import pandas as pd
 from skglm.solvers import ProxNewton
@@ -326,6 +327,87 @@ def test_Cox_sk_compatibility():
     check_estimator(CoxEstimator())


+@pytest.mark.parametrize("use_efron, issparse", product([True, False], repeat=2))
+def test_equivalence_cox_SLOPE_cox_L1(use_efron, issparse):
+    # this only tests the case of SLOPE equivalent to L1 (equal alphas)
+    reg = 1e-2
+    n_samples, n_features = 100, 10
+    X_density = 1. if not issparse else 0.2
+
+    X, y = make_dummy_survival_data(
+        n_samples, n_features, with_ties=use_efron, X_density=X_density,
+        random_state=0)
+
+    # init datafit
+    datafit = compiled_clone(Cox(use_efron))
+
+    if not issparse:
+        datafit.initialize(X, y)
+    else:
+        datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
+
+    # compute alpha_max
+    grad_0 = datafit.raw_grad(y, np.zeros(n_samples))
+    alpha_max = np.linalg.norm(X.T @ grad_0, ord=np.inf)
+
+    # init penalty
+    alpha = reg * alpha_max
+    alphas = alpha * np.ones(n_features)
+    penalty = compiled_clone(SLOPE(alphas))
+
+    solver = FISTA(opt_strategy="fixpoint", max_iter=10_000, tol=1e-9)
+
+    w, *_ = solver.solve(X, y, datafit, penalty)
+
+    method = 'efron' if use_efron else 'breslow'
+    estimator = CoxEstimator(alpha, l1_ratio=1., method=method, tol=1e-9).fit(X, y)
+
+    np.testing.assert_allclose(w, estimator.coef_, atol=1e-6)
+
+
+@pytest.mark.parametrize("use_efron", [True, False])
+def test_cox_SLOPE(use_efron):
+    reg = 1e-2
+    n_samples, n_features = 100, 10
+
+    X, y = make_dummy_survival_data(
+        n_samples, n_features, with_ties=use_efron, random_state=0)
+
+    # init datafit
+    datafit = compiled_clone(Cox(use_efron))
+    datafit.initialize(X, y)
+
+    # compute alpha_max
+    grad_0 = datafit.raw_grad(y, np.zeros(n_samples))
+    alpha_ref = np.linalg.norm(X.T @ grad_0, ord=np.inf)
+
+    # init penalty
+    alpha = reg * alpha_ref
+    alphas = alpha / np.arange(n_features + 1)[1:]
+    penalty = compiled_clone(SLOPE(alphas))
+
+    solver = FISTA(opt_strategy="fixpoint", max_iter=10_000, tol=1e-9)
+
+    w, *_ = solver.solve(X, y, datafit, penalty)
+
+    result = scipy.optimize.minimize(
+        fun=lambda w: datafit.value(y, w, X @ w) + penalty.value(w),
+        x0=np.zeros(n_features),
+        method="SLSQP",
+        options=dict(
+            ftol=1e-9,
+            maxiter=10_000,
+        ),
+    )
+    w_sp = result.x
+
+    # check both methods yield the same objective
+    np.testing.assert_allclose(
+        datafit.value(y, w, X @ w) + penalty.value(w),
+        datafit.value(y, w_sp, X @ w_sp) + penalty.value(w_sp)
+    )
+
+
 # Test if GeneralizedLinearEstimator returns the correct coefficients
 @pytest.mark.parametrize("Datafit, Penalty, Estimator, pen_args", [
     (Quadratic, L1, Lasso, [alpha]),
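
The first new test leans on the fact that a SLOPE penalty with all regularization levels equal collapses to the plain L1 penalty, so the FISTA solution can be compared against the L1-penalized CoxEstimator; the second checks that FISTA with a decreasing SLOPE sequence reaches the same objective value as a scipy SLSQP solve of the composite problem. The reduction behind the first test is easy to check on its own; a small sketch, assuming SLOPE(alphas) and L1(alpha) expose the usual skglm value(w) penalty interface:

    # hedged sketch: equal alphas make the SLOPE penalty coincide with L1
    import numpy as np
    from skglm.penalties import L1, SLOPE

    alpha = 1e-2
    w = np.array([0.5, -2.0, 0.0, 1.5])

    np.testing.assert_allclose(
        SLOPE(alpha * np.ones(w.shape[0])).value(w),  # sum_j alpha_j * |w|_(j)
        L1(alpha).value(w),                           # alpha * ||w||_1
    )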
