|
26 | 26 | import numpy as np |
27 | 27 | import numpy.random as nprnd |
28 | 28 | import pytest |
29 | | -from sklearn.base import ( |
30 | | - BaseEstimator, |
31 | | - ClassifierMixin, |
32 | | - ClusterMixin, |
33 | | - OutlierMixin, |
34 | | - RegressorMixin, |
35 | | - TransformerMixin, |
36 | | -) |
| 29 | +from sklearn.base import BaseEstimator |
37 | 30 |
|
38 | 31 | from daal4py.sklearn._utils import sklearn_check_version |
39 | 32 | from onedal.tests.utils._dataframes_support import ( |
@@ -149,16 +142,17 @@ def test_standard_estimator_patching(caplog, dataframe, queue, dtype, estimator, |
149 | 142 | and dtype in [np.uint32, np.uint64] |
150 | 143 | ): |
151 | 144 | pytest.skip("Windows segmentation fault for Ridge.predict for unsigned ints") |
152 | | - elif not hasattr(est, method): |
| 145 | + elif method and not hasattr(est, method): |
153 | 146 | pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}") |
154 | 147 |
|
155 | 148 | X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype) |
156 | 149 | est.fit(X, y) |
157 | 150 |
|
158 | | - if method != "score": |
159 | | - getattr(est, method)(X) |
160 | | - else: |
161 | | - est.score(X, y) |
| 151 | + if method: |
| 152 | + if method != "score": |
| 153 | + getattr(est, method)(X) |
| 154 | + else: |
| 155 | + est.score(X, y) |
162 | 156 | assert all( |
163 | 157 | [ |
164 | 158 | "running accelerated version" in i.message |
@@ -186,12 +180,15 @@ def test_special_estimator_patching(caplog, dataframe, queue, dtype, estimator, |
186 | 180 | X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype) |
187 | 181 | est.fit(X, y) |
188 | 182 |
|
189 | | - if not hasattr(est, method): |
| 183 | + if method and not hasattr(est, method): |
190 | 184 | pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}") |
191 | | - if method != "score": |
192 | | - getattr(est, method)(X) |
193 | | - else: |
194 | | - est.score(X, y) |
| 185 | + |
| 186 | + if method: |
| 187 | + if method != "score": |
| 188 | + getattr(est, method)(X) |
| 189 | + else: |
| 190 | + est.score(X, y) |
| 191 | + |
195 | 192 | assert all( |
196 | 193 | [ |
197 | 194 | "running accelerated version" in i.message |
@@ -336,18 +333,6 @@ def test_if_estimator_inherits_sklearn(estimator): |
336 | 333 | ), f"{estimator} does not inherit from the patched sklearn estimator" |
337 | 334 | else: |
338 | 335 | assert issubclass(est, BaseEstimator) |
339 | | - assert any( |
340 | | - [ |
341 | | - issubclass(est, i) |
342 | | - for i in [ |
343 | | - ClassifierMixin, |
344 | | - ClusterMixin, |
345 | | - OutlierMixin, |
346 | | - RegressorMixin, |
347 | | - TransformerMixin, |
348 | | - ] |
349 | | - ] |
350 | | - ), f"{estimator} does not inherit a sklearn Mixin" |
351 | 336 |
|
352 | 337 |
|
353 | 338 | @pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys()) |
|
0 commit comments