1717import numpy as np
1818from sklearn .base import BaseEstimator
1919from sklearn .utils import check_array , gen_batches
20+ from sklearn .utils .validation import _check_sample_weight
2021
2122from daal4py .sklearn ._n_jobs_support import control_n_jobs
2223from daal4py .sklearn ._utils import sklearn_check_version
@@ -139,7 +140,7 @@ def _onedal_finalize_fit(self):
139140 self ._onedal_estimator .finalize_fit ()
140141 self ._need_to_finalize = False
141142
142- def _onedal_partial_fit (self , X , weights , queue ):
143+ def _onedal_partial_fit (self , X , sample_weight = None , queue = None ):
143144 first_pass = not hasattr (self , "n_samples_seen_" ) or self .n_samples_seen_ == 0
144145
145146 if sklearn_check_version ("1.0" ):
@@ -152,9 +153,11 @@ def _onedal_partial_fit(self, X, weights, queue):
152153 X = check_array (
153154 X ,
154155 dtype = [np .float64 , np .float32 ],
155- copy = self .copy_X ,
156156 )
157157
158+ if sample_weight is not None :
159+ sample_weight = _check_sample_weight (sample_weight , X )
160+
158161 if first_pass :
159162 self .n_samples_seen_ = X .shape [0 ]
160163 self .n_features_in_ = X .shape [1 ]
@@ -168,15 +171,18 @@ def _onedal_partial_fit(self, X, weights, queue):
168171 self ._onedal_estimator = self ._onedal_incremental_basic_statistics (
169172 ** onedal_params
170173 )
171- self ._onedal_estimator .partial_fit (X , weights , queue )
174+ self ._onedal_estimator .partial_fit (X , sample_weight , queue )
172175 self ._need_to_finalize = True
173176
174- def _onedal_fit (self , X , weights , queue = None ):
177+ def _onedal_fit (self , X , sample_weight = None , queue = None ):
175178 if sklearn_check_version ("1.0" ):
176179 X = self ._validate_data (X , dtype = [np .float64 , np .float32 ])
177180 else :
178181 X = check_array (X , dtype = [np .float64 , np .float32 ])
179182
183+ if sample_weight is not None :
184+ sample_weight = _check_sample_weight (sample_weight , X )
185+
180186 n_samples , n_features = X .shape
181187 if self .batch_size is None :
182188 self .batch_size_ = 5 * n_features
@@ -189,7 +195,7 @@ def _onedal_fit(self, X, weights, queue=None):
189195
190196 for batch in gen_batches (X .shape [0 ], self .batch_size_ ):
191197 X_batch = X [batch ]
192- weights_batch = weights [batch ] if weights is not None else None
198+ weights_batch = sample_weight [batch ] if sample_weight is not None else None
193199 self ._onedal_partial_fit (X_batch , weights_batch , queue = queue )
194200
195201 if sklearn_check_version ("1.2" ):
@@ -217,7 +223,7 @@ def __getattr__(self, attr):
217223 f"'{ self .__class__ .__name__ } ' object has no attribute '{ attr } '"
218224 )
219225
220- def partial_fit (self , X , weights = None ):
226+ def partial_fit (self , X , sample_weight = None ):
221227 """Incremental fit with X. All of X is processed as a single batch.
222228
223229 Parameters
@@ -226,7 +232,10 @@ def partial_fit(self, X, weights=None):
226232 Data for compute, where `n_samples` is the number of samples and
227233 `n_features` is the number of features.
228234
229- weights : array-like of shape (n_samples,)
235+ y : Ignored
236+ Not used, present for API consistency by convention.
237+
238+ sample_weight : array-like of shape (n_samples,), default=None
230239 Weights for compute weighted statistics, where `n_samples` is the number of samples.
231240
232241 Returns
@@ -242,11 +251,11 @@ def partial_fit(self, X, weights=None):
242251 "sklearn" : None ,
243252 },
244253 X ,
245- weights ,
254+ sample_weight ,
246255 )
247256 return self
248257
249- def fit (self , X , weights = None ):
258+ def fit (self , X , y = None , sample_weight = None ):
250259 """Compute statistics with X, using minibatches of size batch_size.
251260
252261 Parameters
@@ -255,7 +264,10 @@ def fit(self, X, weights=None):
255264 Data for compute, where `n_samples` is the number of samples and
256265 `n_features` is the number of features.
257266
258- weights : array-like of shape (n_samples,)
267+ y : Ignored
268+ Not used, present for API consistency by convention.
269+
270+ sample_weight : array-like of shape (n_samples,), default=None
259271 Weights for compute weighted statistics, where `n_samples` is the number of samples.
260272
261273 Returns
@@ -271,6 +283,6 @@ def fit(self, X, weights=None):
271283 "sklearn" : None ,
272284 },
273285 X ,
274- weights ,
286+ sample_weight ,
275287 )
276288 return self
0 commit comments