@@ -76,17 +76,17 @@ class DirichletProcess:
7676 cluster_params : dict
7777 Dictionary of cluster parameters for each cluster.
7878 Contains 'mean' (centroid) and 'count' (number of points).
79- global_mean : ndarray
79+ global_mean : Optional[EmbeddingTensor]
8080 Global mean of all document embeddings.
8181 next_id : int
8282 Next available cluster ID.
83- embeddings_ : ndarray
83+ embeddings_ : Optional[EmbeddingTensor]
8484 Document embeddings after fitting.
85- labels_ : ndarray
85+ labels_ : Optional[NDArray[np.int64]]
8686 Cluster assignments after fitting.
87- text_embeddings : dict
87+ text_embeddings : dict[str, EmbeddingTensor]
8888 Cache of text to embedding mappings.
89- embedding_dim : int or None
89+ embedding_dim : Optional[int]
9090 Dimension of the embedding vectors.
9191 """
9292
@@ -108,10 +108,10 @@ def __init__(
108108 kappa : float
109109 Precision parameter for the von Mises-Fisher distribution.
110110 Higher values lead to tighter, more concentrated clusters.
111- model_name : str, optional
111+ model_name : Optional[str]
112112 Name of the sentence transformer model to use.
113113 Default is "all-MiniLM-L6-v2".
114- random_state : int, optional
114+ random_state : Optional[int]
115115 Random seed for reproducibility.
116116 If None, fresh, unpredictable entropy will be pulled from the OS.
117117 """
@@ -124,10 +124,10 @@ def __init__(
124124
125125 self.clusters = []
126126 self.cluster_params = {}
127- self.global_mean = None
127+ self.global_mean: Optional[EmbeddingTensor] = None
128128 self.next_id = 0
129- self.embeddings_ = None
130- self.labels_ = None
129+ self.embeddings_: Optional[EmbeddingTensor] = None
130+ self.labels_: Optional[NDArray[np.int64]] = None
131131
132132 # For tracking processed texts and their embeddings
133133 self.text_embeddings: dict[str, EmbeddingTensor] = {}
@@ -148,7 +148,7 @@ def get_embedding(self, text: Union[str, list[str]]) -> EmbeddingTensor:
148148
149149 Returns
150150 -------
151- numpy.ndarray
151+ EmbeddingTensor
152152 The normalized embedding vector(s) for the text.
153153 If input is a single string, returns a single embedding vector.
154154 If input is a list, returns an array of embedding vectors.
@@ -191,7 +191,7 @@ def get_embedding(self, text: Union[str, list[str]]) -> EmbeddingTensor:
191191 # Return single embedding or list based on input
192192 return results[0] if is_single else np.array(results)
193193
194- def _normalize(self, embedding: EmbeddingTensor) -> EmbeddingTensor:
194+ def _normalize(self, embedding: EmbeddingTensor) -> NDArray[np.float32]:
195195 """
196196 Normalize vector to unit length for use with von Mises-Fisher distribution.
197197
@@ -205,13 +205,15 @@ def _normalize(self, embedding: EmbeddingTensor) -> EmbeddingTensor:
205205
206206 Returns
207207 -------
208- EmbeddingTensor
209- The normalized embedding vector with unit length.
208+ NDArray[np.float32]
209+ The normalized embedding vector with unit length as a NumPy array.
210210 """
211211 norm = np.linalg.norm(embedding)
212212 # Convert to numpy array to ensure division works properly
213213 embedding_np = to_numpy(embedding)
214- return embedding_np / norm if norm > 0 else embedding_np
214+ # Ensure the result is float32 to match the return type
215+ result = embedding_np / norm if norm > 0 else embedding_np
216+ return result.astype(np.float32)
215217
216218 def _log_likelihood_vmf(self, embedding: EmbeddingTensor, cluster_id: int) -> float:
217219 """
@@ -262,7 +264,7 @@ def log_crp_prior(self, cluster_id: Optional[int] = None) -> float:
262264
263265 Parameters
264266 ----------
265- cluster_id : int, optional
267+ cluster_id : Optional[int]
266268 The cluster ID.
267269 If provided, calculate prior for an existing cluster.
268270 If None, calculate prior for a new cluster.
@@ -399,7 +401,7 @@ def _create_or_update_cluster(
399401 Document embedding vector.
400402 is_new_cluster : bool
401403 Whether to create a new cluster.
402- existing_cluster_id : int, optional
404+ existing_cluster_id : Optional[int]
403405 ID of existing cluster to update, if is_new_cluster is False.
404406
405407 Returns
@@ -486,9 +488,9 @@ def fit(self, documents, _y: Union[Any, None] = None):
486488
487489 Parameters
488490 ----------
489- documents : array-like of shape (n_samples,)
491+ documents : Union[list[str], list[EmbeddingTensor]]
490492 The text documents or embeddings to cluster.
491- _y : Any, optional
493+ _y : Union[Any, None]
492494 Ignored. Added for compatibility with scikit-learn API.
493495
494496 Returns
@@ -500,9 +502,9 @@ def fit(self, documents, _y: Union[Any, None] = None):
500502 ----
501503 After fitting, the following attributes are set:
502504
503- - :data:`embeddings_` : ndarray of shape (n_samples, n_features)
505+ - :data:`embeddings_` : Optional[EmbeddingTensor]
504506 The document embeddings.
505- - :data:`labels_` : ndarray of shape (n_samples,)
507+ - :data:`labels_` : NDArray[np.int64]
506508 The cluster assignments for each document.
507509 - :data:`clusters` : list
508510 List of cluster IDs for each document.
@@ -540,12 +542,12 @@ def predict(self, documents):
540542
541543 Parameters
542544 ----------
543- documents : array-like of shape (n_samples,)
545+ documents : Union[list[str], list[EmbeddingTensor]]
544546 The text documents or embeddings to predict clusters for.
545547
546548 Returns
547549 -------
548- labels : ndarray of shape (n_samples,)
550+ labels : NDArray[np.int64]
549551 Cluster labels for each document.
550552 Returns -1 if no clusters exist yet.
551553
@@ -584,14 +586,14 @@ def fit_predict(self, documents, _y: Union[Any, None] = None):
584586
585587 Parameters
586588 ----------
587- documents : array-like of shape (n_samples,)
589+ documents : Union[list[str], list[EmbeddingTensor]]
588590 The text documents or embeddings to cluster.
589- _y : Ignored
591+ _y : Union[Any, None]
590592 This parameter exists only for compatibility with scikit-learn API.
591593
592594 Returns
593595 -------
594- labels : ndarray of shape (n_samples,)
596+ labels : NDArray[np.int64]
595597 Cluster labels for each document.
596598
597599 Notes
@@ -624,17 +626,17 @@ class PitmanYorProcess(DirichletProcess):
624626 cluster_params : dict
625627 Dictionary of cluster parameters for each cluster.
626628 Contains 'mean' (centroid) and 'count' (number of points).
627- global_mean : ndarray
629+ global_mean : Optional[EmbeddingTensor]
628630 Global mean of all document embeddings.
629631 next_id : int
630632 Next available cluster ID.
631- embeddings_ : ndarray
633+ embeddings_ : Optional[EmbeddingTensor]
632634 Document embeddings after fitting.
633- labels_ : ndarray
635+ labels_ : Optional[NDArray[np.int64]]
634636 Cluster assignments after fitting.
635- text_embeddings : dict
637+ text_embeddings : dict[str, EmbeddingTensor]
636638 Cache of text to embedding mappings.
637- embedding_dim : int, optional
639+ embedding_dim : Optional[int]
638640 Dimension of the embedding vectors.
639641
640642 Notes
@@ -685,10 +687,10 @@ def __init__(
685687 Controls the power-law behavior. Higher values create more
686688 power-law-like cluster size distributions. When σ=0, the model
687689 reduces to a Dirichlet Process.
688- model_name : str, optional
690+ model_name : Optional[str]
689691 Name of the sentence transformer model to use.
690692 Default is "all-MiniLM-L6-v2".
691- random_state : int, optional
693+ random_state : Optional[int]
692694 Random seed for reproducibility.
693695 If None, fresh, unpredictable entropy will be pulled from the OS.
694696
0 commit comments