Skip to content

Commit 9ae3309

Browse files
authored
[testing] Remove sklearn mixin dependence from test_patching (#1795)
* Update test_patching.py * Update _utils.py * formatting * Update _utils.py * Update _utils.py * Update test_patching.py
1 parent 87ddd6f commit 9ae3309

File tree

2 files changed

+27
-35
lines changed

2 files changed

+27
-35
lines changed

sklearnex/tests/_utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,13 @@ def _load_all_models(with_sklearnex=True, estimator=True):
104104
def gen_models_info(algorithms):
105105
output = []
106106
for i in algorithms:
107-
# split handles SPECIAL_INSTANCES or custom inputs
108-
# custom sklearn inputs must be a dict of estimators
109-
# with keys set by the __str__ method
110-
est = PATCHED_MODELS[i.split("(")[0]]
107+
108+
if i in PATCHED_MODELS:
109+
est = PATCHED_MODELS[i]
110+
elif i in SPECIAL_INSTANCES:
111+
est = SPECIAL_INSTANCES[i].__class__
112+
else:
113+
raise KeyError(f"Unrecognized sklearnex estimator: {i}")
111114

112115
methods = set()
113116
candidates = set(
@@ -118,7 +121,11 @@ def gen_models_info(algorithms):
118121
if issubclass(est, mixin):
119122
methods |= candidates & set(method)
120123

121-
output += [[i, j] for j in methods]
124+
output += [[i, j] for j in methods] if methods else [[i, None]]
125+
126+
# In the case that no methods are available, set method to None.
127+
# This will allow estimators without mixins to still test the fit
128+
# method in various tests.
122129
return output
123130

124131

sklearnex/tests/test_patching.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,7 @@
2626
import numpy as np
2727
import numpy.random as nprnd
2828
import pytest
29-
from sklearn.base import (
30-
BaseEstimator,
31-
ClassifierMixin,
32-
ClusterMixin,
33-
OutlierMixin,
34-
RegressorMixin,
35-
TransformerMixin,
36-
)
29+
from sklearn.base import BaseEstimator
3730

3831
from daal4py.sklearn._utils import sklearn_check_version
3932
from onedal.tests.utils._dataframes_support import (
@@ -149,16 +142,17 @@ def test_standard_estimator_patching(caplog, dataframe, queue, dtype, estimator,
149142
and dtype in [np.uint32, np.uint64]
150143
):
151144
pytest.skip("Windows segmentation fault for Ridge.predict for unsigned ints")
152-
elif not hasattr(est, method):
145+
elif method and not hasattr(est, method):
153146
pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}")
154147

155148
X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)
156149
est.fit(X, y)
157150

158-
if method != "score":
159-
getattr(est, method)(X)
160-
else:
161-
est.score(X, y)
151+
if method:
152+
if method != "score":
153+
getattr(est, method)(X)
154+
else:
155+
est.score(X, y)
162156
assert all(
163157
[
164158
"running accelerated version" in i.message
@@ -186,12 +180,15 @@ def test_special_estimator_patching(caplog, dataframe, queue, dtype, estimator,
186180
X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)
187181
est.fit(X, y)
188182

189-
if not hasattr(est, method):
183+
if method and not hasattr(est, method):
190184
pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}")
191-
if method != "score":
192-
getattr(est, method)(X)
193-
else:
194-
est.score(X, y)
185+
186+
if method:
187+
if method != "score":
188+
getattr(est, method)(X)
189+
else:
190+
est.score(X, y)
191+
195192
assert all(
196193
[
197194
"running accelerated version" in i.message
@@ -336,18 +333,6 @@ def test_if_estimator_inherits_sklearn(estimator):
336333
), f"{estimator} does not inherit from the patched sklearn estimator"
337334
else:
338335
assert issubclass(est, BaseEstimator)
339-
assert any(
340-
[
341-
issubclass(est, i)
342-
for i in [
343-
ClassifierMixin,
344-
ClusterMixin,
345-
OutlierMixin,
346-
RegressorMixin,
347-
TransformerMixin,
348-
]
349-
]
350-
), f"{estimator} does not inherit a sklearn Mixin"
351336

352337

353338
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())

0 commit comments

Comments
 (0)