Skip to content

Commit 8956737

Browse files
MAINT: adding common BaseEstimatorSPMD for spmd ifaces (#1679)
1 parent aa3e156 commit 8956737

File tree

8 files changed

+41
-65
lines changed

8 files changed

+41
-65
lines changed

onedal/spmd/_common.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# ==============================================================================
2+
# Copyright 2024 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
from abc import ABC
18+
19+
from ..common._spmd_policy import _get_spmd_policy
20+
21+
22+
class BaseEstimatorSPMD(ABC):
23+
def _get_policy(self, queue, *data):
24+
return _get_spmd_policy(queue)

onedal/spmd/basic_statistics/basic_statistics.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,13 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
from abc import ABC
18-
1917
from onedal.basic_statistics import BasicStatistics as BasicStatistics_Batch
2018

2119
from ..._device_offload import support_usm_ndarray
22-
from ...common._spmd_policy import _get_spmd_policy
23-
24-
25-
class BaseBasicStatisticsSPMD(ABC):
26-
def _get_policy(self, queue, *data):
27-
return _get_spmd_policy(queue)
20+
from .._common import BaseEstimatorSPMD
2821

2922

30-
class BasicStatistics(BaseBasicStatisticsSPMD, BasicStatistics_Batch):
23+
class BasicStatistics(BaseEstimatorSPMD, BasicStatistics_Batch):
3124
@support_usm_ndarray()
3225
def compute(self, data, weights=None, queue=None):
3326
return super().compute(data, weights, queue)

onedal/spmd/cluster/dbscan.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,10 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
from abc import ABC
18-
1917
from onedal.cluster import DBSCAN as DBSCAN_Batch
2018

21-
from ...common._spmd_policy import _get_spmd_policy
22-
23-
24-
class BaseDBSCANspmd(ABC):
25-
def _get_policy(self, queue, *data):
26-
return _get_spmd_policy(queue)
19+
from .._common import BaseEstimatorSPMD
2720

2821

29-
class DBSCAN(BaseDBSCANspmd, DBSCAN_Batch):
22+
class DBSCAN(BaseEstimatorSPMD, DBSCAN_Batch):
3023
pass

onedal/spmd/cluster/kmeans.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,13 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
from abc import ABC
18-
1917
from onedal.cluster import KMeans as KMeans_Batch
2018

2119
from ..._device_offload import support_usm_ndarray
22-
from ...common._spmd_policy import _get_spmd_policy
23-
24-
25-
class BaseKMeansSPMD(ABC):
26-
def _get_policy(self, queue, *data):
27-
return _get_spmd_policy(queue)
20+
from .._common import BaseEstimatorSPMD
2821

2922

30-
class KMeans(BaseKMeansSPMD, KMeans_Batch):
23+
class KMeans(BaseEstimatorSPMD, KMeans_Batch):
3124
@support_usm_ndarray()
3225
def fit(self, X, queue=None):
3326
return super().fit(X, queue)

onedal/spmd/decomposition/pca.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,13 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
1817
from onedal.decomposition.pca import PCA as PCABatch
1918

2019
from ..._device_offload import support_usm_ndarray
21-
from ...common._spmd_policy import _get_spmd_policy
22-
23-
24-
class BasePCASPMD:
25-
def _get_policy(self, queue, *data):
26-
return _get_spmd_policy(queue)
20+
from .._common import BaseEstimatorSPMD
2721

2822

29-
class PCA(BasePCASPMD, PCABatch):
23+
class PCA(BaseEstimatorSPMD, PCABatch):
3024
@support_usm_ndarray()
3125
def fit(self, X, queue):
3226
return super().fit(X, queue)

onedal/spmd/ensemble/forest.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,15 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
from abc import ABC
18-
1917
from onedal.ensemble import RandomForestClassifier as RandomForestClassifier_Batch
2018
from onedal.ensemble import RandomForestRegressor as RandomForestRegressor_Batch
2119

22-
from ...common._spmd_policy import _get_spmd_policy
23-
24-
25-
class BaseForestSPMD(ABC):
26-
def _get_policy(self, queue, *data):
27-
return _get_spmd_policy(queue)
20+
from .._common import BaseEstimatorSPMD
2821

2922

30-
class RandomForestClassifier(BaseForestSPMD, RandomForestClassifier_Batch):
23+
class RandomForestClassifier(BaseEstimatorSPMD, RandomForestClassifier_Batch):
3124
pass
3225

3326

34-
class RandomForestRegressor(BaseForestSPMD, RandomForestRegressor_Batch):
27+
class RandomForestRegressor(BaseEstimatorSPMD, RandomForestRegressor_Batch):
3528
pass

onedal/spmd/linear_model/linear_model.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,13 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
from abc import ABC
18-
1917
from onedal.linear_model import LinearRegression as LinearRegression_Batch
2018

2119
from ..._device_offload import support_usm_ndarray
22-
from ...common._spmd_policy import _get_spmd_policy
23-
24-
25-
class BaseLinearRegressionSPMD(ABC):
26-
def _get_policy(self, queue, *data):
27-
return _get_spmd_policy(queue)
20+
from .._common import BaseEstimatorSPMD
2821

2922

30-
class LinearRegression(BaseLinearRegressionSPMD, LinearRegression_Batch):
23+
class LinearRegression(BaseEstimatorSPMD, LinearRegression_Batch):
3124
@support_usm_ndarray()
3225
def fit(self, X, y, queue=None):
3326
return super().fit(X, y, queue)

onedal/spmd/neighbors/neighbors.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,14 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
from abc import ABC
18-
1917
from onedal.neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch
2018
from onedal.neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch
2119

2220
from ..._device_offload import support_usm_ndarray
23-
from ...common._spmd_policy import _get_spmd_policy
24-
25-
26-
class NeighborsCommonBaseSPMD(ABC):
27-
def _get_policy(self, queue, *data):
28-
return _get_spmd_policy(queue)
21+
from .._common import BaseEstimatorSPMD
2922

3023

31-
class KNeighborsClassifier(NeighborsCommonBaseSPMD, KNeighborsClassifier_Batch):
24+
class KNeighborsClassifier(BaseEstimatorSPMD, KNeighborsClassifier_Batch):
3225
@support_usm_ndarray()
3326
def fit(self, X, y, queue=None):
3427
return super().fit(X, y, queue)
@@ -46,7 +39,7 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None)
4639
return super().kneighbors(X, n_neighbors, return_distance, queue)
4740

4841

49-
class KNeighborsRegressor(NeighborsCommonBaseSPMD, KNeighborsRegressor_Batch):
42+
class KNeighborsRegressor(BaseEstimatorSPMD, KNeighborsRegressor_Batch):
5043
@support_usm_ndarray()
5144
def fit(self, X, y, queue=None):
5245
if queue is not None and queue.sycl_device.is_gpu:
@@ -72,7 +65,7 @@ def _get_onedal_params(self, X, y=None):
7265
return params
7366

7467

75-
class NearestNeighbors(NeighborsCommonBaseSPMD):
68+
class NearestNeighbors(BaseEstimatorSPMD):
7669
@support_usm_ndarray()
7770
def fit(self, X, y, queue=None):
7871
return super().fit(X, y, queue)

0 commit comments

Comments
 (0)