-
Couldn't load subscription status.
- Fork 38
POC - skglm GPU support
#149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
skglm/gpu/solvers/jax_solver.py
Outdated
| from skglm.gpu.utils.host_utils import compute_obj, eval_opt_crit # noqa | ||
|
|
||
|
|
||
| class JaxSolver: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In order to test the design choices, can you call this FISTAJax and make it modular (ie pass an objective function directly) ? and same for other solvers
|
close in favor of #155 |
This explores adding GPU support to the package.
It implements FISTA solver for Lasso problem using
JAXCuPyNumba.cudaRefer to README to install and play around with the solvers