Skip to content

Commit 8496857

Browse files
authored
enh: Add SPMD interfaces for Covariance (#1697)
* enh: Add SPMD interfaces for Covariance * linted
1 parent 79b28ad commit 8496857

File tree

10 files changed

+133
-0
lines changed

10 files changed

+133
-0
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# ==============================================================================
2+
# Copyright 2024 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
import dpctl
18+
import dpctl.tensor as dpt
19+
import numpy as np
20+
from mpi4py import MPI
21+
22+
from sklearnex.spmd.covariance import EmpiricalCovariance
23+
24+
25+
def get_data(data_seed):
26+
ns, nf = 3000, 3
27+
drng = np.random.default_rng(data_seed)
28+
X = drng.random(size=(ns, nf))
29+
return X
30+
31+
32+
q = dpctl.SyclQueue("gpu")
33+
comm = MPI.COMM_WORLD
34+
rank = comm.Get_rank()
35+
size = comm.Get_size()
36+
37+
X = get_data(rank)
38+
dpt_X = dpt.asarray(X, usm_type="device", sycl_queue=q)
39+
40+
cov = EmpiricalCovariance().fit(dpt_X)
41+
42+
print(f"Computed covariance values on rank {rank}:\n", cov.covariance_)

onedal/spmd/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
__all__ = [
1818
"basic_statistics",
1919
"cluster",
20+
"covariance",
2021
"decomposition",
2122
"ensemble",
2223
"linear_model",

onedal/spmd/covariance/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# ==============================================================================
2+
# Copyright 2024 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
from .covariance import EmpiricalCovariance
18+
19+
__all__ = ["EmpiricalCovariance"]
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# ==============================================================================
2+
# Copyright 2024 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
from onedal.covariance import EmpiricalCovariance as EmpiricalCovariance_Batch
18+
19+
from ..._device_offload import support_usm_ndarray
20+
from .._common import BaseEstimatorSPMD
21+
22+
23+
class EmpiricalCovariance(BaseEstimatorSPMD, EmpiricalCovariance_Batch):
24+
@support_usm_ndarray()
25+
def fit(self, X, queue=None):
26+
return super().fit(X, queue)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ def run(self):
565565
if build_distribute:
566566
packages_with_tests += [
567567
"onedal.spmd",
568+
"onedal.spmd.covariance",
568569
"onedal.spmd.decomposition",
569570
"onedal.spmd.ensemble",
570571
]

setup_sklearnex.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
if build_distribute:
106106
packages_with_tests += [
107107
"sklearnex.spmd",
108+
"sklearnex.spmd.covariance",
108109
"sklearnex.spmd.decomposition",
109110
"sklearnex.spmd.ensemble",
110111
]

sklearnex/spmd/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
__all__ = [
1818
"basic_statistics",
1919
"cluster",
20+
"covariance",
2021
"decomposition",
2122
"ensemble",
2223
"linear_model",
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# ==============================================================================
2+
# Copyright 2024 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
from .covariance import EmpiricalCovariance
18+
19+
__all__ = ["EmpiricalCovariance"]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# ==============================================================================
2+
# Copyright 2024 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
from onedal.spmd.covariance import EmpiricalCovariance
18+
19+
# TODO:
20+
# Currently it uses `onedal` module interface.
21+
# Add sklearnex dispatching.

tests/run_examples.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def check_library(rule):
147147

148148
req_device = defaultdict(lambda: [])
149149
req_device["basic_statistics_spmd.py"] = ["gpu"]
150+
req_device["covariance_spmd.py"] = ["gpu"]
150151
req_device["dbscan_spmd.py"] = ["gpu"]
151152
req_device["kmeans_spmd.py"] = ["gpu"]
152153
req_device["knn_bf_classification_dpnp.py"] = ["gpu"]
@@ -163,6 +164,7 @@ def check_library(rule):
163164

164165
req_library = defaultdict(lambda: [])
165166
req_library["basic_statistics_spmd.py"] = ["dpctl", "mpi4py"]
167+
req_library["covariance_spmd.py"] = ["dpctl", "mpi4py"]
166168
req_library["dbscan_spmd.py"] = ["dpctl", "mpi4py"]
167169
req_library["basic_statistics_spmd.py"] = ["dpctl", "mpi4py"]
168170
req_library["kmeans_spmd.py"] = ["dpctl", "mpi4py"]

0 commit comments

Comments
 (0)