Skip to content

Commit 6058825

Browse files
QB3mathurinm
andauthored
FIX updates for sklearn 1.1 (#15)
Co-authored-by: mathurinm <[email protected]>
1 parent a187c66 commit 6058825

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

skglm/estimators.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ def fit(self, X, y):
196196
elif isinstance(self.datafit, Logistic):
197197
self.datafit = Logistic_32()
198198

199+
if not hasattr(self, "n_features_in_"):
200+
self.n_features_in_ = X.shape[1]
201+
199202
self.classes_ = None
200203
n_classes_ = 0
201204

@@ -259,8 +262,10 @@ def fit(self, X, y):
259262
self.coef_ = np.empty([len(self.classes_), X.shape[1]])
260263
self.intercept_ = 0
261264
multiclass = OneVsRestClassifier(self).fit(X, y)
262-
self.coef_ = np.array([clf.coef_[0] for clf in multiclass.estimators_])
263-
self.n_iter_ = max(clf.n_iter_ for clf in multiclass.estimators_)
265+
self.coef_ = np.array([clf.coef_[0]
266+
for clf in multiclass.estimators_])
267+
self.n_iter_ = max(
268+
clf.n_iter_ for clf in multiclass.estimators_)
264269
elif isinstance(self.datafit, Logistic):
265270
self.coef_ = coefs.T
266271
return self
@@ -905,6 +910,9 @@ def fit(self, X, y):
905910
self.classes_ = enc.classes_
906911
n_classes = len(enc.classes_)
907912

913+
if not hasattr(self, "n_features_in_"):
914+
self.n_features_in_ = X.shape[1]
915+
908916
if n_classes <= 2:
909917
_, coefs, _, self.n_iter_ = self.path(
910918
X, 2 * y_ind - 1, np.array([self.alpha]), solver=self.solver)
@@ -914,7 +922,8 @@ def fit(self, X, y):
914922
self.coef_ = np.empty([n_classes, X.shape[1]])
915923
self.intercept_ = 0.
916924
multiclass = OneVsRestClassifier(self).fit(X, y)
917-
self.coef_ = multiclass.coef_
925+
self.coef_ = np.array([clf.coef_[0]
926+
for clf in multiclass.estimators_])
918927
self.n_iter_ = max(clf.n_iter_ for clf in multiclass.estimators_)
919928
return self
920929

@@ -1082,10 +1091,14 @@ def fit(self, X, y):
10821091
check_classification_targets(y)
10831092
self.classes_ = np.unique(y)
10841093

1094+
if not hasattr(self, "n_features_in_"):
1095+
self.n_features_in_ = X.shape[1]
1096+
10851097
enc = LabelEncoder()
10861098
y_ind = enc.fit_transform(y)
10871099
self.classes_ = enc.classes_
10881100
n_classes = len(enc.classes_)
1101+
10891102
if n_classes <= 2:
10901103
y_ind = 2 * y_ind - 1
10911104
is_sparse = issparse(X)
@@ -1114,7 +1127,8 @@ def fit(self, X, y):
11141127
self.coef_ = np.empty([n_classes, X.shape[1]])
11151128
self.intercept_ = 0.
11161129
multiclass = OneVsRestClassifier(self).fit(X, y)
1117-
self.coef_ = np.array([clf.coef_[0] for clf in multiclass.estimators_])
1130+
self.coef_ = np.array([clf.coef_[0]
1131+
for clf in multiclass.estimators_])
11181132
self.n_iter_ = max(clf.n_iter_ for clf in multiclass.estimators_)
11191133
return self
11201134

0 commit comments

Comments
 (0)