ENH -jit-compile datafits and penalties inside solver #270

Merged

Changes from 14 commits
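Summary of the change: callers no longer wrap datafits and penalties in compiled_clone; BaseSolver.solve now JIT-compiles them internally (see skglm/solvers/base.py below). A minimal before/after sketch, assuming random data and the AndersonCD/Quadratic/L1 objects used throughout this diff — illustrative, not part of the PR:

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import AndersonCD

X, y = np.random.randn(50, 100), np.random.randn(50)

# Before: w = solver.solve(X, y, compiled_clone(Quadratic()), compiled_clone(L1(0.1)))[0]
# After: pass plain instances; solve() compiles them itself.
solver = AndersonCD(fit_intercept=False)
w, p_obj, stop_crit = solver.solve(X, y, Quadratic(), L1(alpha=0.1))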
3 changes: 1 addition & 2 deletions examples/plot_sparse_recovery.py
@@ -18,7 +18,6 @@
 from skglm.utils.data import make_correlated_data
 from skglm.solvers import AndersonCD
 from skglm.datafits import Quadratic
-from skglm.utils.jit_compilation import compiled_clone
 from skglm.penalties import L1, MCPenalty, L0_5, L2_3, SCAD

 cmap = plt.get_cmap('tab10')
@@ -74,7 +73,7 @@
 for idx, estimator in enumerate(penalties.keys()):
     print(f'Running {estimator}...')
     estimator_path = solver.path(
-        X, y, compiled_clone(datafit), compiled_clone(penalties[estimator]),
+        X, y, datafit, penalties[estimator],
         alphas=alphas)

     f1_temp = np.zeros(n_alphas)
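The same applies to path: plain objects in, compilation inside the solver. A hedged sketch of the updated call (the alphas grid and options here are illustrative; path packs the coefficients along the grid plus convergence info):

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import MCPenalty
from skglm.solvers import AndersonCD

X, y = np.random.randn(50, 100), np.random.randn(50)
alphas = np.geomspace(1.0, 1e-2, num=10)

solver = AndersonCD(fit_intercept=False)
# No compiled_clone: the solver compiles Quadratic/MCPenalty on first use.
path_res = solver.path(X, y, Quadratic(), MCPenalty(alpha=1.0, gamma=3.0), alphas=alphas)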
25 changes: 12 additions & 13 deletions examples/plot_survival_analysis.py
@@ -15,6 +15,15 @@
 # Let's first generate synthetic data on which to run the Cox estimator,
 # using ``skglm`` data utils.
 #
+import warnings
+import time
+from lifelines import CoxPHFitter
+import pandas as pd
+import numpy as np
+from skglm.solvers import ProxNewton
+from skglm.penalties import L1
+from skglm.datafits import Cox
+import matplotlib.pyplot as plt
 from skglm.utils.data import make_dummy_survival_data

 n_samples, n_features = 500, 100
@@ -34,7 +43,6 @@
 # * ``s`` indicates the observations censorship and follows a Bernoulli(0.5) distribution
 #
 # Let's inspect the data quickly:
-import matplotlib.pyplot as plt

 fig, axes = plt.subplots(
     1, 3,
@@ -59,18 +67,14 @@
 # To do so, we need to combine a Cox datafit and a :math:`\ell_1` penalty
 # and solve the resulting problem using skglm Proximal Newton solver ``ProxNewton``.
 # We set the intensity of the :math:`\ell_1` regularization to ``alpha=1e-2``.
-from skglm.datafits import Cox
-from skglm.penalties import L1
-from skglm.solvers import ProxNewton

-from skglm.utils.jit_compilation import compiled_clone

 # regularization intensity
 alpha = 1e-2

 # skglm internals: init datafit and penalty
-datafit = compiled_clone(Cox())
-penalty = compiled_clone(L1(alpha))
+datafit = Cox()
+penalty = L1(alpha)

 datafit.initialize(X, y)

@@ -90,9 +94,6 @@
 # %%
 # Let's solve the problem with ``lifelines`` through its ``CoxPHFitter``
 # estimator and compare the objectives found by the two packages.
-import numpy as np
-import pandas as pd
-from lifelines import CoxPHFitter

 # format data
 stacked_y_X = np.hstack((y, X))
@@ -126,8 +127,6 @@
 # let's compare their execution time. To get the evolution of the suboptimality
 # (objective - optimal objective) we run both estimators with increasing number of
 # iterations.
-import time
-import warnings

 warnings.filterwarnings('ignore')

@@ -230,7 +229,7 @@
 # We only need to pass in ``use_efron=True`` to the ``Cox`` datafit.

 # ensure using Efron estimate
-datafit = compiled_clone(Cox(use_efron=True))
+datafit = Cox(use_efron=True)
 datafit.initialize(X, y)

 # solve the problem
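Pulling the example's pieces together under the new API — a minimal sketch with synthetic data generated inline (the Bernoulli(0.5) censoring mirrors the description above; the Weibull times and solver options are illustrative assumptions):

import numpy as np
from skglm.datafits import Cox
from skglm.penalties import L1
from skglm.solvers import ProxNewton

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 100))
tm = rng.weibull(1.0, 500)                    # occurrence times (assumed distribution)
s = rng.binomial(1, 0.5, 500).astype(float)   # censoring, Bernoulli(0.5) as above
y = np.column_stack((tm, s))

datafit, penalty = Cox(), L1(alpha=1e-2)
datafit.initialize(X, y)   # as in the example; useful to evaluate datafit.value yourself

w = ProxNewton(fit_intercept=False).solve(X, y, datafit, penalty)[0]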
47 changes: 19 additions & 28 deletions skglm/estimators.py
@@ -18,7 +18,6 @@
 from sklearn.utils._param_validation import Interval, StrOptions
 from sklearn.multiclass import OneVsRestClassifier, check_classification_targets

-from skglm.utils.jit_compilation import compiled_clone
 from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD
 from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC,
                             QuadraticMultiTask, QuadraticGroup,)
@@ -102,12 +101,10 @@ def _glm_fit(X, y, model, datafit, penalty, solver):

     n_samples, n_features = X_.shape

-    penalty_jit = compiled_clone(penalty)
-    datafit_jit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
     if issparse(X):
-        datafit_jit.initialize_sparse(X_.data, X_.indptr, X_.indices, y)
+        datafit.initialize_sparse(X_.data, X_.indptr, X_.indices, y)
     else:
-        datafit_jit.initialize(X_, y)
+        datafit.initialize(X_, y)

     # if model.warm_start and hasattr(model, 'coef_') and model.coef_ is not None:
     if solver.warm_start and hasattr(model, 'coef_') and model.coef_ is not None:
@@ -136,7 +133,7 @@ def _glm_fit(X, y, model, datafit, penalty, solver):
             "The size of the WeightedL1 penalty weights should be n_features, "
             "expected %i, got %i." % (X_.shape[1], len(penalty.weights)))

-    coefs, p_obj, kkt = solver.solve(X_, y, datafit_jit, penalty_jit, w, Xw)
+    coefs, p_obj, kkt = solver.solve(X_, y, datafit, penalty, w, Xw)
     model.coef_, model.stop_crit_ = coefs[:n_features], kkt
     if y.ndim == 1:
         model.intercept_ = coefs[-1] if fit_intercept else 0.
@@ -440,8 +437,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             The number of iterations along the path. If return_n_iter is set to
             ``True``.
         """
-        penalty = compiled_clone(L1(self.alpha, self.positive))
-        datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
+        penalty = L1(self.alpha, self.positive)
+        datafit = Quadratic()
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
@@ -581,8 +578,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             raise ValueError("The number of weights must match the number of \
                 features. Got %s, expected %s." % (
                 len(weights), X.shape[1]))
-        penalty = compiled_clone(WeightedL1(self.alpha, weights, self.positive))
-        datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
+        penalty = WeightedL1(self.alpha, weights, self.positive)
+        datafit = Quadratic()
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
@@ -744,8 +741,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             The number of iterations along the path. If return_n_iter is set to
             ``True``.
         """
-        penalty = compiled_clone(L1_plus_L2(self.alpha, self.l1_ratio, self.positive))
-        datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
+        penalty = L1_plus_L2(self.alpha, self.l1_ratio, self.positive)
+        datafit = Quadratic()
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
@@ -917,19 +914,17 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             ``True``.
         """
         if self.weights is None:
-            penalty = compiled_clone(
-                MCPenalty(self.alpha, self.gamma, self.positive)
-            )
+            penalty = MCPenalty(self.alpha, self.gamma, self.positive)
         else:
             if X.shape[1] != len(self.weights):
                 raise ValueError(
                     "The number of weights must match the number of features. "
                     f"Got {len(self.weights)}, expected {X.shape[1]}."
                 )
-            penalty = compiled_clone(
-                WeightedMCPenalty(self.alpha, self.gamma, self.weights, self.positive)
-            )
-        datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
+            penalty = WeightedMCPenalty(
+                self.alpha, self.gamma, self.weights, self.positive)
+
+        datafit = Quadratic()
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
@@ -1369,10 +1364,6 @@ def fit(self, X, y):
         else:
             penalty = L2(self.alpha)

-        # skglm internal: JIT compile classes
-        datafit = compiled_clone(datafit)
-        penalty = compiled_clone(penalty)
-
         # init solver
         if self.l1_ratio == 0.:
             solver = LBFGS(max_iter=self.max_iter, tol=self.tol, verbose=self.verbose)
@@ -1518,14 +1509,14 @@ def fit(self, X, Y):
         if not self.warm_start or not hasattr(self, "coef_"):
             self.coef_ = None

-        datafit_jit = compiled_clone(QuadraticMultiTask(), X.dtype == np.float32)
-        penalty_jit = compiled_clone(L2_1(self.alpha), X.dtype == np.float32)
+        datafit = QuadraticMultiTask()
+        penalty = L2_1(self.alpha)

         solver = MultiTaskBCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
             warm_start=self.warm_start, verbose=self.verbose)
-        W, obj_out, kkt = solver.solve(X, Y, datafit_jit, penalty_jit)
+        W, obj_out, kkt = solver.solve(X, Y, datafit, penalty)

         self.coef_ = W[:X.shape[1], :].T
         self.intercept_ = self.fit_intercept * W[-1, :]
@@ -1573,8 +1564,8 @@ def path(self, X, Y, alphas, coef_init=None, return_n_iter=False, **params):
             The number of iterations along the path. If return_n_iter is set to
             ``True``.
         """
-        datafit = compiled_clone(QuadraticMultiTask(), to_float32=X.dtype == np.float32)
-        penalty = compiled_clone(L2_1(self.alpha))
+        datafit = QuadraticMultiTask()
+        penalty = L2_1(self.alpha)
         solver = MultiTaskBCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
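The public estimator API is unchanged by all of the above; only the internal *_jit clones disappear. A quick smoke-test sketch (illustrative data and alpha):

import numpy as np
from skglm import Lasso

X, y = np.random.randn(30, 50), np.random.randn(30)
model = Lasso(alpha=0.1).fit(X, y)  # no compiled_clone anywhere user-side
print(model.coef_.shape)  # (50,)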
4 changes: 2 additions & 2 deletions skglm/experimental/reweighted.py
@@ -69,9 +69,9 @@ def fit(self, X, y):
                 f"penalty {self.penalty.__class__.__name__}")

         n_features = X.shape[1]
-        _penalty = compiled_clone(WeightedL1(self.penalty.alpha, np.ones(n_features)))
-        self.datafit = compiled_clone(self.datafit)
         # we need to compile this as it is not passed to solver.solve:
         self.penalty = compiled_clone(self.penalty)
+        _penalty = WeightedL1(self.penalty.alpha, np.ones(n_features))
+
         self.loss_history_ = []

5 changes: 2 additions & 3 deletions skglm/experimental/sqrt_lasso.py
@@ -6,7 +6,6 @@

 from skglm.penalties import L1
 from skglm.utils.prox_funcs import ST_vec, proj_L2ball, BST
-from skglm.utils.jit_compilation import compiled_clone
 from skglm.datafits.base import BaseDatafit
 from skglm.solvers.prox_newton import ProxNewton

@@ -179,8 +178,8 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
             alphas = np.sort(alphas)[::-1]

         n_features = X.shape[1]
-        sqrt_quadratic = compiled_clone(SqrtQuadratic())
-        l1_penalty = compiled_clone(L1(1.))  # alpha is set along the path
+        sqrt_quadratic = SqrtQuadratic()
+        l1_penalty = L1(1.)  # alpha is set along the path

         coefs = np.zeros((n_alphas, n_features))

5 changes: 2 additions & 3 deletions skglm/experimental/tests/test_quantile_regression.py
@@ -6,7 +6,6 @@
 from skglm import GeneralizedLinearEstimator
 from skglm.experimental.pdcd_ws import PDCD_WS
 from skglm.experimental.quantile_regression import Pinball
-from skglm.utils.jit_compilation import compiled_clone

 from skglm.utils.data import make_correlated_data
 from sklearn.linear_model import QuantileRegressor
@@ -23,8 +22,8 @@ def test_PDCD_WS(quantile_level):
     alpha_max = norm(X.T @ (np.sign(y)/2 + (quantile_level - 0.5)), ord=np.inf)
     alpha = alpha_max / 5

-    datafit = compiled_clone(Pinball(quantile_level))
-    penalty = compiled_clone(L1(alpha))
+    datafit = Pinball(quantile_level)
+    penalty = L1(alpha)

     w = PDCD_WS(
         dual_init=np.sign(y)/2 + (quantile_level - 0.5)
5 changes: 2 additions & 3 deletions skglm/experimental/tests/test_sqrt_lasso.py
@@ -7,7 +7,6 @@
 from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic,
                                            _chambolle_pock_sqrt)
 from skglm.experimental.pdcd_ws import PDCD_WS
-from skglm.utils.jit_compilation import compiled_clone


 def test_alpha_max():
@@ -70,8 +69,8 @@ def test_PDCD_WS(with_dual_init):

     dual_init = y / norm(y) if with_dual_init else None

-    datafit = compiled_clone(SqrtQuadratic())
-    penalty = compiled_clone(L1(alpha))
+    datafit = SqrtQuadratic()
+    penalty = L1(alpha)

     w = PDCD_WS(dual_init=dual_init).solve(X, y, datafit, penalty)[0]
     clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
28 changes: 26 additions & 2 deletions skglm/solvers/base.py
@@ -1,5 +1,10 @@
+import warnings
 from abc import abstractmethod, ABC

+import numpy as np
+
 from skglm.utils.validation import check_attrs
+from skglm.utils.jit_compilation import compiled_clone

+
 class BaseSolver(ABC):
@@ -89,8 +94,9 @@ def custom_checks(self, X, y, datafit, penalty):
         """
         pass

-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
-              *, run_checks=True):
+    def solve(
+        self, X, y, datafit, penalty, w_init=None, Xw_init=None, *, run_checks=True
+    ):
         """Solve the optimization problem after validating its compatibility.

         A proxy of ``_solve`` method that implicitly ensures the compatibility
@@ -101,6 +107,24 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
         >>> ...
         >>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty)
         """
+        # TODO do it properly instead of searching for a string
+        if "jitclass" in str(type(datafit)):
+            warnings.warn(
+                "Do not pass a compiled datafit, compilation is done inside solver now"
+            )
+        if "jitclass" in str(type(penalty)):
+            warnings.warn(
+                "Do not pass a compiled penalty, compilation is done inside solver now"
+            )
+        else:
+            if datafit is not None:
+                datafit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
+            if penalty is not None:
+                penalty = compiled_clone(penalty)
+            # TODO add support for bool spec in compiled_clone
+            # currently, doing so breaks the code
+            # penalty = compiled_clone(penalty, to_float32=X.dtype == np.float32)
+
         if run_checks:
             self._validate(X, y, datafit, penalty)

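The string-based jitclass check keeps old callers working: pre-compiled objects are accepted with a warning and skip recompilation. A sketch of that behavior (illustrative; both objects pre-compiled):

import warnings
import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import AndersonCD
from skglm.utils.jit_compilation import compiled_clone

X, y = np.random.randn(20, 40), np.random.randn(20)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    AndersonCD(fit_intercept=False).solve(
        X, y, compiled_clone(Quadratic()), compiled_clone(L1(0.1)))
print([str(w.message) for w in caught])  # two "Do not pass a compiled ..." warnings

Note that the else binds to the penalty check only, so mixed inputs (compiled datafit, plain penalty) are not handled symmetrically — presumably part of what the TODO above is about.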
3 changes: 1 addition & 2 deletions skglm/solvers/common.py
@@ -46,8 +46,7 @@ def dist_fix_point_cd(w, grad_ws, lipschitz_ws, datafit, penalty, ws):


 @njit
-def dist_fix_point_bcd(
-        w, grad_ws, lipschitz_ws, datafit, penalty, ws):
+def dist_fix_point_bcd(w, grad_ws, lipschitz_ws, datafit, penalty, ws):
     """Compute the violation of the fixed point iterate scheme for BCD.

     Parameters
2 changes: 2 additions & 0 deletions skglm/solvers/fista.py
@@ -51,10 +51,12 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)

         if X_is_sparse:
+            datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
             lipschitz = datafit.get_global_lipschitz_sparse(
                 X.data, X.indptr, X.indices, y
             )
         else:
+            datafit.initialize(X, y)
             lipschitz = datafit.get_global_lipschitz(X, y)

         for n_iter in range(self.max_iter):
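With initialization moved into _solve, FISTA can now be handed bare, uninitialized objects end to end; GroupProxNewton below gets the same treatment. A sketch (illustrative data and options):

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import FISTA

X, y = np.random.randn(40, 60), np.random.randn(40)
w = FISTA(max_iter=200).solve(X, y, Quadratic(), L1(alpha=0.1))[0]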
7 changes: 7 additions & 0 deletions skglm/solvers/group_prox_newton.py
@@ -69,6 +69,13 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         stop_crit = 0.
         p_objs_out = []

+        # TODO: to be isolated in a separate method
+        is_sparse = issparse(X)
+        if is_sparse:
+            datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
+        else:
+            datafit.initialize(X, y)
+
         for iter in range(self.max_iter):
             grad = _construct_grad(X, y, w, Xw, datafit, all_groups)