Skip to content

Commit e232bc1

Browse files
authored
[enhancement] Enable oneDAL support for EmpiricalCovariance with default parameters (#1677)
* Update covariance.py * Update covariance.py * formatting * Update covariance.py * Update covariance.py * Update covariance.py * isort * Update covariance.py * Update covariance.py * Update covariance.py
1 parent 130d068 commit e232bc1

File tree

2 files changed

+42
-17
lines changed

2 files changed

+42
-17
lines changed

onedal/covariance/covariance.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from abc import ABCMeta
1717

1818
import numpy as np
19-
from sklearn.utils import check_array
2019

21-
from daal4py.sklearn._utils import daal_check_version, get_dtype, make2d
20+
from daal4py.sklearn._utils import daal_check_version, get_dtype
2221
from onedal import _backend
22+
from onedal.utils import _check_array
2323

2424
from ..common._policy import _get_policy
2525
from ..common.hyperparameters import get_hyperparameters
@@ -86,11 +86,7 @@ def fit(self, X, queue=None):
8686
Returns the instance itself.
8787
"""
8888
policy = self._get_policy(queue, X)
89-
X = check_array(X, dtype=[np.float64, np.float32])
90-
X = make2d(X)
91-
types = [np.float32, np.float64]
92-
if get_dtype(X) not in types:
93-
X = X.astype(np.float64)
89+
X = _check_array(X, dtype=[np.float64, np.float32])
9490
X = _convert_to_supported(policy, X)
9591
dtype = get_dtype(X)
9692
params = self._get_onedal_params(dtype)

sklearnex/preview/covariance/covariance.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# limitations under the License.
1515
# ===============================================================================
1616

17+
import warnings
18+
19+
import numpy as np
1720
from scipy import sparse as sp
1821
from sklearn.covariance import EmpiricalCovariance as sklearn_EmpiricalCovariance
1922
from sklearn.utils import check_array
@@ -22,13 +25,15 @@
2225
from daal4py.sklearn._utils import sklearn_check_version
2326
from onedal.common.hyperparameters import get_hyperparameters
2427
from onedal.covariance import EmpiricalCovariance as onedal_EmpiricalCovariance
28+
from sklearnex import config_context
29+
from sklearnex.metrics import pairwise_distances
2530

26-
from ..._device_offload import dispatch
31+
from ..._device_offload import dispatch, wrap_output_data
2732
from ..._utils import PatchingConditionsChain, register_hyperparameters
2833

2934

3035
@register_hyperparameters({"fit": get_hyperparameters("covariance", "compute")})
31-
@control_n_jobs(decorated_methods=["fit"])
36+
@control_n_jobs(decorated_methods=["fit", "mahalanobis"])
3237
class EmpiricalCovariance(sklearn_EmpiricalCovariance):
3338
__doc__ = sklearn_EmpiricalCovariance.__doc__
3439

@@ -39,12 +44,17 @@ class EmpiricalCovariance(sklearn_EmpiricalCovariance):
3944

4045
def _save_attributes(self):
4146
assert hasattr(self, "_onedal_estimator")
42-
self.covariance_ = self._onedal_estimator.covariance_
47+
self._set_covariance(self._onedal_estimator.covariance_)
4348
self.location_ = self._onedal_estimator.location_
4449

4550
_onedal_covariance = staticmethod(onedal_EmpiricalCovariance)
4651

4752
def _onedal_fit(self, X, queue=None):
53+
if X.shape[0] == 1:
54+
warnings.warn(
55+
"Only one sample available. You may want to reshape your data array"
56+
)
57+
4858
onedal_params = {
4959
"method": "dense",
5060
"bias": True,
@@ -59,18 +69,14 @@ def _onedal_supported(self, method_name, *data):
5969
patching_status = PatchingConditionsChain(
6070
f"sklearn.covariance.{class_name}.{method_name}"
6171
)
62-
if method_name == "fit":
72+
if method_name in ["fit", "mahalanobis"]:
6373
(X,) = data
6474
patching_status.and_conditions(
6575
[
6676
(
6777
self.assume_centered == False,
6878
"assume_centered parameter is not supported on oneDAL side",
6979
),
70-
(
71-
self.store_precision == False,
72-
"precision matrix calculation is not supported on oneDAL side",
73-
),
7480
(not sp.issparse(X), "X is sparse. Sparse input is not supported."),
7581
]
7682
)
@@ -84,9 +90,9 @@ def fit(self, X, y=None):
8490
if sklearn_check_version("1.2"):
8591
self._validate_params()
8692
if sklearn_check_version("0.23"):
87-
self._validate_data(X)
93+
X = self._validate_data(X, force_all_finite=False)
8894
else:
89-
check_array(X)
95+
X = check_array(X, force_all_finite=False)
9096

9197
dispatch(
9298
self,
@@ -100,4 +106,27 @@ def fit(self, X, y=None):
100106

101107
return self
102108

109+
# expose sklearnex pairwise_distances if mahalanobis distance eventually supported
110+
@wrap_output_data
111+
def mahalanobis(self, X):
112+
if sklearn_check_version("1.0"):
113+
X = self._validate_data(X, reset=False)
114+
else:
115+
X = check_array(X)
116+
117+
precision = self.get_precision()
118+
with config_context(assume_finite=True):
119+
# compute mahalanobis distances
120+
dist = pairwise_distances(
121+
X, self.location_[np.newaxis, :], metric="mahalanobis", VI=precision
122+
)
123+
124+
return np.reshape(dist, (len(X),)) ** 2
125+
126+
error_norm = wrap_output_data(sklearn_EmpiricalCovariance.error_norm)
127+
score = wrap_output_data(sklearn_EmpiricalCovariance.score)
128+
103129
fit.__doc__ = sklearn_EmpiricalCovariance.fit.__doc__
130+
mahalanobis.__doc__ = sklearn_EmpiricalCovariance.mahalanobis
131+
error_norm.__doc__ = sklearn_EmpiricalCovariance.error_norm.__doc__
132+
score.__doc__ = sklearn_EmpiricalCovariance.score.__doc__

0 commit comments

Comments
 (0)