Skip to content
This repository was archived by the owner on Jan 8, 2026. It is now read-only.

Commit fc0fcd5

Browse files
committed
Update type hints
1 parent 75bf3d2 commit fc0fcd5

File tree

4 files changed

+39
-35
lines changed

4 files changed

+39
-35
lines changed

clusx/clustering/models.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,17 @@ class DirichletProcess:
7676
cluster_params : dict
7777
Dictionary of cluster parameters for each cluster.
7878
Contains 'mean' (centroid) and 'count' (number of points).
79-
global_mean : ndarray
79+
global_mean : Optional[EmbeddingTensor]
8080
Global mean of all document embeddings.
8181
next_id : int
8282
Next available cluster ID.
83-
embeddings_ : ndarray
83+
embeddings_ : Optional[EmbeddingTensor]
8484
Document embeddings after fitting.
85-
labels_ : ndarray
85+
labels_ : Optional[NDArray[np.int64]]
8686
Cluster assignments after fitting.
87-
text_embeddings : dict
87+
text_embeddings : dict[str, EmbeddingTensor]
8888
Cache of text to embedding mappings.
89-
embedding_dim : int or None
89+
embedding_dim : Optional[int]
9090
Dimension of the embedding vectors.
9191
"""
9292

@@ -108,10 +108,10 @@ def __init__(
108108
kappa : float
109109
Precision parameter for the von Mises-Fisher distribution.
110110
Higher values lead to tighter, more concentrated clusters.
111-
model_name : str, optional
111+
model_name : Optional[str]
112112
Name of the sentence transformer model to use.
113113
Default is "all-MiniLM-L6-v2".
114-
random_state : int, optional
114+
random_state : Optional[int]
115115
Random seed for reproducibility.
116116
If None, fresh, unpredictable entropy will be pulled from the OS.
117117
"""
@@ -124,10 +124,10 @@ def __init__(
124124

125125
self.clusters = []
126126
self.cluster_params = {}
127-
self.global_mean = None
127+
self.global_mean: Optional[EmbeddingTensor] = None
128128
self.next_id = 0
129-
self.embeddings_ = None
130-
self.labels_ = None
129+
self.embeddings_: Optional[EmbeddingTensor] = None
130+
self.labels_: Optional[NDArray[np.int64]] = None
131131

132132
# For tracking processed texts and their embeddings
133133
self.text_embeddings: dict[str, EmbeddingTensor] = {}
@@ -148,7 +148,7 @@ def get_embedding(self, text: Union[str, list[str]]) -> EmbeddingTensor:
148148
149149
Returns
150150
-------
151-
numpy.ndarray
151+
EmbeddingTensor
152152
The normalized embedding vector(s) for the text.
153153
If input is a single string, returns a single embedding vector.
154154
If input is a list, returns an array of embedding vectors.
@@ -191,7 +191,7 @@ def get_embedding(self, text: Union[str, list[str]]) -> EmbeddingTensor:
191191
# Return single embedding or list based on input
192192
return results[0] if is_single else np.array(results)
193193

194-
def _normalize(self, embedding: EmbeddingTensor) -> EmbeddingTensor:
194+
def _normalize(self, embedding: EmbeddingTensor) -> NDArray[np.float32]:
195195
"""
196196
Normalize vector to unit length for use with von Mises-Fisher distribution.
197197
@@ -205,13 +205,15 @@ def _normalize(self, embedding: EmbeddingTensor) -> EmbeddingTensor:
205205
206206
Returns
207207
-------
208-
EmbeddingTensor
209-
The normalized embedding vector with unit length.
208+
NDArray[np.float32]
209+
The normalized embedding vector with unit length as a NumPy array.
210210
"""
211211
norm = np.linalg.norm(embedding)
212212
# Convert to numpy array to ensure division works properly
213213
embedding_np = to_numpy(embedding)
214-
return embedding_np / norm if norm > 0 else embedding_np
214+
# Ensure the result is float32 to match the return type
215+
result = embedding_np / norm if norm > 0 else embedding_np
216+
return result.astype(np.float32)
215217

216218
def _log_likelihood_vmf(self, embedding: EmbeddingTensor, cluster_id: int) -> float:
217219
"""
@@ -262,7 +264,7 @@ def log_crp_prior(self, cluster_id: Optional[int] = None) -> float:
262264
263265
Parameters
264266
----------
265-
cluster_id : int, optional
267+
cluster_id : Optional[int]
266268
The cluster ID.
267269
If provided, calculate prior for an existing cluster.
268270
If None, calculate prior for a new cluster.
@@ -399,7 +401,7 @@ def _create_or_update_cluster(
399401
Document embedding vector.
400402
is_new_cluster : bool
401403
Whether to create a new cluster.
402-
existing_cluster_id : int, optional
404+
existing_cluster_id : Optional[int]
403405
ID of existing cluster to update, if is_new_cluster is False.
404406
405407
Returns
@@ -486,9 +488,9 @@ def fit(self, documents, _y: Union[Any, None] = None):
486488
487489
Parameters
488490
----------
489-
documents : array-like of shape (n_samples,)
491+
documents : Union[list[str], list[EmbeddingTensor]]
490492
The text documents or embeddings to cluster.
491-
_y : Any, optional
493+
_y : Union[Any, None]
492494
Ignored. Added for compatibility with scikit-learn API.
493495
494496
Returns
@@ -500,9 +502,9 @@ def fit(self, documents, _y: Union[Any, None] = None):
500502
----
501503
After fitting, the following attributes are set:
502504
503-
- :data:`embeddings_` : ndarray of shape (n_samples, n_features)
505+
- :data:`embeddings_` : Optional[EmbeddingTensor]
504506
The document embeddings.
505-
- :data:`labels_` : ndarray of shape (n_samples,)
507+
- :data:`labels_` : NDArray[np.int64]
506508
The cluster assignments for each document.
507509
- :data:`clusters` : list
508510
List of cluster IDs for each document.
@@ -540,12 +542,12 @@ def predict(self, documents):
540542
541543
Parameters
542544
----------
543-
documents : array-like of shape (n_samples,)
545+
documents : Union[list[str], list[EmbeddingTensor]]
544546
The text documents or embeddings to predict clusters for.
545547
546548
Returns
547549
-------
548-
labels : ndarray of shape (n_samples,)
550+
labels : NDArray[np.int64]
549551
Cluster labels for each document.
550552
Returns -1 if no clusters exist yet.
551553
@@ -584,14 +586,14 @@ def fit_predict(self, documents, _y: Union[Any, None] = None):
584586
585587
Parameters
586588
----------
587-
documents : array-like of shape (n_samples,)
589+
documents : Union[list[str], list[EmbeddingTensor]]
588590
The text documents or embeddings to cluster.
589-
_y : Ignored
591+
_y : Union[Any, None]
590592
This parameter exists only for compatibility with scikit-learn API.
591593
592594
Returns
593595
-------
594-
labels : ndarray of shape (n_samples,)
596+
labels : NDArray[np.int64]
595597
Cluster labels for each document.
596598
597599
Notes
@@ -624,17 +626,17 @@ class PitmanYorProcess(DirichletProcess):
624626
cluster_params : dict
625627
Dictionary of cluster parameters for each cluster.
626628
Contains 'mean' (centroid) and 'count' (number of points).
627-
global_mean : ndarray
629+
global_mean : Optional[EmbeddingTensor]
628630
Global mean of all document embeddings.
629631
next_id : int
630632
Next available cluster ID.
631-
embeddings_ : ndarray
633+
embeddings_ : Optional[EmbeddingTensor]
632634
Document embeddings after fitting.
633-
labels_ : ndarray
635+
labels_ : Optional[NDArray[np.int64]]
634636
Cluster assignments after fitting.
635-
text_embeddings : dict
637+
text_embeddings : dict[str, EmbeddingTensor]
636638
Cache of text to embedding mappings.
637-
embedding_dim : int, optional
639+
embedding_dim : Optional[int]
638640
Dimension of the embedding vectors.
639641
640642
Notes
@@ -685,10 +687,10 @@ def __init__(
685687
Controls the power-law behavior. Higher values create more
686688
power-law-like cluster size distributions. When σ=0, the model
687689
reduces to a Dirichlet Process.
688-
model_name : str, optional
690+
model_name : Optional[str]
689691
Name of the sentence transformer model to use.
690692
Default is "all-MiniLM-L6-v2".
691-
random_state : int, optional
693+
random_state : Optional[int]
692694
Random seed for reproducibility.
693695
If None, fresh, unpredictable entropy will be pulled from the OS.
694696

clusx/logging.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def setup_logging(level: Optional[int] = None) -> None:
4040
4141
Args:
4242
level: The logging level (defaults to logging.INFO if None).
43-
Common levels: DEBUG(10), INFO(20), WARNING(30), ERROR(40), CRITICAL(50)
4443
"""
4544
if level is None:
4645
level = logging.INFO

clusx/visualization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Visualization module for QA Dataset Clustering.
2+
Visualization module for Clusterium.
33
44
This module provides functions for visualizing clustering results and evaluation
55
metrics.

docs/source/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,13 @@
3939
# to non-types.
4040
nitpick_ignore = [
4141
("py:class", "np.float32"),
42+
("py:class", "np.int64"),
43+
("py:class", "np.ndarray"),
4244
("py:class", "numpy.bool_"),
4345
("py:class", "NDArray"),
4446
("py:class", "EmbeddingTensor"),
4547
("py:class", "Axes"),
48+
("py:class", "SentenceTransformer"),
4649
]
4750

4851
# -- Options for intersphinx -------------------------------------------------

0 commit comments

Comments
 (0)