Skip to content

Commit 8c80ab3

Browse files
committed
add typing
1 parent 00c8846 commit 8c80ab3

File tree

2 files changed

+60
-43
lines changed

2 files changed

+60
-43
lines changed

src/rapids_singlecell/preprocessing/_harmony/__init__.py

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from cuml import KMeans as cumlKMeans
88

9+
from ._fuses import _calc_R, _div_clip, _get_factor, _get_pen, _log_div_OE, _R_multi_m
910
from ._kernels._normalize import _get_normalize_kernel_optimized
1011

1112
if TYPE_CHECKING:
@@ -36,11 +37,6 @@ def _normalize_cp_p1(X: cp.ndarray) -> cp.ndarray:
3637
return X
3738

3839

39-
@cp.fuse
40-
def _div_clip(X: cp.ndarray, norm: cp.ndarray) -> cp.ndarray:
41-
return X / cp.clip(norm, a_min=1e-12, a_max=cp.inf)
42-
43-
4440
def _normalize_cp(X: cp.ndarray, p: int = 2) -> cp.ndarray:
4541
if p == 2:
4642
return _div_clip(X, cp.linalg.norm(X, ord=p, axis=1, keepdims=True))
@@ -210,12 +206,12 @@ def harmonize(
210206
block_proportion,
211207
)
212208

213-
Z_hat = correction(Z, R, Phi, O, ridge_lambda, correction_method)
209+
Z_hat = _correction(Z, R, Phi, O, ridge_lambda, correction_method)
214210
Z_norm = _normalize_cp(Z_hat, p=2)
215211
if verbose:
216212
print(f"\tCompleted {i + 1} / {max_iter_harmony} iteration(s).")
217213

218-
if is_convergent_harmony(objectives_harmony, tol=tol_harmony):
214+
if _is_convergent_harmony(objectives_harmony, tol=tol_harmony):
219215
if verbose:
220216
print(f"Reach convergence after {i + 1} iteration(s).")
221217
break
@@ -253,16 +249,6 @@ def _initialize_centroids(
253249
return R, E, O, objectives_harmony
254250

255251

256-
@cp.fuse
257-
def _get_pen(E: cp.ndarray, O: cp.ndarray, theta: cp.ndarray) -> cp.ndarray:
258-
return cp.power(cp.divide(E + 1, O + 1), theta)
259-
260-
261-
@cp.fuse
262-
def _calc_R(term: cp.ndarray, mm: cp.ndarray) -> cp.ndarray:
263-
return cp.exp(term * (1 - mm))
264-
265-
266252
def _clustering(
267253
Z_norm: cp.ndarray,
268254
Pr_b: cp.ndarray,
@@ -335,19 +321,28 @@ def _clustering(
335321
pos += block_size
336322
_compute_objective(Y_norm, Z_norm, R, theta, sigma, O, E, objectives_clustering)
337323

338-
if is_convergent_clustering(objectives_clustering, tol):
324+
if _is_convergent_clustering(objectives_clustering, tol):
339325
objectives_harmony.append(objectives_clustering[-1])
340326
break
341327

342328

343-
def correction(X, R, Phi, O, ridge_lambda, correction_method): # noqa: PLR0917, RUF100
329+
def _correction(
330+
X: cp.ndarray,
331+
R: cp.ndarray,
332+
Phi: cp.ndarray,
333+
O: cp.ndarray,
334+
ridge_lambda: float,
335+
correction_method: str,
336+
) -> cp.ndarray:
344337
if correction_method == "fast":
345-
return correction_fast(X, R, Phi, O, ridge_lambda)
338+
return _correction_fast(X, R, Phi, O, ridge_lambda)
346339
else:
347-
return correction_original(X, R, Phi, ridge_lambda)
340+
return _correction_original(X, R, Phi, ridge_lambda)
348341

349342

350-
def correction_original(X, R, Phi, ridge_lambda):
343+
def _correction_original(
344+
X: cp.ndarray, R: cp.ndarray, Phi: cp.ndarray, ridge_lambda: float
345+
) -> cp.ndarray:
351346
n_cells = X.shape[0]
352347
n_clusters = R.shape[1]
353348
n_batches = Phi.shape[1]
@@ -367,12 +362,9 @@ def correction_original(X, R, Phi, ridge_lambda):
367362
return Z
368363

369364

370-
@cp.fuse
371-
def _get_factor(O_k, ridge_lambda):
372-
return 1 / (O_k + ridge_lambda)
373-
374-
375-
def correction_fast(X, R, Phi, O, ridge_lambda):
365+
def _correction_fast(
366+
X: cp.ndarray, R: cp.ndarray, Phi: cp.ndarray, O: cp.ndarray, ridge_lambda: float
367+
) -> cp.ndarray:
376368
n_cells = X.shape[0]
377369
n_clusters = R.shape[1]
378370
n_batches = Phi.shape[1]
@@ -408,16 +400,6 @@ def correction_fast(X, R, Phi, O, ridge_lambda):
408400
return Z
409401

410402

411-
@cp.fuse
412-
def log_div_OE(O: cp.ndarray, E: cp.ndarray) -> cp.ndarray:
413-
return O * cp.log((O + 1) / (E + 1))
414-
415-
416-
@cp.fuse
417-
def R_multi_m(R, other):
418-
return R * 2 * (1 - other)
419-
420-
421403
def _compute_objective(
422404
Y_norm: cp.ndarray,
423405
Z_norm: cp.ndarray,
@@ -427,17 +409,17 @@ def _compute_objective(
427409
O: cp.ndarray,
428410
E: cp.ndarray,
429411
objective_arr: list,
430-
): # noqa: PLR0917, RUF100
431-
kmeans_error = cp.sum(R_multi_m(R, cp.dot(Z_norm, Y_norm.T)))
412+
):
413+
kmeans_error = cp.sum(_R_multi_m(R, cp.dot(Z_norm, Y_norm.T)))
432414
R = R / R.sum(axis=1, keepdims=True)
433415
entropy = cp.sum(R * cp.log(R + 1e-12))
434416
entropy_term = sigma * entropy
435-
diversity_penalty = sigma * cp.sum(cp.dot(theta, log_div_OE(O, E)))
417+
diversity_penalty = sigma * cp.sum(cp.dot(theta, _log_div_OE(O, E)))
436418
objective = kmeans_error + entropy_term + diversity_penalty
437419
objective_arr.append(objective)
438420

439421

440-
def is_convergent_harmony(objectives_harmony, tol):
422+
def _is_convergent_harmony(objectives_harmony: list, tol: float) -> bool:
441423
if len(objectives_harmony) < 2:
442424
return False
443425

@@ -447,7 +429,9 @@ def is_convergent_harmony(objectives_harmony, tol):
447429
return (obj_old - obj_new) < tol * np.abs(obj_old)
448430

449431

450-
def is_convergent_clustering(objectives_clustering, tol, window_size=3):
432+
def _is_convergent_clustering(
433+
objectives_clustering: list, tol: list, window_size: int = 3
434+
):
451435
if len(objectives_clustering) < window_size + 1:
452436
return False
453437
obj_old = 0.0
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
import cupy as cp
4+
5+
6+
@cp.fuse
7+
def _get_factor(O_k, ridge_lambda):
8+
return 1 / (O_k + ridge_lambda)
9+
10+
11+
@cp.fuse
12+
def _get_pen(E: cp.ndarray, O: cp.ndarray, theta: cp.ndarray) -> cp.ndarray:
13+
return cp.power(cp.divide(E + 1, O + 1), theta)
14+
15+
16+
@cp.fuse
17+
def _calc_R(term: cp.ndarray, mm: cp.ndarray) -> cp.ndarray:
18+
return cp.exp(term * (1 - mm))
19+
20+
21+
@cp.fuse
22+
def _div_clip(X: cp.ndarray, norm: cp.ndarray) -> cp.ndarray:
23+
return X / cp.clip(norm, a_min=1e-12, a_max=cp.inf)
24+
25+
26+
@cp.fuse
27+
def _log_div_OE(O: cp.ndarray, E: cp.ndarray) -> cp.ndarray:
28+
return O * cp.log((O + 1) / (E + 1))
29+
30+
31+
@cp.fuse
32+
def _R_multi_m(R, other):
33+
return R * 2 * (1 - other)

0 commit comments

Comments
 (0)