Skip to content

Commit 495333b

Browse files
FIX make PDCD_WS solver usable in GeneralizedLinearEstimator (#274)
Co-authored-by: Badr-MOUFAD <[email protected]>
1 parent 6ac303c commit 495333b

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

skglm/experimental/pdcd_ws.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,17 @@ class PDCD_WS(BaseSolver):
8282
_datafit_required_attr = ('prox_conjugate',)
8383
_penalty_required_attr = ("prox_1d",)
8484

85-
def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
86-
p0=100, tol=1e-6, verbose=False):
85+
def __init__(
86+
self, max_iter=1000, max_epochs=1000, dual_init=None, p0=100, tol=1e-6,
87+
fit_intercept=False, warm_start=True, verbose=False
88+
):
8789
self.max_iter = max_iter
8890
self.max_epochs = max_epochs
8991
self.dual_init = dual_init
9092
self.p0 = p0
9193
self.tol = tol
94+
self.fit_intercept = fit_intercept # TODO not handled
95+
self.warm_start = warm_start # TODO not handled
9296
self.verbose = verbose
9397

9498
def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):

skglm/experimental/tests/test_quantile_regression.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from numpy.linalg import norm
44

55
from skglm.penalties import L1
6+
from skglm import GeneralizedLinearEstimator
67
from skglm.experimental.pdcd_ws import PDCD_WS
78
from skglm.experimental.quantile_regression import Pinball
89
from skglm.utils.jit_compilation import compiled_clone
@@ -37,6 +38,13 @@ def test_PDCD_WS(quantile_level):
3738
).fit(X, y)
3839

3940
np.testing.assert_allclose(w, clf.coef_, atol=1e-5)
41+
# test compatibility when inside GLM:
42+
estimator = GeneralizedLinearEstimator(
43+
datafit=Pinball(.2),
44+
penalty=L1(alpha=1.),
45+
solver=PDCD_WS(),
46+
)
47+
estimator.fit(X, y)
4048

4149

4250
if __name__ == '__main__':

0 commit comments

Comments
 (0)