@@ -174,6 +174,14 @@ def _daal4py_k_means_fit(X, nClusters, numIterations,
174174 if numIterations < 0 :
175175 raise ValueError ("Wrong iterations number" )
176176
177+ def is_string (s , target_str ):
178+ return isinstance (s , str ) and s == target_str
179+
180+ if n_init == 'auto' :
181+ if is_string (cluster_centers_0 , 'random' ):
182+ n_init = 10
183+ elif is_string (cluster_centers_0 , 'k-means++' ):
184+ n_init = 1
177185 X_fptype = getFPType (X )
178186 abs_tol = _tolerance (X , tol ) # tol is relative tolerance
179187 is_sparse = sp .isspmatrix (X )
@@ -311,17 +319,32 @@ def _fit(self, X, y=None, sample_weight=None):
311319 f"max_iter should be > 0, got { self .max_iter } instead." )
312320
313321 algorithm = self .algorithm
314- if algorithm == "elkan" and self .n_clusters == 1 :
315- warnings .warn ("algorithm='elkan' doesn't make sense for a single "
316- "cluster. Using 'full' instead." , RuntimeWarning )
317- algorithm = "full"
322+ if sklearn_check_version ('1.2' ):
323+ if algorithm == "elkan" and self .n_clusters == 1 :
324+ warnings .warn ("algorithm='elkan' doesn't make sense for a single "
325+ "cluster. Using 'full' instead." , RuntimeWarning )
326+ algorithm = "lloyd"
327+
328+ if algorithm == "auto" or algorithm == "full" :
329+ warnings .warn ("algorithm= {'auto','full'} is deprecated"
330+ "Using 'lloyd' instead." , RuntimeWarning )
331+ algorithm = "lloyd" if self .n_clusters == 1 else "elkan"
332+
333+ if algorithm not in ["lloyd" , "full" , "elkan" ]:
334+ raise ValueError ("Algorithm must be 'auto','lloyd', 'full' or 'elkan',"
335+ "got {}" .format (str (algorithm )))
336+ else :
337+ if algorithm == "elkan" and self .n_clusters == 1 :
338+ warnings .warn ("algorithm='elkan' doesn't make sense for a single "
339+ "cluster. Using 'full' instead." , RuntimeWarning )
340+ algorithm = "full"
318341
319- if algorithm == "auto" :
320- algorithm = "full" if self .n_clusters == 1 else "elkan"
342+ if algorithm == "auto" :
343+ algorithm = "full" if self .n_clusters == 1 else "elkan"
321344
322- if algorithm not in ["full" , "elkan" ]:
323- raise ValueError ("Algorithm must be 'auto', 'full' or 'elkan', got"
324- " {}" .format (str (algorithm )))
345+ if algorithm not in ["full" , "elkan" ]:
346+ raise ValueError ("Algorithm must be 'auto', 'full' or 'elkan', got"
347+ " {}" .format (str (algorithm )))
325348
326349 X_len = _num_samples (X )
327350
@@ -422,7 +445,33 @@ def _predict(self, X, sample_weight=None):
422445class KMeans (KMeans_original ):
423446 __doc__ = KMeans_original .__doc__
424447
425- if sklearn_check_version ('1.0' ):
448+ if sklearn_check_version ('1.2' ):
449+ @_deprecate_positional_args
450+ def __init__ (
451+ self ,
452+ n_clusters = 8 ,
453+ * ,
454+ init = 'k-means++' ,
455+ n_init = 10 ,
456+ max_iter = 300 ,
457+ tol = 1e-4 ,
458+ verbose = 0 ,
459+ random_state = None ,
460+ copy_x = True ,
461+ algorithm = 'lloyd' ,
462+ ):
463+ super (KMeans , self ).__init__ (
464+ n_clusters = n_clusters ,
465+ init = init ,
466+ max_iter = max_iter ,
467+ tol = tol ,
468+ n_init = n_init ,
469+ verbose = verbose ,
470+ random_state = random_state ,
471+ copy_x = copy_x ,
472+ algorithm = algorithm ,
473+ )
474+ elif sklearn_check_version ('1.0' ):
426475 @_deprecate_positional_args
427476 def __init__ (
428477 self ,
0 commit comments