From e4f8fc0de7dbbf8ea564a60557ec5ea64458eb3e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A1rton=20Kardos?=
Date: Tue, 10 Dec 2024 11:35:48 +0100
Subject: [PATCH 1/3] Added fighting-words term importance to clustering models

---
 turftopic/feature_importance.py | 57 +++++++++++++++++++++++++++++++++
 turftopic/models/cluster.py     | 24 ++++++++++----
 2 files changed, 74 insertions(+), 7 deletions(-)

diff --git a/turftopic/feature_importance.py b/turftopic/feature_importance.py
index ea1fb32..1b2442a 100644
--- a/turftopic/feature_importance.py
+++ b/turftopic/feature_importance.py
@@ -1,3 +1,7 @@
+from __future__ import annotations
+
+from typing import Literal
+
 import numpy as np
 import scipy.sparse as spr
 from sklearn.metrics.pairwise import cosine_similarity
@@ -126,3 +130,56 @@ def bayes_rule(
     p_tw = (p_wt.T * p_t).T / p_w
     p_tw /= np.nansum(p_tw, axis=0)
     return p_tw
+
+
+def fighting_words(
+    doc_topic_matrix: np.ndarray,
+    doc_term_matrix: spr.csr_matrix,
+    prior: float | Literal["corpus"] = "corpus",
+) -> np.ndarray:
+    """Computes feature importance using the *Fighting Words* algorithm.
+
+    Parameters
+    ----------
+    doc_topic_matrix: np.ndarray
+        Document-topic matrix of shape (n_documents, n_topics)
+    doc_term_matrix: spr.csr_matrix
+        Document-term matrix of shape (n_documents, vocab_size)
+    prior: float or "corpus", default "corpus"
+        Dirichlet prior to use. When a float, it indicates the alpha
+        parameter of a symmetric Dirichlet; if "corpus",
+        word frequencies from the background corpus are used.
+    Returns
+    -------
+    ndarray of shape (n_topics, vocab_size)
+        Term importance matrix.
+    """
+    labels = np.argmax(doc_topic_matrix, axis=1)
+    n_topics = doc_topic_matrix.shape[1]
+    n_vocab = doc_term_matrix.shape[1]
+    components = []
+    if prior == "corpus":
+        priors = np.ravel(np.asarray(doc_term_matrix.sum(axis=0)))
+    else:
+        priors = np.full(n_vocab, prior)
+    a0 = np.sum(priors)  # prior * n_vocab
+    for i_topic in range(n_topics):
+        topic_freq = np.ravel(
+            np.asarray(doc_term_matrix[labels == i_topic].sum(axis=0))
+        )
+        rest_freq = np.ravel(
+            np.asarray(doc_term_matrix[labels != i_topic].sum(axis=0))
+        )
+        n1 = np.sum(topic_freq)
+        n2 = np.sum(rest_freq)
+        topic_logodds = np.log(
+            (topic_freq + priors) / (n1 + a0 - topic_freq - priors)
+        )
+        rest_logodds = np.log(
+            (rest_freq + priors) / (n2 + a0 - rest_freq - priors)
+        )
+        delta = topic_logodds - rest_logodds
+        delta_var = 1 / (topic_freq + priors) + 1 / (rest_freq + priors)
+        zscore = delta / np.sqrt(delta_var)
+        components.append(zscore)
+    return np.stack(components)
diff --git a/turftopic/models/cluster.py b/turftopic/models/cluster.py
index a6eea27..8f442f4 100644
--- a/turftopic/models/cluster.py
+++ b/turftopic/models/cluster.py
@@ -20,7 +20,7 @@
 from turftopic.dynamic import DynamicTopicModel
 from turftopic.feature_importance import (bayes_rule,
                                           cluster_centroid_distance, ctf_idf,
-                                          soft_ctf_idf)
+                                          fighting_words, soft_ctf_idf)
 from turftopic.vectorizer import default_vectorizer
 
 integer_message = """
@@ -39,7 +39,7 @@
 """
 
 feature_message = """
-feature_importance must be one of 'soft-c-tf-idf', 'c-tf-idf', 'centroid'
+feature_importance must be one of 'soft-c-tf-idf', 'c-tf-idf', 'centroid', 'bayes', 'fighting-words'
 """
 
 NOT_MATCHING_ERROR = (
@@ -152,14 +152,14 @@ class ClusteringTopicModel(ContextualModel, ClusterMixin, DynamicTopicModel):
         Clustering method to use for finding topics.
         Defaults to OPTICS with 25 minimum cluster size.
         To imitate the behavior of BERTopic or Top2Vec you should use HDBSCAN.
-    feature_importance: {'soft-c-tf-idf', 'c-tf-idf', 'bayes', 'centroid'}, default 'soft-c-tf-idf'
+    feature_importance: {'soft-c-tf-idf', 'c-tf-idf', 'bayes', 'fighting-words', 'centroid'}, default 'soft-c-tf-idf'
         Method for estimating term importances.
         'centroid' uses distances from cluster centroid similarly
         to Top2Vec.
         'c-tf-idf' uses BERTopic's c-tf-idf.
         'soft-c-tf-idf' uses Soft c-TF-IDF from GMM, the results should
         be very similar to 'c-tf-idf'.
-        'bayes' uses Bayes' rule.
+        'fighting-words' uses the fighting-words algorithm (a Bayesian probabilistic model).
     n_reduce_to: int, default None
         Number of topics to reduce topics to.
         The specified reduction method will be used to merge them.
@@ -188,6 +188,7 @@ def __init__(
             "soft-c-tf-idf",
             "centroid",
             "bayes",
+            "fighting-words",
         ] = "soft-c-tf-idf",
         n_reduce_to: Optional[int] = None,
         reduction_method: Literal[
@@ -202,6 +203,7 @@
             "soft-c-tf-idf",
             "centroid",
             "bayes",
+            "fighting-words",
         ]:
             raise ValueError(feature_message)
         if isinstance(encoder, int):
@@ -364,21 +366,21 @@ def reset_topics(self):
     def estimate_components(
         self,
         feature_importance: Literal[
-            "centroid", "soft-c-tf-idf", "bayes", "c-tf-idf"
+            "centroid", "soft-c-tf-idf", "bayes", "c-tf-idf", "fighting-words"
         ],
     ) -> np.ndarray:
         """Estimates feature importances based on a fitted clustering.
 
         Parameters
         ----------
-        feature_importance: {'soft-c-tf-idf', 'c-tf-idf', 'bayes', 'centroid'}, default 'soft-c-tf-idf'
+        feature_importance: {'soft-c-tf-idf', 'c-tf-idf', 'bayes', 'centroid', 'fighting-words'}, default 'soft-c-tf-idf'
             Method for estimating term importances.
             'centroid' uses distances from cluster centroid similarly
             to Top2Vec.
             'c-tf-idf' uses BERTopic's c-tf-idf.
             'soft-c-tf-idf' uses Soft c-TF-IDF from GMM, the results should
             be very similar to 'c-tf-idf'.
-            'bayes' uses Bayes' rule.
+            'fighting-words' uses the fighting-words algorithm (a Bayesian probabilistic model).
 
         Returns
         -------
@@ -426,6 +428,10 @@ def estimate_components(
             self.components_ = bayes_rule(
                 document_topic_matrix, self.doc_term_matrix
             )
+        elif feature_importance == "fighting-words":
+            self.components_ = fighting_words(
+                document_topic_matrix, self.doc_term_matrix
+            )
         else:
             self.components_ = ctf_idf(
                 document_topic_matrix, self.doc_term_matrix
             )
@@ -556,6 +562,10 @@ def estimate_temporal_components(
                 self.temporal_components_[i_timebin] = bayes_rule(
                     t_doc_topic, t_dtm
                 )
+            elif feature_importance == "fighting-words":
+                self.temporal_components_[i_timebin] = fighting_words(
+                    t_doc_topic, t_dtm
+                )
             elif feature_importance == "centroid":
                 t_topic_vectors = self._calculate_topic_vectors(
                     time_labels == i_timebin,

From d8662f062ff3bcdd01c28c0532edd57c0e988206 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A1rton=20Kardos?=
Date: Wed, 18 Dec 2024 14:19:12 +0100
Subject: [PATCH 2/3] Added semantic-difference based feature importance

---
 turftopic/feature_importance.py | 34 +++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/turftopic/feature_importance.py b/turftopic/feature_importance.py
index 1b2442a..4c057be 100644
--- a/turftopic/feature_importance.py
+++ b/turftopic/feature_importance.py
@@ -5,6 +5,7 @@
 import numpy as np
 import scipy.sparse as spr
 from sklearn.metrics.pairwise import cosine_similarity
+from sklearn.preprocessing import scale
 
 
 def cluster_centroid_distance(
@@ -183,3 +184,36 @@ def fighting_words(
         zscore = delta / np.sqrt(delta_var)
         components.append(zscore)
     return np.stack(components)
+
+
+def semantic_difference(
+    doc_topic_matrix: np.ndarray,
+    embeddings: np.ndarray,
+    vocab_embeddings: np.ndarray,
+) -> np.ndarray:
+    """Computes feature importances based on semantic differences
+    between one group and the rest.
+
+    Parameters
+    ----------
+    doc_topic_matrix: np.ndarray
+        Document-topic matrix of shape (n_documents, n_topics)
+    embeddings: np.ndarray
+        Document embeddings of shape (n_documents, embedding_size).
+    vocab_embeddings: np.ndarray
+        Term embeddings of shape (vocab_size, embedding_size)
+
+    Returns
+    -------
+    ndarray of shape (n_topics, vocab_size)
+        Term importance matrix.
+    """
+    labels = np.argmax(doc_topic_matrix, axis=1)
+    unique_labels = np.sort(np.unique(labels))
+    components = []
+    for label in unique_labels:
+        mean_diff = np.mean(embeddings[label == labels], axis=0) - np.mean(
+            embeddings[label != labels], axis=0
+        )
+        components.append(np.dot(vocab_embeddings, mean_diff))
+    return scale(np.stack(components), axis=1)

From 3d8d1fd5c1d74465fc0057d7be60d72a459e8dd8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A1rton=20Kardos?=
Date: Wed, 18 Dec 2024 16:35:19 +0100
Subject: [PATCH 3/3] Added a first draft of SemanticLexicalAnalysis

---
 turftopic/supervised/__init__.py         |   0
 turftopic/supervised/semantic_lexical.py | 308 +++++++++++++++++++++++
 2 files changed, 308 insertions(+)
 create mode 100644 turftopic/supervised/__init__.py
 create mode 100644 turftopic/supervised/semantic_lexical.py

diff --git a/turftopic/supervised/__init__.py b/turftopic/supervised/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/turftopic/supervised/semantic_lexical.py b/turftopic/supervised/semantic_lexical.py
new file mode 100644
index 0000000..10bc01a
--- /dev/null
+++ b/turftopic/supervised/semantic_lexical.py
@@ -0,0 +1,308 @@
+from typing import Literal, Optional, Union
+
+import numpy as np
+from rich.console import Console
+from sentence_transformers import SentenceTransformer
+from sklearn.feature_extraction.text import CountVectorizer
+from sklearn.linear_model import LinearRegression
+from sklearn.metrics.pairwise import euclidean_distances
+from sklearn.preprocessing import label_binarize, scale
+
+from turftopic.base import ContextualModel, Encoder
+from turftopic.feature_importance import fighting_words, semantic_difference
+from turftopic.vectorizer import default_vectorizer
+
+NOT_MATCHING_ERROR = (
+    "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
+    + "Perhaps you are using precomputed embeddings but forgot to pass an encoder to your model. "
+    + "Try to initialize the model with the encoder you used for computing the embeddings."
+)
+
+
+class SemanticLexicalAnalysis(ContextualModel):
+    """Analyzes groups of texts based on their semantic/lexical differences.
+
+    ```python
+    from turftopic import SemanticLexicalAnalysis
+
+    corpus: list[str] = ["some text", "more text", ...]
+    labels: list[str] = ["group0", "group1", ...]
+
+    model = SemanticLexicalAnalysis().fit(corpus, y=labels)
+    model.print_topics()
+    ```
+
+    Parameters
+    ----------
+    encoder: str or SentenceTransformer
+        Model to encode documents/terms, all-MiniLM-L6-v2 is the default.
+    vectorizer: CountVectorizer, default None
+        Vectorizer used for term extraction.
+        Can be used to prune or filter the vocabulary.
+ """ + + def __init__( + self, + *, + encoder: Union[ + Encoder, str + ] = "sentence-transformers/all-MiniLM-L6-v2", + vectorizer: Optional[CountVectorizer] = None, + random_state: Optional[int] = None, + ): + self.encoder = encoder + if isinstance(encoder, str): + self.encoder_ = SentenceTransformer(encoder) + else: + self.encoder_ = encoder + if vectorizer is None: + self.vectorizer = default_vectorizer() + else: + self.vectorizer = vectorizer + + def fit_transform( + self, raw_documents, y, embeddings: Optional[np.ndarray] = None + ) -> np.ndarray: + y = np.array(y) + self.classes_ = np.sort(np.unique(y)) + doc_topic_matrix = label_binarize(y, classes=self.classes_) + console = Console() + self.embeddings = embeddings + with console.status("Fitting model") as status: + if self.embeddings is None: + status.update("Encoding documents") + self.embeddings = self.encoder_.encode(raw_documents) + console.log("Documents encoded.") + status.update("Extracting terms.") + doc_term_matrix = self.vectorizer.fit_transform(raw_documents) + vocab = self.vectorizer.get_feature_names_out() + console.log("Term extraction done.") + status.update("Computing lexical differences.") + self.lexical_components_ = scale( + fighting_words(doc_topic_matrix, doc_term_matrix), axis=1 + ) + console.log("Lexical components done.") + status.update("Encoding vocabulary") + self.vocab_embeddings = self.encoder_.encode(vocab) + if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]: + raise ValueError( + NOT_MATCHING_ERROR.format( + n_dims=self.embeddings.shape[1], + n_word_dims=self.vocab_embeddings.shape[1], + ) + ) + console.log("Vocabulary encoded.") + status.update("Computing semantic differences.") + self.semantic_components_ = semantic_difference( + doc_topic_matrix, self.embeddings, self.vocab_embeddings + ) + self.components_ = self.semantic_components_ + console.log("Semantic comoponents done.") + console.log("Model fitting done.") + return doc_topic_matrix + + def plot_semantic_lexical_square(self, label): + vocab = self.get_vocab() + try: + import plotly.express as px + except (ImportError, ModuleNotFoundError) as e: + raise ModuleNotFoundError( + "Please install plotly if you intend to use plots in Turftopic." 
+ ) from e + i_component = {lab: i for i, lab in enumerate(self.classes_)}[label] + # Semantic-lexical compass + x = self.semantic_components_[i_component] + y = self.lexical_components_[i_component] + points = np.array(list(zip(x, y))) + xx, yy = np.meshgrid( + np.linspace(np.min(x), np.max(x), 20), + np.linspace(np.min(y), np.max(y), 20), + ) + coords = np.array(list(zip(np.ravel(xx), np.ravel(yy)))) + coords = coords + np.random.default_rng(0).normal( + [0, 0], [0.1, 0.1], size=coords.shape + ) + dist = euclidean_distances(coords, points) + idxs = np.argmin(dist, axis=1) + fig = px.scatter( + x=x[idxs], + y=y[idxs], + text=vocab[idxs], + template="plotly_white", + ) + fig = fig.update_traces( + mode="text", textfont_color="black", marker=dict(color="black") + ).update_layout( + xaxis_title="Semantic Importance", + yaxis_title="Lexical Importance", + ) + fig = fig.update_layout( + width=1000, + height=1000, + font=dict(family="Times New Roman", color="black", size=21), + margin=dict(l=5, r=5, t=5, b=5), + ) + fig = fig.add_hline(y=0, line_color="black", line_width=4) + fig = fig.add_vline(x=0, line_color="black", line_width=4) + fig.add_annotation( + text="Lexical-Semantic", + x=np.max(x[(x > 0) & (y > 0)]), + y=np.max(y[(x > 0) & (y > 0)]), + ax=60, + ay=-60, + showarrow=True, + arrowwidth=3, + arrowhead=6, + arrowcolor="black", + font=dict(size=34, color="black"), + ) + fig.add_annotation( + text="Lexical-Nonsemantic", + x=np.min(x[(x < 0) & (y > 0)]), + y=np.max(y[(x < 0) & (y > 0)]), + ax=-60, + ay=-60, + showarrow=True, + arrowwidth=3, + arrowhead=6, + arrowcolor="black", + font=dict(size=34, color="black"), + ) + fig.add_annotation( + text="Semantic-Nonlexical", + x=np.max(x[(x > 0) & (y < 0)]), + y=np.min(y[(x > 0) & (y < 0)]), + ax=60, + ay=60, + showarrow=True, + arrowwidth=3, + arrowhead=6, + arrowcolor="black", + font=dict(size=34, color="black"), + ) + return fig + + def plot_residuals( + self, + label, + independent_variable: Literal["semantic", "lexical"] = "semantic", + ): + try: + import plotly.express as px + except (ImportError, ModuleNotFoundError) as e: + raise ModuleNotFoundError( + "Please install plotly if you intend to use plots in Turftopic." 
+ ) from e + vocab = self.get_vocab() + i_component = {lab: i for i, lab in enumerate(self.classes_)}[label] + semantic_component = self.semantic_components_[i_component] + lexical_component = self.lexical_components_[i_component] + if independent_variable == "semantic": + x, y = semantic_component, lexical_component + else: + x, y = lexical_component, semantic_component + linreg = LinearRegression().fit(x[:, None], y) + y_pred = linreg.predict(x[:, None]) + residuals = y_pred - y + absres = np.abs(residuals) + sorted_res = np.argsort(residuals) + idxs = [*sorted_res[:100], *sorted_res[-100:]] + fig = px.scatter( + x=x, + y=residuals, + text=vocab, + template="plotly_white", + size=absres, + ) + for idx in idxs: + fig.add_annotation( + text=vocab[idx], + x=x[idx], + y=residuals[idx], + showarrow=False, + font=dict(size=max(int(8 * np.sqrt(absres[idx])), 8)), + ) + max_absres = np.max(absres) + fig.update_yaxes(range=(-max_absres * 1.1, max_absres * 1.1)) + fig.update_traces(mode="text") + fig.update_traces( + mode="markers", + textfont_color="black", + marker=dict(color="white", line=dict(color="black")), + hovertemplate="%{text}", + opacity=1, + ).update_layout( + xaxis_title=( + "Semantic Importance" + if independent_variable == "semantic" + else "Lexical Importance" + ), + yaxis_title=( + "Lexical Residual" + if independent_variable == "semantic" + else "Semantic Residual" + ), + ) + fig.update_layout( + width=1200, + height=600, + font=dict(family="Times New Roman", color="black", size=21), + margin=dict(l=5, r=5, t=5, b=5), + hoverlabel=dict( + bgcolor="white", font_size=24, font_family="Times New Roman" + ), + ) + fig.add_hline(y=0, line_color="black", line_width=4) + return fig + + def _topics_table( + self, + top_k: int = 10, + show_scores: bool = False, + show_negative: bool = False, + ) -> list[list[str]]: + columns = ["Topic ID"] + if getattr(self, "topic_names_", None): + columns.append("Topic Name") + columns.append("Semantic") + columns.append("Lexical") + rows = [] + try: + classes = self.classes_ + except AttributeError: + classes = list(range(self.components_.shape[0])) + vocab = self.get_vocab() + for i_topic, (topic_id, sem_component, lex_component) in enumerate( + zip(classes, self.semantic_components_, self.lexical_components_) + ): + semantic = np.argpartition(-sem_component, top_k)[:top_k] + semantic = semantic[np.argsort(-sem_component[semantic])] + lexical = np.argpartition(-lex_component, top_k)[:top_k] + lexical = lexical[np.argsort(-lex_component[lexical])] + if show_scores: + concat_semantic = ", ".join( + [ + f"{word}({importance:.2f})" + for word, importance in zip( + vocab[semantic], sem_component[semantic] + ) + ] + ) + concat_lexical = ", ".join( + [ + f"{word}({importance:.2f})" + for word, importance in zip( + vocab[lexical], lex_component[lexical] + ) + ] + ) + else: + concat_semantic = ", ".join([word for word in vocab[semantic]]) + concat_lexical = ", ".join([word for word in vocab[lexical]]) + row = [f"{topic_id}"] + if getattr(self, "topic_names_", None): + row.append(self.topic_names_[i_topic]) + row.append(concat_semantic) + row.append(concat_lexical) + rows.append(row) + return [columns, *rows]