@@ -14,15 +14,16 @@
 from sklearn.svm import LinearSVC as LinearSVC_sklearn
 from sklearn.utils.estimator_checks import check_estimator

+import scipy.optimize
 from scipy.sparse import csc_matrix, issparse

 from skglm.utils.data import make_correlated_data, make_dummy_survival_data
 from skglm.estimators import (
     GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet,
     MCPRegression, SparseLogisticRegression, LinearSVC)
 from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox
-from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1
-from skglm.solvers import AndersonCD
+from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE
+from skglm.solvers import AndersonCD, FISTA

 import pandas as pd
 from skglm.solvers import ProxNewton
@@ -326,6 +327,87 @@ def test_Cox_sk_compatibility():
     check_estimator(CoxEstimator())


+@pytest.mark.parametrize("use_efron, issparse", product([True, False], repeat=2))
+def test_equivalence_cox_SLOPE_cox_L1(use_efron, issparse):
+    # this only tests the case where SLOPE is equivalent to L1 (all alphas equal)
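+    # (the SLOPE penalty is sum_j alphas_j * |w|_(j), with |w| sorted in
+    # decreasing order, so equal alphas yield exactly alpha * ||w||_1)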
+    reg = 1e-2
+    n_samples, n_features = 100, 10
+    X_density = 1. if not issparse else 0.2
+
+    X, y = make_dummy_survival_data(
+        n_samples, n_features, with_ties=use_efron, X_density=X_density,
+        random_state=0)
+
+    # init datafit
+    datafit = compiled_clone(Cox(use_efron))
+
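+    # dense and sparse X go through different initializations: the sparse path
+    # consumes the raw CSC buffers (data, indptr, indices) directly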
+    if not issparse:
+        datafit.initialize(X, y)
+    else:
+        datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
+
+    # compute alpha_max
+    grad_0 = datafit.raw_grad(y, np.zeros(n_samples))
+    alpha_max = np.linalg.norm(X.T @ grad_0, ord=np.inf)
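+    # for an L1 penalty, alpha_max = ||X.T @ grad_0||_inf is the smallest
+    # regularization strength at which the solution is identically zero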
+
+    # init penalty
+    alpha = reg * alpha_max
+    alphas = alpha * np.ones(n_features)
+    penalty = compiled_clone(SLOPE(alphas))
+
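+    # SLOPE is not separable, so a per-coordinate "subdiff" optimality check is
+    # unavailable; "fixpoint" tracks the prox-gradient fixed-point residual instead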
+    solver = FISTA(opt_strategy="fixpoint", max_iter=10_000, tol=1e-9)
+
+    w, *_ = solver.solve(X, y, datafit, penalty)
+
+    method = 'efron' if use_efron else 'breslow'
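+    # with l1_ratio=1. the reference estimator's penalty reduces to pure L1,
+    # which should match the equal-alphas SLOPE solution computed above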
+    estimator = CoxEstimator(alpha, l1_ratio=1., method=method, tol=1e-9).fit(X, y)
+
+    np.testing.assert_allclose(w, estimator.coef_, atol=1e-6)
+
+
+@pytest.mark.parametrize("use_efron", [True, False])
+def test_cox_SLOPE(use_efron):
+    reg = 1e-2
+    n_samples, n_features = 100, 10
+
+    X, y = make_dummy_survival_data(
+        n_samples, n_features, with_ties=use_efron, random_state=0)
+
+    # init datafit
+    datafit = compiled_clone(Cox(use_efron))
+    datafit.initialize(X, y)
+
+    # compute alpha_max
+    grad_0 = datafit.raw_grad(y, np.zeros(n_samples))
+    alpha_ref = np.linalg.norm(X.T @ grad_0, ord=np.inf)
+
+    # init penalty
+    alpha = reg * alpha_ref
+    alphas = alpha / np.arange(n_features + 1)[1:]
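+    # a strictly decreasing sequence alpha/1, alpha/2, ..., alpha/n_features,
+    # since SLOPE expects its alphas sorted in non-increasing order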
+    penalty = compiled_clone(SLOPE(alphas))
+
+    solver = FISTA(opt_strategy="fixpoint", max_iter=10_000, tol=1e-9)
+
+    w, *_ = solver.solve(X, y, datafit, penalty)
+
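+    # reference: minimize the same composite objective with a generic scipy
+    # solver; SLSQP is a smooth solver, so its coefficients may differ slightly
+    # and the test below compares objective values rather than coefficients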
+    result = scipy.optimize.minimize(
+        fun=lambda w: datafit.value(y, w, X @ w) + penalty.value(w),
+        x0=np.zeros(n_features),
+        method="SLSQP",
+        options=dict(
+            ftol=1e-9,
+            maxiter=10_000,
+        ),
+    )
+    w_sp = result.x
+
+    # check both methods yield the same objective
+    np.testing.assert_allclose(
+        datafit.value(y, w, X @ w) + penalty.value(w),
+        datafit.value(y, w_sp, X @ w_sp) + penalty.value(w_sp)
+    )
+
+
 # Test if GeneralizedLinearEstimator returns the correct coefficients
 @pytest.mark.parametrize("Datafit, Penalty, Estimator, pen_args", [
     (Quadratic, L1, Lasso, [alpha]),