Skip to content

Commit ad475ff

Browse files
md-shafiul-alam authored and napetrov committed
allow n_init param in KMeans to accept str 'auto' (#1045)
* allow n_init param to accept in kmeans
* whitespace removed
* whitespace removed
* removed check for string for param n_init in kmeans
* changed param 'algorithm' default to lloyd
* whitespace
* add versioning for param change
* pep8 fix
* pep8 fix
* pep8 fix
1 parent 68c9be4 commit ad475ff

File tree

1 file changed

+59
-10
lines changed

1 file changed

+59
-10
lines changed

daal4py/sklearn/cluster/_k_means_0_23.py

Lines changed: 59 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,14 @@ def _daal4py_k_means_fit(X, nClusters, numIterations,
174174
if numIterations < 0:
175175
raise ValueError("Wrong iterations number")
176176

177+
def is_string(s, target_str):
178+
return isinstance(s, str) and s == target_str
179+
180+
if n_init == 'auto':
181+
if is_string(cluster_centers_0, 'random'):
182+
n_init = 10
183+
elif is_string(cluster_centers_0, 'k-means++'):
184+
n_init = 1
177185
X_fptype = getFPType(X)
178186
abs_tol = _tolerance(X, tol) # tol is relative tolerance
179187
is_sparse = sp.isspmatrix(X)
@@ -311,17 +319,32 @@ def _fit(self, X, y=None, sample_weight=None):
311319
f"max_iter should be > 0, got {self.max_iter} instead.")
312320

313321
algorithm = self.algorithm
314-
if algorithm == "elkan" and self.n_clusters == 1:
315-
warnings.warn("algorithm='elkan' doesn't make sense for a single "
316-
"cluster. Using 'full' instead.", RuntimeWarning)
317-
algorithm = "full"
322+
if sklearn_check_version('1.2'):
323+
if algorithm == "elkan" and self.n_clusters == 1:
324+
warnings.warn("algorithm='elkan' doesn't make sense for a single "
325+
"cluster. Using 'full' instead.", RuntimeWarning)
326+
algorithm = "lloyd"
327+
328+
if algorithm == "auto" or algorithm == "full":
329+
warnings.warn("algorithm= {'auto','full'} is deprecated"
330+
"Using 'lloyd' instead.", RuntimeWarning)
331+
algorithm = "lloyd" if self.n_clusters == 1 else "elkan"
332+
333+
if algorithm not in ["lloyd", "full", "elkan"]:
334+
raise ValueError("Algorithm must be 'auto','lloyd', 'full' or 'elkan',"
335+
"got {}".format(str(algorithm)))
336+
else:
337+
if algorithm == "elkan" and self.n_clusters == 1:
338+
warnings.warn("algorithm='elkan' doesn't make sense for a single "
339+
"cluster. Using 'full' instead.", RuntimeWarning)
340+
algorithm = "full"
318341

319-
if algorithm == "auto":
320-
algorithm = "full" if self.n_clusters == 1 else "elkan"
342+
if algorithm == "auto":
343+
algorithm = "full" if self.n_clusters == 1 else "elkan"
321344

322-
if algorithm not in ["full", "elkan"]:
323-
raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got"
324-
" {}".format(str(algorithm)))
345+
if algorithm not in ["full", "elkan"]:
346+
raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got"
347+
" {}".format(str(algorithm)))
325348

326349
X_len = _num_samples(X)
327350

@@ -422,7 +445,33 @@ def _predict(self, X, sample_weight=None):
422445
class KMeans(KMeans_original):
423446
__doc__ = KMeans_original.__doc__
424447

425-
if sklearn_check_version('1.0'):
448+
if sklearn_check_version('1.2'):
449+
@_deprecate_positional_args
450+
def __init__(
451+
self,
452+
n_clusters=8,
453+
*,
454+
init='k-means++',
455+
n_init=10,
456+
max_iter=300,
457+
tol=1e-4,
458+
verbose=0,
459+
random_state=None,
460+
copy_x=True,
461+
algorithm='lloyd',
462+
):
463+
super(KMeans, self).__init__(
464+
n_clusters=n_clusters,
465+
init=init,
466+
max_iter=max_iter,
467+
tol=tol,
468+
n_init=n_init,
469+
verbose=verbose,
470+
random_state=random_state,
471+
copy_x=copy_x,
472+
algorithm=algorithm,
473+
)
474+
elif sklearn_check_version('1.0'):
426475
@_deprecate_positional_args
427476
def __init__(
428477
self,

0 commit comments

Comments
 (0)