POC - skglm GPU support #149
Closed
52 commits (all by Badr-MOUFAD):

42f6736  FISTA CPU
116dda1  cupy solver
ff79040  unittest eval optimality condition
1385900  cleanups cpu solver
e63cdb7  jax solver
d49291c  add unittest jax solver
ec8c663  pass flake8
b703d25  numba solver layout
c82e352  numba cuda utils
8f57931  fix numba solver
68c0b71  unittest numba solver
0677519  numba solver example
cd686b3  move into solvers folder
dd7298a  add README to install
3b95086  fix conda env name
204c5e7  fix bug numba solver
b7ce4fe  update unittest & example
71819a0  fix bug init numba
b6372f9  base class for Fista solvers
97ceb4d  sparse matrix support CPU & CuPy
2ad2f2c  unittest sparse data
b8c753a  base quadratic and L1
872f9b3  refactor CPU solver
0aa15b3  test utils and fixes
14efe7b  unittest FISTA CPU
05a4f36  sparse data unittest
d23ac12  modular CuPy solver
601eb86  fix cupy verbose
761ab54  modular jax
6274e5f  unittest jax
3af102e  sparse matrices modular jax
715d3fb  modular Numba solver
34cc4a8  unittest numba && dev utils
b6f971c  comments && prob formula
a7d1375  Numba with shared memory
c786a12  Numba shared memory version
2c4cd63  kernels as static methods && Numba fix tests
e3ac70a  sparse Numba solver
ce4367d  fix bug numba gradient
c8fd8a1  fix bug numba sparse residual
a6df22a  n_samples instead of shape
cf5dc9e  Numba_solver: striding for scalable kernels
20f9274  Numba_L1: striding for scalable kernels
d8d7157  Numba sparse datafit: striding
4e4e6c1  Numba dense datafit: striding
1d07d9e  info comments Numba solver
ca9f694  update installation && normalize df and pen cupy
324cac5  pytorch solver [buggy]
c5c1dfe  fix grad bug pytorch solver && unittest
32f1014  pytorch solver sparse data
0caa9f9  set order between jax pytorch && xfail sparse and auto_diff false
545a27f  test on obj value
@@ -0,0 +1,34 @@
## Installation

1. checkout the branch
```shell
# add the remote if it doesn't exist (check with: git remote -v)
git remote add Badr-MOUFAD https://github.com/Badr-MOUFAD/skglm.git

git fetch Badr-MOUFAD skglm-gpu

git checkout skglm-gpu
```

2. create then activate a ``conda`` environment
```shell
# create
conda create -n skglm-gpu python=3.7

# activate env
conda activate skglm-gpu
```

3. install ``skglm`` in editable mode
```shell
pip install -e .
```

4. install dependencies
```shell
# cupy
conda install -c conda-forge cupy cudatoolkit=11.5

# jax
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
```
@@ -0,0 +1,6 @@
"""Solve Lasso problem using FISTA GPU-implementation.

Problem reads::

    min_w (1/2) * ||y - Xw||^2 + lmbd * ||w||_1
"""
@@ -0,0 +1,56 @@
import time

import numpy as np
from numpy.linalg import norm

from skglm.gpu.solvers import NumbaSolver, CPUSolver

from skglm.gpu.utils.host_utils import compute_obj, eval_opt_crit


random_state = 1265
n_samples, n_features = 10_000, 500
reg = 1e-2

# generate dummy data
rng = np.random.RandomState(random_state)
X = rng.randn(n_samples, n_features)
y = rng.randn(n_samples)


# set lambda
lmbd_max = norm(X.T @ y, ord=np.inf)
lmbd = reg * lmbd_max

# warm-up run (triggers kernel compilation before timing)
solver = NumbaSolver(verbose=0)
solver.max_iter = 10
solver.solve(X, y, lmbd)

# solve problem
start = time.perf_counter()
solver.max_iter = 1000
w_gpu = solver.solve(X, y, lmbd)
end = time.perf_counter()

print("gpu time: ", end - start)


solver_cpu = CPUSolver()
start = time.perf_counter()
w_cpu = solver_cpu.solve(X, y, lmbd)
end = time.perf_counter()
print("cpu time: ", end - start)


print(
    "Objective\n"
    f"gpu : {compute_obj(X, y, lmbd, w_gpu):.8f}\n"
    f"cpu : {compute_obj(X, y, lmbd, w_cpu):.8f}"
)


print(
    "Optimality condition\n"
    f"gpu : {eval_opt_crit(X, y, lmbd, w_gpu):.8f}\n"
    f"cpu : {eval_opt_crit(X, y, lmbd, w_cpu):.8f}"
)
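A natural follow-up check, not in the PR script (the tolerance here is a guess), is to compare the two solutions directly:

```python
# both solvers target the same problem, so the solutions should agree
# up to solver tolerance
print("max |w_gpu - w_cpu|:", np.max(np.abs(w_gpu - w_cpu)))
print("allclose:", np.allclose(w_gpu, w_cpu, atol=1e-6))
```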
@@ -0,0 +1,4 @@ | ||
from skglm.gpu.solvers.cpu_solver import CPUSolver # noqa | ||
from skglm.gpu.solvers.cupy_solver import CupySolver # noqa | ||
from skglm.gpu.solvers.jax_solver import JaxSolver # noqa | ||
from skglm.gpu.solvers.numba_solver import NumbaSolver # noqa |
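All four classes expose the same `solve(X, y, lmbd)` signature, so backends can be swapped or benchmarked in a loop. A minimal sketch, assuming the GPU backends are installed:

```python
import numpy as np
from skglm.gpu.solvers import CPUSolver, CupySolver, JaxSolver, NumbaSolver

rng = np.random.RandomState(0)
X, y = rng.randn(500, 100), rng.randn(500)
lmbd = 1e-2 * np.linalg.norm(X.T @ y, ord=np.inf)

# run every backend on the same problem and print the solution norm
for solver in (CPUSolver(), CupySolver(), JaxSolver(), NumbaSolver()):
    w = solver.solve(X, y, lmbd)
    print(f"{type(solver).__name__:12s} ||w|| = {np.linalg.norm(w):.6f}")
```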
@@ -0,0 +1,56 @@
import numpy as np

from skglm.utils.prox_funcs import ST_vec
from skglm.gpu.utils.host_utils import compute_obj, eval_opt_crit


class CPUSolver:

    def __init__(self, max_iter=1000, verbose=0):
        self.max_iter = max_iter
        self.verbose = verbose

    def solve(self, X, y, lmbd):
        n_samples, n_features = X.shape

        # compute step
        lipschitz = np.linalg.norm(X, ord=2) ** 2
        if lipschitz == 0.:
            return np.zeros(n_features)

        step = 1 / lipschitz

        # init vars
        w = np.zeros(n_features)
        old_w = np.zeros(n_features)
        mid_w = np.zeros(n_features)
        grad = np.zeros(n_features)

        t_old, t_new = 1, 1

        for it in range(self.max_iter):

            # compute grad
            grad = X.T @ (X @ mid_w - y)

            # forward / backward
            mid_w = mid_w - step * grad
            w = ST_vec(mid_w, step * lmbd)

            if self.verbose:
                p_obj = compute_obj(X, y, lmbd, w)
                opt_crit = eval_opt_crit(X, y, lmbd, w)

                print(
                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={opt_crit:.4e}"
                )

            # extrapolate
            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)

            # update FISTA vars
            t_old = t_new
            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
            old_w = np.copy(w)

        return w
@@ -0,0 +1,65 @@
import cupy as cp
import numpy as np

from skglm.gpu.utils.host_utils import compute_obj, eval_opt_crit


class CupySolver:

    def __init__(self, max_iter=1000, verbose=0):
        self.max_iter = max_iter
        self.verbose = verbose

    def solve(self, X, y, lmbd):
        n_samples, n_features = X.shape

        # compute step
        lipschitz = np.linalg.norm(X, ord=2) ** 2
        if lipschitz == 0.:
            return np.zeros(n_features)

        step = 1 / lipschitz

        # transfer to device
        X_gpu = cp.array(X)
        y_gpu = cp.array(y)

        # init vars in device
        w = cp.zeros(n_features)
        old_w = cp.zeros(n_features)
        mid_w = cp.zeros(n_features)
        grad = cp.zeros(n_features)

        t_old, t_new = 1, 1

        for it in range(self.max_iter):

            # compute grad
            cp.dot(X_gpu.T, X_gpu @ mid_w - y_gpu, out=grad)

            # forward / backward: w = ST(mid_w - step * grad, step * lmbd)
            mid_w = mid_w - step * grad
            w = cp.sign(mid_w) * cp.maximum(cp.abs(mid_w) - step * lmbd, 0.)

            if self.verbose:
                w_cpu = cp.asnumpy(w)

                p_obj = compute_obj(X, y, lmbd, w_cpu)
                opt_crit = eval_opt_crit(X, y, lmbd, w_cpu)

                print(
                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={opt_crit:.4e}"
                )

            # extrapolate
            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)

            # update FISTA vars
            t_old = t_new
            t_new = (1 + cp.sqrt(1 + 4 * t_old ** 2)) / 2
            old_w = cp.copy(w)

        # transfer back to host
        w_cpu = cp.asnumpy(w)

        return w_cpu
@@ -0,0 +1,89 @@
# if not set, raises an error related to CUDA linking API.
# as recommended, setting the 'XLA_FLAGS' to bypass it.
# side-effect: (perhaps) slow compilation time.
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_force_compilation_parallelism=1'  # noqa

import numpy as np  # noqa

import jax  # noqa
import jax.numpy as jnp  # noqa
# set float64 as default float type.
# if not, amplifies rounding errors.
jax.config.update("jax_enable_x64", True)  # noqa

from skglm.gpu.utils.host_utils import compute_obj, eval_opt_crit  # noqa


class JaxSolver:

    def __init__(self, max_iter=1000, use_auto_diff=True, verbose=0):
        self.max_iter = max_iter
        self.use_auto_diff = use_auto_diff
        self.verbose = verbose

    def solve(self, X, y, lmbd):
        n_samples, n_features = X.shape

        # compute step
        lipschitz = np.linalg.norm(X, ord=2) ** 2
        if lipschitz == 0.:
            return np.zeros(n_features)

        step = 1 / lipschitz

        # transfer to device
        X_gpu = jnp.asarray(X)
        y_gpu = jnp.asarray(y)

        # get grad func of datafit
        if self.use_auto_diff:
            grad_quad_loss = jax.grad(_quad_loss)

        # init vars in device
        w = jnp.zeros(n_features)
        old_w = jnp.zeros(n_features)
        mid_w = jnp.zeros(n_features)
        grad = jnp.zeros(n_features)

        t_old, t_new = 1, 1

        for it in range(self.max_iter):

            # compute grad
            if self.use_auto_diff:
                grad = grad_quad_loss(mid_w, X_gpu, y_gpu)
            else:
                grad = jnp.dot(X_gpu.T, jnp.dot(X_gpu, mid_w) - y_gpu)

            # forward / backward
            mid_w = mid_w - step * grad
            w = jnp.sign(mid_w) * jnp.maximum(jnp.abs(mid_w) - step * lmbd, 0.)

            if self.verbose:
                w_cpu = np.asarray(w, dtype=np.float64)

                p_obj = compute_obj(X, y, lmbd, w_cpu)
                opt_crit = eval_opt_crit(X, y, lmbd, w_cpu)

                print(
                    f"Iteration {it:4}: p_obj={p_obj:.8f}, opt crit={opt_crit:.4e}"
                )

            # extrapolate
            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)

            # update FISTA vars
            t_old = t_new
            t_new = 0.5 * (1 + jnp.sqrt(1. + 4. * t_old ** 2))
            old_w = jnp.copy(w)

        # transfer back to host
        w_cpu = np.asarray(w, dtype=np.float64)

        return w_cpu


def _quad_loss(w, X_gpu, y_gpu):
    pred_y = jnp.dot(X_gpu, w)
    return 0.5 * jnp.sum((y_gpu - pred_y) ** 2)
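Not part of the PR, but a possible follow-up for the JAX backend: because the loop body above is plain Python, each iteration dispatches several small XLA ops. A sketch of jitting one FISTA step (hypothetical helper, mirroring the loop above) could reduce that overhead:

```python
import jax
import jax.numpy as jnp


@jax.jit
def _fista_step(mid_w, old_w, t_old, t_new, X_gpu, y_gpu, step, lmbd):
    # one FISTA iteration, same ordering as JaxSolver.solve above
    grad = jnp.dot(X_gpu.T, jnp.dot(X_gpu, mid_w) - y_gpu)
    z = mid_w - step * grad
    w = jnp.sign(z) * jnp.maximum(jnp.abs(z) - step * lmbd, 0.)

    mid_w = w + ((t_old - 1) / t_new) * (w - old_w)
    t_old, t_new = t_new, 0.5 * (1 + jnp.sqrt(1. + 4. * t_new ** 2))
    return mid_w, w, t_old, t_new

# inside the loop: mid_w, old_w, t_old, t_new = _fista_step(
#     mid_w, old_w, t_old, t_new, X_gpu, y_gpu, step, lmbd)
```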
Review comment: In order to test the design choices, can you call this FISTAJax and make it modular (i.e., pass an objective function directly)? And the same for the other solvers.
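A rough sketch of the requested interface (all names hypothetical, not code from this PR): the solver takes the datafit as a callable and the penalty's prox as another callable, and differentiates the datafit with jax.grad, so swapping the objective requires no change to the solver itself.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


class FistaJax:
    """FISTA with a pluggable datafit (callable) and penalty prox (callable)."""

    def __init__(self, datafit, prox, max_iter=1000):
        self.datafit = datafit    # w -> scalar loss
        self.prox = prox          # (u, step) -> prox of the penalty at u
        self.max_iter = max_iter

    def solve(self, w_init, step):
        grad_fn = jax.grad(self.datafit)

        w = old_w = mid_w = w_init
        t_old = t_new = 1.

        for _ in range(self.max_iter):
            # forward / backward
            w = self.prox(mid_w - step * grad_fn(mid_w), step)

            # FISTA extrapolation (same ordering as the solvers in this PR)
            mid_w = w + ((t_old - 1) / t_new) * (w - old_w)
            t_old, t_new = t_new, 0.5 * (1 + jnp.sqrt(1. + 4. * t_new ** 2))
            old_w = w

        return w


# usage on the Lasso: quadratic datafit + L1 prox (soft-thresholding)
rng = np.random.RandomState(0)
X = jnp.asarray(rng.randn(100, 20))
y = jnp.asarray(rng.randn(100))
lmbd = 1e-1 * jnp.max(jnp.abs(X.T @ y))


def datafit(w):
    return 0.5 * jnp.sum((y - X @ w) ** 2)


def prox_l1(u, step):
    return jnp.sign(u) * jnp.maximum(jnp.abs(u) - step * lmbd, 0.)


step = 1 / np.linalg.norm(np.asarray(X), ord=2) ** 2
w_hat = FistaJax(datafit, prox_l1, max_iter=500).solve(jnp.zeros(20), step)
```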