Skip to content

Commit f7a1983

Browse files
API design with penalty and datafit jitted inside fit (#44)
Co-authored-by: Badr MOUFAD <[email protected]>
1 parent 8a581b0 commit f7a1983

File tree

13 files changed: +408 −263 lines changed

examples/plot_sparse_recovery.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from skglm.utils import make_correlated_data
1919
from skglm.solvers import cd_solver_path
2020
from skglm.datafits import Quadratic
21+
from skglm.utils import compiled_clone
2122
from skglm.penalties import L1, MCPenalty, L0_5, L2_3, SCAD
2223

2324
cmap = plt.get_cmap('tab10')
@@ -71,7 +72,8 @@
7172
for idx, estimator in enumerate(penalties.keys()):
7273
print(f'Running {estimator}...')
7374
estimator_path = cd_solver_path(
74-
X, y, datafit, penalties[estimator], alphas=alphas, ws_strategy="fixpoint")
75+
X, y, compiled_clone(datafit), compiled_clone(penalties[estimator]),
76+
alphas=alphas, ws_strategy="fixpoint")
7577

7678
f1_temp = np.zeros(n_alphas)
7779
prediction_error_temp = np.zeros(n_alphas)

skglm/datafits/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from .base import BaseDatafit, BaseMultitaskDatafit # noqa F401
22

3-
from .single_task import ( # noqa F401
4-
Quadratic, Quadratic_32, QuadraticSVC, QuadraticSVC_32, Logistic, Logistic_32,
5-
Huber, Huber_32,
6-
)
3+
from .single_task import Quadratic, QuadraticSVC, Logistic, Huber # noqa F401
74

85
from .multi_task import QuadraticMultiTask # noqa F401
96

skglm/datafits/base.py

Lines changed: 40 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,29 @@
11
from abc import abstractmethod
22

3-
import numba
4-
from numba import float32, float64
5-
from numba.experimental import jitclass
6-
7-
8-
def spec_to_float32(spec):
9-
"""Convert a numba specification to an equivalent float32 one.
10-
11-
Parameters
12-
----------
13-
spec : list
14-
A list of (name, dtype) for every attribute of a jitclass.
15-
16-
Returns
17-
-------
18-
spec32 : list
19-
A list of (name, dtype) for every attribute of a jitclass, where float64
20-
have been replaced by float32.
21-
"""
22-
spec32 = []
23-
for name, dtype in spec:
24-
if dtype == float64:
25-
dtype32 = float32
26-
elif isinstance(dtype, numba.core.types.npytypes.Array):
27-
dtype32 = dtype.copy(dtype=float32)
28-
else:
29-
raise ValueError(f"Unknown spec type {dtype}")
30-
spec32.append((name, dtype32))
31-
return spec32
32-
33-
34-
def jit_factory(Datafit, spec):
35-
"""JIT-compile a datafit class in float32 and float64 contexts.
36-
37-
Parameters
38-
----------
39-
Datafit : datafit class, inheriting from BaseDatafit
40-
A datafit class, to be compiled.
41-
42-
spec : list
43-
A list of type specifications for every attribute of Datafit.
44-
45-
Returns
46-
-------
47-
Datafit_64 : Jitclass
48-
A compiled datafit class with attribute types float64.
49-
50-
Datafit_32 : Jitclass
51-
A compiled datafit class with attribute types float32.
52-
"""
53-
spec32 = spec_to_float32(spec)
54-
return jitclass(spec)(Datafit), jitclass(spec32)(Datafit)
55-
563

574
class BaseDatafit():
585
"""Base class for datafits."""
596

7+
@abstractmethod
8+
def get_spec(self):
9+
"""Specify the numba types of the class attributes.
10+
11+
Returns
12+
-------
13+
spec: Tuple of (attribute_name, dtype)
14+
spec to be passed to Numba jitclass to compile the class.
15+
"""
16+
17+
@abstractmethod
18+
def params_to_dict(self):
19+
"""Get the parameters to initialize an instance of the class.
20+
21+
Returns
22+
-------
23+
dict_of_params : dict
24+
The parameters to instantiate an object of the class.
25+
"""
26+
6027
@abstractmethod
6128
def initialize(self, X, y):
6229
"""Pre-computations before fitting on X and y.
@@ -172,6 +139,26 @@ def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
172139
class BaseMultitaskDatafit():
173140
"""Base class for multitask datafits."""
174141

142+
@abstractmethod
143+
def get_spec(self):
144+
"""Specify the numba types of the class attributes.
145+
146+
Returns
147+
-------
148+
spec: Tuple of (attribute_name, dtype)
149+
spec to be passed to Numba jitclass to compile the class.
150+
"""
151+
152+
@abstractmethod
153+
def params_to_dict(self):
154+
"""Get the parameters to initialize an instance of the class.
155+
156+
Returns
157+
-------
158+
dict_of_params : dict
159+
The parameters to instantiate an object of the class.
160+
"""
161+
175162
@abstractmethod
176163
def initialize(self, X, Y):
177164
"""Store useful values before fitting on X and Y.

skglm/datafits/group.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,10 @@
11
import numpy as np
22
from numpy.linalg import norm
3-
from numba.experimental import jitclass
43
from numba import int32, float64
54

65
from skglm.datafits.base import BaseDatafit
76

87

9-
spec_QuadraticGroup = [
10-
('grp_ptr', int32[:]),
11-
('grp_indices', int32[:]),
12-
('lipschitz', float64[:])
13-
]
14-
15-
16-
@jitclass(spec_QuadraticGroup)
178
class QuadraticGroup(BaseDatafit):
189
r"""Quadratic datafit used with group penalties.
1910
@@ -38,6 +29,18 @@ class QuadraticGroup(BaseDatafit):
3829
def __init__(self, grp_ptr, grp_indices):
3930
self.grp_ptr, self.grp_indices = grp_ptr, grp_indices
4031

32+
def get_spec(self):
33+
spec = (
34+
('grp_ptr', int32[:]),
35+
('grp_indices', int32[:]),
36+
('lipschitz', float64[:])
37+
)
38+
return spec
39+
40+
def params_to_dict(self):
41+
return dict(grp_ptr=self.grp_ptr,
42+
grp_indices=self.grp_indices)
43+
4144
def initialize(self, X, y):
4245
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
4346
n_groups = len(grp_ptr) - 1

skglm/datafits/multi_task.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
import numpy as np
22
from numpy.linalg import norm
33
from numba import float64
4-
from numba.experimental import jitclass
54

65
from skglm.datafits.base import BaseMultitaskDatafit
76

87

9-
spec_quadratic = [
10-
('XtY', float64[:, :]),
11-
('lipschitz', float64[:]),
12-
]
13-
14-
15-
@jitclass(spec_quadratic)
168
class QuadraticMultiTask(BaseMultitaskDatafit):
179
"""Quadratic datafit used for multi-task regression.
1810
@@ -33,6 +25,16 @@ class QuadraticMultiTask(BaseMultitaskDatafit):
3325
def __init__(self):
3426
pass
3527

28+
def get_spec(self):
29+
spec = (
30+
('XtY', float64[:, :]),
31+
('lipschitz', float64[:]),
32+
)
33+
return spec
34+
35+
def params_to_dict(self):
36+
return dict()
37+
3638
def initialize(self, X, Y):
3739
"""Compute optimization quantities before fitting on X and Y."""
3840
self.XtY = X.T @ Y

0 commit comments

Comments
 (0)