Skip to content

Commit 605bd79

Browse files
committed
compile inside solver
1 parent 7399ba0 commit 605bd79

File tree

4 files changed

+25
-16
lines changed

4 files changed

+25
-16
lines changed

examples/plot_sparse_recovery.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from skglm.utils.data import make_correlated_data
1919
from skglm.solvers import AndersonCD
2020
from skglm.datafits import Quadratic
21-
from skglm.utils.jit_compilation import compiled_clone
2221
from skglm.penalties import L1, MCPenalty, L0_5, L2_3, SCAD
2322

2423
cmap = plt.get_cmap('tab10')
@@ -74,7 +73,7 @@
7473
for idx, estimator in enumerate(penalties.keys()):
7574
print(f'Running {estimator}...')
7675
estimator_path = solver.path(
77-
X, y, compiled_clone(datafit), compiled_clone(penalties[estimator]),
76+
X, y, datafit, penalties[estimator],
7877
alphas=alphas)
7978

8079
f1_temp = np.zeros(n_alphas)

examples/plot_survival_analysis.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@
1515
# Let's first generate synthetic data on which to run the Cox estimator,
1616
# using ``skglm`` data utils.
1717
#
18+
import warnings
19+
import time
20+
from lifelines import CoxPHFitter
21+
import pandas as pd
22+
import numpy as np
23+
from skglm.solvers import ProxNewton
24+
from skglm.penalties import L1
25+
from skglm.datafits import Cox
26+
import matplotlib.pyplot as plt
1827
from skglm.utils.data import make_dummy_survival_data
1928

2029
n_samples, n_features = 500, 100
@@ -34,7 +43,6 @@
3443
# * ``s`` indicates the observations censorship and follows a Bernoulli(0.5) distribution
3544
#
3645
# Let's inspect the data quickly:
37-
import matplotlib.pyplot as plt
3846

3947
fig, axes = plt.subplots(
4048
1, 3,
@@ -59,18 +67,14 @@
5967
# Todo so, we need to combine a Cox datafit and a :math:`\ell_1` penalty
6068
# and solve the resulting problem using skglm Proximal Newton solver ``ProxNewton``.
6169
# We set the intensity of the :math:`\ell_1` regularization to ``alpha=1e-2``.
62-
from skglm.datafits import Cox
63-
from skglm.penalties import L1
64-
from skglm.solvers import ProxNewton
6570

66-
from skglm.utils.jit_compilation import compiled_clone
6771

6872
# regularization intensity
6973
alpha = 1e-2
7074

7175
# skglm internals: init datafit and penalty
72-
datafit = compiled_clone(Cox())
73-
penalty = compiled_clone(L1(alpha))
76+
datafit = Cox()
77+
penalty = L1(alpha)
7478

7579
datafit.initialize(X, y)
7680

@@ -90,9 +94,6 @@
9094
# %%
9195
# Let's solve the problem with ``lifelines`` through its ``CoxPHFitter``
9296
# estimator and compare the objectives found by the two packages.
93-
import numpy as np
94-
import pandas as pd
95-
from lifelines import CoxPHFitter
9697

9798
# format data
9899
stacked_y_X = np.hstack((y, X))
@@ -126,8 +127,6 @@
126127
# let's compare their execution time. To get the evolution of the suboptimality
127128
# (objective - optimal objective) we run both estimators with increasing number of
128129
# iterations.
129-
import time
130-
import warnings
131130

132131
warnings.filterwarnings('ignore')
133132

@@ -230,7 +229,7 @@
230229
# We only need to pass in ``use_efron=True`` to the ``Cox`` datafit.
231230

232231
# ensure using Efron estimate
233-
datafit = compiled_clone(Cox(use_efron=True))
232+
datafit = Cox(use_efron=True)
234233
datafit.initialize(X, y)
235234

236235
# solve the problem

skglm/estimators.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from sklearn.utils._param_validation import Interval, StrOptions
1919
from sklearn.multiclass import OneVsRestClassifier, check_classification_targets
2020

21-
from skglm.utils.jit_compilation import compiled_clone
2221
from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD
2322
from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC,
2423
QuadraticMultiTask, QuadraticGroup)

skglm/solvers/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
import warnings
12
from abc import abstractmethod, ABC
3+
4+
import numpy as np
5+
26
from skglm.utils.validation import check_attrs
7+
from skglm.utils.jit_compilation import compiled_clone
38

49

510
class BaseSolver(ABC):
@@ -101,6 +106,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
101106
>>> ...
102107
>>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty)
103108
"""
109+
if "jitclass" in str(type(datafit)):
110+
warnings.warn(
111+
"Do not pass a compiled datafit, compilation is done inside solver now")
112+
else:
113+
datafit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
114+
penalty = compiled_clone(penalty, to_float32=X.dtype == np.float32)
115+
104116
if run_checks:
105117
self._validate(X, y, datafit, penalty)
106118

0 commit comments

Comments
 (0)