Skip to content

Commit b450725

Browse files
authored
[Enhancement] enable array API support in EmpiricalCovariance and IncrementalEmpiricalCovariance (#2207)
* fix onedal side * add SPMD interface * add new assert_all_finite * change incremental algo * remove old code * readd deletion * missing? * Update test_covariance.py * fix error in spmd interface * fixes * remove changes * fix * fix preview * fixes * forgotten import * remove import * fix squeeze * Update test_covariance.py * Update covariance.py * Update incremental_covariance.py * Update incremental_covariance.py * try to fix issues * updates to try and solve score * fixes for mahalanobis * Update incremental_covariance.py * Update incremental_covariance.py * Update incremental_covariance.py * Update incremental_covariance.py * Update incremental_covariance.py * Update incremental_covariance.py * expand array API support in sklearnex * fix spelling mistake * remove print statement * fix issues with respect to dpctl dpnp * Update covariance.py * fix error_norm and score * try to fix mahal * fixes * fixes * fixes again * try again * fixes for test case failures * try again * try to remove regression * fix mistakes * updates * fix covariance * standardize validate_params * Update covariance.py * Update covariance.py * fix issues with double import * Update covariance.py * Update covariance.py * Update _array_api.py * Update _array_api.py * Update _array_api.py * Update _array_api.py * Update incremental_covariance.py * Update incremental_covariance.py * Update covariance.py * Update incremental_covariance.py * Update covariance.py * Update covariance.py * Update incremental_covariance.py * Update incremental_covariance.py * Update incremental_covariance.py * Update incremental_covariance.py * Update covariance.py * Update _array_api.py * Update covariance.py * Update incremental_covariance.py * Update test_incremental_covariance_spmd.py * Update covariance.py * Update test_incremental_covariance_spmd.py * Update covariance.py * Update incremental_covariance.py * Update covariance.py * Update test_covariance_spmd.py * Update test_covariance_spmd.py * Update 
test_covariance_spmd.py * Update test_covariance_spmd.py * Update test_covariance_spmd.py * many requisite changes * Update test_covariance.py * force output to a 0d * try to fix test * updates
1 parent 8d35f64 commit b450725

File tree

10 files changed

+429
-165
lines changed

10 files changed

+429
-165
lines changed

onedal/covariance/covariance.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from daal4py.sklearn._utils import daal_check_version
2121
from onedal._device_offload import supports_queue
2222
from onedal.common._backend import bind_default_backend
23-
from onedal.utils.validation import _check_array
2423

2524
from .._config import _get_config
2625
from ..common.hyperparameters import get_hyperparameters
@@ -101,13 +100,7 @@ def fit(self, X, y=None, queue=None):
101100
self : object
102101
Returns the instance itself.
103102
"""
104-
use_raw_input = _get_config()["use_raw_input"] is True
105-
sua_iface, xp, _ = _get_sycl_namespace(X)
106-
if use_raw_input and sua_iface:
107-
queue = X.sycl_queue
108103

109-
if not use_raw_input:
110-
X = _check_array(X, dtype=[np.float64, np.float32])
111104
X_table = to_table(X, queue=queue)
112105

113106
params = self._get_onedal_params(X_table.dtype)
@@ -123,6 +116,6 @@ def fit(self, X, y=None, queue=None):
123116
from_table(result.cov_matrix, like=X) * (X.shape[0] - 1) / X.shape[0]
124117
)
125118

126-
self.location_ = xp.squeeze(from_table(result.means, like=X))
119+
self.location_ = from_table(result.means, like=X)[0, ...]
127120

128121
return self

onedal/covariance/incremental_covariance.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
import numpy as np
1818

1919
from daal4py.sklearn._utils import daal_check_version
20-
from onedal._device_offload import supports_queue
21-
from onedal.common._backend import bind_default_backend
22-
from onedal.utils import _sycl_queue_manager as QM
2320

2421
from .._config import _get_config
25-
from ..datatypes import from_table, to_table
22+
from .._device_offload import supports_queue
23+
from ..common._backend import bind_default_backend
24+
from ..datatypes import from_table, return_type_constructor, to_table
25+
from ..utils import _sycl_queue_manager as QM
2626
from ..utils._array_api import _get_sycl_namespace
2727
from ..utils.validation import _check_array
2828
from .covariance import BaseEmpiricalCovariance
@@ -74,6 +74,7 @@ def finalize_compute(self, params, partial_result): ...
7474
def _reset(self):
7575
self._need_to_finalize = False
7676
self._queue = None
77+
self._outtype = None
7778
self._partial_result = self.partial_compute_result()
7879

7980
def __getstate__(self):
@@ -108,15 +109,10 @@ def partial_fit(self, X, y=None, queue=None):
108109
self : object
109110
Returns the instance itself.
110111
"""
111-
use_raw_input = _get_config()["use_raw_input"] is True
112-
sua_iface, _, _ = _get_sycl_namespace(X)
113-
114-
if use_raw_input and sua_iface:
115-
queue = X.sycl_queue
116-
if not use_raw_input:
117-
X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True)
118112

119113
self._queue = queue
114+
if not self._outtype:
115+
self._outtype = return_type_constructor(X)
120116
X_table = to_table(X, queue=queue)
121117

122118
if not hasattr(self, "_dtype"):
@@ -125,8 +121,6 @@ def partial_fit(self, X, y=None, queue=None):
125121
params = self._get_onedal_params(self._dtype)
126122
self._partial_result = self.partial_compute(params, self._partial_result, X_table)
127123
self._need_to_finalize = True
128-
# store the queue for when we finalize
129-
self._queue = queue
130124

131125
def finalize_fit(self):
132126
"""Finalize covariance matrix from the current `_partial_result`.
@@ -143,13 +137,14 @@ def finalize_fit(self):
143137
with QM.manage_global_queue(self._queue):
144138
result = self.finalize_compute(params, self._partial_result)
145139

146-
if daal_check_version((2024, "P", 1)) or (not self.bias):
147-
self.covariance_ = from_table(result.cov_matrix)
148-
else:
140+
self.covariance_ = from_table(result.cov_matrix, like=self._outtype)
141+
142+
if self.bias and not daal_check_version((2024, "P", 1)):
149143
n_rows = self._partial_result.partial_n_rows
150-
self.covariance_ = from_table(result.cov_matrix) * (n_rows - 1) / n_rows
144+
self.covariance_ *= (n_rows - 1) / n_rows
151145

152-
self.location_ = from_table(result.means).ravel()
146+
self.location_ = from_table(result.means, like=self._outtype)[0, ...]
147+
self._outtype = None
153148

154149
self._need_to_finalize = False
155150

onedal/covariance/tests/test_covariance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,23 @@
2525
def test_onedal_import_covariance(queue):
2626
from onedal.covariance import EmpiricalCovariance
2727

28-
X = np.array([[0, 1], [0, 1]])
28+
X = np.array([[0, 1], [0, 1]], dtype=np.float64)
2929
result = EmpiricalCovariance().fit(X, queue=queue)
3030
expected_covariance = np.array([[0, 0], [0, 0]])
3131
expected_means = np.array([0, 1])
3232

3333
assert_allclose(expected_covariance, result.covariance_)
3434
assert_allclose(expected_means, result.location_)
3535

36-
X = np.array([[1, 2], [3, 6]])
36+
X = np.array([[1, 2], [3, 6]], dtype=np.float64)
3737
result = EmpiricalCovariance().fit(X, queue=queue)
3838
expected_covariance = np.array([[2, 4], [4, 8]])
3939
expected_means = np.array([2, 4])
4040

4141
assert_allclose(expected_covariance, result.covariance_)
4242
assert_allclose(expected_means, result.location_)
4343

44-
X = np.array([[1, 2], [3, 6]])
44+
X = np.array([[1, 2], [3, 6]], dtype=np.float64)
4545
result = EmpiricalCovariance(bias=True).fit(X, queue=queue)
4646
expected_covariance = np.array([[1, 2], [2, 4]])
4747
expected_means = np.array([2, 4])

0 commit comments

Comments
 (0)