
Commit 6e1ecc4

implement comments from Badr (docstrings, invert sequence, remove scoring)
1 parent: e761de7

File tree

1 file changed: 55 additions, 4 deletions


skglm/cv.py

Lines changed: 55 additions & 4 deletions
@@ -6,18 +6,69 @@
 
 
 class GeneralizedLinearEstimatorCV(GeneralizedLinearEstimator):
-    """CV wrapper for GeneralizedLinearEstimator."""
+    """
+    Cross-validated wrapper for GeneralizedLinearEstimator.
+
+    This class performs cross-validated selection of the regularization parameter(s)
+    for a generalized linear estimator, supporting both L1 and elastic-net penalties.
+
+    Parameters
+    ----------
+    datafit : object
+        Datafit (loss) function instance (e.g., Logistic, Quadratic).
+    penalty : object
+        Penalty instance with an 'alpha' parameter (and optionally 'l1_ratio').
+    solver : object
+        Solver instance to use for optimization.
+    alphas : array-like of shape (n_alphas,), optional
+        List of alpha values to try. If None, they are set automatically.
+    l1_ratio : float or array-like, optional
+        The ElasticNet mixing parameter(s), with 0 <= l1_ratio <= 1.
+        Only used if the penalty supports 'l1_ratio'. If None, defaults to 1.0 (Lasso).
+    cv : int, default=4
+        Number of cross-validation folds.
+    n_jobs : int, default=1
+        Number of jobs to run in parallel for cross-validation.
+    random_state : int or None, default=None
+        Random seed for cross-validation splitting.
+    eps : float, default=1e-3
+        Ratio of minimum to maximum alpha if alphas are set automatically.
+    n_alphas : int, default=100
+        Number of alphas along the regularization path if alphas are set automatically.
+
+    Attributes
+    ----------
+    alpha_ : float
+        Best alpha found by cross-validation.
+    l1_ratio_ : float or None
+        Best l1_ratio found by cross-validation (if applicable).
+    best_estimator_ : GeneralizedLinearEstimator
+        Estimator fitted on the full data with the best parameters.
+    coef_ : ndarray
+        Coefficients of the fitted model.
+    intercept_ : float or ndarray
+        Intercept of the fitted model.
+    alphas_ : ndarray
+        Array of alphas used in the search.
+    scores_path_ : ndarray
+        Cross-validation scores for each parameter combination.
+    n_iter_ : int or None
+        Number of iterations run by the solver (if available).
+    n_features_in_ : int or None
+        Number of features seen during fit.
+    feature_names_in_ : ndarray or None
+        Names of features seen during fit.
+    """
 
     def __init__(self, datafit, penalty, solver, alphas=None, l1_ratio=None,
-                 cv=4, n_jobs=1, random_state=None, scoring=None,
+                 cv=4, n_jobs=1, random_state=None,
                  eps=1e-3, n_alphas=100):
         super().__init__(datafit=datafit, penalty=penalty, solver=solver)
         self.alphas = alphas
         self.l1_ratio = l1_ratio
         self.cv = cv
         self.n_jobs = n_jobs
         self.random_state = random_state
-        self.scoring = scoring
         self.eps = eps
         self.n_alphas = n_alphas
 
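For readers skimming the diff, here is a minimal usage sketch of the class documented above. It assumes the estimator is importable as skglm.cv.GeneralizedLinearEstimatorCV (inferred from the file name skglm/cv.py) and pairs it with existing skglm components (Quadratic, L1, AndersonCD); the synthetic data and parameter values are illustrative only.

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import AndersonCD
from skglm.cv import GeneralizedLinearEstimatorCV  # assumed import path (skglm/cv.py)

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 20))
y = X @ rng.standard_normal(20) + 0.1 * rng.standard_normal(100)

# L1 has no `l1_ratio`, so only `alpha` is tuned; the initial alpha value
# is a placeholder that the CV search overrides.
est = GeneralizedLinearEstimatorCV(
    datafit=Quadratic(),
    penalty=L1(alpha=1.0),
    solver=AndersonCD(),
    cv=4,
    n_alphas=30,
    eps=1e-3,
    random_state=0,
)
est.fit(X, y)
print(est.alpha_, est.coef_.shape)

Per the attributes listed in the docstring, the best alpha (and l1_ratio, when applicable) is exposed after fitting, and best_estimator_ holds the refit on the full data.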

@@ -44,7 +95,7 @@ def fit(self, X, y):
                 alpha_max,
                 alpha_max * self.eps,
                 self.n_alphas
-            )[::-1]
+            )
         has_l1_ratio = hasattr(self.penalty, "l1_ratio")
         l1_ratios = [1.] if not has_l1_ratio else np.atleast_1d(
             self.l1_ratio if self.l1_ratio is not None else [1.])
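This second hunk carries the "invert sequence" part of the commit message: rather than building an ascending grid and reversing it with [::-1], the grid is generated directly from alpha_max down to alpha_max * eps. A minimal sketch of the resulting ordering, assuming the surrounding call (whose name falls outside the shown context lines) is np.geomspace:

import numpy as np

alpha_max = 1.0          # illustrative value; in fit() it is derived from the data
eps, n_alphas = 1e-3, 100

# Passing the endpoints as (alpha_max, alpha_max * eps) already yields a
# decreasing grid, so the [::-1] reversal removed in this commit is redundant.
alphas = np.geomspace(alpha_max, alpha_max * eps, n_alphas)
assert alphas[0] == alpha_max
assert np.all(np.diff(alphas) < 0)   # strictly decreasing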
