@@ -109,14 +109,7 @@ def _onedal_fit_supported(self, method_name, X, y=None, sample_weight=None):
109109 _is_csr (X ) and daal_check_version ((2024 , "P" , 700 ))
110110 ) or not issparse (X )
111111
112- _acceptable_sample_weights = True
113- if sample_weight is not None or not isinstance (sample_weight , numbers .Number ):
114- sample_weight = _check_sample_weight (
115- sample_weight , X , dtype = X .dtype if hasattr (X , "dtype" ) else None
116- )
117- _acceptable_sample_weights = np .allclose (
118- sample_weight , np .ones_like (sample_weight )
119- )
112+ _acceptable_sample_weights = self ._validate_sample_weight (sample_weight , X )
120113
121114 patching_status .and_conditions (
122115 [
@@ -127,7 +120,7 @@ def _onedal_fit_supported(self, method_name, X, y=None, sample_weight=None):
127120 (correct_count , "n_clusters is smaller than number of samples" ),
128121 (
129122 _acceptable_sample_weights ,
130- "oneDAL doesn't support sample_weight, either None or ones are acceptable " ,
123+ "oneDAL doesn't support sample_weight. Accepted options are None, constant, or equal weights. " ,
131124 ),
132125 (
133126 is_data_supported ,
@@ -161,6 +154,9 @@ def _onedal_fit(self, X, _, sample_weight, queue=None):
161154 X ,
162155 accept_sparse = "csr" ,
163156 dtype = [np .float64 , np .float32 ],
157+ order = "C" ,
158+ copy = self .copy_x ,
159+ accept_large_sparse = False ,
164160 )
165161
166162 if sklearn_check_version ("1.2" ):
@@ -176,6 +172,22 @@ def _onedal_fit(self, X, _, sample_weight, queue=None):
176172
177173 self ._save_attributes ()
178174
175+ def _validate_sample_weight (self , sample_weight , X ):
176+ if sample_weight is None :
177+ return True
178+ elif isinstance (sample_weight , numbers .Number ):
179+ return True
180+ else :
181+ sample_weight = _check_sample_weight (
182+ sample_weight ,
183+ X ,
184+ dtype = X .dtype if hasattr (X , "dtype" ) else None ,
185+ )
186+ if np .all (sample_weight == sample_weight [0 ]):
187+ return True
188+ else :
189+ return False
190+
179191 def _onedal_predict_supported (self , method_name , X , sample_weight = None ):
180192 class_name = self .__class__ .__name__
181193 is_data_supported = (
@@ -194,12 +206,9 @@ def _onedal_predict_supported(self, method_name, X, sample_weight=None):
194206 )
195207
196208 _acceptable_sample_weights = True
197- if sample_weight is not None or not isinstance (sample_weight , numbers .Number ):
198- sample_weight = _check_sample_weight (
199- sample_weight , X , dtype = X .dtype if hasattr (X , "dtype" ) else None
200- )
201- _acceptable_sample_weights = np .allclose (
202- sample_weight , np .ones_like (sample_weight )
209+ if not sklearn_check_version ("1.5" ):
210+ _acceptable_sample_weights = self ._validate_sample_weight (
211+ sample_weight , X
203212 )
204213
205214 patching_status .and_conditions (
@@ -214,7 +223,7 @@ def _onedal_predict_supported(self, method_name, X, sample_weight=None):
214223 ),
215224 (
216225 _acceptable_sample_weights ,
217- "oneDAL doesn't support sample_weight, None or ones are acceptable " ,
226+ "oneDAL doesn't support sample_weight. Acceptable options are None, constant, or equal weights. " ,
218227 ),
219228 ]
220229 )
0 commit comments