Skip to content

Commit 8558ee0

Browse files
committed
fix some tests
1 parent 5e1ad04 commit 8558ee0

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

skglm/solvers/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
111111
"Do not pass a compiled datafit, compilation is done inside solver now")
112112
else:
113113
datafit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
114-
penalty = compiled_clone(penalty, to_float32=X.dtype == np.float32)
114+
penalty = compiled_clone(penalty)
115+
# TODO add support for bool spec in compiled_clone
116+
# penalty = compiled_clone(penalty, to_float32=X.dtype == np.float32)
115117

116118
if run_checks:
117119
self._validate(X, y, datafit, penalty)

skglm/tests/test_fista.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import numpy as np
44
from numpy.linalg import norm
55

6-
from scipy.sparse import csc_matrix, issparse
6+
from scipy.sparse import csc_matrix
77

8-
from skglm.penalties import L1, IndicatorBox
8+
from skglm.penalties import L1
99
from skglm.solvers import FISTA, AndersonCD
10-
from skglm.datafits import Quadratic, Logistic, QuadraticSVC
10+
from skglm.datafits import Quadratic, Logistic
1111

1212
from skglm.utils.data import make_correlated_data
1313

@@ -36,11 +36,6 @@
3636
def test_fista_solver(X, Datafit, Penalty):
3737
_y = y if isinstance(Datafit, Quadratic) else y_classif
3838
datafit = Datafit()
39-
# _init = y @ X.T if isinstance(Datafit, QuadraticSVC) else X
40-
# if issparse(X):
41-
# datafit.initialize_sparse(_init.data, _init.indptr, _init.indices, _y)
42-
# else:
43-
# datafit.initialize(_init, _y)
4439
penalty = Penalty(alpha)
4540

4641
solver = FISTA(max_iter=1000, tol=tol)

skglm/utils/jit_compilation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def spec_to_float32(spec):
2929
else:
3030
dtype32 = dtype
3131
else:
32-
raise ValueError(f"Unknown spec type {dtype}")
32+
# raise ValueError(f"Unknown spec type {dtype}")
33+
# bool types and others are not affected:
34+
dtype32 = dtype
3335
spec32.append((name, dtype32))
3436
return spec32
3537

0 commit comments

Comments
 (0)