Skip to content
This repository was archived by the owner on Jan 8, 2026. It is now read-only.

Commit 01d9450

Browse files
committed
Amend docs
1 parent 8292a59 commit 01d9450

File tree

1 file changed

+160
-8
lines changed

1 file changed

+160
-8
lines changed

qadst/clustering/models.py

Lines changed: 160 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,24 @@
1717

1818

1919
class DirichletProcess:
20-
"""Dirichlet Process clustering implementation."""
20+
"""
21+
Dirichlet Process clustering implementation for text data.
22+
23+
This implementation uses a Chinese Restaurant Process (CRP) formulation combined
24+
with semantic similarity measures to cluster text data.
25+
26+
The model assigns each text to either an existing cluster or creates a new cluster
27+
based on both the CRP probabilities and the semantic similarity between the text
28+
and existing cluster centroids.
29+
30+
Attributes:
31+
alpha (float): Concentration parameter for new cluster creation.
32+
Higher values lead to more clusters.
33+
clusters (list[int]): List of cluster assignments for each processed text.
34+
cluster_params (list[Tensor]): One embedding per cluster, used as that cluster's parameters.
35+
model: Sentence transformer model used for text embeddings.
36+
cache (EmbeddingCache): Optional cache for storing text embeddings.
37+
"""
2138

2239
def __init__(
2340
self,
@@ -26,6 +43,19 @@ def __init__(
2643
similarity_metric: Optional[Callable[[str, Tensor], float]] = None,
2744
cache: Optional[EmbeddingCache] = None,
2845
):
46+
"""
47+
Initialize a Dirichlet Process clustering model.
48+
49+
Args:
50+
alpha (float): Concentration parameter for new cluster creation.
51+
Higher values lead to more clusters.
52+
base_measure (Optional[Tensor]): Base measure for the Dirichlet Process.
53+
Currently not used in this implementation.
54+
similarity_metric (Optional[Callable]): Function to compute similarity
55+
between a text and cluster parameters. If None, uses bert_similarity.
56+
cache (Optional[EmbeddingCache]): Cache for storing text embeddings.
57+
Helps avoid redundant embedding computations.
58+
"""
2959
self.alpha = alpha
3060
self.base_measure = base_measure
3161
self.clusters: list[int] = []
@@ -40,6 +70,15 @@ def __init__(
4070
self.cache.load_cache()
4171

4272
def get_embedding(self, text: str) -> Tensor:
73+
"""
74+
Get the embedding for a text, using cache if available.
75+
76+
Args:
77+
text (str): The text to embed.
78+
79+
Returns:
80+
Tensor: The embedding vector for the text.
81+
"""
4382
# Try to get from cache first
4483
if self.cache and text in self.cache:
4584
embedding = self.cache.get(text)
@@ -56,19 +95,61 @@ def get_embedding(self, text: str) -> Tensor:
5695
return embedding
5796

5897
def save_embedding_cache(self):
98+
"""
99+
Save the embedding cache to disk if a cache provider is available.
100+
This helps preserve embeddings between runs for faster processing.
101+
"""
59102
if self.cache:
60103
self.cache.save_cache()
61104

62105
def bert_similarity(self, text, cluster_param):
106+
"""
107+
Calculate cosine similarity between a text and cluster parameters.
108+
109+
Args:
110+
text (str): The text to compare.
111+
cluster_param (Tensor): The cluster parameters (embedding).
112+
113+
Returns:
114+
float: Similarity score between 0 and 1; negative cosine similarities are clamped to 0.0, and 1 means the embeddings point in the same direction.
115+
"""
63116
text_embedding = self.get_embedding(text)
64117
cluster_embedding = cluster_param
65118
similarity = 1 - cosine(text_embedding, cluster_embedding)
66119
return max(0.0, similarity)
67120

68121
def sample_new_cluster(self, text):
122+
"""
123+
Create parameters for a new cluster from the given text; the text's embedding is used directly and nothing is randomly sampled.
124+
125+
Args:
126+
text (str): The text to use as the basis for the new cluster.
127+
128+
Returns:
129+
Tensor: Embedding to use as parameters for the new cluster.
130+
"""
69131
return self.get_embedding(text)
70132

71133
def assign_cluster(self, text):
134+
"""
135+
Assign a text to a cluster using the Chinese Restaurant Process with similarity.
136+
137+
This method computes probabilities for assigning the text to each existing
138+
cluster or creating a new cluster. The probabilities are based on:
139+
1. The number of texts already in each cluster (CRP prior)
140+
2. The similarity between the text and each cluster's parameters
141+
3. The concentration parameter alpha
142+
143+
The method then samples from this probability distribution to make the
144+
assignment.
145+
146+
Args:
147+
text (str): The text to assign to a cluster.
148+
149+
Side effects:
150+
Updates self.clusters with the cluster assignment for this text.
151+
Updates self.cluster_params if a new cluster is created.
152+
"""
72153
probs = []
73154
total_points = len(self.clusters)
74155

@@ -100,11 +181,16 @@ def fit(self, texts: List[str]) -> Tuple[List[int], List[Tensor]]:
100181
"""
101182
Train the Dirichlet Process model on the given text data.
102183
184+
This method processes each text in the input list, assigning it to a cluster
185+
using the Chinese Restaurant Process.
186+
103187
Args:
104-
texts: List of text strings to cluster
188+
texts (List[str]): List of text strings to cluster.
105189
106190
Returns:
107-
Tuple containing (cluster_assignments, cluster_parameters)
191+
Tuple[List[int], List[Tensor]]: A tuple containing:
192+
- List of cluster assignments for each text
193+
- List of cluster parameters (embeddings)
108194
"""
109195
logger.info(f"Processing {len(texts)} texts...")
110196
for text in tqdm(texts, desc="Clustering"):
@@ -116,7 +202,30 @@ def fit(self, texts: List[str]) -> Tuple[List[int], List[Tensor]]:
116202

117203

118204
class PitmanYorProcess(DirichletProcess):
119-
"""Pitman-Yor Process clustering implementation."""
205+
"""
206+
Pitman-Yor Process clustering implementation for text data.
207+
208+
The Pitman-Yor Process is a generalization of the Dirichlet Process that introduces
209+
a discount parameter (sigma) to control the power-law behavior of the cluster
210+
size distribution. It is particularly effective for modeling natural language
211+
phenomena that exhibit power-law distributions.
212+
213+
This implementation extends the DirichletProcess class, adding the sigma parameter
214+
and modifying the cluster assignment probabilities according to the Pitman-Yor
215+
Process. It also includes optimizations for tracking cluster sizes.
216+
217+
Attributes:
218+
alpha (float): Concentration parameter inherited from DirichletProcess.
219+
sigma (float): Discount parameter controlling power-law behavior.
220+
Should be in range [0, 1). Higher values create more heavy-tailed
221+
distributions.
222+
clusters (list[int]): List of cluster assignments for each processed text.
223+
cluster_params (list[Tensor]): List of cluster parameters (embeddings) for
224+
each cluster.
225+
cluster_sizes (dict[int, int]): Dictionary tracking the size of each cluster.
226+
model: Sentence transformer model used for text embeddings.
227+
cache (EmbeddingCache): Optional cache for storing text embeddings.
228+
"""
120229

121230
def __init__(
122231
self,
@@ -126,13 +235,51 @@ def __init__(
126235
similarity_metric: Optional[Callable[[str, Tensor], float]] = None,
127236
cache: Optional[EmbeddingCache] = None,
128237
):
238+
"""
239+
Initialize a Pitman-Yor Process clustering model.
240+
241+
Args:
242+
alpha (float): Concentration parameter that controls the propensity to
243+
create new clusters. Higher values lead to more clusters.
244+
sigma (float): Discount parameter controlling power-law behavior.
245+
Should be in range [0, 1). Higher values create more heavy-tailed
246+
distributions.
247+
base_measure (Optional[Tensor]): Base measure for the Pitman-Yor Process.
248+
Currently not used in this implementation.
249+
similarity_metric (Optional[Callable[[str, Tensor], float]]): Function to
250+
compute similarity between a text and cluster parameters.
251+
If None, uses bert_similarity.
252+
cache (Optional[EmbeddingCache]): Cache for storing text embeddings.
253+
Helps avoid redundant embedding computations.
254+
"""
129255
super().__init__(alpha, base_measure, similarity_metric, cache)
130256
self.sigma = sigma
131257
# Keep track of cluster sizes for faster access
132258
self.cluster_sizes = {}
133259

134260
def assign_cluster(self, text):
135-
"""Uses Pitman-Yor process probability calculations."""
261+
"""
262+
Assign a text to a cluster using the Pitman-Yor Process with similarity.
263+
264+
This method extends the DirichletProcess assignment method by modifying the
265+
probability calculations to incorporate the discount parameter sigma. The
266+
Pitman-Yor probabilities favor a power-law distribution of cluster sizes,
267+
which is often more realistic for natural language data.
268+
269+
The method computes probabilities for assigning the text to each existing
270+
cluster or creating a new cluster based on:
271+
1. The number of texts already in each cluster (PYP prior)
272+
2. The similarity between the text and each cluster's parameters
273+
3. The concentration parameter alpha and discount parameter sigma
274+
275+
Args:
276+
text (str): The text to assign to a cluster.
277+
278+
Side effects:
279+
Updates self.clusters with the cluster assignment for this text.
280+
Updates self.cluster_params if a new cluster is created.
281+
Updates self.cluster_sizes to track cluster populations.
282+
"""
136283
probs = []
137284
total_points = len(self.clusters)
138285

@@ -185,13 +332,18 @@ def assign_cluster(self, text):
185332

186333
def fit(self, texts: List[str]) -> Tuple[List[int], List[Tensor]]:
187334
"""
188-
Optimized version of fit for PitmanYorProcess.
335+
Train the Pitman-Yor Process model on the given text data.
336+
337+
This is an optimized version of the fit method for PitmanYorProcess: it keeps
338+
track of cluster sizes while processing the texts, for better performance.
189339
190340
Args:
191-
texts: List of text strings to cluster
341+
texts (List[str]): List of text strings to cluster.
192342
193343
Returns:
194-
Tuple containing (cluster_assignments, cluster_parameters)
344+
Tuple[List[int], List[Tensor]]: A tuple containing:
345+
- List of cluster assignments for each text
346+
- List of cluster parameters (embeddings)
195347
"""
196348
logger.info(f"Processing {len(texts)} texts with optimized PitmanYorProcess...")
197349

0 commit comments

Comments
 (0)