66import numpy as np
77from cuml import KMeans as cumlKMeans
88
9+ from ._fuses import _calc_R , _div_clip , _get_factor , _get_pen , _log_div_OE , _R_multi_m
910from ._kernels ._normalize import _get_normalize_kernel_optimized
1011
1112if TYPE_CHECKING :
@@ -36,11 +37,6 @@ def _normalize_cp_p1(X: cp.ndarray) -> cp.ndarray:
3637 return X
3738
3839
39- @cp .fuse
40- def _div_clip (X : cp .ndarray , norm : cp .ndarray ) -> cp .ndarray :
41- return X / cp .clip (norm , a_min = 1e-12 , a_max = cp .inf )
42-
43-
4440def _normalize_cp (X : cp .ndarray , p : int = 2 ) -> cp .ndarray :
4541 if p == 2 :
4642 return _div_clip (X , cp .linalg .norm (X , ord = p , axis = 1 , keepdims = True ))
@@ -210,12 +206,12 @@ def harmonize(
210206 block_proportion ,
211207 )
212208
213- Z_hat = correction (Z , R , Phi , O , ridge_lambda , correction_method )
209+ Z_hat = _correction (Z , R , Phi , O , ridge_lambda , correction_method )
214210 Z_norm = _normalize_cp (Z_hat , p = 2 )
215211 if verbose :
216212 print (f"\t Completed { i + 1 } / { max_iter_harmony } iteration(s)." )
217213
218- if is_convergent_harmony (objectives_harmony , tol = tol_harmony ):
214+ if _is_convergent_harmony (objectives_harmony , tol = tol_harmony ):
219215 if verbose :
220216 print (f"Reach convergence after { i + 1 } iteration(s)." )
221217 break
@@ -253,16 +249,6 @@ def _initialize_centroids(
253249 return R , E , O , objectives_harmony
254250
255251
256- @cp .fuse
257- def _get_pen (E : cp .ndarray , O : cp .ndarray , theta : cp .ndarray ) -> cp .ndarray :
258- return cp .power (cp .divide (E + 1 , O + 1 ), theta )
259-
260-
261- @cp .fuse
262- def _calc_R (term : cp .ndarray , mm : cp .ndarray ) -> cp .ndarray :
263- return cp .exp (term * (1 - mm ))
264-
265-
266252def _clustering (
267253 Z_norm : cp .ndarray ,
268254 Pr_b : cp .ndarray ,
@@ -335,19 +321,28 @@ def _clustering(
335321 pos += block_size
336322 _compute_objective (Y_norm , Z_norm , R , theta , sigma , O , E , objectives_clustering )
337323
338- if is_convergent_clustering (objectives_clustering , tol ):
324+ if _is_convergent_clustering (objectives_clustering , tol ):
339325 objectives_harmony .append (objectives_clustering [- 1 ])
340326 break
341327
342328
def _correction(
    X: cp.ndarray,
    R: cp.ndarray,
    Phi: cp.ndarray,
    O: cp.ndarray,
    ridge_lambda: float,
    correction_method: str,
) -> cp.ndarray:
    """Dispatch to the correction routine selected by ``correction_method``.

    Parameters
    ----------
    X, R, Phi
        Forwarded unchanged to the selected correction routine.
    O
        Forwarded only to ``_correction_fast``; the original routine does
        not take it.
    ridge_lambda
        Forwarded unchanged to the selected correction routine.
    correction_method
        ``"fast"`` selects ``_correction_fast``. Any other value selects
        ``_correction_original``.

    Returns
    -------
    Whatever the selected correction routine returns.
    """
    # Every value other than the exact string "fast" takes the original path.
    if correction_method != "fast":
        return _correction_original(X, R, Phi, ridge_lambda)
    return _correction_fast(X, R, Phi, O, ridge_lambda)
348341
349342
350- def correction_original (X , R , Phi , ridge_lambda ):
343+ def _correction_original (
344+ X : cp .ndarray , R : cp .ndarray , Phi : cp .ndarray , ridge_lambda : float
345+ ) -> cp .ndarray :
351346 n_cells = X .shape [0 ]
352347 n_clusters = R .shape [1 ]
353348 n_batches = Phi .shape [1 ]
@@ -367,12 +362,9 @@ def correction_original(X, R, Phi, ridge_lambda):
367362 return Z
368363
369364
370- @cp .fuse
371- def _get_factor (O_k , ridge_lambda ):
372- return 1 / (O_k + ridge_lambda )
373-
374-
375- def correction_fast (X , R , Phi , O , ridge_lambda ):
365+ def _correction_fast (
366+ X : cp .ndarray , R : cp .ndarray , Phi : cp .ndarray , O : cp .ndarray , ridge_lambda : float
367+ ) -> cp .ndarray :
376368 n_cells = X .shape [0 ]
377369 n_clusters = R .shape [1 ]
378370 n_batches = Phi .shape [1 ]
@@ -408,16 +400,6 @@ def correction_fast(X, R, Phi, O, ridge_lambda):
408400 return Z
409401
410402
411- @cp .fuse
412- def log_div_OE (O : cp .ndarray , E : cp .ndarray ) -> cp .ndarray :
413- return O * cp .log ((O + 1 ) / (E + 1 ))
414-
415-
416- @cp .fuse
417- def R_multi_m (R , other ):
418- return R * 2 * (1 - other )
419-
420-
421403def _compute_objective (
422404 Y_norm : cp .ndarray ,
423405 Z_norm : cp .ndarray ,
@@ -427,17 +409,17 @@ def _compute_objective(
427409 O : cp .ndarray ,
428410 E : cp .ndarray ,
429411 objective_arr : list ,
430- ): # noqa: PLR0917, RUF100
431- kmeans_error = cp .sum (R_multi_m (R , cp .dot (Z_norm , Y_norm .T )))
412+ ):
413+ kmeans_error = cp .sum (_R_multi_m (R , cp .dot (Z_norm , Y_norm .T )))
432414 R = R / R .sum (axis = 1 , keepdims = True )
433415 entropy = cp .sum (R * cp .log (R + 1e-12 ))
434416 entropy_term = sigma * entropy
435- diversity_penalty = sigma * cp .sum (cp .dot (theta , log_div_OE (O , E )))
417+ diversity_penalty = sigma * cp .sum (cp .dot (theta , _log_div_OE (O , E )))
436418 objective = kmeans_error + entropy_term + diversity_penalty
437419 objective_arr .append (objective )
438420
439421
440- def is_convergent_harmony (objectives_harmony , tol ) :
422+ def _is_convergent_harmony (objectives_harmony : list , tol : float ) -> bool :
441423 if len (objectives_harmony ) < 2 :
442424 return False
443425
@@ -447,7 +429,9 @@ def is_convergent_harmony(objectives_harmony, tol):
447429 return (obj_old - obj_new ) < tol * np .abs (obj_old )
448430
449431
450- def is_convergent_clustering (objectives_clustering , tol , window_size = 3 ):
432+ def _is_convergent_clustering (
433+ objectives_clustering : list , tol : list , window_size : int = 3
434+ ):
451435 if len (objectives_clustering ) < window_size + 1 :
452436 return False
453437 obj_old = 0.0
0 commit comments