17
17
import numpy as np
18
18
19
19
from daal4py .sklearn ._utils import daal_check_version
20
- from onedal ._device_offload import supports_queue
21
- from onedal .common ._backend import bind_default_backend
22
- from onedal .utils import _sycl_queue_manager as QM
23
20
24
21
from .._config import _get_config
25
- from ..datatypes import from_table , to_table
22
+ from .._device_offload import supports_queue
23
+ from ..common ._backend import bind_default_backend
24
+ from ..datatypes import from_table , return_type_constructor , to_table
25
+ from ..utils import _sycl_queue_manager as QM
26
26
from ..utils ._array_api import _get_sycl_namespace
27
27
from ..utils .validation import _check_array
28
28
from .covariance import BaseEmpiricalCovariance
@@ -74,6 +74,7 @@ def finalize_compute(self, params, partial_result): ...
74
74
def _reset (self ):
75
75
self ._need_to_finalize = False
76
76
self ._queue = None
77
+ self ._outtype = None
77
78
self ._partial_result = self .partial_compute_result ()
78
79
79
80
def __getstate__ (self ):
@@ -108,15 +109,10 @@ def partial_fit(self, X, y=None, queue=None):
108
109
self : object
109
110
Returns the instance itself.
110
111
"""
111
- use_raw_input = _get_config ()["use_raw_input" ] is True
112
- sua_iface , _ , _ = _get_sycl_namespace (X )
113
-
114
- if use_raw_input and sua_iface :
115
- queue = X .sycl_queue
116
- if not use_raw_input :
117
- X = _check_array (X , dtype = [np .float64 , np .float32 ], ensure_2d = True )
118
112
119
113
self ._queue = queue
114
+ if not self ._outtype :
115
+ self ._outtype = return_type_constructor (X )
120
116
X_table = to_table (X , queue = queue )
121
117
122
118
if not hasattr (self , "_dtype" ):
@@ -125,8 +121,6 @@ def partial_fit(self, X, y=None, queue=None):
125
121
params = self ._get_onedal_params (self ._dtype )
126
122
self ._partial_result = self .partial_compute (params , self ._partial_result , X_table )
127
123
self ._need_to_finalize = True
128
- # store the queue for when we finalize
129
- self ._queue = queue
130
124
131
125
def finalize_fit (self ):
132
126
"""Finalize covariance matrix from the current `_partial_result`.
@@ -143,13 +137,14 @@ def finalize_fit(self):
143
137
with QM .manage_global_queue (self ._queue ):
144
138
result = self .finalize_compute (params , self ._partial_result )
145
139
146
- if daal_check_version (( 2024 , "P" , 1 )) or ( not self .bias ):
147
- self . covariance_ = from_table ( result . cov_matrix )
148
- else :
140
+ self . covariance_ = from_table ( result . cov_matrix , like = self ._outtype )
141
+
142
+ if self . bias and not daal_check_version (( 2024 , "P" , 1 )) :
149
143
n_rows = self ._partial_result .partial_n_rows
150
- self .covariance_ = from_table ( result . cov_matrix ) * (n_rows - 1 ) / n_rows
144
+ self .covariance_ *= (n_rows - 1 ) / n_rows
151
145
152
- self .location_ = from_table (result .means ).ravel ()
146
+ self .location_ = from_table (result .means , like = self ._outtype )[0 , ...]
147
+ self ._outtype = None
153
148
154
149
self ._need_to_finalize = False
155
150
0 commit comments