Skip to content

Commit 2e68c62

Browse files
committed
warning message & checking for jit compilation
1 parent e1cd53f commit 2e68c62

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

skglm/solvers/base.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -107,23 +107,28 @@ def solve(
107107
>>> ...
108108
>>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty)
109109
"""
110-
# TODO do it properly instead of searching for a string
110+
# TODO check for datafit/penalty being jit-compiled properly
111+
# instead of searching for a string
111112
if "jitclass" in str(type(datafit)):
112113
warnings.warn(
113-
"Do not pass a compiled datafit, compilation is done inside solver now"
114+
"Passing in a compiled datafit is deprecated since skglm v0.5 "
115+
"Compilation is now done inside solver."
116+
"This will raise an error starting skglm v0.6 onwards."
114117
)
118+
elif datafit is not None:
119+
datafit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
120+
115121
if "jitclass" in str(type(penalty)):
116122
warnings.warn(
117-
"Do not pass a compiled penalty, compilation is done inside solver now"
123+
"Passing in a compiled penalty is deprecated since skglm v0.5 "
124+
"Compilation is now done inside solver. "
125+
"This will raise an error starting skglm v0.6 onwards."
118126
)
119-
else:
120-
if datafit is not None:
121-
datafit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
122-
if penalty is not None:
123-
penalty = compiled_clone(penalty)
124-
# TODO add support for bool spec in compiled_clone
125-
# currently, doing so break the code
126-
# penalty = compiled_clone(penalty, to_float32=X.dtype == np.float32)
127+
elif penalty is not None:
128+
penalty = compiled_clone(penalty)
129+
# TODO add support for bool spec in compiled_clone
130+
# currently, doing so break the code
131+
# penalty = compiled_clone(penalty, to_float32=X.dtype == np.float32)
127132

128133
if run_checks:
129134
self._validate(X, y, datafit, penalty)

0 commit comments

Comments
 (0)