1818
1919from onedal import _backend
2020from ..common ._policy import _get_policy
21- from ..datatypes ._data_conversion import from_table , to_table , _convert_to_supported
21+ from ..datatypes ._data_conversion import from_table , to_table
22+ from ..datatypes import _convert_to_supported
2223from daal4py .sklearn ._utils import sklearn_check_version
2324
2425
@@ -43,17 +44,20 @@ def get_onedal_params(self, data):
4344 'is_deterministic' : self .is_deterministic
4445 }
4546
46- def fit (self , X , y , queue ):
47+ def _get_policy (self , queue , * data ):
48+ return _get_policy (queue , * data )
49+
50+ def fit (self , X , queue ):
4751 n_samples , n_features = X .shape
4852 n_sf_min = min (n_samples , n_features )
4953
50- policy = _get_policy (queue , X , y )
51-
54+ policy = self ._get_policy (queue , X )
5255 # TODO: investigate why np.ndarray with OWNDATA=FALSE flag
5356 # fails to be converted to oneDAL table
5457 if isinstance (X , np .ndarray ) and not X .flags ['OWNDATA' ]:
5558 X = X .copy ()
56- X , y = _convert_to_supported (policy , X , y )
59+ X = _convert_to_supported (policy , X )
60+
5761 params = self .get_onedal_params (X )
5862 cov_result = _backend .covariance .compute (
5963 policy ,
@@ -99,10 +103,11 @@ def fit(self, X, y, queue):
99103 def _create_model (self ):
100104 m = _backend .decomposition .dim_reduction .model ()
101105 m .eigenvectors = to_table (self .components_ )
106+ self ._onedal_model = m
102107 return m
103108
104109 def predict (self , X , queue ):
105- policy = _get_policy (queue , X )
110+ policy = self . _get_policy (queue , X )
106111 model = self ._create_model ()
107112
108113 X = _convert_to_supported (policy , X )
0 commit comments