1414# limitations under the License.
1515# ==============================================================================
1616
17- from abc import ABC
18-
1917from onedal .neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch
2018from onedal .neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch
2119
2220from ..._device_offload import support_usm_ndarray
23- from ...common ._spmd_policy import _get_spmd_policy
24-
25-
26- class NeighborsCommonBaseSPMD (ABC ):
27- def _get_policy (self , queue , * data ):
28- return _get_spmd_policy (queue )
21+ from .._common import BaseEstimatorSPMD
2922
3023
31- class KNeighborsClassifier (NeighborsCommonBaseSPMD , KNeighborsClassifier_Batch ):
24+ class KNeighborsClassifier (BaseEstimatorSPMD , KNeighborsClassifier_Batch ):
3225 @support_usm_ndarray ()
3326 def fit (self , X , y , queue = None ):
3427 return super ().fit (X , y , queue )
@@ -46,7 +39,7 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None)
4639 return super ().kneighbors (X , n_neighbors , return_distance , queue )
4740
4841
49- class KNeighborsRegressor (NeighborsCommonBaseSPMD , KNeighborsRegressor_Batch ):
42+ class KNeighborsRegressor (BaseEstimatorSPMD , KNeighborsRegressor_Batch ):
5043 @support_usm_ndarray ()
5144 def fit (self , X , y , queue = None ):
5245 if queue is not None and queue .sycl_device .is_gpu :
@@ -72,7 +65,7 @@ def _get_onedal_params(self, X, y=None):
7265 return params
7366
7467
75- class NearestNeighbors (NeighborsCommonBaseSPMD ):
68+ class NearestNeighbors (BaseEstimatorSPMD ):
7669 @support_usm_ndarray ()
7770 def fit (self , X , y , queue = None ):
7871 return super ().fit (X , y , queue )
0 commit comments