From b0bcac88b0f2d476eb765a81c4e16c881bb4d3db Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Tue, 22 Jul 2025 10:36:01 -0700 Subject: [PATCH 01/33] ArrayAPI update. --- onedal/cluster/kmeans.py | 8 +++++--- sklearnex/cluster/k_means.py | 9 ++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index a0155bfa66..321748d793 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -38,7 +38,7 @@ from ..common._mixin import ClusterMixin, TransformerMixin from ..datatypes import from_table, to_table from ..utils.validation import _check_array, _is_arraylike_not_scalar, _is_csr - +from ..utils._array_api import get_namespace class _BaseKMeans(TransformerMixin, ClusterMixin, ABC): def __init__( @@ -363,14 +363,15 @@ def cluster_centers_(self): @cluster_centers_.setter def cluster_centers_(self, cluster_centers): - self._cluster_centers_ = np.asarray(cluster_centers) + xp, _ = get_namespace(cluster_centers) + self._cluster_centers_ = xp.asarray(cluster_centers) self.n_iter_ = 0 self.inertia_ = 0 self.model_.centroids = to_table(self._cluster_centers_) self.n_features_in_ = self.model_.centroids.column_count - self.labels_ = np.arange(self.model_.centroids.row_count) + self.labels_ = xp.arange(self.model_.centroids.row_count) return self @@ -401,6 +402,7 @@ def _score(self, X): ) def _transform(self, X): + xp, _ = get_namespace(X) return euclidean_distances(X, self.cluster_centers_) diff --git a/sklearnex/cluster/k_means.py b/sklearnex/cluster/k_means.py index aba871c21b..9c862eef55 100644 --- a/sklearnex/cluster/k_means.py +++ b/sklearnex/cluster/k_means.py @@ -155,11 +155,12 @@ def fit(self, X, y=None, sample_weight=None): return self def _onedal_fit(self, X, _, sample_weight, queue=None): + xp, _ = get_namespace(X) X = validate_data( self, X, accept_sparse="csr", - dtype=[np.float64, np.float32], + dtype=[xp.float64, xp.float32], order="C", copy=self.copy_x, accept_large_sparse=False, @@ -294,12 +295,13 @@ def predict( ) def _onedal_predict(self, X, sample_weight=None, queue=None): + xp, _ = get_namespace(X) X = validate_data( self, X, accept_sparse="csr", reset=False, - dtype=[np.float64, np.float32], + dtype=[xp.float64, xp.float32], ) if not hasattr(self, "_onedal_estimator"): @@ -351,12 +353,13 @@ def score(self, X, y=None, sample_weight=None): ) def _onedal_score(self, X, y=None, sample_weight=None, queue=None): + xp, _ = get_namespace(X) X = validate_data( self, X, accept_sparse="csr", reset=False, - dtype=[np.float64, np.float32], + dtype=[xp.float64, xp.float32], ) if not sklearn_check_version("1.5") and sklearn_check_version("1.3"): From 883b627ec556e0e3219774572f9cc2afe41921da Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Tue, 22 Jul 2025 11:51:21 -0700 Subject: [PATCH 02/33] Updated imports. --- onedal/cluster/kmeans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 321748d793..6eb95b2463 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -38,7 +38,7 @@ from ..common._mixin import ClusterMixin, TransformerMixin from ..datatypes import from_table, to_table from ..utils.validation import _check_array, _is_arraylike_not_scalar, _is_csr -from ..utils._array_api import get_namespace +from sklearnex.utils._array_api import get_namespace class _BaseKMeans(TransformerMixin, ClusterMixin, ABC): def __init__( From 8acc28015587190239cfc6e18dc66a3c57ab27f8 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Tue, 22 Jul 2025 14:14:59 -0700 Subject: [PATCH 03/33] Updated default parameter for oneAPI. --- onedal/cluster/kmeans.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 6eb95b2463..69538582b0 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -176,9 +176,15 @@ def _init_centroids_onedal( init, random_seed, is_csr, - dtype=np.float32, + dtype=None, n_centroids=None, ): + + xp = X_table.__array_namespace__() + + if dtype is None: + dtype = xp.float32 + n_clusters = self.n_clusters if n_centroids is None else n_centroids if isinstance(init, str) and init == "k-means++": @@ -219,10 +225,15 @@ def _init_centroids_onedal( return centers_table - def _init_centroids_sklearn(self, X, init, random_state, dtype=np.float32): + def _init_centroids_sklearn(self, X, init, random_state, dtype=None): # For oneDAL versions < 2023.2 or callable init, # using the scikit-learn implementation logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn") + xp, _ = get_namespace(X) + + if dtype is None: + dtype = xp.float32 + n_samples = X.shape[0] if isinstance(init, str) and init == "k-means++": @@ -249,7 +260,13 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=np.float32): return to_table(centers, queue=getattr(QM.get_global_queue(), "_queue", None)) - def _fit_backend(self, X_table, centroids_table, dtype=np.float32, is_csr=False): + def _fit_backend(self, X_table, centroids_table, dtype=None, is_csr=False): + + xp = X_table.__array_namespace__() + + if dtype is None: + dtype = xp.float32 + params = self._get_onedal_params(is_csr, dtype) assert X_table.dtype == dtype From 3f683ecc832c597275697c85dcdc9f7589567105 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Tue, 22 Jul 2025 14:19:18 -0700 Subject: [PATCH 04/33] Updated format. --- onedal/cluster/kmeans.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 69538582b0..48f5ab40de 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -179,7 +179,6 @@ def _init_centroids_onedal( dtype=None, n_centroids=None, ): - xp = X_table.__array_namespace__() if dtype is None: From c73ad6521052556b0893c41f81e77caaa07e4ec2 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Tue, 22 Jul 2025 14:22:17 -0700 Subject: [PATCH 05/33] Updated format. --- onedal/cluster/kmeans.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 48f5ab40de..364934f4a7 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -34,11 +34,13 @@ from sklearn.metrics.pairwise import euclidean_distances from sklearn.utils import check_random_state +from sklearnex.utils._array_api import get_namespace + from .._config import _get_config from ..common._mixin import ClusterMixin, TransformerMixin from ..datatypes import from_table, to_table from ..utils.validation import _check_array, _is_arraylike_not_scalar, _is_csr -from sklearnex.utils._array_api import get_namespace + class _BaseKMeans(TransformerMixin, ClusterMixin, ABC): def __init__( @@ -179,8 +181,9 @@ def _init_centroids_onedal( dtype=None, n_centroids=None, ): + xp = X_table.__array_namespace__() - + if dtype is None: dtype = xp.float32 @@ -229,7 +232,7 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): # using the scikit-learn implementation logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn") xp, _ = get_namespace(X) - + if dtype is None: dtype = xp.float32 @@ -260,8 +263,8 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): return to_table(centers, queue=getattr(QM.get_global_queue(), "_queue", None)) def _fit_backend(self, X_table, centroids_table, dtype=None, is_csr=False): - - xp = X_table.__array_namespace__() + + xp = X_table.__array_namespace__() if dtype is None: dtype = xp.float32 From 06f042272bf6a7d68457898d21a464f3dda3fc42 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Tue, 22 Jul 2025 15:02:53 -0700 Subject: [PATCH 06/33] Fixed to get namespace from function. --- onedal/cluster/kmeans.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 364934f4a7..3da6ef56f9 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -264,7 +264,8 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): def _fit_backend(self, X_table, centroids_table, dtype=None, is_csr=False): - xp = X_table.__array_namespace__() + xp, _ = get_namespace(X_table) + xp = X_table if dtype is None: dtype = xp.float32 @@ -421,7 +422,6 @@ def _score(self, X): ) def _transform(self, X): - xp, _ = get_namespace(X) return euclidean_distances(X, self.cluster_centers_) From cdda10d29538a41b26c27c72628967670e504b34 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Tue, 22 Jul 2025 15:44:30 -0700 Subject: [PATCH 07/33] Updated format. --- onedal/cluster/kmeans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 3da6ef56f9..529233a6a9 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -182,7 +182,7 @@ def _init_centroids_onedal( n_centroids=None, ): - xp = X_table.__array_namespace__() + xp, _ = get_namespace(X_table) if dtype is None: dtype = xp.float32 From 292b97a8d2ba0f0e819f55f4802adc0477e8b340 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 23 Jul 2025 13:26:02 -0700 Subject: [PATCH 08/33] Fixed for csr_matrix. --- onedal/utils/_array_api.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index 56211197f9..00fb0107f5 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -20,9 +20,14 @@ from functools import lru_cache import numpy as np +import scipy.sparse as sp from ..utils._third_party import _is_subclass_fast +try: + from dpctl.tensor import usm_ndarray +except ImportError: + usm_ndarray = () # fallback if not available def _supports_buffer_protocol(obj): # the array_api standard mandates conversion with the buffer protocol, @@ -74,12 +79,14 @@ def _cls_to_sycl_namespace(cls): else: raise ValueError(f"SYCL type not recognized: {cls}") - def _get_sycl_namespace(*arrays): """Get namespace of sycl arrays.""" - - # sycl support designed to work regardless of array_api_dispatch sklearn global value - sua_iface = {type(x): x for x in arrays if hasattr(x, "__sycl_usm_array_interface__")} + # Accept only known dense SYCL-compatible arrays + sua_iface = { + type(x): x + for x in arrays + if isinstance(x, usm_ndarray) + } if len(sua_iface) > 1: raise ValueError(f"Multiple SYCL types for array inputs: {sua_iface}") From cc3672fd36a518e447fad52ec8af026b30290f4c Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 23 Jul 2025 13:31:24 -0700 Subject: [PATCH 09/33] Fixed formatting. --- onedal/utils/_array_api.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index 00fb0107f5..909fa8cf1a 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -29,6 +29,7 @@ except ImportError: usm_ndarray = () # fallback if not available + def _supports_buffer_protocol(obj): # the array_api standard mandates conversion with the buffer protocol, # which can only be checked via a try-catch in native python @@ -79,14 +80,12 @@ def _cls_to_sycl_namespace(cls): else: raise ValueError(f"SYCL type not recognized: {cls}") + def _get_sycl_namespace(*arrays): """Get namespace of sycl arrays.""" # Accept only known dense SYCL-compatible arrays - sua_iface = { - type(x): x - for x in arrays - if isinstance(x, usm_ndarray) - } + + sua_iface = {type(x): x for x in arrays if isinstance(x, usm_ndarray)} if len(sua_iface) > 1: raise ValueError(f"Multiple SYCL types for array inputs: {sua_iface}") From 92bd00ef479d607797d923778b5b7c212405a718 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 23 Jul 2025 13:59:09 -0700 Subject: [PATCH 10/33] Updated formatting. --- onedal/utils/_array_api.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index 909fa8cf1a..5f32b5e69e 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -81,11 +81,32 @@ def _cls_to_sycl_namespace(cls): raise ValueError(f"SYCL type not recognized: {cls}") +def _get_allowed_sycl_types(): + """Return a tuple of known SYCL-compatible dense array types.""" + allowed = [] + + try: + from dpctl.tensor import usm_ndarray + + allowed.append(usm_ndarray) + except ImportError: + pass + + try: + from dpnp import ndarray as dpnp_ndarray + + allowed.append(dpnp_ndarray) + except ImportError: + pass + + return tuple(allowed) + + def _get_sycl_namespace(*arrays): - """Get namespace of sycl arrays.""" - # Accept only known dense SYCL-compatible arrays + """Get namespace of SYCL-compatible arrays (excluding sparse or unsupported types).""" + allowed_sycl_types = _get_allowed_sycl_types() - sua_iface = {type(x): x for x in arrays if isinstance(x, usm_ndarray)} + sua_iface = {type(x): x for x in arrays if isinstance(x, allowed_sycl_types)} if len(sua_iface) > 1: raise ValueError(f"Multiple SYCL types for array inputs: {sua_iface}") From 7714377d5976cb6d6962d29988a93c99f1863dba Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 23 Jul 2025 14:04:16 -0700 Subject: [PATCH 11/33] Fixed formatting. --- onedal/utils/_array_api.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index 5f32b5e69e..be40c15bdc 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -24,11 +24,6 @@ from ..utils._third_party import _is_subclass_fast -try: - from dpctl.tensor import usm_ndarray -except ImportError: - usm_ndarray = () # fallback if not available - def _supports_buffer_protocol(obj): # the array_api standard mandates conversion with the buffer protocol, From 90075179eb1d06d7bd0fe3929fac2c7e07a409db Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 23 Jul 2025 14:23:36 -0700 Subject: [PATCH 12/33] Fixed formatting. --- onedal/utils/_array_api.py | 40 +++++++++++++------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index be40c15bdc..56d15f0ea3 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -22,7 +22,7 @@ import numpy as np import scipy.sparse as sp -from ..utils._third_party import _is_subclass_fast +from ..utils._third_party import _is_subclass_fast, is_dpctl_tensor, is_dpnp_ndarray def _supports_buffer_protocol(obj): @@ -76,32 +76,20 @@ def _cls_to_sycl_namespace(cls): raise ValueError(f"SYCL type not recognized: {cls}") -def _get_allowed_sycl_types(): - """Return a tuple of known SYCL-compatible dense array types.""" - allowed = [] - - try: - from dpctl.tensor import usm_ndarray - - allowed.append(usm_ndarray) - except ImportError: - pass - - try: - from dpnp import ndarray as dpnp_ndarray - - allowed.append(dpnp_ndarray) - except ImportError: - pass - - return tuple(allowed) - - def _get_sycl_namespace(*arrays): - """Get namespace of SYCL-compatible arrays (excluding sparse or unsupported types).""" - allowed_sycl_types = _get_allowed_sycl_types() - - sua_iface = {type(x): x for x in arrays if isinstance(x, allowed_sycl_types)} + sua_iface = {} + for x in arrays: + try: + has_sycl = getattr(x, "__sycl_usm_array_interface__", None) is not None + except RuntimeError: + has_sycl = False + + if ( + has_sycl + and not isinstance(x, sp.spmatrix) + and (is_dpctl_tensor(x) or is_dpnp_ndarray(x)) + ): + sua_iface[type(x)] = x if len(sua_iface) > 1: raise ValueError(f"Multiple SYCL types for array inputs: {sua_iface}") From 2b83525eac894d9a349599be6a2f77a4ba0351ee Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 23 Jul 2025 14:35:59 -0700 Subject: [PATCH 13/33] Changed to use getattr. --- onedal/utils/_array_api.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index 56d15f0ea3..e94281fd0f 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -77,19 +77,17 @@ def _cls_to_sycl_namespace(cls): def _get_sycl_namespace(*arrays): + """Get namespace of sycl arrays.""" + + # sycl support designed to work regardless of array_api_dispatch sklearn global value sua_iface = {} for x in arrays: try: - has_sycl = getattr(x, "__sycl_usm_array_interface__", None) is not None + if getattr(x, "__sycl_usm_array_interface__", None) is not None: + sua_iface[type(x)] = x except RuntimeError: - has_sycl = False - - if ( - has_sycl - and not isinstance(x, sp.spmatrix) - and (is_dpctl_tensor(x) or is_dpnp_ndarray(x)) - ): - sua_iface[type(x)] = x + # Skip objects that raise errors when accessing the attribute + continue if len(sua_iface) > 1: raise ValueError(f"Multiple SYCL types for array inputs: {sua_iface}") From b59ba3684e914f9edce618c65ad2f788cb10f38f Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 23 Jul 2025 14:38:42 -0700 Subject: [PATCH 14/33] Updated imports. --- onedal/utils/_array_api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index e94281fd0f..47a3f69a59 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -20,9 +20,8 @@ from functools import lru_cache import numpy as np -import scipy.sparse as sp -from ..utils._third_party import _is_subclass_fast, is_dpctl_tensor, is_dpnp_ndarray +from ..utils._third_party import _is_subclass_fast def _supports_buffer_protocol(obj): From 24f55b67cd5561201a49099ac2dff40d187b2a51 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 23 Jul 2025 17:10:35 -0700 Subject: [PATCH 15/33] Formatted. --- onedal/cluster/kmeans.py | 2 +- onedal/utils/_array_api.py | 31 ++++++++++++++++++++----------- sklearnex/utils/_array_api.py | 13 ++++++++++++- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 529233a6a9..cafb11f228 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -216,7 +216,7 @@ def _init_centroids_onedal( # oneDAL KMeans only supports Dense Centroids centers = init.toarray() else: - centers = np.asarray(init) + centers = xp.asarray(init) assert centers.shape[0] == n_clusters assert centers.shape[1] == X_table.column_count # KMeans is implemented on both CPU and GPU for Dense and CSR data diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index 47a3f69a59..507a73d3ba 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -20,9 +20,15 @@ from functools import lru_cache import numpy as np +import scipy.sparse as sp from ..utils._third_party import _is_subclass_fast +try: + from onedal._onedal_py_dpc import table as onedal_table_type +except ImportError: + onedal_table_type = type(None) + def _supports_buffer_protocol(obj): # the array_api standard mandates conversion with the buffer protocol, @@ -75,18 +81,21 @@ def _cls_to_sycl_namespace(cls): raise ValueError(f"SYCL type not recognized: {cls}") +def _is_valid_sycl_array(x): + try: + if getattr(x, "__sycl_usm_array_interface__", None) is None: + return False + except RuntimeError: + return False + + if isinstance(x, (sp.spmatrix, onedal_table_type)): + return False + + return True + + def _get_sycl_namespace(*arrays): - """Get namespace of sycl arrays.""" - - # sycl support designed to work regardless of array_api_dispatch sklearn global value - sua_iface = {} - for x in arrays: - try: - if getattr(x, "__sycl_usm_array_interface__", None) is not None: - sua_iface[type(x)] = x - except RuntimeError: - # Skip objects that raise errors when accessing the attribute - continue + sua_iface = {type(x): x for x in arrays if _is_valid_sycl_array(x)} if len(sua_iface) > 1: raise ValueError(f"Multiple SYCL types for array inputs: {sua_iface}") diff --git a/sklearnex/utils/_array_api.py b/sklearnex/utils/_array_api.py index 8162d5d040..de5d95653d 100644 --- a/sklearnex/utils/_array_api.py +++ b/sklearnex/utils/_array_api.py @@ -17,6 +17,7 @@ """Tools to support array_api.""" import numpy as np +from scipy import sparse as sp from daal4py.sklearn._utils import sklearn_check_version from onedal.utils._array_api import _get_sycl_namespace @@ -29,6 +30,11 @@ if sklearn_check_version("1.2"): from sklearn.utils._array_api import get_namespace as sklearn_get_namespace +try: + from onedal._onedal_py_dpc import table as onedal_table_type +except ImportError: + onedal_table_type = type(None) + def get_namespace(*arrays): """Get namespace of arrays. @@ -81,7 +87,12 @@ def get_namespace(*arrays): if sycl_type: return xp, is_array_api_compliant - elif sklearn_check_version("1.2"): + + for x in arrays: + if isinstance(x, (onedal_table_type, sp.spmatrix)): + return np, False + + if sklearn_check_version("1.2"): return sklearn_get_namespace(*arrays) else: return np, False From a6a70c57a2f8049047e2a628133ce360d0cce0e2 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Thu, 31 Jul 2025 13:28:17 -0700 Subject: [PATCH 16/33] Fixes for get_namespace. --- onedal/datatypes/sycl_usm/data_conversion.cpp | 3 +- onedal/utils/_array_api.py | 24 +----- sklearnex/utils/_array_api.py | 76 ++++++++++++------- 3 files changed, 54 insertions(+), 49 deletions(-) diff --git a/onedal/datatypes/sycl_usm/data_conversion.cpp b/onedal/datatypes/sycl_usm/data_conversion.cpp index d5e370c3d6..21eb0c96b9 100644 --- a/onedal/datatypes/sycl_usm/data_conversion.cpp +++ b/onedal/datatypes/sycl_usm/data_conversion.cpp @@ -156,7 +156,8 @@ py::dict construct_sua_iface(const dal::table& input) { // for constructing DPCTL usm_ndarray or DPNP ndarray with zero-copy on python level. const auto kind = input.get_kind(); if (kind != dal::homogen_table::kind()) - report_problem_to_sua_iface(": only homogen tables are supported"); + throw py::attribute_error( + "__sycl_usm_array_interface__ not available: only homogen tables are supported"); const auto& homogen_input = reinterpret_cast(input); diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index 507a73d3ba..56211197f9 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -20,15 +20,9 @@ from functools import lru_cache import numpy as np -import scipy.sparse as sp from ..utils._third_party import _is_subclass_fast -try: - from onedal._onedal_py_dpc import table as onedal_table_type -except ImportError: - onedal_table_type = type(None) - def _supports_buffer_protocol(obj): # the array_api standard mandates conversion with the buffer protocol, @@ -81,21 +75,11 @@ def _cls_to_sycl_namespace(cls): raise ValueError(f"SYCL type not recognized: {cls}") -def _is_valid_sycl_array(x): - try: - if getattr(x, "__sycl_usm_array_interface__", None) is None: - return False - except RuntimeError: - return False - - if isinstance(x, (sp.spmatrix, onedal_table_type)): - return False - - return True - - def _get_sycl_namespace(*arrays): - sua_iface = {type(x): x for x in arrays if _is_valid_sycl_array(x)} + """Get namespace of sycl arrays.""" + + # sycl support designed to work regardless of array_api_dispatch sklearn global value + sua_iface = {type(x): x for x in arrays if hasattr(x, "__sycl_usm_array_interface__")} if len(sua_iface) > 1: raise ValueError(f"Multiple SYCL types for array inputs: {sua_iface}") diff --git a/sklearnex/utils/_array_api.py b/sklearnex/utils/_array_api.py index de5d95653d..3308023e12 100644 --- a/sklearnex/utils/_array_api.py +++ b/sklearnex/utils/_array_api.py @@ -16,8 +16,10 @@ """Tools to support array_api.""" +from collections.abc import Callable +from typing import Union + import numpy as np -from scipy import sparse as sp from daal4py.sklearn._utils import sklearn_check_version from onedal.utils._array_api import _get_sycl_namespace @@ -30,11 +32,6 @@ if sklearn_check_version("1.2"): from sklearn.utils._array_api import get_namespace as sklearn_get_namespace -try: - from onedal._onedal_py_dpc import table as onedal_table_type -except ImportError: - onedal_table_type = type(None) - def get_namespace(*arrays): """Get namespace of arrays. @@ -87,33 +84,13 @@ def get_namespace(*arrays): if sycl_type: return xp, is_array_api_compliant - - for x in arrays: - if isinstance(x, (onedal_table_type, sp.spmatrix)): - return np, False - - if sklearn_check_version("1.2"): + elif sklearn_check_version("1.2"): return sklearn_get_namespace(*arrays) else: return np, False -def enable_array_api(original_class: type[oneDALEstimator]) -> type[oneDALEstimator]: - """Enable sklearnex to use dpctl, dpnp or array_api inputs in oneDAL offloading. - - This wrapper sets the proper flags/tags for the sklearnex infrastructure - to maintain the data framework, as the estimator can use it natively. - - Parameters - ---------- - original_class : oneDALEstimator subclass - Class which should enable data zero-copy support in sklearnex. - - Returns - ------- - original_class : modified oneDALEstimator subclass - Estimator class. - """ +def _enable_array_api(original_class: type[oneDALEstimator]) -> type[oneDALEstimator]: if sklearn_check_version("1.6"): def __sklearn_tags__(self) -> Tags: @@ -131,3 +108,46 @@ def _more_tags(self) -> dict[str, bool]: original_class._more_tags = _more_tags return original_class + + +def enable_array_api( + class_or_str: Union[type[oneDALEstimator], str], +) -> Union[type[oneDALEstimator], Callable]: + """Enable sklearnex to use dpctl, dpnp or array API inputs in oneDAL offloading. + + This wrapper sets the proper flags/tags for the sklearnex infrastructure + to maintain the data framework, as the estimator can use it natively. + + Parameters + ---------- + class_or_str : oneDALEstimator subclass or str + Class which should enable data zero-copy support in sklearnex. By + default it will enable for sklearn versions >1.3. If the wrapper is + decorated with an argument, it must be a string defining the oldest + sklearn version where array API support begins. + + Returns + ------- + cls or wrapper : modified oneDALEstimator subclass or wrapper + Estimator class or wrapper. + + Examples + -------- + @enable_array_api # default array API support + class PCA(): + ... + + @enable_array_api("1.5") # array API support for sklearn > 1.5 + class Ridge(): + ... + """ + if isinstance(class_or_str, str): + # enable array_api for the estimator for a given sklearn version str + if sklearn_check_version(class_or_str): + return _enable_array_api + else: + # do not apply the wrapper as it is not supported + return lambda x: x + else: + # default setting (apply array_api enablement for sklearn >=1.3) + return _enable_array_api(class_or_str) From 510e6e9c8a0aa23181ee556c2e10d0d7444041fe Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Fri, 1 Aug 2025 11:30:16 -0700 Subject: [PATCH 17/33] Updated formatting. --- onedal/utils/_array_api.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index 56211197f9..2dd29c209a 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -23,6 +23,11 @@ from ..utils._third_party import _is_subclass_fast +try: + from onedal._onedal_py_dpc import table as onedal_table_type +except ImportError: + onedal_table_type = type(None) + def _supports_buffer_protocol(obj): # the array_api standard mandates conversion with the buffer protocol, @@ -91,5 +96,7 @@ def _get_sycl_namespace(*arrays): _cls_to_sycl_namespace(type(X)), hasattr(X, "__array_namespace__"), ) + elif any(isinstance(x, onedal_table_type) for x in arrays): + return True, np, False return sua_iface, np, False From bd1818e03e39377489c4476dfc787dc9c8ea6d8a Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Fri, 1 Aug 2025 13:36:32 -0700 Subject: [PATCH 18/33] Updated format. --- onedal/cluster/kmeans.py | 18 +++++++++++------- onedal/utils/_array_api.py | 7 ------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 2b6d79e07e..bb8d28efed 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -34,6 +34,7 @@ from sklearn.metrics.pairwise import euclidean_distances from sklearn.utils import check_random_state +from onedal._onedal_py_dpc import table from sklearnex.utils._array_api import get_namespace from .._config import _get_config @@ -182,10 +183,12 @@ def _init_centroids_onedal( n_centroids=None, ): - xp, _ = get_namespace(X_table) - if dtype is None: - dtype = xp.float32 + if isinstance(X_table, table): + dtype = np.float32 + else: + xp, _ = get_namespace(X_table) + dtype = xp.float32 n_clusters = self.n_clusters if n_centroids is None else n_centroids @@ -264,11 +267,12 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): def _fit_backend(self, X_table, centroids_table, dtype=None, is_csr=False): - xp, _ = get_namespace(X_table) - xp = X_table - if dtype is None: - dtype = xp.float32 + if isinstance(X_table, table): + dtype = np.float32 + else: + xp, _ = get_namespace(X_table) + dtype = xp.float32 params = self._get_onedal_params(is_csr, dtype) diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py index 2dd29c209a..56211197f9 100644 --- a/onedal/utils/_array_api.py +++ b/onedal/utils/_array_api.py @@ -23,11 +23,6 @@ from ..utils._third_party import _is_subclass_fast -try: - from onedal._onedal_py_dpc import table as onedal_table_type -except ImportError: - onedal_table_type = type(None) - def _supports_buffer_protocol(obj): # the array_api standard mandates conversion with the buffer protocol, @@ -96,7 +91,5 @@ def _get_sycl_namespace(*arrays): _cls_to_sycl_namespace(type(X)), hasattr(X, "__array_namespace__"), ) - elif any(isinstance(x, onedal_table_type) for x in arrays): - return True, np, False return sua_iface, np, False From 6710d0345ce7be821ae7e6ed4f9966ab86b0449a Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Fri, 1 Aug 2025 13:49:58 -0700 Subject: [PATCH 19/33] Updated the format. --- onedal/cluster/kmeans.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index bb8d28efed..7b83f41bee 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -95,6 +95,16 @@ def _get_kmeans_init(self, cluster_count, seed, algorithm, is_csr): is_csr=is_csr, ) + def _infer_dtype(self, X_table, dtype=None): + if dtype is not None: + return dtype + + if isinstance(X_table, table): + return np.float32 + + xp, _ = get_namespace(X_table) + return xp.float32 + # Get appropriate backend (required for SPMD) def _get_basic_statistics_backend(self, result_options): return BasicStatistics(result_options) @@ -183,12 +193,7 @@ def _init_centroids_onedal( n_centroids=None, ): - if dtype is None: - if isinstance(X_table, table): - dtype = np.float32 - else: - xp, _ = get_namespace(X_table) - dtype = xp.float32 + dtype = self._infer_dtype(X_table, dtype) n_clusters = self.n_clusters if n_centroids is None else n_centroids @@ -267,12 +272,7 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): def _fit_backend(self, X_table, centroids_table, dtype=None, is_csr=False): - if dtype is None: - if isinstance(X_table, table): - dtype = np.float32 - else: - xp, _ = get_namespace(X_table) - dtype = xp.float32 + dtype = self._infer_dtype(X_table, dtype) params = self._get_onedal_params(is_csr, dtype) From 612c8bb784ce89cf9020537dfe41713cba630149 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Fri, 1 Aug 2025 14:39:03 -0700 Subject: [PATCH 20/33] Fixed format. --- onedal/cluster/kmeans.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 7b83f41bee..7bbeb9d8f7 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -95,15 +95,20 @@ def _get_kmeans_init(self, cluster_count, seed, algorithm, is_csr): is_csr=is_csr, ) + def _infer_namespace(self, X_table): + if isinstance(X_table, table): + xp = np + else: + xp, _ = get_namespace(X_table) + return xp + def _infer_dtype(self, X_table, dtype=None): - if dtype is not None: - return dtype + xp = self._infer_namespace(X_table) - if isinstance(X_table, table): - return np.float32 + if dtype is not None: + return xp, dtype - xp, _ = get_namespace(X_table) - return xp.float32 + return xp, xp.float32 # Get appropriate backend (required for SPMD) def _get_basic_statistics_backend(self, result_options): @@ -193,7 +198,7 @@ def _init_centroids_onedal( n_centroids=None, ): - dtype = self._infer_dtype(X_table, dtype) + xp, dtype = self._infer_dtype(X_table, dtype) n_clusters = self.n_clusters if n_centroids is None else n_centroids @@ -272,7 +277,7 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): def _fit_backend(self, X_table, centroids_table, dtype=None, is_csr=False): - dtype = self._infer_dtype(X_table, dtype) + xp, dtype = self._infer_dtype(X_table, dtype) params = self._get_onedal_params(is_csr, dtype) From c5a44e4ae4c6dad7b78f73f43d29bf5a155f17ce Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Fri, 1 Aug 2025 15:15:19 -0700 Subject: [PATCH 21/33] Don't need array api stuff for X_tables. --- onedal/cluster/kmeans.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 7bbeb9d8f7..2bf7dbbd2f 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -130,6 +130,7 @@ def _tolerance(self, X_table, rtol, is_csr, dtype): def _check_params_vs_input( self, X_table, is_csr, default_n_init=10, dtype=np.float32 ): + # n_clusters if X_table.shape[0] < self.n_clusters: raise ValueError( @@ -194,12 +195,10 @@ def _init_centroids_onedal( init, random_seed, is_csr, - dtype=None, + dtype=np.float32, n_centroids=None, ): - xp, dtype = self._infer_dtype(X_table, dtype) - n_clusters = self.n_clusters if n_centroids is None else n_centroids if isinstance(init, str) and init == "k-means++": @@ -229,7 +228,7 @@ def _init_centroids_onedal( # oneDAL KMeans only supports Dense Centroids centers = init.toarray() else: - centers = xp.asarray(init) + centers = np.asarray(init) assert centers.shape[0] == n_clusters assert centers.shape[1] == X_table.column_count # KMeans is implemented on both CPU and GPU for Dense and CSR data @@ -244,10 +243,7 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): # For oneDAL versions < 2023.2 or callable init, # using the scikit-learn implementation logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn") - xp, _ = get_namespace(X) - - if dtype is None: - dtype = xp.float32 + xp, dbtype = self._infer_dtype(X, dtype) n_samples = X.shape[0] @@ -275,9 +271,7 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): return to_table(centers, queue=getattr(QM.get_global_queue(), "_queue", None)) - def _fit_backend(self, X_table, centroids_table, dtype=None, is_csr=False): - - xp, dtype = self._infer_dtype(X_table, dtype) + def _fit_backend(self, X_table, centroids_table, dtype=np.float32, is_csr=False): params = self._get_onedal_params(is_csr, dtype) @@ -295,13 +289,16 @@ def _fit_backend(self, X_table, centroids_table, dtype=None, is_csr=False): def _fit(self, X): is_csr = _is_csr(X) + xp = self._infer_namespace(X) + if _get_config()["use_raw_input"] is False: X = _check_array( X, - dtype=[np.float64, np.float32], + dtype=[xp.float64, xp.float32], accept_sparse="csr", force_all_finite=False, ) + X_table = to_table(X, queue=QM.get_global_queue()) dtype = X_table.dtype From fb806e0ed798815f5c5bed4739ec67a6df906a5e Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Fri, 1 Aug 2025 15:57:33 -0700 Subject: [PATCH 22/33] Removed table import. --- onedal/cluster/kmeans.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 2bf7dbbd2f..8f3c4079b7 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -34,7 +34,6 @@ from sklearn.metrics.pairwise import euclidean_distances from sklearn.utils import check_random_state -from onedal._onedal_py_dpc import table from sklearnex.utils._array_api import get_namespace from .._config import _get_config @@ -96,10 +95,7 @@ def _get_kmeans_init(self, cluster_count, seed, algorithm, is_csr): ) def _infer_namespace(self, X_table): - if isinstance(X_table, table): - xp = np - else: - xp, _ = get_namespace(X_table) + xp, _ = get_namespace(X_table) return xp def _infer_dtype(self, X_table, dtype=None): From 78a364d6f9fc60c450a0540ae37fdf611ccdb939 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Mon, 4 Aug 2025 09:01:09 -0700 Subject: [PATCH 23/33] Updated format. --- onedal/cluster/kmeans.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 8f3c4079b7..079859ced1 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -94,18 +94,6 @@ def _get_kmeans_init(self, cluster_count, seed, algorithm, is_csr): is_csr=is_csr, ) - def _infer_namespace(self, X_table): - xp, _ = get_namespace(X_table) - return xp - - def _infer_dtype(self, X_table, dtype=None): - xp = self._infer_namespace(X_table) - - if dtype is not None: - return xp, dtype - - return xp, xp.float32 - # Get appropriate backend (required for SPMD) def _get_basic_statistics_backend(self, result_options): return BasicStatistics(result_options) @@ -239,7 +227,10 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): # For oneDAL versions < 2023.2 or callable init, # using the scikit-learn implementation logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn") - xp, dbtype = self._infer_dtype(X, dtype) + xp, _ = get_namespace(X_table) + + if dtype is None: + dtype = xp.float32 n_samples = X.shape[0] @@ -285,7 +276,7 @@ def _fit_backend(self, X_table, centroids_table, dtype=np.float32, is_csr=False) def _fit(self, X): is_csr = _is_csr(X) - xp = self._infer_namespace(X) + xp, _ = get_namespace(X) if _get_config()["use_raw_input"] is False: X = _check_array( From 8d422cdb30cfd8168b32f76a7eeb9f1df0d37ab6 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Mon, 4 Aug 2025 09:12:40 -0700 Subject: [PATCH 24/33] Fix what is being passed into namespace. --- onedal/cluster/kmeans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 079859ced1..19f2c7d740 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -227,7 +227,7 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): # For oneDAL versions < 2023.2 or callable init, # using the scikit-learn implementation logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn") - xp, _ = get_namespace(X_table) + xp, _ = get_namespace(X) if dtype is None: dtype = xp.float32 From d8254ea6e5dc7e3d3c5cae8c0c9de629768a800d Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Mon, 4 Aug 2025 09:58:21 -0700 Subject: [PATCH 25/33] Updated format. --- onedal/cluster/kmeans.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 19f2c7d740..8748765968 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -226,7 +226,13 @@ def _init_centroids_onedal( def _init_centroids_sklearn(self, X, init, random_state, dtype=None): # For oneDAL versions < 2023.2 or callable init, # using the scikit-learn implementation + import logging + + from sklearn.utils._array_api import get_namespace + logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn") + + # Get the Array API namespace (xp) and original type info (_) xp, _ = get_namespace(X) if dtype is None: @@ -245,7 +251,8 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): centers = X[seeds] elif callable(init): cc_arr = init(X, self.n_clusters, random_state) - cc_arr = np.ascontiguousarray(cc_arr, dtype=dtype) + if cc_arr.dtype != dtype: + cc_arr = xp.astype(cc_arr, dtype) self._validate_center_shape(X, cc_arr) centers = cc_arr elif _is_arraylike_not_scalar(init): @@ -253,7 +260,7 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): else: raise ValueError( f"init should be either 'k-means++', 'random', a ndarray or a " - f"callable, got '{ init }' instead." + f"callable, got '{init}' instead." ) return to_table(centers, queue=getattr(QM.get_global_queue(), "_queue", None)) From f4c4a0b8b4db6916d48236a52725a0bcca14f890 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Mon, 4 Aug 2025 10:48:44 -0700 Subject: [PATCH 26/33] Updated init. --- onedal/cluster/kmeans.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 8748765968..a16bc174c9 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -251,8 +251,9 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): centers = X[seeds] elif callable(init): cc_arr = init(X, self.n_clusters, random_state) - if cc_arr.dtype != dtype: - cc_arr = xp.astype(cc_arr, dtype) + xp_cc_arr, _ = get_namespace(cc_arr) + if xp_cc_arr.dtype != dtype: + cc_arr = xp_cc_arr.astype(cc_arr, dtype) self._validate_center_shape(X, cc_arr) centers = cc_arr elif _is_arraylike_not_scalar(init): From 117d8bd343929670f18f63cf01d03dde8d90f3b6 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Mon, 4 Aug 2025 13:15:11 -0700 Subject: [PATCH 27/33] Updated format. --- onedal/cluster/kmeans.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index a16bc174c9..8cab11cd7d 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -228,8 +228,6 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): # using the scikit-learn implementation import logging - from sklearn.utils._array_api import get_namespace - logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn") # Get the Array API namespace (xp) and original type info (_) @@ -251,9 +249,15 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): centers = X[seeds] elif callable(init): cc_arr = init(X, self.n_clusters, random_state) - xp_cc_arr, _ = get_namespace(cc_arr) - if xp_cc_arr.dtype != dtype: - cc_arr = xp_cc_arr.astype(cc_arr, dtype) + + # Try Array API path + if hasattr(cc_arr, "__array_namespace__"): + xp, _ = get_namespace(cc_arr) + if cc_arr.dtype != dtype: + cc_arr = xp.astype(cc_arr, dtype) + else: + cc_arr = np.ascontiguousarray(cc_arr, dtype=dtype) + self._validate_center_shape(X, cc_arr) centers = cc_arr elif _is_arraylike_not_scalar(init): From dd4d424cfc1fde54801b350beffb76a880908024 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Mon, 4 Aug 2025 13:17:14 -0700 Subject: [PATCH 28/33] fixed init. --- onedal/cluster/kmeans.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 8cab11cd7d..60eb6f453b 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -253,8 +253,8 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): # Try Array API path if hasattr(cc_arr, "__array_namespace__"): xp, _ = get_namespace(cc_arr) - if cc_arr.dtype != dtype: - cc_arr = xp.astype(cc_arr, dtype) + if cc_arr.dtype != dtype: + cc_arr = xp.astype(cc_arr, dtype) else: cc_arr = np.ascontiguousarray(cc_arr, dtype=dtype) From 0bb921c563eea01303c189505a054d8d7baed2f3 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Mon, 4 Aug 2025 13:25:18 -0700 Subject: [PATCH 29/33] Updated name. --- onedal/cluster/kmeans.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 60eb6f453b..43440a9907 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -252,9 +252,9 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): # Try Array API path if hasattr(cc_arr, "__array_namespace__"): - xp, _ = get_namespace(cc_arr) + xp_cc_arr, _ = get_namespace(cc_arr) if cc_arr.dtype != dtype: - cc_arr = xp.astype(cc_arr, dtype) + cc_arr = xp_cc_arr.astype(cc_arr, dtype) else: cc_arr = np.ascontiguousarray(cc_arr, dtype=dtype) From ccf2a9f8fe26ed637f661798d8cbfc5f76ab26b1 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 6 Aug 2025 08:32:56 -0700 Subject: [PATCH 30/33] merged other changes. --- onedal/cluster/kmeans.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 43440a9907..c8848a226c 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -226,7 +226,6 @@ def _init_centroids_onedal( def _init_centroids_sklearn(self, X, init, random_state, dtype=None): # For oneDAL versions < 2023.2 or callable init, # using the scikit-learn implementation - import logging logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn") From e634e06a963a794949facab13f952d500fc961b1 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 6 Aug 2025 13:24:17 -0700 Subject: [PATCH 31/33] Formated. --- onedal/cluster/kmeans.py | 382 ++++++++++++++------------------------- 1 file changed, 131 insertions(+), 251 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index c8848a226c..8fddce782d 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -26,6 +26,31 @@ from onedal.common._backend import bind_default_backend from onedal.utils import _sycl_queue_manager as QM +if daal_check_version((2023, "P", 200)): + from .kmeans_init import KMeansInit + +import logging +import warnings +from abc import ABC + +import numpy as np +from sklearn.cluster._kmeans import _kmeans_plusplus +from sklearn.exceptions import ConvergenceWarning +from sklearn.metrics.pairwise import euclidean_distances +from sklearn.utils import check_random_state + +from daal4py.sklearn._utils import daal_check_version +from onedal._device_offload import supports_queue +from onedal.basic_statistics import BasicStatistics +from onedal.common._backend import bind_default_backend +from onedal.utils import _sycl_queue_manager as QM +from sklearnex.utils._array_api import get_namespace + +from .._config import _get_config +from ..common._mixin import ClusterMixin, TransformerMixin +from ..datatypes import from_table, to_table +from ..utils.validation import _check_array, _is_arraylike_not_scalar, _is_csr + if daal_check_version((2023, "P", 200)): from .kmeans_init import KMeansInit @@ -54,15 +79,29 @@ def __init__( verbose, random_state, n_local_trials=None, + algorithm="lloyd", ): + # __init__ only stores user-visible params self.n_clusters = n_clusters self.init = init + self.n_init = n_init self.max_iter = max_iter self.tol = tol - self.n_init = n_init self.verbose = verbose self.random_state = random_state self.n_local_trials = n_local_trials + self.algorithm = algorithm # kept for parity; we support "lloyd" only + + # runtime/learned attrs (set during fit) + self._tol = None + self.model_ = None + self.n_iter_ = None + self.inertia_ = None + self.labels_ = None + self.n_features_in_ = None + self._cluster_centers_ = None + + # --- pybind11 backends (thin proxies) --- @bind_default_backend("kmeans_common", no_policy=True) def _is_same_clustering(self, labels, best_labels, n_clusters): ... @@ -71,10 +110,14 @@ def _is_same_clustering(self, labels, best_labels, n_clusters): ... def train(self, params, X_table, centroids_table): ... @bind_default_backend("kmeans.clustering") - def infer(self, params, model, centroids_table): ... + def infer(self, params, model, X_table): ... + + # --- helpers matching the pattern --- + + def _get_basic_statistics_backend(self, result_options): + return BasicStatistics(result_options) def _validate_center_shape(self, X, centers): - """Check if centers is compatible with X and n_clusters.""" if centers.shape[0] != self.n_clusters: raise ValueError( f"The shape of the initial centers {centers.shape} does not " @@ -86,53 +129,33 @@ def _validate_center_shape(self, X, centers): f"match the number of features of the data {X.shape[1]}." ) - def _get_kmeans_init(self, cluster_count, seed, algorithm, is_csr): - return KMeansInit( - cluster_count=cluster_count, - seed=seed, - algorithm=algorithm, - is_csr=is_csr, - ) - - # Get appropriate backend (required for SPMD) - def _get_basic_statistics_backend(self, result_options): - return BasicStatistics(result_options) - def _tolerance(self, X_table, rtol, is_csr, dtype): - """Compute absolute tolerance from the relative tolerance""" if rtol == 0.0: - return rtol + return 0.0 dummy = to_table(None) - bs = self._get_basic_statistics_backend("variance") - res = bs._compute_raw(X_table, dummy, dtype, is_csr) mean_var = from_table(res.variance).mean() - return mean_var * rtol def _check_params_vs_input( self, X_table, is_csr, default_n_init=10, dtype=np.float32 ): - - # n_clusters if X_table.shape[0] < self.n_clusters: raise ValueError( f"n_samples={X_table.shape[0]} should be >= n_clusters={self.n_clusters}." ) - - # tol + # compute absolute tolerance once we know dtype self._tol = self._tolerance(X_table, self.tol, is_csr, dtype) - # n-init - # TODO(1.4): Remove + # n_init resolution (kept from your logic) self._n_init = self.n_init if self._n_init == "warn": warnings.warn( ( "The default value of `n_init` will change from " - f"{default_n_init} to 'auto' in 1.4. Set the value of `n_init`" - " explicitly to suppress the warning" + f"{default_n_init} to 'auto' in 1.4. Set `n_init` explicitly " + "to suppress the warning" ), FutureWarning, stacklevel=2, @@ -151,20 +174,24 @@ def _check_params_vs_input( if _is_arraylike_not_scalar(self.init) and self._n_init != 1: warnings.warn( ( - "Explicit initial center position passed: performing only" - f" one init in {self.__class__.__name__} instead of " + "Explicit initial center position passed: performing only " + f"one init in {self.__class__.__name__} instead of " f"n_init={self._n_init}." ), RuntimeWarning, stacklevel=2, ) self._n_init = 1 + + # only "lloyd" is supported in this implementation assert self.algorithm == "lloyd" def _get_onedal_params(self, is_csr=False, dtype=np.float32, result_options=None): - thr = self._tol if hasattr(self, "_tol") else self.tol + thr = self._tol if self._tol is not None else self.tol return { + # fptype chosen from input table dtype (pattern) "fptype": dtype, + # map method names to backend dispatch (CSR vs dense) "method": "lloyd_csr" if is_csr else "by_default", "seed": -1, "max_iteration_count": self.max_iter, @@ -173,6 +200,14 @@ def _get_onedal_params(self, is_csr=False, dtype=np.float32, result_options=None "result_options": "" if result_options is None else result_options, } + def _get_kmeans_init(self, cluster_count, seed, algorithm, is_csr): + return KMeansInit( + cluster_count=cluster_count, + seed=seed, + algorithm=algorithm, + is_csr=is_csr, + ) + def _init_centroids_onedal( self, X_table, @@ -182,81 +217,46 @@ def _init_centroids_onedal( dtype=np.float32, n_centroids=None, ): - n_clusters = self.n_clusters if n_centroids is None else n_centroids - if isinstance(init, str) and init == "k-means++": algorithm = "plus_plus_dense" if not is_csr else "plus_plus_csr" - alg = self._get_kmeans_init( - cluster_count=n_clusters, - seed=random_seed, - algorithm=algorithm, - is_csr=is_csr, - ) - # We pass down the queue that was set through the KMeans.fit() - queue = QM.get_global_queue() - centers_table = alg.compute_raw(X_table, dtype, queue=queue) elif isinstance(init, str) and init == "random": algorithm = "random_dense" if not is_csr else "random_csr" - alg = self._get_kmeans_init( - cluster_count=n_clusters, - seed=random_seed, - algorithm=algorithm, - is_csr=is_csr, - ) - # We pass down the queue that was set through the KMeans.fit() - queue = QM.get_global_queue() - centers_table = alg.compute_raw(X_table, dtype, queue=queue) elif _is_arraylike_not_scalar(init): - if _is_csr(init): - # oneDAL KMeans only supports Dense Centroids - centers = init.toarray() - else: - centers = np.asarray(init) - assert centers.shape[0] == n_clusters - assert centers.shape[1] == X_table.column_count - # KMeans is implemented on both CPU and GPU for Dense and CSR data - # The original policy can be used here - centers_table = to_table(centers, queue=QM.get_global_queue()) + centers = init.toarray() if _is_csr(init) else np.asarray(init) + self._validate_center_shape(np.empty((0, X_table.column_count)), centers) + return to_table(centers, queue=QM.get_global_queue()) else: raise TypeError("Unsupported type of the `init` value") - return centers_table + alg = self._get_kmeans_init( + cluster_count=n_clusters, + seed=random_seed, + algorithm=algorithm, + is_csr=is_csr, + ) + return alg.compute_raw(X_table, dtype, queue=QM.get_global_queue()) def _init_centroids_sklearn(self, X, init, random_state, dtype=None): - # For oneDAL versions < 2023.2 or callable init, - # using the scikit-learn implementation - logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn") - - # Get the Array API namespace (xp) and original type info (_) xp, _ = get_namespace(X) - if dtype is None: dtype = xp.float32 n_samples = X.shape[0] - if isinstance(init, str) and init == "k-means++": - centers, _ = _kmeans_plusplus( - X, - self.n_clusters, - random_state=random_state, - ) + centers, _ = _kmeans_plusplus(X, self.n_clusters, random_state=random_state) elif isinstance(init, str) and init == "random": seeds = random_state.choice(n_samples, size=self.n_clusters, replace=False) centers = X[seeds] elif callable(init): cc_arr = init(X, self.n_clusters, random_state) - - # Try Array API path if hasattr(cc_arr, "__array_namespace__"): xp_cc_arr, _ = get_namespace(cc_arr) if cc_arr.dtype != dtype: cc_arr = xp_cc_arr.astype(cc_arr, dtype) else: cc_arr = np.ascontiguousarray(cc_arr, dtype=dtype) - self._validate_center_shape(X, cc_arr) centers = cc_arr elif _is_arraylike_not_scalar(init): @@ -269,14 +269,11 @@ def _init_centroids_sklearn(self, X, init, random_state, dtype=None): return to_table(centers, queue=getattr(QM.get_global_queue(), "_queue", None)) - def _fit_backend(self, X_table, centroids_table, dtype=np.float32, is_csr=False): + # --- core train/infer wrappers in the estimator pattern --- + def _fit_backend(self, X_table, centroids_table, dtype=np.float32, is_csr=False): params = self._get_onedal_params(is_csr, dtype) - - assert X_table.dtype == dtype - result = self.train(params, X_table, centroids_table) - return ( result.responses, result.objective_function_value, @@ -284,9 +281,17 @@ def _fit_backend(self, X_table, centroids_table, dtype=np.float32, is_csr=False) result.iteration_count, ) - def _fit(self, X): - is_csr = _is_csr(X) + def _predict_backend(self, X_table, result_options=None): + params = self._get_onedal_params( + _is_csr=False, dtype=X_table.dtype, result_options=result_options + ) + return self.infer(params, self.model_, X_table) + # --- public API matched to the pattern --- + + @supports_queue + def fit(self, X, y=None, queue=None): + is_csr = _is_csr(X) xp, _ = get_namespace(X) if _get_config()["use_raw_input"] is False: @@ -301,26 +306,24 @@ def _fit(self, X): dtype = X_table.dtype self._check_params_vs_input(X_table, is_csr, dtype=dtype) - self.n_features_in_ = X_table.column_count - best_model, best_n_iter = None, None - best_inertia, best_labels = None, None + best_model = best_labels = None + best_inertia = None + best_n_iter = None - def is_better_iteration(inertia, labels): + def is_better(inertia, labels): if best_inertia is None: return True - else: - better_inertia = inertia < best_inertia - return better_inertia and not self._is_same_clustering( - labels, best_labels, self.n_clusters - ) + better = inertia < best_inertia + return better and not self._is_same_clustering( + labels, best_labels, self.n_clusters + ) random_state = check_random_state(self.random_state) init = self.init - init_is_array_like = _is_arraylike_not_scalar(init) - if init_is_array_like: + if _is_arraylike_not_scalar(init): init = _check_array( init, dtype=dtype, accept_sparse="csr", copy=True, order="C" ) @@ -330,9 +333,9 @@ def is_better_iteration(inertia, labels): for _ in range(self._n_init): if use_onedal_init: - random_seed = random_state.randint(np.iinfo("i").max) + seed = random_state.randint(np.iinfo("i").max) centroids_table = self._init_centroids_onedal( - X_table, init, random_seed, is_csr, dtype=dtype + X_table, init, seed, is_csr, dtype=dtype ) else: centroids_table = self._init_centroids_sklearn( @@ -342,90 +345,73 @@ def is_better_iteration(inertia, labels): if self.verbose: print("Initialization complete") - labels, inertia, model, n_iter = self._fit_backend( + labels_t, inertia, model, n_iter = self._fit_backend( X_table, centroids_table, dtype, is_csr ) if self.verbose: - print("Iteration {}, inertia {}.".format(n_iter, inertia)) + print(f"Iteration {n_iter}, inertia {inertia}.") - if is_better_iteration(inertia, labels): + if is_better(inertia, labels_t): best_model, best_n_iter = model, n_iter - best_inertia, best_labels = inertia, labels + best_inertia, best_labels = inertia, labels_t - # Types without conversion + # assign learned attributes (pattern) self.model_ = best_model - - # Simple types self.n_iter_ = best_n_iter self.inertia_ = best_inertia - - # Complex type conversion self.labels_ = from_table(best_labels).ravel() distinct_clusters = len(np.unique(self.labels_)) if distinct_clusters < self.n_clusters: warnings.warn( - "Number of distinct clusters ({}) found smaller than " - "n_clusters ({}). Possibly due to duplicate points " - "in X.".format(distinct_clusters, self.n_clusters), + "Number of distinct clusters ({}) found smaller than n_clusters ({}). " + "Possibly due to duplicate points in X.".format( + distinct_clusters, self.n_clusters + ), ConvergenceWarning, stacklevel=2, ) - return self @property def cluster_centers_(self): - if not hasattr(self, "_cluster_centers_"): - if hasattr(self, "model_"): - centroids = self.model_.centroids - self._cluster_centers_ = from_table(centroids) - else: + if self._cluster_centers_ is None: + if not hasattr(self, "model_") or self.model_ is None: raise NameError("This model has not been trained") + self._cluster_centers_ = from_table(self.model_.centroids) return self._cluster_centers_ @cluster_centers_.setter def cluster_centers_(self, cluster_centers): xp, _ = get_namespace(cluster_centers) self._cluster_centers_ = xp.asarray(cluster_centers) - self.n_iter_ = 0 self.inertia_ = 0 - + # keep backend model in sync self.model_.centroids = to_table(self._cluster_centers_) self.n_features_in_ = self.model_.centroids.column_count self.labels_ = xp.arange(self.model_.centroids.row_count) - return self - @cluster_centers_.deleter def cluster_centers_(self): - del self._cluster_centers_ - - def _predict(self, X, result_options=None): - is_csr = _is_csr(X) + self._cluster_centers_ = None + @supports_queue + def predict(self, X, queue=None): X_table = to_table(X, queue=QM.get_global_queue()) - params = self._get_onedal_params(is_csr, X_table.dtype, result_options) - - result = self.infer(params, self.model_, X_table) - - if result_options == "compute_exact_objective_function": - # This is only set for score function - return -1 * result.objective_function_value - else: - return from_table(result.responses).ravel() - - def _score(self, X): - result_options = "compute_exact_objective_function" + result = self._predict_backend(X_table) + return from_table(result.responses).ravel() - return self._predict( - X, - result_options, + @supports_queue + def score(self, X, queue=None): + X_table = to_table(X, queue=QM.get_global_queue()) + result = self._predict_backend( + X_table, result_options="compute_exact_objective_function" ) + return -1 * result.objective_function_value - def _transform(self, X): + def transform(self, X): return euclidean_distances(X, self.cluster_centers_) @@ -451,128 +437,21 @@ def __init__( tol=tol, verbose=verbose, random_state=random_state, + algorithm=algorithm, ) - - self.copy_x = copy_x - self.algorithm = algorithm - assert self.algorithm == "lloyd" + self.copy_x = copy_x # stored, but not used by oneDAL path @supports_queue def fit(self, X, y=None, queue=None): - return self._fit(X) + return super().fit(X, y=y, queue=queue) @supports_queue def predict(self, X, queue=None): - """Predict the closest cluster each sample in X belongs to. - - In the vector quantization literature, `cluster_centers_` is called - the code book and each value returned by `predict` is the index of - the closest code in the code book. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - New data to predict. - - queue : SyclQueue or None, default=None - SYCL Queue object for device code execution. Default - value None causes computation on host. - - Returns - ------- - labels : ndarray of shape (n_samples,) - Index of the cluster each sample belongs to. - """ - return self._predict(X) - - def fit_predict(self, X, y=None, queue=None): - """Compute cluster centers and predict cluster index for each sample. - - Convenience method; equivalent to calling fit(X) followed by - predict(X). - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - New data to transform. - - y : Ignored - Not used, present here for API consistency by convention. - - queue : SyclQueue or None, default=None - SYCL Queue object for device code execution. Default - value None causes computation on host. - - Returns - ------- - labels : ndarray of shape (n_samples,) - Index of the cluster each sample belongs to. - """ - return self.fit(X, queue=queue).labels_ - - def fit_transform(self, X, y=None, queue=None): - """Compute clustering and transform X to cluster-distance space. - - Equivalent to fit(X).transform(X), but more efficiently implemented. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - New data to transform. - - y : Ignored - Not used, present here for API consistency by convention. - - queue : SyclQueue or None, default=None - SYCL Queue object for device code execution. Default - value None causes computation on host. - - Returns - ------- - X_new : ndarray of shape (n_samples, n_clusters) - X transformed in the new space. - """ - return self.fit(X, queue=queue)._transform(X) - - def transform(self, X): - """Transform X to a cluster-distance space. - - In the new space, each dimension is the distance to the cluster - centers. Note that even if X is sparse, the array returned by - `transform` will typically be dense. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - New data to transform. - - Returns - ------- - X_new : ndarray of shape (n_samples, n_clusters) - X transformed in the new space. - """ - - return self._transform(X) + return super().predict(X, queue=queue) @supports_queue def score(self, X, queue=None): - """Opposite of the value of X on the K-means objective. - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - New data. - - queue : SyclQueue or None, default=None - SYCL Queue object for device code execution. Default - value None causes computation on host. - - Returns - ------- - score: float - Opposite of the value of X on the K-means objective. - """ - return self._score(X) + return super().score(X, queue=queue) def k_means( @@ -601,6 +480,7 @@ def k_means( copy_x=copy_x, algorithm=algorithm, ).fit(X, queue=queue) + if return_n_iter: return est.cluster_centers_, est.labels_, est.inertia_, est.n_iter_ else: From 28fad432d8fcae18d0db9dbca309320abb6dafc1 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 6 Aug 2025 13:51:51 -0700 Subject: [PATCH 32/33] Fixed imports. --- onedal/cluster/kmeans.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index 8fddce782d..b7b9da39af 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -26,31 +26,6 @@ from onedal.common._backend import bind_default_backend from onedal.utils import _sycl_queue_manager as QM -if daal_check_version((2023, "P", 200)): - from .kmeans_init import KMeansInit - -import logging -import warnings -from abc import ABC - -import numpy as np -from sklearn.cluster._kmeans import _kmeans_plusplus -from sklearn.exceptions import ConvergenceWarning -from sklearn.metrics.pairwise import euclidean_distances -from sklearn.utils import check_random_state - -from daal4py.sklearn._utils import daal_check_version -from onedal._device_offload import supports_queue -from onedal.basic_statistics import BasicStatistics -from onedal.common._backend import bind_default_backend -from onedal.utils import _sycl_queue_manager as QM -from sklearnex.utils._array_api import get_namespace - -from .._config import _get_config -from ..common._mixin import ClusterMixin, TransformerMixin -from ..datatypes import from_table, to_table -from ..utils.validation import _check_array, _is_arraylike_not_scalar, _is_csr - if daal_check_version((2023, "P", 200)): from .kmeans_init import KMeansInit From c0f24f030ce92981e16a76893214499df8a04562 Mon Sep 17 00:00:00 2001 From: "Mcgrievy, Kathleen" Date: Wed, 6 Aug 2025 14:40:36 -0700 Subject: [PATCH 33/33] fixed _get_onedal_params call error. --- onedal/cluster/kmeans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onedal/cluster/kmeans.py b/onedal/cluster/kmeans.py index b7b9da39af..cd05f8f969 100644 --- a/onedal/cluster/kmeans.py +++ b/onedal/cluster/kmeans.py @@ -258,7 +258,7 @@ def _fit_backend(self, X_table, centroids_table, dtype=np.float32, is_csr=False) def _predict_backend(self, X_table, result_options=None): params = self._get_onedal_params( - _is_csr=False, dtype=X_table.dtype, result_options=result_options + is_csr=False, dtype=X_table.dtype, result_options=result_options ) return self.infer(params, self.model_, X_table)