
Commit fba7cbe

Some fixes before sklearn 1.0 (#641)
* tsne fix
* fix logreg
* fixes
* fix fixes :-)
* wrong argument
1 parent 467b1e8 commit fba7cbe


6 files changed, +54 -18 lines changed


daal4py/sklearn/cluster/_k_means_0_23.py

Lines changed: 9 additions & 7 deletions
@@ -15,6 +15,7 @@
 #===============================================================================
 
 import numpy as np
+import numbers
 from scipy import sparse as sp
 
 from sklearn.utils import (check_random_state, check_array)
@@ -284,14 +285,15 @@ def _fit(self, X, y=None, sample_weight=None):
         raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got"
                          " {}".format(str(algorithm)))
 
-    daal_ready = True
-    if daal_ready:
-        X_len = _num_samples(X)
-        daal_ready = (self.n_clusters <= X_len)
-    if daal_ready and sample_weight is not None:
+    X_len = _num_samples(X)
+    daal_ready = self.n_clusters <= X_len
+    if daal_ready and sample_weight is not None:
+        if isinstance(sample_weight, numbers.Number):
+            sample_weight = np.full(X_len, sample_weight, dtype=np.float64)
+        else:
             sample_weight = np.asarray(sample_weight)
-        daal_ready = (sample_weight.shape == (X_len,)) and (
-            np.allclose(sample_weight, np.ones_like(sample_weight)))
+        daal_ready = (sample_weight.shape == (X_len,)) and (
+            np.allclose(sample_weight, np.ones_like(sample_weight)))
 
     if daal_ready:
         logging.info(
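The net effect of the k-means change: the redundant daal_ready = True preamble is gone, and a scalar sample_weight is now broadcast to a per-sample array before the uniformity check that gates the oneDAL path. Below is a minimal standalone sketch of the new gating logic; the helper name is hypothetical and not part of daal4py.

import numbers
import numpy as np

def _daal_ready_for_kmeans(n_clusters, X_len, sample_weight):
    # Mirrors the updated gate: the accelerated path is only considered when
    # the number of clusters does not exceed the number of samples and every
    # sample weight is (close to) 1.
    daal_ready = n_clusters <= X_len
    if daal_ready and sample_weight is not None:
        if isinstance(sample_weight, numbers.Number):
            # A scalar weight is broadcast to one weight per sample before
            # the uniformity check.
            sample_weight = np.full(X_len, sample_weight, dtype=np.float64)
        else:
            sample_weight = np.asarray(sample_weight)
        daal_ready = (sample_weight.shape == (X_len,)) and \
            np.allclose(sample_weight, np.ones_like(sample_weight))
    return daal_ready

With this logic, sample_weight=1.0 keeps daal_ready True, while sample_weight=2.0 or a non-uniform weight array turns it off.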

daal4py/sklearn/ensemble/_forest.py

Lines changed: 1 addition & 1 deletion
@@ -775,7 +775,7 @@ class RandomForestRegressor(RandomForestRegressor_original):
 
     def __init__(self,
                  n_estimators=100, *,
-                 criterion="mse",
+                 criterion="squared_error" if sklearn_check_version('1.0') else "mse",
                  max_depth=None,
                  min_samples_split=2,
                  min_samples_leaf=1,
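scikit-learn 1.0 renamed the regression criterion "mse" to "squared_error" and deprecated the old spelling, so the patched RandomForestRegressor now picks its default per installed version via daal4py's sklearn_check_version. A hedged sketch of the same version-gated default, using a stand-in comparison helper (an assumption, not daal4py's implementation):

import sklearn
from packaging.version import Version

def _sklearn_at_least(minimum):
    # Stand-in for daal4py's sklearn_check_version: compares the installed
    # scikit-learn version against a minimum release string.
    return Version(sklearn.__version__) >= Version(minimum)

# Use the criterion name the installed scikit-learn expects as the default.
default_criterion = "squared_error" if _sklearn_at_least("1.0") else "mse"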

daal4py/sklearn/linear_model/_linear_0_24.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def __init__(self, fit_intercept=True, normalize=False, copy_X=True,
     def fit(self, X, y, sample_weight=None):
         if sklearn_check_version('1.0'):
             from sklearn.linear_model._base import _deprecate_normalize
-            self.normalize = _deprecate_normalize(
+            self._normalize = _deprecate_normalize(
                 self.normalize, default=False,
                 estimator_name=self.__class__.__name__
             )

daal4py/sklearn/linear_model/_ridge_0_22.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def _fit_ridge(self, X, y, sample_weight=None):
     """
     if sklearn_check_version('1.0'):
         from sklearn.linear_model._base import _deprecate_normalize
-        self.normalize = _deprecate_normalize(
+        self._normalize = _deprecate_normalize(
             self.normalize, default=False,
             estimator_name=self.__class__.__name__
         )
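Both linear_model fixes are the same one-line change: the value returned by scikit-learn's _deprecate_normalize helper is stored in a private self._normalize attribute instead of being written back over the public self.normalize parameter, matching the private attribute upstream scikit-learn 1.0 uses and leaving constructor arguments untouched. A small illustration of the pattern, assuming scikit-learn 1.0/1.1 where _deprecate_normalize is available; the class here is a sketch, not the daal4py estimator.

from sklearn.linear_model._base import _deprecate_normalize  # sklearn 1.0/1.1

class _NormalizeSketch:
    def __init__(self, fit_intercept=True, normalize=False, copy_X=True):
        # Constructor arguments are stored as given, per scikit-learn convention.
        self.fit_intercept = fit_intercept
        self.normalize = normalize
        self.copy_X = copy_X

    def fit(self, X, y, sample_weight=None):
        # Resolve the deprecated flag into a private attribute; writing the
        # result back to self.normalize (the old behaviour) would overwrite a
        # public constructor parameter from inside fit().
        self._normalize = _deprecate_normalize(
            self.normalize, default=False,
            estimator_name=self.__class__.__name__
        )
        # ... the rest of fit() would read self._normalize ...
        return self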

daal4py/sklearn/linear_model/logistic_path.py

Lines changed: 8 additions & 3 deletions
@@ -54,7 +54,7 @@
 from sklearn.linear_model._base import (LinearClassifierMixin, SparseCoefMixin,
                                         BaseEstimator)
 from .._utils import (daal_check_version, getFPType,
-                      get_patch_message)
+                      get_patch_message, sklearn_check_version)
 import logging
 
 
@@ -243,8 +243,13 @@ def __logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
     daal_ready = daal_ready and sample_weight is None and class_weight is None
 
     if not daal_ready:
-        sample_weight = _check_sample_weight(sample_weight, X,
-                                             dtype=X.dtype)
+        if sklearn_check_version('0.24'):
+            sample_weight = _check_sample_weight(sample_weight, X,
+                                                 dtype=X.dtype,
+                                                 copy=True)
+        else:
+            sample_weight = _check_sample_weight(sample_weight, X,
+                                                 dtype=X.dtype)
     # If class_weights is a dict (provided by the user), the weights
     # are assigned to the original labels. If it is "balanced", then
     # the class_weights are assigned after masking the labels with a OvR.
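In the logistic regression path, the non-DAAL fallback now passes copy=True to _check_sample_weight on scikit-learn >= 0.24 (the release where, as the version gate implies, that keyword is accepted), so the validated weights are a fresh array rather than the caller's own; presumably this protects user-supplied weights from later in-place scaling. A minimal hedged sketch of the gated call, with a hypothetical wrapper name and version flag:

from sklearn.utils.validation import _check_sample_weight

def _validated_sample_weight(sample_weight, X, sklearn_at_least_0_24):
    # Pass copy=True only where _check_sample_weight accepts it; older
    # releases fall back to the dtype-only form.
    if sklearn_at_least_0_24:
        return _check_sample_weight(sample_weight, X, dtype=X.dtype, copy=True)
    return _check_sample_weight(sample_weight, X, dtype=X.dtype)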

daal4py/sklearn/manifold/_t_sne.py

Lines changed: 34 additions & 5 deletions
@@ -40,10 +40,35 @@ class TSNE(BaseTSNE):
     def _fit(self, X, skip_num_points=0):
         """Private function to fit the model using X as training data."""
 
+        if isinstance(self.init, str) and self.init == 'warn':
+            warnings.warn("The default initialization in TSNE will change "
+                          "from 'random' to 'pca' in 1.2.", FutureWarning)
+            self._init = 'random'
+        else:
+            self._init = self.init
+
+        if isinstance(self._init, str) and self._init == 'pca' and issparse(X):
+            raise TypeError("PCA initialization is currently not suported "
+                            "with the sparse input matrix. Use "
+                            "init=\"random\" instead.")
+
         if self.method not in ['barnes_hut', 'exact']:
             raise ValueError("'method' must be 'barnes_hut' or 'exact'")
         if self.angle < 0.0 or self.angle > 1.0:
             raise ValueError("'angle' must be between 0.0 - 1.0")
+        if self.learning_rate == 'warn':
+            warnings.warn("The default learning rate in TSNE will change "
+                          "from 200.0 to 'auto' in 1.2.", FutureWarning)
+            self._learning_rate = 200.0
+        else:
+            self._learning_rate = self.learning_rate
+        if self._learning_rate == 'auto':
+            self._learning_rate = X.shape[0] / self.early_exaggeration / 4
+            self._learning_rate = np.maximum(self._learning_rate, 50)
+        else:
+            if not (self._learning_rate > 0):
+                raise ValueError("'learning_rate' must be a positive number "
+                                 "or 'auto'.")
 
         if hasattr(self, 'square_distances'):
             if self.square_distances not in [True, 'legacy']:
@@ -74,7 +99,7 @@ def _fit(self, X, skip_num_points=0):
         X = check_array(X, accept_sparse=['csr', 'csc', 'coo'],
                         dtype=[np.float32, np.float64])
         if self.metric == "precomputed":
-            if isinstance(self.init, str) and self.init == 'pca':
+            if isinstance(self._init, str) and self._init == 'pca':
                 raise ValueError("The parameter init=\"pca\" cannot be "
                                  "used with metric=\"precomputed\".")
             if X.shape[0] != X.shape[1]:
@@ -187,13 +212,17 @@ def _fit(self, X, skip_num_points=0):
             P = _joint_probabilities_nn(distances_nn, self.perplexity,
                                         self.verbose)
 
-        if isinstance(self.init, np.ndarray):
-            X_embedded = self.init
-        elif self.init == 'pca':
+        if isinstance(self._init, np.ndarray):
+            X_embedded = self._init
+        elif self._init == 'pca':
             pca = PCA(n_components=self.n_components, svd_solver='randomized',
                       random_state=random_state)
             X_embedded = pca.fit_transform(X).astype(np.float32, copy=False)
-        elif self.init == 'random':
+            warnings.warn("The PCA initialization in TSNE will change to "
+                          "have the standard deviation of PC1 equal to 1e-4 "
+                          "in 1.2. This will ensure better convergence.",
+                          FutureWarning)
+        elif self._init == 'random':
             # The embedding is initialized with iid samples from Gaussians with
             # standard deviation 1e-4.
             X_embedded = 1e-4 * random_state.randn(
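The TSNE patch mirrors upstream scikit-learn's deprecation handling for the new 'warn' sentinels: the old defaults (init='random', learning_rate=200.0) stay in effect with a FutureWarning about the 1.2 change, and learning_rate='auto' is resolved from the data size as max(n_samples / early_exaggeration / 4, 50). A small sketch of just that resolution step, with a hypothetical helper name:

def _resolve_tsne_learning_rate(learning_rate, n_samples, early_exaggeration=12.0):
    # 'auto' scales the learning rate with the data size, floored at 50;
    # anything else must be a positive number, as in the diff above.
    if learning_rate == 'auto':
        return max(n_samples / early_exaggeration / 4, 50)
    if not learning_rate > 0:
        raise ValueError("'learning_rate' must be a positive number or 'auto'.")
    return learning_rate

For example, 10,000 samples with the default early_exaggeration of 12.0 gives 10000 / 12 / 4 ≈ 208.3, well above the floor of 50.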
