Skip to content

Commit c1910f5

Browse files
authored
Fix sklearn tests [part 1] (#948)
1 parent a8014de commit c1910f5

File tree

26 files changed

+98
-82
lines changed

26 files changed

+98
-82
lines changed

.circleci/deselect_tests.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,25 @@
1919
import argparse
2020
import os.path
2121
from yaml import FullLoader, load as yaml_load
22-
from distutils.version import LooseVersion
22+
from packaging.version import Version
2323
import sklearn
2424
from sklearn import __version__ as sklearn_version
2525
import warnings
2626

2727

2828
def evaluate_cond(cond, v):
2929
if cond.startswith(">="):
30-
return LooseVersion(v) >= LooseVersion(cond[2:])
30+
return Version(v) >= Version(cond[2:])
3131
if cond.startswith("<="):
32-
return LooseVersion(v) <= LooseVersion(cond[2:])
32+
return Version(v) <= Version(cond[2:])
3333
if cond.startswith("=="):
34-
return LooseVersion(v) == LooseVersion(cond[2:])
34+
return Version(v) == Version(cond[2:])
3535
if cond.startswith("!="):
36-
return LooseVersion(v) != LooseVersion(cond[2:])
36+
return Version(v) != Version(cond[2:])
3737
if cond.startswith("<"):
38-
return LooseVersion(v) < LooseVersion(cond[1:])
38+
return Version(v) < Version(cond[1:])
3939
if cond.startswith(">"):
40-
return LooseVersion(v) > LooseVersion(cond[1:])
40+
return Version(v) > Version(cond[1:])
4141
warnings.warn(
4242
'Test selection condition "{0}" should start with '
4343
'>=, <=, ==, !=, < or > to compare to version of scikit-learn run. '

daal4py/sklearn/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from daal4py import _get__daal_link_version__ as dv
2121
from sklearn import __version__ as sklearn_version
22-
from distutils.version import LooseVersion
22+
from packaging.version import Version
2323
import logging
2424

2525

@@ -56,7 +56,7 @@ def daal_check_version(rule):
5656

5757

5858
def sklearn_check_version(ver):
59-
return bool(LooseVersion(sklearn_version) >= LooseVersion(ver))
59+
return bool(Version(sklearn_version) >= Version(ver))
6060

6161

6262
def get_daal_version():

daal4py/sklearn/cluster/_dbscan.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424

2525
import daal4py
2626
from daal4py.sklearn._utils import (
27-
make2d, getFPType, get_patch_message, PatchingConditionsChain)
28-
import logging
27+
make2d, getFPType, PatchingConditionsChain)
2928

3029
from .._device_offload import support_usm_ndarray
3130
from .._utils import sklearn_check_version

daal4py/sklearn/cluster/_k_means_0_22.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@
3535

3636
import daal4py
3737
from .._utils import (
38-
getFPType, get_patch_message, daal_check_version, PatchingConditionsChain)
38+
getFPType, daal_check_version, PatchingConditionsChain)
3939
from .._device_offload import support_usm_ndarray
40-
import logging
4140

4241

4342
def _tolerance(X, rtol):

daal4py/sklearn/cluster/_k_means_0_23.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,9 @@
3737
import daal4py
3838
from .._utils import (
3939
getFPType,
40-
get_patch_message,
4140
sklearn_check_version,
4241
PatchingConditionsChain)
4342
from .._device_offload import support_usm_ndarray
44-
import logging
4543

4644

4745
def _validate_center_shape(X, n_centers, centers):

daal4py/sklearn/decomposition/_pca.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@
2525

2626
import daal4py
2727
from .._utils import (
28-
getFPType, get_patch_message, sklearn_check_version, PatchingConditionsChain)
28+
getFPType, sklearn_check_version, PatchingConditionsChain)
2929
from .._device_offload import support_usm_ndarray
30-
import logging
3130

3231
if sklearn_check_version('0.22'):
3332
from sklearn.decomposition._pca import PCA as PCA_original

daal4py/sklearn/ensemble/AdaBoostClassifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .._utils import getFPType
2727

2828
from sklearn import __version__ as sklearn_version
29-
from distutils.version import LooseVersion
29+
from packaging.version import Version
3030

3131

3232
class AdaBoostClassifier(BaseEstimator, ClassifierMixin):
@@ -129,7 +129,7 @@ def fit(self, X, y):
129129

130130
def predict(self, X):
131131
# Check is fit had been called
132-
if LooseVersion(sklearn_version) >= LooseVersion("0.22"):
132+
if Version(sklearn_version) >= Version("0.22"):
133133
check_is_fitted(self)
134134
else:
135135
check_is_fitted(self, ['n_features_in_', 'n_classes_'])

daal4py/sklearn/ensemble/_forest.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
import warnings
2121

2222
import daal4py
23-
from .._utils import (getFPType, get_patch_message)
23+
from .._utils import getFPType
2424
from .._device_offload import support_usm_ndarray
2525
from daal4py.sklearn._utils import (
2626
daal_check_version, sklearn_check_version,
2727
PatchingConditionsChain)
2828
import logging
2929

30-
from sklearn.tree import (DecisionTreeClassifier, DecisionTreeRegressor)
30+
from sklearn.tree import DecisionTreeClassifier
3131
from sklearn.tree._tree import Tree
3232
from sklearn.ensemble import RandomForestClassifier as RandomForestClassifier_original
3333
from sklearn.ensemble import RandomForestRegressor as RandomForestRegressor_original
@@ -40,8 +40,6 @@
4040
from sklearn.base import clone
4141
from sklearn.exceptions import DataConversionWarning
4242

43-
from sklearn import __version__ as sklearn_version
44-
from distutils.version import LooseVersion
4543
from math import ceil
4644
from scipy import sparse as sp
4745

@@ -97,6 +95,12 @@ def _get_n_samples_bootstrap(n_samples, max_samples):
9795

9896

9997
def _check_parameters(self):
98+
if not self.bootstrap and self.max_samples is not None:
99+
raise ValueError(
100+
"`max_sample` cannot be set if `bootstrap=False`. "
101+
"Either switch to `bootstrap=True` or set "
102+
"`max_sample=None`."
103+
)
100104
if isinstance(self.min_samples_leaf, numbers.Integral):
101105
if not 1 <= self.min_samples_leaf:
102106
raise ValueError("min_samples_leaf must be at least 1 "

daal4py/sklearn/linear_model/_linear.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,11 @@
1616

1717
import numpy as np
1818
from scipy import sparse as sp
19-
from scipy import linalg
2019

21-
from sklearn.linear_model._base import _rescale_data
2220
from ..utils.validation import _daal_check_array, _daal_check_X_y
2321
from ..utils.base import _daal_validate_data
2422
from .._utils import sklearn_check_version
2523
from .._device_offload import support_usm_ndarray
26-
from sklearn.utils.fixes import sparse_lsqr
27-
from sklearn.utils.validation import _check_sample_weight
2824
from sklearn.utils import check_array
2925

3026
from sklearn.linear_model import LinearRegression as LinearRegression_original

daal4py/sklearn/linear_model/logistic_path.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@
4848
LogisticRegression as LogisticRegression_original)
4949
from sklearn.preprocessing import LabelEncoder, LabelBinarizer
5050
from .._utils import (
51-
getFPType, get_patch_message, sklearn_check_version, PatchingConditionsChain)
51+
getFPType, sklearn_check_version, PatchingConditionsChain)
5252
from .._device_offload import support_usm_ndarray
53-
import logging
5453

5554

5655
# Code adapted from sklearn.linear_model.logistic version 0.21

0 commit comments

Comments
 (0)