1414# limitations under the License.
1515# ===============================================================================
1616
17+ import warnings
18+
19+ import numpy as np
1720from scipy import sparse as sp
1821from sklearn .covariance import EmpiricalCovariance as sklearn_EmpiricalCovariance
1922from sklearn .utils import check_array
2225from daal4py .sklearn ._utils import sklearn_check_version
2326from onedal .common .hyperparameters import get_hyperparameters
2427from onedal .covariance import EmpiricalCovariance as onedal_EmpiricalCovariance
28+ from sklearnex import config_context
29+ from sklearnex .metrics import pairwise_distances
2530
26- from ..._device_offload import dispatch
31+ from ..._device_offload import dispatch , wrap_output_data
2732from ..._utils import PatchingConditionsChain , register_hyperparameters
2833
2934
3035@register_hyperparameters ({"fit" : get_hyperparameters ("covariance" , "compute" )})
31- @control_n_jobs (decorated_methods = ["fit" ])
36+ @control_n_jobs (decorated_methods = ["fit" , "mahalanobis" ])
3237class EmpiricalCovariance (sklearn_EmpiricalCovariance ):
3338 __doc__ = sklearn_EmpiricalCovariance .__doc__
3439
@@ -39,12 +44,17 @@ class EmpiricalCovariance(sklearn_EmpiricalCovariance):
3944
4045 def _save_attributes (self ):
4146 assert hasattr (self , "_onedal_estimator" )
42- self .covariance_ = self ._onedal_estimator .covariance_
47+ self ._set_covariance ( self ._onedal_estimator .covariance_ )
4348 self .location_ = self ._onedal_estimator .location_
4449
4550 _onedal_covariance = staticmethod (onedal_EmpiricalCovariance )
4651
4752 def _onedal_fit (self , X , queue = None ):
53+ if X .shape [0 ] == 1 :
54+ warnings .warn (
55+ "Only one sample available. You may want to reshape your data array"
56+ )
57+
4858 onedal_params = {
4959 "method" : "dense" ,
5060 "bias" : True ,
@@ -59,18 +69,14 @@ def _onedal_supported(self, method_name, *data):
5969 patching_status = PatchingConditionsChain (
6070 f"sklearn.covariance.{ class_name } .{ method_name } "
6171 )
62- if method_name == "fit" :
72+ if method_name in [ "fit" , "mahalanobis" ] :
6373 (X ,) = data
6474 patching_status .and_conditions (
6575 [
6676 (
6777 self .assume_centered == False ,
6878 "assume_centered parameter is not supported on oneDAL side" ,
6979 ),
70- (
71- self .store_precision == False ,
72- "precision matrix calculation is not supported on oneDAL side" ,
73- ),
7480 (not sp .issparse (X ), "X is sparse. Sparse input is not supported." ),
7581 ]
7682 )
@@ -84,9 +90,9 @@ def fit(self, X, y=None):
8490 if sklearn_check_version ("1.2" ):
8591 self ._validate_params ()
8692 if sklearn_check_version ("0.23" ):
87- self ._validate_data (X )
93+ X = self ._validate_data (X , force_all_finite = False )
8894 else :
89- check_array (X )
95+ X = check_array (X , force_all_finite = False )
9096
9197 dispatch (
9298 self ,
@@ -100,4 +106,27 @@ def fit(self, X, y=None):
100106
101107 return self
102108
109+ # expose sklearnex pairwise_distances if mahalanobis distance eventually supported
110+ @wrap_output_data
111+ def mahalanobis (self , X ):
112+ if sklearn_check_version ("1.0" ):
113+ X = self ._validate_data (X , reset = False )
114+ else :
115+ X = check_array (X )
116+
117+ precision = self .get_precision ()
118+ with config_context (assume_finite = True ):
119+ # compute mahalanobis distances
120+ dist = pairwise_distances (
121+ X , self .location_ [np .newaxis , :], metric = "mahalanobis" , VI = precision
122+ )
123+
124+ return np .reshape (dist , (len (X ),)) ** 2
125+
126+ error_norm = wrap_output_data (sklearn_EmpiricalCovariance .error_norm )
127+ score = wrap_output_data (sklearn_EmpiricalCovariance .score )
128+
103129 fit .__doc__ = sklearn_EmpiricalCovariance .fit .__doc__
130+ mahalanobis .__doc__ = sklearn_EmpiricalCovariance .mahalanobis
131+ error_norm .__doc__ = sklearn_EmpiricalCovariance .error_norm .__doc__
132+ score .__doc__ = sklearn_EmpiricalCovariance .score .__doc__
0 commit comments