|
41 | 41 | PatchingConditionsChain) |
42 | 42 | from .._device_offload import support_usm_ndarray |
43 | 43 |
|
| 44 | +if sklearn_check_version('1.1'): |
| 45 | + from sklearn.utils.validation import ( |
| 46 | + _check_sample_weight, _is_arraylike_not_scalar) |
| 47 | + |
44 | 48 |
|
45 | 49 | def _validate_center_shape(X, n_centers, centers): |
46 | 50 | """Check if centers is compatible with X and n_centers""" |
@@ -242,53 +246,82 @@ def _fit(self, X, y=None, sample_weight=None): |
242 | 246 | are assigned equal weight (default: None) |
243 | 247 |
|
244 | 248 | """ |
245 | | - if hasattr(self, 'precompute_distances'): |
246 | | - if self.precompute_distances != 'deprecated': |
247 | | - if sklearn_check_version('0.24'): |
248 | | - warnings.warn("'precompute_distances' was deprecated in version " |
249 | | - "0.23 and will be removed in 1.0 (renaming of 0.25)." |
250 | | - " It has no effect", FutureWarning) |
251 | | - elif sklearn_check_version('0.23'): |
252 | | - warnings.warn("'precompute_distances' was deprecated in version " |
253 | | - "0.23 and will be removed in 0.25. It has no " |
254 | | - "effect", FutureWarning) |
255 | | - |
256 | | - self._n_threads = None |
257 | | - if hasattr(self, 'n_jobs'): |
258 | | - if self.n_jobs != 'deprecated': |
259 | | - if sklearn_check_version('0.24'): |
260 | | - warnings.warn("'n_jobs' was deprecated in version 0.23 and will be" |
261 | | - " removed in 1.0 (renaming of 0.25).", FutureWarning) |
262 | | - elif sklearn_check_version('0.23'): |
263 | | - warnings.warn("'n_jobs' was deprecated in version 0.23 and will be" |
264 | | - " removed in 0.25.", FutureWarning) |
265 | | - self._n_threads = self.n_jobs |
266 | | - self._n_threads = _openmp_effective_n_threads(self._n_threads) |
267 | | - |
268 | | - if self.n_init <= 0: |
269 | | - raise ValueError( |
270 | | - f"n_init should be > 0, got {self.n_init} instead.") |
271 | | - |
272 | | - random_state = check_random_state(self.random_state) |
273 | | - if sklearn_check_version("1.0"): |
274 | | - self._check_feature_names(X, reset=True) |
275 | | - |
276 | | - if self.max_iter <= 0: |
277 | | - raise ValueError( |
278 | | - f"max_iter should be > 0, got {self.max_iter} instead.") |
| 249 | + init = self.init |
| 250 | + if sklearn_check_version('1.1'): |
| 251 | + if sklearn_check_version('1.2'): |
| 252 | + self._validate_params() |
| 253 | + |
| 254 | + X = self._validate_data( |
| 255 | + X, |
| 256 | + accept_sparse="csr", |
| 257 | + dtype=[np.float64, np.float32], |
| 258 | + order="C", |
| 259 | + copy=self.copy_x, |
| 260 | + accept_large_sparse=False, |
| 261 | + ) |
279 | 262 |
|
280 | | - algorithm = self.algorithm |
281 | | - if algorithm == "elkan" and self.n_clusters == 1: |
282 | | - warnings.warn("algorithm='elkan' doesn't make sense for a single " |
283 | | - "cluster. Using 'full' instead.", RuntimeWarning) |
284 | | - algorithm = "full" |
| 263 | + if sklearn_check_version('1.2'): |
| 264 | + self._check_params_vs_input(X) |
| 265 | + else: |
| 266 | + self._check_params(X) |
285 | 267 |
|
286 | | - if algorithm == "auto": |
287 | | - algorithm = "full" if self.n_clusters == 1 else "elkan" |
| 268 | + random_state = check_random_state(self.random_state) |
| 269 | + sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype) |
| 270 | + self._n_threads = _openmp_effective_n_threads() |
288 | 271 |
|
289 | | - if algorithm not in ["full", "elkan"]: |
290 | | - raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got" |
291 | | - " {}".format(str(algorithm))) |
| 272 | + # Validate init array |
| 273 | + init_is_array_like = _is_arraylike_not_scalar(init) |
| 274 | + if init_is_array_like: |
| 275 | + init = check_array(init, dtype=X.dtype, copy=True, order="C") |
| 276 | + self._validate_center_shape(X, init) |
| 277 | + else: |
| 278 | + if hasattr(self, 'precompute_distances'): |
| 279 | + if self.precompute_distances != 'deprecated': |
| 280 | + if sklearn_check_version('0.24'): |
| 281 | + warnings.warn("'precompute_distances' was deprecated in version " |
| 282 | + "0.23 and will be removed in 1.0 (renaming of 0.25)." |
| 283 | + " It has no effect", FutureWarning) |
| 284 | + elif sklearn_check_version('0.23'): |
| 285 | + warnings.warn("'precompute_distances' was deprecated in version " |
| 286 | + "0.23 and will be removed in 0.25. It has no " |
| 287 | + "effect", FutureWarning) |
| 288 | + |
| 289 | + self._n_threads = None |
| 290 | + if hasattr(self, 'n_jobs'): |
| 291 | + if self.n_jobs != 'deprecated': |
| 292 | + if sklearn_check_version('0.24'): |
| 293 | + warnings.warn("'n_jobs' was deprecated in version 0.23 and will be" |
| 294 | + " removed in 1.0 (renaming of 0.25).", FutureWarning) |
| 295 | + elif sklearn_check_version('0.23'): |
| 296 | + warnings.warn("'n_jobs' was deprecated in version 0.23 and will be" |
| 297 | + " removed in 0.25.", FutureWarning) |
| 298 | + self._n_threads = self.n_jobs |
| 299 | + self._n_threads = _openmp_effective_n_threads(self._n_threads) |
| 300 | + |
| 301 | + if self.n_init <= 0: |
| 302 | + raise ValueError( |
| 303 | + f"n_init should be > 0, got {self.n_init} instead.") |
| 304 | + |
| 305 | + random_state = check_random_state(self.random_state) |
| 306 | + if sklearn_check_version("1.0"): |
| 307 | + self._check_feature_names(X, reset=True) |
| 308 | + |
| 309 | + if self.max_iter <= 0: |
| 310 | + raise ValueError( |
| 311 | + f"max_iter should be > 0, got {self.max_iter} instead.") |
| 312 | + |
| 313 | + algorithm = self.algorithm |
| 314 | + if algorithm == "elkan" and self.n_clusters == 1: |
| 315 | + warnings.warn("algorithm='elkan' doesn't make sense for a single " |
| 316 | + "cluster. Using 'full' instead.", RuntimeWarning) |
| 317 | + algorithm = "full" |
| 318 | + |
| 319 | + if algorithm == "auto": |
| 320 | + algorithm = "full" if self.n_clusters == 1 else "elkan" |
| 321 | + |
| 322 | + if algorithm not in ["full", "elkan"]: |
| 323 | + raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got" |
| 324 | + " {}".format(str(algorithm))) |
292 | 325 |
|
293 | 326 | X_len = _num_samples(X) |
294 | 327 |
|
@@ -317,8 +350,10 @@ def _fit(self, X, y=None, sample_weight=None): |
317 | 350 | self.n_features_in_ = X.shape[1] |
318 | 351 | self.cluster_centers_, self.labels_, self.inertia_, self.n_iter_ = \ |
319 | 352 | _daal4py_k_means_fit( |
320 | | - X, self.n_clusters, self.max_iter, self.tol, self.init, self.n_init, |
| 353 | + X, self.n_clusters, self.max_iter, self.tol, init, self.n_init, |
321 | 354 | self.verbose, random_state) |
| 355 | + if sklearn_check_version('1.1'): |
| 356 | + self._n_features_out = self.cluster_centers_.shape[0] |
322 | 357 | else: |
323 | 358 | super(KMeans, self).fit(X, y=y, sample_weight=sample_weight) |
324 | 359 | return self |
|
0 commit comments