diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..98053944 Binary files /dev/null and b/.DS_Store differ diff --git a/skglm/estimators.py b/skglm/estimators.py index 870b1cbe..6197101c 100644 --- a/skglm/estimators.py +++ b/skglm/estimators.py @@ -448,6 +448,11 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params): warm_start=self.warm_start, verbose=self.verbose) return solver.path(X, y, datafit, penalty, alphas, coef_init, return_n_iter) + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.sparse = True + return tags + class WeightedLasso(RegressorMixin, LinearModel): r"""WeightedLasso estimator based on Celer solver and primal extrapolation. @@ -611,6 +616,11 @@ def fit(self, X, y): warm_start=self.warm_start, verbose=self.verbose) return _glm_fit(X, y, self, Quadratic(), penalty, solver) + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.sparse = True + return tags + class ElasticNet(RegressorMixin, LinearModel): r"""Elastic net estimator. @@ -765,6 +775,11 @@ def fit(self, X, y): return _glm_fit(X, y, self, Quadratic(), L1_plus_L2(self.alpha, self.l1_ratio, self.positive), solver) + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.sparse = True + return tags + class MCPRegression(RegressorMixin, LinearModel): r"""Linear regression with MCP penalty estimator. @@ -953,6 +968,11 @@ def fit(self, X, y): warm_start=self.warm_start, verbose=self.verbose) return _glm_fit(X, y, self, Quadratic(), penalty, solver) + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.sparse = True + return tags + class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator): r"""Sparse Logistic regression estimator. @@ -1380,6 +1400,11 @@ def fit(self, X, y): return self + def __sklearn_tags__(self): + tags = super().__sklearn_tags__() + tags.input_tags.sparse = True + return tags + class MultiTaskLasso(RegressorMixin, LinearModel): r"""MultiTaskLasso estimator.