Skip to content

Commit f9cb32c

Browse files
authored
[enhancement] add assume_centered capability to EmpiricalCovariance (#1796)
* Update covariance.cpp * Update covariance.py * Update covariance.py * Update covariance.py * Update test_covariance.py * import daal_check_version * Update covariance.py * Update covariance.py
1 parent 11d676f commit f9cb32c

File tree

4 files changed

+39
-11
lines changed

4 files changed

+39
-11
lines changed

onedal/covariance/covariance.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ struct params2desc {
5252
desc.set_bias(params["bias"].cast<bool>());
5353
}
5454
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION>=20240001
55+
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240400
56+
if (params.contains("assumeCentered")) {
57+
desc.set_assume_centered(params["assumeCentered"].cast<bool>());
58+
}
59+
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION>=20240400
5560
return desc;
5661
}
5762
};

onedal/covariance/covariance.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626

2727

2828
class BaseEmpiricalCovariance(BaseEstimator, metaclass=ABCMeta):
29-
def __init__(self, method="dense", bias=False):
29+
def __init__(self, method="dense", bias=False, assume_centered=False):
3030
self.method = method
3131
self.bias = bias
32+
self.assume_centered = assume_centered
3233

3334
def _get_onedal_params(self, dtype=np.float32):
3435
params = {
@@ -37,6 +38,8 @@ def _get_onedal_params(self, dtype=np.float32):
3738
}
3839
if daal_check_version((2024, "P", 1)):
3940
params["bias"] = self.bias
41+
if daal_check_version((2024, "P", 400)):
42+
params["assumeCentered"] = self.assume_centered
4043

4144
return params
4245

@@ -55,6 +58,12 @@ class EmpiricalCovariance(BaseEmpiricalCovariance):
5558
If True biased estimation of covariance is computed which equals to
5659
the unbiased one multiplied by (n_samples - 1) / n_samples.
5760
61+
assume_centered : bool, default=False
62+
If True, data are not centered before computation.
63+
Useful when working with data whose mean is almost, but not exactly
64+
zero.
65+
If False (default), data are centered before computation.
66+
5867
Attributes
5968
----------
6069
location_ : ndarray of shape (n_features,)

sklearnex/preview/covariance/covariance.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sklearn.utils import check_array
2323

2424
from daal4py.sklearn._n_jobs_support import control_n_jobs
25-
from daal4py.sklearn._utils import sklearn_check_version
25+
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
2626
from onedal.common.hyperparameters import get_hyperparameters
2727
from onedal.covariance import EmpiricalCovariance as onedal_EmpiricalCovariance
2828
from sklearnex import config_context
@@ -44,6 +44,10 @@ class EmpiricalCovariance(sklearn_EmpiricalCovariance):
4444

4545
def _save_attributes(self):
4646
assert hasattr(self, "_onedal_estimator")
47+
if not daal_check_version((2024, "P", 400)) and self.assume_centered:
48+
location = self._onedal_estimator.location_[None, :]
49+
self._onedal_estimator.covariance_ += np.dot(location.T, location)
50+
self._onedal_estimator.location_ = np.zeros_like(np.squeeze(location))
4751
self._set_covariance(self._onedal_estimator.covariance_)
4852
self.location_ = self._onedal_estimator.location_
4953

@@ -58,6 +62,7 @@ def _onedal_fit(self, X, queue=None):
5862
onedal_params = {
5963
"method": "dense",
6064
"bias": True,
65+
"assume_centered": self.assume_centered,
6166
}
6267

6368
self._onedal_estimator = self._onedal_covariance(**onedal_params)
@@ -73,10 +78,6 @@ def _onedal_supported(self, method_name, *data):
7378
(X,) = data
7479
patching_status.and_conditions(
7580
[
76-
(
77-
self.assume_centered == False,
78-
"assume_centered parameter is not supported on oneDAL side",
79-
),
8081
(not sp.issparse(X), "X is sparse. Sparse input is not supported."),
8182
]
8283
)

sklearnex/preview/covariance/tests/test_covariance.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,27 +27,40 @@
2727

2828
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
2929
@pytest.mark.parametrize("macro_block", [None, 1024])
30-
def test_sklearnex_import_covariance(dataframe, queue, macro_block):
30+
@pytest.mark.parametrize("assume_centered", [True, False])
31+
def test_sklearnex_import_covariance(dataframe, queue, macro_block, assume_centered):
3132
from sklearnex.preview.covariance import EmpiricalCovariance
3233

3334
X = np.array([[0, 1], [0, 1]])
35+
3436
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
35-
empcov = EmpiricalCovariance()
37+
empcov = EmpiricalCovariance(assume_centered=assume_centered)
3638
if daal_check_version((2024, "P", 0)) and macro_block is not None:
3739
hparams = empcov.get_hyperparameters("fit")
3840
hparams.cpu_macro_block = macro_block
3941
result = empcov.fit(X)
42+
4043
expected_covariance = np.array([[0, 0], [0, 0]])
41-
expected_means = np.array([0, 1])
44+
expected_means = np.array([0, 0])
45+
46+
if assume_centered:
47+
expected_covariance = np.array([[0, 0], [0, 1]])
48+
else:
49+
expected_means = np.array([0, 1])
4250

4351
assert_allclose(expected_covariance, result.covariance_)
4452
assert_allclose(expected_means, result.location_)
4553

4654
X = np.array([[1, 2], [3, 6]])
55+
4756
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
4857
result = empcov.fit(X)
49-
expected_covariance = np.array([[1, 2], [2, 4]])
50-
expected_means = np.array([2, 4])
58+
59+
if assume_centered:
60+
expected_covariance = np.array([[5, 10], [10, 20]])
61+
else:
62+
expected_covariance = np.array([[1, 2], [2, 4]])
63+
expected_means = np.array([2, 4])
5164

5265
assert_allclose(expected_covariance, result.covariance_)
5366
assert_allclose(expected_means, result.location_)

0 commit comments

Comments
 (0)