diff --git a/skglm/skglm_jax/README.md b/skglm/skglm_jax/README.md
new file mode 100644
index 000000000..52146bafc
--- /dev/null
+++ b/skglm/skglm_jax/README.md
@@ -0,0 +1,22 @@
+## Installation
+
+
+1. Create then activate the ``conda`` environment
+```shell
+# create
+conda create -n skglm-jax python=3.10
+
+# activate env
+conda activate skglm-jax
+```
+
+2. Install ``skglm`` in editable mode
+```shell
+pip install -e .
+```
+
+3. Install the dependencies
+```shell
+# jax
+conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
+```
diff --git a/skglm/skglm_jax/__init__.py b/skglm/skglm_jax/__init__.py
new file mode 100644
index 000000000..9f3937231
--- /dev/null
+++ b/skglm/skglm_jax/__init__.py
@@ -0,0 +1,12 @@
+# if not set, JAX raises an error related to the CUDA linking API.
+# as recommended, set 'XLA_FLAGS' to bypass it.
+# side effect: (perhaps) slower compilation time.
+# import os
+# os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1'  # noqa
+
+# set flag to work around a bug with `jax.linalg.norm`
+# ref: https://github.com/google/jax/issues/8916#issuecomment-1101113497
+# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "False"  # noqa
+
+import jax
+jax.config.update("jax_enable_x64", True)
diff --git a/skglm/skglm_jax/anderson_cd.py b/skglm/skglm_jax/anderson_cd.py
new file mode 100644
index 000000000..a3b45f433
--- /dev/null
+++ b/skglm/skglm_jax/anderson_cd.py
@@ -0,0 +1,141 @@
+from functools import partial
+
+import jax
+import numpy as np
+import jax.numpy as jnp
+
+from skglm.skglm_jax.datafits import QuadraticJax
+from skglm.skglm_jax.penalties import L1Jax
+from skglm.skglm_jax.utils import JaxAA
+
+
+class AndersonCD:
+
+    EPS_TOL = 0.3
+
+    def __init__(self, max_iter=100, max_epochs=100, tol=1e-6, p0=10,
+                 use_acc=False, verbose=0):
+        self.max_iter = max_iter
+        self.max_epochs = max_epochs
+        self.tol = tol
+        self.p0 = p0
+        self.use_acc = use_acc
+        self.verbose = verbose
+
+    def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax):
+        X, y = self._transfer_to_device(X, y)
+
+        n_samples, n_features = X.shape
+        lipschitz = datafit.get_features_lipschitz_cst(X, y)
+
+        w = jnp.zeros(n_features)
+        Xw = jnp.zeros(n_samples)
+        all_features = jnp.full(n_features, fill_value=True, dtype=bool)
+
+        for it in range(self.max_iter):
+
+            # check convergence
+            grad = datafit.gradient_ws(X, y, w, Xw, all_features)
+            scores = penalty.subdiff_dist_ws(w, grad, all_features)
+            stop_crit = jnp.max(scores)
+
+            if self.verbose:
+                p_obj = datafit.value(X, y, w) + penalty.value(w)
+
+                print(
+                    f"Iteration {it}: p_obj={p_obj:.8f} "
+                    f"stop_crit={stop_crit:.4e}"
+                )
+
+            if stop_crit <= self.tol:
+                break
+
+            # build ws
+            gsupp_size = penalty.generalized_support(w).sum()
+            ws_size = min(
+                max(2 * gsupp_size, self.p0),
+                n_features
+            )
+
+            ws = jnp.full(n_features, fill_value=False, dtype=bool)
+            ws_features = jnp.argsort(scores)[-ws_size:]
+            ws = ws.at[ws_features].set(True)
+
+            tol_in = AndersonCD.EPS_TOL * stop_crit
+
+            w, Xw = self._solve_sub_problem(X, y, w, Xw, ws, lipschitz, tol_in,
+                                            datafit, penalty)
+
+        w_cpu = np.asarray(w)
+        return w_cpu
+
+    def _solve_sub_problem(self, X, y, w, Xw, ws, lipschitz, tol_in,
+                           datafit, penalty):
+
+        if self.use_acc:
+            accelerator = JaxAA(K=5)
+
+        for epoch in range(self.max_epochs):
+
+            w, Xw = self._cd_epoch(X, y, w, Xw, ws, lipschitz,
+                                   datafit, penalty)
+
+            if self.use_acc:
+                w, Xw = accelerator.extrapolate(w, Xw)
+
+            # check convergence
+            grad_ws = datafit.gradient_ws(X, y, w, Xw, ws)
+            scores_ws = penalty.subdiff_dist_ws(w, grad_ws, ws)
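+            # inner stopping criterion: largest distance of the negative
+            # gradient to the subdifferential, restricted to the working set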
+            stop_crit_in = jnp.max(scores_ws)
+
+            if self.verbose > 1:
+                p_obj_in = datafit.value(X, y, w) + penalty.value(w)
+
+                print(
+                    f"Epoch {epoch}: p_obj_in={p_obj_in:.8f} "
+                    f"stop_crit_in={stop_crit_in:.4e}"
+                )
+
+            if stop_crit_in <= tol_in:
+                break
+
+        return w, Xw
+
+    @partial(jax.jit, static_argnums=(0, -2, -1))
+    def _cd_epoch(self, X, y, w, Xw, ws, lipschitz, datafit, penalty):
+        for j, in_ws in enumerate(ws):
+
+            w, Xw = jax.lax.cond(
+                in_ws,
+                lambda X, y, w, Xw, j, lipschitz: self._cd_epoch_j(X, y, w, Xw, j, lipschitz, datafit, penalty),  # noqa
+                lambda X, y, w, Xw, j, lipschitz: (w, Xw),
+                *(X, y, w, Xw, j, lipschitz)
+            )
+
+        return w, Xw
+
+    @partial(jax.jit, static_argnums=(0, -2, -1))
+    def _cd_epoch_j(self, X, y, w, Xw, j, lipschitz, datafit, penalty):
+
+        # Null columns of X would break this function
+        # as their corresponding Lipschitz constant is 0
+        # TODO: implement condition using lax
+        # if lipschitz[j] == 0.:
+        #     continue
+
+        step = 1 / lipschitz[j]
+
+        grad_j = datafit.gradient_1d(X, y, w, Xw, j)
+        next_w_j = penalty.prox_1d(w[j] - step * grad_j, step)
+
+        delta_w_j = next_w_j - w[j]
+
+        w = w.at[j].set(next_w_j)
+        Xw = Xw + delta_w_j * X[:, j]
+
+        return w, Xw
+
+    def _transfer_to_device(self, X, y):
+        # TODO: other checks
+        # - skip if they are already jax arrays
+        return jnp.asarray(X), jnp.asarray(y)
diff --git a/skglm/skglm_jax/datafits.py b/skglm/skglm_jax/datafits.py
new file mode 100644
index 000000000..948133df6
--- /dev/null
+++ b/skglm/skglm_jax/datafits.py
@@ -0,0 +1,48 @@
+import jax
+import jax.numpy as jnp
+from jax.numpy.linalg import norm as jnorm
+
+from skglm.skglm_jax.utils import jax_jit_method
+
+
+class QuadraticJax:
+    """1 / (2 n_samples) ||y - Xw||^2"""
+
+    def value(self, X, y, w):
+        n_samples = X.shape[0]
+        return ((X @ w - y) ** 2).sum() / (2. * n_samples)
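+
+    # Xw is maintained by the solver so that the residuals Xw - y can be
+    # formed without recomputing X @ w at every coordinate update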
+
+    def gradient_1d(self, X, y, w, Xw, j):
+        n_samples = X.shape[0]
+        return X[:, j] @ (Xw - y) / n_samples
+
+    @jax_jit_method
+    def gradient_ws(self, X, y, w, Xw, ws):
+        n_features = X.shape[1]
+        Xw_minus_y = Xw - y
+
+        grad_ws = jnp.empty(n_features)
+        for j, in_ws in enumerate(ws):
+
+            grad_j = jax.lax.cond(
+                in_ws,
+                lambda X, Xw_minus_y, j: X[:, j] @ Xw_minus_y / len(Xw_minus_y),
+                lambda X, Xw_minus_y, j: 0.,
+                *(X, Xw_minus_y, j)
+            )
+
+            grad_ws = grad_ws.at[j].set(grad_j)
+
+        return grad_ws
+
+    def get_features_lipschitz_cst(self, X, y):
+        n_samples = X.shape[0]
+        return jnorm(X, ord=2, axis=0) ** 2 / n_samples
+
+    def get_global_lipschitz_cst(self, X, y):
+        n_samples = X.shape[0]
+        return jnorm(X, ord=2) ** 2 / n_samples
+
+    def gradient(self, X, y, w):
+        n_samples = X.shape[0]
+        return X.T @ (X @ w - y) / n_samples
diff --git a/skglm/skglm_jax/fista.py b/skglm/skglm_jax/fista.py
new file mode 100644
index 000000000..f39b26a2e
--- /dev/null
+++ b/skglm/skglm_jax/fista.py
@@ -0,0 +1,79 @@
+import numpy as np
+
+import jax
+import jax.numpy as jnp
+
+from skglm.skglm_jax.datafits import QuadraticJax
+from skglm.skglm_jax.penalties import L1Jax
+
+
+class Fista:
+
+    def __init__(self, max_iter=200, use_auto_diff=True, verbose=0):
+        self.max_iter = max_iter
+        self.use_auto_diff = use_auto_diff
+        self.verbose = verbose
+
+    def solve(self, X, y, datafit: QuadraticJax, penalty: L1Jax):
+        n_samples, n_features = X.shape
+        X_gpu, y_gpu = jnp.asarray(X), jnp.asarray(y)
+
+        # compute step
+        lipschitz = datafit.get_global_lipschitz_cst(X_gpu, y_gpu)
+        if lipschitz == 0.:
+            return np.zeros(n_features)
+
+        step = 1 / lipschitz
+        all_features = jnp.full(n_features, fill_value=True, dtype=bool)
+
+        # get grad func of datafit
+        if self.use_auto_diff:
+            auto_grad = jax.jit(jax.grad(datafit.value, argnums=-1))
+
+        # init vars in device
+        w = jnp.zeros(n_features)
+        old_w = jnp.zeros(n_features)
+        mid_w = jnp.zeros(n_features)
+        grad = jnp.zeros(n_features)
+
+        t_old, t_new = 1, 1
+
+        for it in range(self.max_iter):
+
+            # compute grad
+            if self.use_auto_diff:
+                grad = auto_grad(X_gpu, y_gpu, mid_w)
+            else:
+                grad = datafit.gradient(X_gpu, y_gpu, mid_w)
+
+            # forward / backward
+            val = mid_w - step * grad
+            w = penalty.prox(val, step)
+
+            if self.verbose:
+                p_obj = datafit.value(X_gpu, y_gpu, w) + penalty.value(w)
+
+                if self.use_auto_diff:
+                    grad = auto_grad(X_gpu, y_gpu, w)
+                else:
+                    grad = datafit.gradient(X_gpu, y_gpu, w)
+
+                scores = penalty.subdiff_dist_ws(w, grad, all_features)
+                stop_crit = jnp.max(scores)
+
+                print(
+                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={stop_crit:.4e}"
+                )
+
+            # extrapolate
+            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)
+
+            # update FISTA vars
+            t_old = t_new
+            t_new = 0.5 * (1 + jnp.sqrt(1. + 4. * t_old ** 2))
+            old_w = jnp.copy(w)
+
+        # transfer back to host
+        w_cpu = np.asarray(w, dtype=np.float64)
+
+        return w_cpu
diff --git a/skglm/skglm_jax/penalties.py b/skglm/skglm_jax/penalties.py
new file mode 100644
index 000000000..b9a27c0fe
--- /dev/null
+++ b/skglm/skglm_jax/penalties.py
@@ -0,0 +1,54 @@
+import jax
+import jax.numpy as jnp
+
+from skglm.skglm_jax.utils import jax_jit_method
+
+
+class L1Jax:
+    """alpha ||w||_1"""
+
+    def __init__(self, alpha):
+        self.alpha = alpha
+
+    def value(self, w):
+        return (self.alpha * jnp.abs(w)).sum()
+
+    def prox_1d(self, value, stepsize):
+        shifted_value = jnp.abs(value) - stepsize * self.alpha
+        return jnp.sign(value) * jnp.maximum(shifted_value, 0.)
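+
+    # the L1 penalty is separable, so the vector prox is just the element-wise
+    # soft-thresholding implemented in prox_1d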
+
+    def prox(self, value, stepsize):
+        return self.prox_1d(value, stepsize)
+
+    @jax_jit_method
+    def subdiff_dist_ws(self, w, grad_ws, ws):
+        n_features = w.shape[0]
+        dist = jnp.empty(n_features)
+
+        for j, in_ws in enumerate(ws):
+            w_j = w[j]
+            grad_j = grad_ws[j]
+
+            dist_j = jax.lax.cond(
+                in_ws,
+                self._compute_subdiff_dist_j,
+                lambda w_j, grad_j: 0.,
+                *(w_j, grad_j)
+            )
+
+            dist = dist.at[j].set(dist_j)
+
+        return dist
+
+    def generalized_support(self, w):
+        return w != 0.
+
+    @jax_jit_method
+    def _compute_subdiff_dist_j(self, w_j, grad_j):
+        dist_j = jax.lax.cond(
+            w_j == 0.,
+            lambda w_j, grad_j, alpha: jnp.maximum(jnp.abs(grad_j) - alpha, 0.),
+            lambda w_j, grad_j, alpha: jnp.abs(grad_j + jnp.sign(w_j) * alpha),
+            *(w_j, grad_j, self.alpha)
+        )
+        return dist_j
diff --git a/skglm/skglm_jax/tests/test_anderson_cd.py b/skglm/skglm_jax/tests/test_anderson_cd.py
new file mode 100644
index 000000000..f66834860
--- /dev/null
+++ b/skglm/skglm_jax/tests/test_anderson_cd.py
@@ -0,0 +1,45 @@
+import pytest
+
+import numpy as np
+from numpy.linalg import norm
+from skglm.utils.data import make_correlated_data
+
+from skglm.skglm_jax.anderson_cd import AndersonCD
+from skglm.skglm_jax.fista import Fista
+from skglm.skglm_jax.datafits import QuadraticJax
+from skglm.skglm_jax.penalties import L1Jax
+
+from skglm.estimators import Lasso
+
+
+@pytest.mark.parametrize(
+    "solver", [AndersonCD(),
+               Fista(use_auto_diff=True),
+               Fista(use_auto_diff=False)])
+def test_solver(solver):
+    random_state = 135
+    n_samples, n_features = 10_000, 100
+
+    X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state)
+
+    lmbd_max = norm(X.T @ y, ord=np.inf) / n_samples
+    lmbd = 1e-2 * lmbd_max
+
+    datafit = QuadraticJax()
+    penalty = L1Jax(lmbd)
+    w = solver.solve(X, y, datafit, penalty)
+
+    estimator = Lasso(alpha=lmbd, fit_intercept=False).fit(X, y)
+
+    np.testing.assert_allclose(w, estimator.coef_, atol=1e-4)
+
+
+if __name__ == "__main__":
+    import time
+
+    start = time.perf_counter()
+    test_solver(AndersonCD(verbose=2))
+    end = time.perf_counter()
+
+    print("Elapsed time:", end - start)
diff --git a/skglm/skglm_jax/utils.py b/skglm/skglm_jax/utils.py
new file mode 100644
index 000000000..8c560214c
--- /dev/null
+++ b/skglm/skglm_jax/utils.py
@@ -0,0 +1,40 @@
+import jax
+import jax.numpy as jnp
+from functools import partial
+
+
+jax_jit_method = partial(jax.jit, static_argnums=(0,))
+
+
+class JaxAA:
+
+    def __init__(self, K):
+        self.K, self.current_iter = K, 0
+        self.arr_w_, self.arr_Xw_ = None, None
+
+    def extrapolate(self, w, Xw):
+        if self.arr_w_ is None or self.arr_Xw_ is None:
+            self.arr_w_ = jnp.zeros((w.shape[0], self.K+1))
+            self.arr_Xw_ = jnp.zeros((Xw.shape[0], self.K+1))
+
+        if self.current_iter <= self.K:
+            self.arr_w_ = self.arr_w_.at[:, self.current_iter].set(w)
+            self.arr_Xw_ = self.arr_Xw_.at[:, self.current_iter].set(Xw)
+            self.current_iter += 1
+            return w, Xw
+
+        # compute residuals
+        U = jnp.diff(self.arr_w_, axis=1)
+
+        # compute extrapolation coefs
+        try:
+            inv_UTU_ones = jnp.linalg.solve(U.T @ U, jnp.ones(self.K))
+        except Exception:
+            return w, Xw
+        finally:
+            self.current_iter = 0
+
+        # extrapolate
+        C = inv_UTU_ones / jnp.sum(inv_UTU_ones)
+
+        return self.arr_w_[:, 1:] @ C, self.arr_Xw_[:, 1:] @ C
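For reference, a minimal usage sketch of how the solver, datafit, and penalty classes added in this diff fit together. It mirrors `skglm/skglm_jax/tests/test_anderson_cd.py`; the problem size and regularization level are only illustrative.

```python
import numpy as np
from numpy.linalg import norm

from skglm.utils.data import make_correlated_data
from skglm.skglm_jax.anderson_cd import AndersonCD
from skglm.skglm_jax.datafits import QuadraticJax
from skglm.skglm_jax.penalties import L1Jax

# small correlated design, as in the test file
X, y, _ = make_correlated_data(n_samples=1_000, n_features=50, random_state=0)

# regularization strength, set relative to lmbd_max
# (above lmbd_max the Lasso solution is identically zero)
lmbd_max = norm(X.T @ y, ord=np.inf) / X.shape[0]
lmbd = 1e-2 * lmbd_max

# working-set coordinate descent on the Lasso objective
w = AndersonCD(verbose=1).solve(X, y, QuadraticJax(), L1Jax(lmbd))
```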