1717
1818
1919class DirichletProcess :
20- """Dirichlet Process clustering implementation."""
20+ """
21+ Dirichlet Process clustering implementation for text data.
22+
23+ This implementation uses a Chinese Restaurant Process (CRP) formulation combined
24+ with semantic similarity measures to cluster text data.
25+
26+ The model assigns each text to either an existing cluster or creates a new cluster
27+ based on both the CRP probabilities and the semantic similarity between the text
28+ and existing cluster centroids.
29+
30+ Attributes:
31+ alpha (float): Concentration parameter for new cluster creation.
32+ Higher values lead to more clusters.
33+ clusters (list[int]): List of cluster assignments for each processed text.
34+ cluster_params (list[Tensor]): List of cluster embeddings for each cluster.
35+ model: Sentence transformer model used for text embeddings.
36+ cache (EmbeddingCache): Optional cache for storing text embeddings.
37+ """
2138
2239 def __init__ (
2340 self ,
@@ -26,6 +43,19 @@ def __init__(
2643 similarity_metric : Optional [Callable [[str , Tensor ], float ]] = None ,
2744 cache : Optional [EmbeddingCache ] = None ,
2845 ):
46+ """
47+ Initialize a Dirichlet Process clustering model.
48+
49+ Args:
50+ alpha (float): Concentration parameter for new cluster creation.
51+ Higher values lead to more clusters.
52+ base_measure (Optional[Tensor]): Base measure for the Dirichlet Process.
53+ Currently not used in this implementation.
54+ similarity_metric (Optional[Callable]): Function to compute similarity
55+ between a text and cluster parameters. If None, uses bert_similarity.
56+ cache (Optional[EmbeddingCache]): Cache for storing text embeddings.
57+ Helps avoid redundant embedding computations.
58+ """
2959 self .alpha = alpha
3060 self .base_measure = base_measure
3161 self .clusters : list [int ] = []
@@ -40,6 +70,15 @@ def __init__(
4070 self .cache .load_cache ()
4171
4272 def get_embedding (self , text : str ) -> Tensor :
73+ """
74+ Get the embedding for a text, using cache if available.
75+
76+ Args:
77+ text (str): The text to embed.
78+
79+ Returns:
80+ Tensor: The embedding vector for the text.
81+ """
4382 # Try to get from cache first
4483 if self .cache and text in self .cache :
4584 embedding = self .cache .get (text )
@@ -56,19 +95,61 @@ def get_embedding(self, text: str) -> Tensor:
5695 return embedding
5796
5897 def save_embedding_cache (self ):
98+ """
99+ Save the embedding cache to disk if a cache provider is available.
100+ This helps preserve embeddings between runs for faster processing.
101+ """
59102 if self .cache :
60103 self .cache .save_cache ()
61104
62105 def bert_similarity (self , text , cluster_param ):
106+ """
107+ Calculate cosine similarity between a text and cluster parameters.
108+
109+ Args:
110+ text (str): The text to compare.
111+ cluster_param (Tensor): The cluster parameters (embedding).
112+
113+ Returns:
114+ float: Similarity score between 0 and 1, where 1 means identical.
115+ """
63116 text_embedding = self .get_embedding (text )
64117 cluster_embedding = cluster_param
65118 similarity = 1 - cosine (text_embedding , cluster_embedding )
66119 return max (0.0 , similarity )
67120
68121 def sample_new_cluster (self , text ):
122+ """
123+ Sample parameters for a new cluster based on the given text.
124+
125+ Args:
126+ text (str): The text to use as the basis for the new cluster.
127+
128+ Returns:
129+ Tensor: Embedding to use as parameters for the new cluster.
130+ """
69131 return self .get_embedding (text )
70132
71133 def assign_cluster (self , text ):
134+ """
135+ Assign a text to a cluster using the Chinese Restaurant Process with similarity.
136+
137+ This method computes probabilities for assigning the text to each existing
138+ cluster or creating a new cluster. The probabilities are based on:
139+ 1. The number of texts already in each cluster (CRP prior)
140+ 2. The similarity between the text and each cluster's parameters
141+ 3. The concentration parameter alpha
142+
143+ The method then samples from this probability distribution to make the
144+ assignment.
145+
146+ Args:
147+ text (str): The text to assign to a cluster.
148+
149+ Side effects:
150+ Updates self.clusters with the cluster assignment for this text.
151+ Updates self.cluster_params if a new cluster is created.
152+ """
72153 probs = []
73154 total_points = len (self .clusters )
74155
@@ -100,11 +181,16 @@ def fit(self, texts: List[str]) -> Tuple[List[int], List[Tensor]]:
100181 """
101182 Train the Dirichlet Process model on the given text data.
102183
184+ This method processes each text in the input list, assigning it to a cluster
185+ using the Chinese Restaurant Process.
186+
103187 Args:
104- texts: List of text strings to cluster
188+ texts (List[str]): List of text strings to cluster.
105189
106190 Returns:
107- Tuple containing (cluster_assignments, cluster_parameters)
191+ Tuple[List[int], List[Tensor]]: A tuple containing:
192+ - List of cluster assignments for each text
193+ - List of cluster parameters (embeddings)
108194 """
109195 logger .info (f"Processing { len (texts )} texts..." )
110196 for text in tqdm (texts , desc = "Clustering" ):
@@ -116,7 +202,30 @@ def fit(self, texts: List[str]) -> Tuple[List[int], List[Tensor]]:
116202
117203
118204class PitmanYorProcess (DirichletProcess ):
119- """Pitman-Yor Process clustering implementation."""
205+ """
206+ Pitman-Yor Process clustering implementation for text data.
207+
208+ The Pitman-Yor Process is a generalization of the Dirichlet Process that introduces
209+ a discount parameter (sigma) to control the power-law behavior of the cluster
210+ size distribution. It is particularly effective for modeling natural language
211+ phenomena that exhibit power-law distributions.
212+
213+ This implementation extends the DirichletProcess class, adding the sigma parameter
214+ and modifying the cluster assignment probabilities according to the Pitman-Yor
215+ Process. It also includes optimizations for tracking cluster sizes.
216+
217+ Attributes:
218+ alpha (float): Concentration parameter inherited from DirichletProcess.
219+ sigma (float): Discount parameter controlling power-law behavior.
220+ Should be in range [0, 1). Higher values create more heavy-tailed
221+ distributions.
222+ clusters (list[int]): List of cluster assignments for each processed text.
223+ cluster_params (list[Tensor]): List of cluster parameters (embeddings) for
224+ each cluster.
225+ cluster_sizes (dict[int, int]): Dictionary tracking the size of each cluster.
226+ model: Sentence transformer model used for text embeddings.
227+ cache (EmbeddingCache): Optional cache for storing text embeddings.
228+ """
120229
121230 def __init__ (
122231 self ,
@@ -126,13 +235,51 @@ def __init__(
126235 similarity_metric : Optional [Callable [[str , Tensor ], float ]] = None ,
127236 cache : Optional [EmbeddingCache ] = None ,
128237 ):
238+ """
239+ Initialize a Pitman-Yor Process clustering model.
240+
241+ Args:
242+ alpha (float): Concentration parameter that controls the propensity to
243+ create new clusters. Higher values lead to more clusters.
244+ sigma (float): Discount parameter controlling power-law behavior.
245+ Should be in range [0, 1). Higher values create more heavy-tailed
246+ distributions.
247+ base_measure (Optional[Tensor]): Base measure for the Pitman-Yor Process.
248+ Currently not used in this implementation.
249+ similarity_metric (Optional[Callable[[str, Tensor], float]]): Function to
250+ compute similarity between a text and cluster parameters.
251+ If None, uses bert_similarity.
252+ cache (Optional[EmbeddingCache]): Cache for storing text embeddings.
253+ Helps avoid redundant embedding computations.
254+ """
129255 super ().__init__ (alpha , base_measure , similarity_metric , cache )
130256 self .sigma = sigma
131257 # Keep track of cluster sizes for faster access
132258 self .cluster_sizes = {}
133259
134260 def assign_cluster (self , text ):
135- """Uses Pitman-Yor process probability calculations."""
261+ """
262+ Assign a text to a cluster using the Pitman-Yor Process with similarity.
263+
264+ This method extends the DirichletProcess assignment method by modifying the
265+ probability calculations to incorporate the discount parameter sigma. The
266+ Pitman-Yor probabilities favor a power-law distribution of cluster sizes,
267+ which is often more realistic for natural language data.
268+
269+ The method computes probabilities for assigning the text to each existing
270+ cluster or creating a new cluster based on:
271+ 1. The number of texts already in each cluster (PYP prior)
272+ 2. The similarity between the text and each cluster's parameters
273+ 3. The concentration parameter alpha and discount parameter sigma
274+
275+ Args:
276+ text (str): The text to assign to a cluster.
277+
278+ Side effects:
279+ Updates self.clusters with the cluster assignment for this text.
280+ Updates self.cluster_params if a new cluster is created.
281+ Updates self.cluster_sizes to track cluster populations.
282+ """
136283 probs = []
137284 total_points = len (self .clusters )
138285
@@ -185,13 +332,18 @@ def assign_cluster(self, text):
185332
186333 def fit (self , texts : List [str ]) -> Tuple [List [int ], List [Tensor ]]:
187334 """
188- Optimized version of fit for PitmanYorProcess.
335+ Train the Pitman-Yor Process model on the given text data.
336+
337+ This is an optimized version of the fit method for PitmanYorProcess that
338+ processes texts with tracking of cluster sizes for better performance.
189339
190340 Args:
191- texts: List of text strings to cluster
341+ texts (List[str]): List of text strings to cluster.
192342
193343 Returns:
194- Tuple containing (cluster_assignments, cluster_parameters)
344+ Tuple[List[int], List[Tensor]]: A tuple containing:
345+ - List of cluster assignments for each text
346+ - List of cluster parameters (embeddings)
195347 """
196348 logger .info (f"Processing { len (texts )} texts with optimized PitmanYorProcess..." )
197349
0 commit comments