Skip to content

Commit 32c958a

Browse files
Readded estimate_components to clustering models
1 parent decf0fe commit 32c958a

File tree

1 file changed

+34
-4
lines changed

1 file changed

+34
-4
lines changed

turftopic/models/cluster.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,36 @@ def _calculate_topic_vectors(
261261
)
262262
return topic_vectors
263263

264+
def estimate_components(
265+
self, feature_importance: Optional[WordImportance] = None
266+
) -> np.ndarray:
267+
"""Estimates feature importances based on a fitted clustering.
268+
269+
Parameters
270+
----------
271+
feature_importance: WordImportance, default None
272+
Method for estimating term importances.
273+
'centroid' uses distances from cluster centroid similarly
274+
to Top2Vec.
275+
'c-tf-idf' uses BERTopic's c-tf-idf.
276+
'soft-c-tf-idf' uses Soft c-TF-IDF from GMM, the results should
277+
be very similar to 'c-tf-idf'.
278+
'bayes' uses Bayes' rule.
279+
280+
Returns
281+
-------
282+
ndarray of shape (n_components, n_vocab)
283+
Topic-term matrix.
284+
"""
285+
if feature_importance is not None:
286+
if feature_importance not in VALID_WORD_IMPORTANCE:
287+
raise ValueError(
288+
f"feature_importance must be one of {VALID_WORD_IMPORTANCE} got {feature_importance} instead."
289+
)
290+
self.feature_importance = feature_importance
291+
self.hierarchy.estimate_components()
292+
return self.components_
293+
264294
def reduce_topics(
265295
self,
266296
n_reduce_to: int,
@@ -424,15 +454,13 @@ def estimate_temporal_components(
424454
self,
425455
time_labels,
426456
time_bin_edges,
427-
feature_importance: Literal[
428-
"c-tf-idf", "soft-c-tf-idf", "centroid", "bayes"
429-
],
457+
feature_importance: Optional[WordImportance] = None,
430458
) -> np.ndarray:
431459
"""Estimates temporal components based on a fitted topic model.
432460
433461
Parameters
434462
----------
435-
feature_importance: {'soft-c-tf-idf', 'c-tf-idf', 'bayes', 'centroid'}, default 'soft-c-tf-idf'
463+
feature_importance: WordImportance, default None
436464
Method for estimating term importances.
437465
'centroid' uses distances from cluster centroid similarly
438466
to Top2Vec.
@@ -450,6 +478,8 @@ def estimate_temporal_components(
450478
raise NotFittedError(
451479
"The model has not been fitted yet, please fit the model before estimating temporal components."
452480
)
481+
if feature_importance is None:
482+
feature_importance = self.feature_importance
453483
n_comp, n_vocab = self.components_.shape
454484
self.time_bin_edges = time_bin_edges
455485
n_bins = len(self.time_bin_edges) - 1

0 commit comments

Comments
 (0)