Skip to content

Commit e4b916f

Browse files
FIX: Fix html widget docs links (#2299)
* fix html widget doc link * workaround for older sklearn * fix test * attempt at fixing n_jobs failures * better fix for TSNE * revert test to 'daal4py'
1 parent ff244af commit e4b916f

File tree

13 files changed

+78
-26
lines changed

13 files changed

+78
-26
lines changed

sklearnex/_utils.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,27 @@
1616

1717
import logging
1818
import os
19+
import re
1920
import warnings
2021
from abc import ABC
2122

23+
import sklearn
24+
2225
from daal4py.sklearn._utils import (
2326
PatchingConditionsChain as daal4py_PatchingConditionsChain,
2427
)
25-
from daal4py.sklearn._utils import daal_check_version
28+
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
29+
30+
# Note: if inheriting from '_HTMLDocumentationLinkMixin' here, it then doesn't matter
31+
# the order of inheritance of classes for estimators when this is later subclassed,
32+
# whereas if inheriting from something else, the subclass that inherits from this needs
33+
# to be the first inherited class in estimators in order for it to take effect.
34+
if sklearn_check_version("1.4"):
35+
from sklearn.utils._estimator_html_repr import _HTMLDocumentationLinkMixin
36+
37+
BaseForHTMLDocLink = _HTMLDocumentationLinkMixin
38+
else:
39+
BaseForHTMLDocLink = ABC
2640

2741

2842
class PatchingConditionsChain(daal4py_PatchingConditionsChain):
@@ -128,10 +142,8 @@ def get_hyperparameters(self, op):
128142

129143

130144
# This abstract class is meant to generate a clickable doc link for classses
131-
# in sklearnex that are not part of base scikit-learn. It should be inherited
132-
# before inheriting from a scikit-learn estimator, otherwise will get overriden
133-
# by the estimator's original.
134-
class IntelEstimator(ABC):
145+
# in sklearnex that are not part of base scikit-learn.
146+
class IntelEstimator(BaseForHTMLDocLink):
135147
@property
136148
def _doc_link_module(self) -> str:
137149
return "sklearnex"
@@ -141,3 +153,25 @@ def _doc_link_template(self) -> str:
141153
module_path, _ = self.__class__.__module__.rsplit(".", 1)
142154
class_name = self.__class__.__name__
143155
return f"https://uxlfoundation.github.io/scikit-learn-intelex/latest/non-scikit-algorithms.html#{module_path}.{class_name}"
156+
157+
158+
# This abstract class is meant to generate a clickable doc link for classses
159+
# in sklearnex that have counterparts in scikit-learn.
160+
class PatchableEstimator(BaseForHTMLDocLink):
161+
@property
162+
def _doc_link_module(self) -> str:
163+
return "sklearnex"
164+
165+
@property
166+
def _doc_link_template(self) -> str:
167+
if re.search(r"^\d\.\d\.\d$", sklearn.__version__):
168+
sklearn_version_parts = sklearn.__version__.split(".")
169+
doc_version_url = sklearn_version_parts[0] + "." + sklearn_version_parts[1]
170+
else:
171+
doc_version_url = "stable"
172+
module_path, _ = self.__class__.__module__.rsplit(".", 1)
173+
module_path = re.sub("sklearnex", "sklearn", module_path)
174+
class_name = self.__class__.__name__
175+
# for TSNE, which re-uses daal4py
176+
module_path = re.sub(r"daal4py\.", "", module_path)
177+
return f"https://scikit-learn.org/{doc_version_url}/modules/generated/{module_path}.{class_name}.html"

sklearnex/cluster/dbscan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from onedal.cluster import DBSCAN as onedal_DBSCAN
2727

2828
from .._device_offload import dispatch
29-
from .._utils import PatchingConditionsChain
29+
from .._utils import PatchableEstimator, PatchingConditionsChain
3030

3131
if sklearn_check_version("1.1") and not sklearn_check_version("1.2"):
3232
from sklearn.utils import check_scalar
@@ -51,7 +51,7 @@ def _save_attributes(self):
5151

5252

5353
@control_n_jobs(decorated_methods=["fit"])
54-
class DBSCAN(_sklearn_DBSCAN, BaseDBSCAN):
54+
class DBSCAN(PatchableEstimator, _sklearn_DBSCAN, BaseDBSCAN):
5555
__doc__ = _sklearn_DBSCAN.__doc__
5656

5757
if sklearn_check_version("1.2"):

sklearnex/cluster/k_means.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@
3939
from onedal.utils import _is_csr
4040

4141
from .._device_offload import dispatch, wrap_output_data
42-
from .._utils import PatchingConditionsChain
42+
from .._utils import PatchableEstimator, PatchingConditionsChain
4343

4444
if sklearn_check_version("1.6"):
4545
from sklearn.utils.validation import validate_data
4646
else:
4747
validate_data = _sklearn_KMeans._validate_data
4848

4949
@control_n_jobs(decorated_methods=["fit", "fit_transform", "predict", "score"])
50-
class KMeans(_sklearn_KMeans):
50+
class KMeans(PatchableEstimator, _sklearn_KMeans):
5151
__doc__ = _sklearn_KMeans.__doc__
5252

5353
if sklearn_check_version("1.2"):

sklearnex/decomposition/pca.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
from daal4py.sklearn._utils import daal_check_version
2020

21+
from .._utils import PatchableEstimator
22+
2123
if daal_check_version((2024, "P", 100)):
2224
import numbers
2325
from math import sqrt
@@ -50,7 +52,7 @@
5052
validate_data = _sklearn_PCA._validate_data
5153

5254
@control_n_jobs(decorated_methods=["fit", "transform", "fit_transform"])
53-
class PCA(_sklearn_PCA):
55+
class PCA(PatchableEstimator, _sklearn_PCA):
5456
__doc__ = _sklearn_PCA.__doc__
5557

5658
if sklearn_check_version("1.2"):

sklearnex/ensemble/_forest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from sklearnex._utils import register_hyperparameters
6262

6363
from .._device_offload import dispatch, wrap_output_data
64-
from .._utils import PatchingConditionsChain
64+
from .._utils import PatchableEstimator, PatchingConditionsChain
6565
from ..utils._array_api import get_namespace
6666

6767
if sklearn_check_version("1.2"):
@@ -75,7 +75,7 @@
7575
validate_data = BaseEstimator._validate_data
7676

7777

78-
class BaseForest(ABC):
78+
class BaseForest(PatchableEstimator, ABC):
7979
_onedal_factory = None
8080

8181
def _onedal_fit(self, X, y, sample_weight=None, queue=None):

sklearnex/linear_model/linear.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@
2727

2828
from .._config import get_config
2929
from .._device_offload import dispatch, wrap_output_data
30-
from .._utils import PatchingConditionsChain, get_patch_message, register_hyperparameters
30+
from .._utils import (
31+
PatchableEstimator,
32+
PatchingConditionsChain,
33+
get_patch_message,
34+
register_hyperparameters,
35+
)
3136

3237
if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
3338
from sklearn.linear_model._base import _deprecate_normalize
@@ -47,7 +52,7 @@
4752

4853
@register_hyperparameters({"fit": get_hyperparameters("linear_regression", "train")})
4954
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
50-
class LinearRegression(_sklearn_LinearRegression):
55+
class LinearRegression(PatchableEstimator, _sklearn_LinearRegression):
5156
__doc__ = _sklearn_LinearRegression.__doc__
5257

5358
if sklearn_check_version("1.2"):

sklearnex/linear_model/logistic_regression.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
from .._config import get_config
4040
from .._device_offload import dispatch, wrap_output_data
41-
from .._utils import PatchingConditionsChain, get_patch_message
41+
from .._utils import PatchableEstimator, PatchingConditionsChain, get_patch_message
4242

4343
if sklearn_check_version("1.6"):
4444
from sklearn.utils.validation import validate_data
@@ -65,7 +65,9 @@ def _onedal_gpu_save_attributes(self):
6565
"score",
6666
]
6767
)
68-
class LogisticRegression(_sklearn_LogisticRegression, BaseLogisticRegression):
68+
class LogisticRegression(
69+
PatchableEstimator, _sklearn_LogisticRegression, BaseLogisticRegression
70+
):
6971
__doc__ = _sklearn_LogisticRegression.__doc__
7072

7173
if sklearn_check_version("1.2"):

sklearnex/linear_model/ridge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@
3838
from onedal.utils import _num_features, _num_samples
3939

4040
from .._device_offload import dispatch, wrap_output_data
41-
from .._utils import PatchingConditionsChain
41+
from .._utils import PatchableEstimator, PatchingConditionsChain
4242

4343
if sklearn_check_version("1.6"):
4444
from sklearn.utils.validation import validate_data
4545
else:
4646
validate_data = _sklearn_Ridge._validate_data
4747

4848
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
49-
class Ridge(_sklearn_Ridge):
49+
class Ridge(PatchableEstimator, _sklearn_Ridge):
5050
__doc__ = _sklearn_Ridge.__doc__
5151

5252
if sklearn_check_version("1.2"):

sklearnex/manifold/t_sne.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@
1414
# limitations under the License.
1515
# ===============================================================================
1616

17+
from daal4py.sklearn._n_jobs_support import control_n_jobs
1718
from daal4py.sklearn.manifold import TSNE
1819
from onedal._device_offload import support_input_format
1920

21+
from .._utils import PatchableEstimator
22+
2023
TSNE.fit = support_input_format(queue_param=False)(TSNE.fit)
2124
TSNE.fit_transform = support_input_format(queue_param=False)(TSNE.fit_transform)
25+
TSNE._doc_link_module = "daal4py"
26+
TSNE._doc_link_template = PatchableEstimator._doc_link_template

sklearnex/neighbors/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
from daal4py.sklearn._utils import sklearn_check_version
2828
from onedal.utils import _check_array, _num_features, _num_samples
2929

30-
from .._utils import PatchingConditionsChain
30+
from .._utils import PatchableEstimator, PatchingConditionsChain
3131
from ..utils._array_api import get_namespace
3232

3333

34-
class KNeighborsDispatchingBase:
34+
class KNeighborsDispatchingBase(PatchableEstimator):
3535
def _fit_validation(self, X, y=None):
3636
if sklearn_check_version("1.2"):
3737
self._validate_params()

0 commit comments

Comments
 (0)