diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 143566f..1fb4f1b 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -1,35 +1,30 @@
 name: Tests
-on:
-  push:
-    branches: [main]
-  pull_request:
-    branches: [main]
+
+on: [push]
 
 jobs:
-  pytest:
+  build:
+
     runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: ["3.11"]
-    #
-    # This allows a subsequently queued workflow run to interrupt previous runs
-    concurrency:
-      group: "${{ github.workflow }}-${{ matrix.python-version}}-${{ matrix.os }} @ ${{ github.ref }}"
-      cancel-in-progress: true
     steps:
-      - uses: actions/checkout@v4
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+      - uses: actions/checkout@v5
+      - name: Set up Python
+        # This is the version of the action for setting up Python, not the Python version.
+        uses: actions/setup-python@v5
         with:
-          python-version: ${{ matrix.python-version }}
+          # Semantic version range syntax or exact version of a Python version
+          python-version: '3.11'
+          # Optional - x64 or x86 architecture, defaults to x64
+          architecture: 'x64'
           cache: "pip"
-      # You can test your matrix by printing the current Python version
+      # Print the current Python version to verify the setup
       - name: Display Python version
-        run: python3 -c "import sys; print(sys.version)"
-
-      - name: Install dependencies
-        run: python3 -m pip install --upgrade turftopic[pyro-ppl] pandas pytest plotly igraph datasets pillow
+        run: python -c "import sys; print(sys.version)"
+      - name: Install package
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[dev]
+          pip install pytest
       - name: Run tests
-        run: python3 -m pytest tests/
-
+        run: python -m pytest tests/
diff --git a/docs/SensTopic.md b/docs/SensTopic.md
new file mode 100644
index 0000000..deac1b8
--- /dev/null
+++ b/docs/SensTopic.md
@@ -0,0 +1,161 @@

# SensTopic (BETA)

SensTopic is a variant of Semantic Signal Separation that only discovers positive signals, while allowing components to be unbounded.
This is achieved with an algorithm called Semi-nonnegative Matrix Factorization (SNMF).

> :warning: This model is still in an experimental phase. More documentation and a paper are on their way. :warning:

SensTopic uses an efficient implementation of the SNMF algorithm written in raw NumPy, with an optional JAX backend.
If you want to enable hardware acceleration and JIT compilation, make sure to install JAX before running the model:

```bash
pip install jax
```

Here's an example of running SensTopic on the 20 Newsgroups dataset:

```python
from sklearn.datasets import fetch_20newsgroups
from turftopic import SensTopic

corpus = fetch_20newsgroups(
    subset="all",
    remove=("headers", "footers", "quotes"),
).data

model = SensTopic(25)
model.fit(corpus)

model.print_topics()
```

| Topic ID | Highest Ranking |
| - | - |
| | ... |
| 8 | gospels, mormon, catholics, protestant, mormons, synagogues, seminary, catholic, liturgy, churches |
| 9 | encryption, encrypt, encrypting, crypt, cryptosystem, cryptography, cryptosystems, decryption, encrypted, spying |
| 10 | palestinians, israelis, palestinian, israeli, gaza, israel, gazans, palestine, zionist, aviv |
| 11 | nasa, spacecraft, spaceflight, satellites, interplanetary, astronomy, astronauts, astronomical, orbiting, astronomers |
| 12 | imagewriter, colormaps, bitmap, bitmaps, pkzip, imagemagick, colormap, formats, adobe, ghostscript |
| | ... |
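Under the hood, SNMF factorizes the matrix of document embeddings rather than a bag-of-words matrix. As a rough sketch of the objective (our notation, mirroring the implementation in `turftopic/models/_snmf.py`, not a formal statement from a paper):

$$
\min_{F,\; G \ge 0} \; \lVert X - G F^\top \rVert_F^2 \;+\; \lambda \sum_{n,k} G_{nk}
$$

where $X$ holds the document embeddings, the nonnegative $G$ holds document-topic proportions, $F$ holds the unbounded topic components in embedding space, and $\lambda$ is the sparsity penalty discussed in the next section.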

## Sparsity

SensTopic has a sparsity hyper-parameter that roughly dictates how many topics will be assigned to a single document: assigning many topics to one document gets penalized.
This means that, depending on this parameter, the model can behave as a matrix factorization model or as a soft clustering model.
Unlike clustering models, however, it may assign multiple topics to documents that contain them, and won't force every document into a single topic.

Higher values will make your model more like a clustering model, while lower values will make it more like a decomposition model:

??? info "Click to see code"
    ```python
    import pandas as pd
    import numpy as np
    import plotly.express as px
    from sentence_transformers import SentenceTransformer
    from datasets import load_dataset

    from turftopic import SensTopic

    ds = load_dataset("gopalkalpande/bbc-news-summary", split="train")
    corpus = list(ds["Summaries"])

    encoder = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = encoder.encode(corpus, show_progress_bar=True)

    models = []
    doc_topic_ms = []
    sparsities = np.array(
        [
            0.05,
            0.1,
            0.25,
            0.5,
            0.75,
            1.0,
            2.5,
            5.0,
            10.0,
        ]
    )
    for i, sparsity in enumerate(sparsities):
        model = SensTopic(
            n_components=3, random_state=42, sparsity=sparsity, encoder=encoder
        )
        doc_topic = model.fit_transform(corpus, embeddings=embeddings)
        doc_topic = (doc_topic.T / doc_topic.sum(axis=1)).T
        models.append(model)
        doc_topic_ms.append(doc_topic)
    a_name, b_name, c_name = models[0].topic_names
    records = []
    for i, doc_topic in enumerate(doc_topic_ms):
        for dt in doc_topic:
            a, b, c, *_ = dt
            records.append(
                {
                    "sparsity": sparsities[i],
                    a_name: a,
                    b_name: b,
                    c_name: c,
                    "topic": models[0].topic_names[np.argmax(dt)],
                }
            )
    df = pd.DataFrame.from_records(records)
    fig = px.scatter_ternary(
        df, a=a_name, b=b_name, c=c_name, animation_frame="sparsity", color="topic"
    )
    fig.show()
    ```
<figure>
  <iframe src="../images/ternary_sparsity.html"></iframe>
  <figcaption>Ternary plot of topic distribution in a 3-topic SensTopic model, varying with sparsity.</figcaption>
</figure>

You can see that as the sparsity increases, topics get clustered much more clearly, and more weight gets allocated to the edges of the plot.

To see how many topics are present in your documents, you can use the `plot_topic_decay()` method, which shows you how topic weights get assigned to documents.

```python
model.plot_topic_decay()
```
<figure>
  <iframe src="../images/topic_decay.html"></iframe>
  <figcaption>Topic decay in a SensTopic model with sparsity=1.</figcaption>
</figure>
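
If you prefer numbers over a plot, the same information can be pulled straight from the fitted model. A minimal sketch (`document_topic_matrix` is set by `fit_transform()`; the 5% cutoff is an arbitrary choice for this example):

```python
import numpy as np

# Normalize document-topic proportions row-wise, then count how many
# topics exceed a 5% share in each document.
doc_topic = model.document_topic_matrix
proportions = doc_topic / np.maximum(doc_topic.sum(axis=1, keepdims=True), 1e-12)
topics_per_doc = (proportions > 0.05).sum(axis=1)
print("Median topics per document:", np.median(topics_per_doc))
```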

## Automatic number of topics

SensTopic can learn the number of topics in a given dataset.
In order to determine this quantity, we use a version of the Bayesian Information Criterion modified for NMF.
This does not work equally well for all corpora, but it can be a powerful tool when the number of topics is not known a priori.

In this example the model finds 6 topics in the BBC News dataset:

```python
# pip install datasets
from datasets import load_dataset

from turftopic import SensTopic

ds = load_dataset("gopalkalpande/bbc-news-summary", split="train")
corpus = list(ds["Summaries"])

model = SensTopic("auto")
model.fit(corpus)
model.print_topics()
```

| Topic ID | Highest Ranking |
| - | - |
| 0 | liverpool, mourinho, chelsea, premiership, arsenal, striker, madrid, midfield, uefa, manchester |
| 1 | oscar, bafta, oscars, cast, cinema, hollywood, actor, screenplay, actors, films |
| 2 | mobile, mobiles, broadband, devices, digital, internet, computers, microsoft, phones, telecoms |
| 3 | tory, blair, minister, ministers, parliamentary, mps, parliament, politicians, constituency, ukip |
| 4 | tennis, competing, federer, wimbledon, iaaf, olympic, tournament, athlete, rugby, olympics |
| 5 | gdp, stock, economy, earnings, investments, investment, invest, exports, finance, economies |

## API Reference

::: turftopic.models.senstopic.SensTopic
diff --git a/docs/images/ternary_sparsity.html b/docs/images/ternary_sparsity.html
new file mode 100644
index 0000000..7d6981a
--- /dev/null
+++ b/docs/images/ternary_sparsity.html
@@ -0,0 +1,3892 @@
\ No newline at end of file
diff --git a/docs/images/topic_decay.html b/docs/images/topic_decay.html
new file mode 100644
index 0000000..820d474
--- /dev/null
+++ b/docs/images/topic_decay.html
@@ -0,0 +1,3888 @@
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index dff95ab..f0a885e 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -25,6 +25,7 @@ nav:
   - Topic Models:
     - Model Overview: model_overview.md
     - Semantic Signal Separation (S³): s3.md
+    - SensTopic (BETA): SensTopic.md
     - KeyNMF: KeyNMF.md
     - Topeax: Topeax.md
     - GMM: GMM.md
diff --git a/pyproject.toml b/pyproject.toml
index 245fb63..6531db2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,6 @@ dependencies = [
 [project.optional-dependencies]
 pyro-ppl = ["pyro-ppl>=1.8.0,<2.0.0"]
 openai = ["openai>=1.40.0,<2.0.0"]
-opentsne = ["openTSNE>=1.0.0,<2.0.0"]
 datamapplot=["datamapplot>=0.4.2, <1.0.0"]
 jieba = ["jieba>=0.40.0,<1.0.0"]
 spacy = ["spacy>=3.6.0,<4.0.0"]
@@ -52,7 +51,6 @@ docs = [
 dev = [
     "pyro-ppl>=1.8.0,<2.0.0",
     "openai>=1.40.0,<2.0.0",
-    "openTSNE>=1.0.0,<2.0.0",
     "datamapplot>=0.4.2, <1.0.0",
     "jieba>=0.40.0,<1.0.0",
     "snowballstemmer>=2.0.0,<3.0.0",
@@ -65,6 +63,7 @@ dev = [
     "mkdocstrings==0.22.0",
     "mkdocstrings-python==1.8.0",
     "griffe==0.40.0",
+    "datasets>=4.3.0"
 ]
 
 [build-system]
diff --git a/tests/test_integration.py b/tests/test_integration.py
index efe2f05..181e591 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -18,6 +18,8 @@
     FASTopic,
     KeyNMF,
     SemanticSignalSeparation,
+    SensTopic,
+    Topeax,
     load_model,
 )
 
@@ -79,6 +81,8 @@ def generate_dates(
     ),
     AutoEncodingTopicModel(3, combined=True),
     FASTopic(3, batch_size=None),
+    SensTopic(),
+    Topeax(),
 ]
 
 dynamic_models = [
diff --git a/turftopic/__init__.py b/turftopic/__init__.py
index 99250de..8372841 100644
--- a/turftopic/__init__.py
+++ b/turftopic/__init__.py
@@ -3,10 +3,11 @@
 from turftopic.base import ContextualModel
 from turftopic.error import NotInstalled
 from turftopic.models.cluster import BERTopic, ClusteringTopicModel, Top2Vec
-from turftopic.models.decomp import SemanticSignalSeparation
+from turftopic.models.decomp import S3, SemanticSignalSeparation
 from turftopic.models.fastopic import FASTopic
 from turftopic.models.gmm import GMM
 from turftopic.models.keynmf import KeyNMF
+from turftopic.models.senstopic import SensTopic
 from turftopic.models.topeax import Topeax
 from turftopic.serialization import load_model
 
@@ -31,4 +32,6 @@
     "load_model",
     "build_datamapplot",
     "create_concept_browser",
+    "S3",
+    "SensTopic",
 ]
diff --git a/turftopic/models/_snmf.py b/turftopic/models/_snmf.py
new file mode 100644
index 0000000..fcc1d7d
--- /dev/null
+++ b/turftopic/models/_snmf.py
@@ -0,0 +1,143 @@

"""This file implements semi-NMF, where doc-topic proportions are not allowed to be negative, but components are unbounded."""

import warnings
from typing import Optional

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cluster import KMeans
from sklearn.preprocessing import label_binarize
from tqdm import trange

EPSILON = np.finfo(np.float32).eps

try:
    import jax.numpy as jnp
    from jax import jit
except ModuleNotFoundError:
    warnings.warn("JAX not found, continuing with NumPy implementation.")
    jnp = np

    # Dummy JIT decorator: the identity function
    def jit(f):
        return f


def init_G(
    X, n_components: int, constant=0.2, random_state=None
) -> np.ndarray:
    """Initializes G by running k-means over the columns of X.

    Returns a soft indicator matrix, offset by a small constant so that
    no entry starts at exactly zero."""
    kmeans = KMeans(n_components, random_state=random_state).fit(X.T)
    # shape: (n_columns, n_components)
    G = label_binarize(kmeans.labels_, classes=np.arange(n_components))
    return G + constant


@jit
def separate(A):
    """Splits A into positive and negative parts such that A = pos - neg."""
    abs_A = jnp.abs(A)
    pos = (abs_A + A) / 2
    neg = (abs_A - A) / 2
    return pos, neg
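
# The alternating scheme below follows semi-NMF as described by Ding, Li &
# Jordan ("Convex and Semi-Nonnegative Matrix Factorizations", IEEE TPAMI,
# 2010): F has a closed-form least-squares solution, F = X G (G^T G)^-1,
# while G receives a multiplicative update built from the positive and
# negative parts returned by `separate`, which keeps G nonnegative while
# decreasing the reconstruction error.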

@jit
def update_F(X, G):
    """Closed-form least-squares update for the unconstrained factor."""
    return X @ G @ jnp.linalg.inv(G.T @ G)


@jit
def update_G(X, G, F, sparsity=0):
    """Multiplicative update for G that preserves nonnegativity."""
    pos_xtf, neg_xtf = separate(X.T @ F)
    pos_gftf, neg_gftf = separate(G @ (F.T @ F))
    numerator = pos_xtf + neg_gftf
    denominator = neg_xtf + pos_gftf
    denominator += sparsity
    denominator = jnp.maximum(denominator, EPSILON)
    delta_G = jnp.sqrt(numerator / denominator)
    G *= delta_G
    return G


@jit
def rec_err(X, F, G):
    """Frobenius norm of the reconstruction error X - F Gᵀ."""
    err = X - (F @ G.T)
    return jnp.linalg.norm(err)


class SNMF(TransformerMixin, BaseEstimator):
    def __init__(
        self,
        n_components: int,
        tol: float = 1e-5,
        max_iter: int = 200,
        progress_bar: bool = True,
        random_state: Optional[int] = None,
        sparsity: float = 0.0,
        verbose: bool = False,
    ):
        self.n_components = n_components
        self.tol = tol
        self.max_iter = max_iter
        self.progress_bar = progress_bar
        self.random_state = random_state
        self.sparsity = sparsity
        self.verbose = verbose

    def fit_transform(self, X: np.ndarray, y=None):
        G = init_G(X.T, self.n_components, random_state=self.random_state)
        F = update_F(X.T, G)
        error_at_init = rec_err(X.T, F, G)
        prev_error = error_at_init
        for i in trange(
            self.max_iter,
            desc="Iterative updates.",
            disable=not self.progress_bar,
        ):
            G = update_G(X.T, G, F, self.sparsity)
            F = update_F(X.T, G)
            error = rec_err(X.T, F, G)
            difference = prev_error - error
            if (error < error_at_init) and (
                ((prev_error - error) / error_at_init) < self.tol
            ):
                if self.verbose:
                    print(f"Converged after {i} iterations")
                self.n_iter_ = i
                break
            prev_error = error
            if self.verbose:
                print(
                    f"Iteration: {i}, Error: {error}, init_error: {error_at_init}, difference from previous: {difference}"
                )
        else:
            warnings.warn(
                "SNMF did not converge, try specifying a higher max_iter."
            )
        self.components_ = np.array(F.T)
        self.reconstruction_err_ = error
        self.n_iter_ = i
        return np.array(G)

    def fit_timeslice(self, X_t: np.ndarray, G_t: np.ndarray):
        """Re-estimates components for a single time slice with fixed G."""
        F = update_F(X_t.T, G_t)
        return F.T

    def transform(self, X: np.ndarray):
        G = jnp.maximum(X @ jnp.linalg.pinv(self.components_), 0)
        return np.array(G)
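    # NOTE: unlike fit_transform, transform runs no iterative updates: it
    # projects new data onto the learned components via the pseudoinverse
    # (a least-squares solve) and clips negative proportions to zero, so
    # out-of-sample proportions are only an approximation of a full fit.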
    def inverse_transform(self, X):
        """Transform data back to its original space.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_components)
            Transformed data matrix.

        Returns
        -------
        X_original : ndarray of shape (n_samples, n_features)
            Returns a data matrix of the original shape.
        """
        return X @ self.components_
diff --git a/turftopic/models/decomp.py b/turftopic/models/decomp.py
index 9e4a221..ef21911 100644
--- a/turftopic/models/decomp.py
+++ b/turftopic/models/decomp.py
@@ -3,7 +3,6 @@
 from typing import Literal, Optional, Union
 
 import numpy as np
-from PIL import Image
 from rich.console import Console
 from sentence_transformers import SentenceTransformer
 from sklearn.base import TransformerMixin
@@ -131,7 +130,7 @@ def estimate_components(
 
     @property
     def has_negative_side(self) -> bool:
-        return False
+        return True
 
     def fit_transform(
         self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
@@ -181,6 +180,7 @@ def fit_transform(
         self.negative_documents = self.get_top_documents(
             raw_documents, document_topic_matrix=doc_topic, positive=False
         )
+        self.document_topic_matrix = doc_topic
         console.log("Model fitting done.")
         return doc_topic
 
@@ -484,7 +484,7 @@ def fit_transform_dynamic(
         )
         self.temporal_importance_ = np.zeros((n_bins, n_comp))
         whitened_embeddings = np.copy(self.embeddings)
-        if getattr(self.decomposition, "whiten"):
+        if getattr(self.decomposition, "whiten", False):
             whitened_embeddings -= self.decomposition.mean_
         # doc_topic = np.dot(X, self.components_.T)
         for i_timebin in np.unique(time_labels):
@@ -641,19 +641,21 @@ def transform(
 
     def print_topics(
         self,
-        top_k: int = 5,
+        top_k: int = 10,
         show_scores: bool = False,
-        show_negative: bool = True,
+        show_negative: bool = False,
     ):
+        show_negative = self.has_negative_side
         super().print_topics(top_k, show_scores, show_negative)
 
     def export_topics(
         self,
-        top_k: int = 5,
+        top_k: int = 10,
         show_scores: bool = False,
         show_negative: bool = True,
         format: str = "csv",
     ) -> str:
+        show_negative = self.has_negative_side
         return super().export_topics(top_k, show_scores, show_negative, format)
 
     def print_representative_documents(
@@ -661,9 +663,10 @@ def print_representative_documents(
         topic_id,
         raw_documents,
         document_topic_matrix=None,
-        top_k=5,
-        show_negative: bool = True,
+        top_k=10,
+        show_negative: bool = False,
     ):
+        show_negative = self.has_negative_side
         super().print_representative_documents(
             topic_id,
             raw_documents,
@@ -677,10 +680,11 @@ def export_representative_documents(
         topic_id,
         raw_documents,
         document_topic_matrix=None,
-        top_k=5,
-        show_negative: bool = True,
+        top_k=10,
+        show_negative: bool = False,
         format: str = "csv",
     ):
+        show_negative = self.has_negative_side
         return super().export_representative_documents(
             topic_id,
             raw_documents,
@@ -958,3 +962,7 @@ def _topics_over_time(
         fields.append(concat_words)
         rows.append(fields)
     return [columns, *rows]
+
+
+# Alias for SemanticSignalSeparation
+S3 = SemanticSignalSeparation
diff --git a/turftopic/models/senstopic.py b/turftopic/models/senstopic.py
new file mode 100644
index 0000000..cadc37a
--- /dev/null
+++ b/turftopic/models/senstopic.py
@@ -0,0 +1,494 @@

from datetime import datetime
from functools import partial
from typing import Literal, Optional, Union

import numpy as np
from rich.console import Console
from sentence_transformers import SentenceTransformer
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity

from turftopic._datamapplot import build_datamapplot
from turftopic.base import ContextualModel, Encoder
from turftopic.dynamic import DynamicTopicModel
from turftopic.encoders.multimodal import MultimodalEncoder
from turftopic.models._snmf import SNMF
from turftopic.multimodal import (
    ImageRepr,
    MultimodalEmbeddings,
    MultimodalModel,
)
from turftopic.optimization import (
    optimize_n_components,
)
from turftopic.vectorizers.default import default_vectorizer

NOT_MATCHING_ERROR = (
    "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
    + "Perhaps you are using precomputed embeddings but forgot to pass an encoder to your model. "
    + "Try to initialize the model with the encoder you used for computing the embeddings."
)


def bic_snmf(
    n_components: int, sparsity: float, X, random_state: int = 42
) -> float:
    decomp = SNMF(
        n_components=n_components,
        sparsity=sparsity,
        random_state=random_state,
        verbose=False,
        progress_bar=False,
    )
    doc_topic = decomp.fit_transform(X)
    rss = np.square(decomp.reconstruction_err_)
    n_docs, n_dims = X.shape
    # BIC1 from https://pmc.ncbi.nlm.nih.gov/articles/PMC9181460/
    bic1 = np.log(rss) + n_components * (
        (n_docs + n_dims) / (n_docs * n_dims)
    ) * np.log((n_docs * n_dims) / (n_docs + n_dims))
    return bic1


class SensTopic(ContextualModel, DynamicTopicModel, MultimodalModel):
    """Semi-nonnegative Semantic Signal Separation.

    ```python
    from turftopic import SensTopic

    corpus: list[str] = ["some text", "more text", ...]

    model = SensTopic(10).fit(corpus)
    model.print_topics()
    ```

    Parameters
    ----------
    n_components: int or "auto", default "auto"
        Number of topics.
        If "auto", the number of topics is determined using BIC.
    encoder: str or SentenceTransformer
        Model to encode documents/terms, all-MiniLM-L6-v2 is the default.
    vectorizer: CountVectorizer, default None
        Vectorizer used for term extraction.
        Can be used to prune or filter the vocabulary.
    max_iter: int, default 200
        Maximum number of iterations for S-NMF.
    feature_importance: "axial", "angular" or "combined", default "combined"
        Defines whether a word's position on an axis ('axial'), its angle to the axis ('angular')
        or their combination ('combined') should determine the word's importance for a topic.
    random_state: int, default None
        Random state to use so that results are exactly reproducible.
    sparsity: float, default 1
        L1 penalty applied to document-topic proportions.
        Higher values push the model to assign fewer topics to a single document,
        while lower values will distribute topics across documents.
    """

    def __init__(
        self,
        n_components: Union[int, Literal["auto"]] = "auto",
        encoder: Union[
            Encoder, str, MultimodalEncoder
        ] = "sentence-transformers/all-MiniLM-L6-v2",
        vectorizer: Optional[CountVectorizer] = None,
        max_iter: int = 200,
        feature_importance: Literal[
            "axial", "angular", "combined"
        ] = "combined",
        random_state: Optional[int] = None,
        sparsity: float = 1,
    ):
        self.n_components = n_components
        self.encoder = encoder
        self.feature_importance = feature_importance
        if isinstance(encoder, str):
            self.encoder_ = SentenceTransformer(encoder)
        else:
            self.encoder_ = encoder
        self.validate_encoder()
        if vectorizer is None:
            self.vectorizer = default_vectorizer()
        else:
            self.vectorizer = vectorizer
        self.max_iter = max_iter
        self.random_state = random_state
        self.sparsity = sparsity

    def estimate_components(
        self, feature_importance: Literal["axial", "angular", "combined"]
    ) -> np.ndarray:
        """Reestimates components based on the chosen feature_importance method."""
        if feature_importance == "axial":
            self.components_ = self.axial_components_
        elif feature_importance == "angular":
            self.components_ = self.angular_components_
        elif feature_importance == "combined":
            self.components_ = (
                np.square(self.axial_components_) * self.angular_components_
            )
        if hasattr(self, "axial_temporal_components_"):
            if feature_importance == "axial":
                self.temporal_components_ = self.axial_temporal_components_
            elif feature_importance == "angular":
                self.temporal_components_ = self.angular_temporal_components_
            elif feature_importance == "combined":
                self.temporal_components_ = (
                    np.square(self.axial_temporal_components_)
                    * self.angular_temporal_components_
                )
        return self.components_
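    # A note on the importance modes used below (our summary of this file's
    # code): "axial" projects vocabulary embeddings onto the learned
    # components, "angular" is the cosine similarity of each word to a topic
    # axis (see the angular_components_ property), and "combined" multiplies
    # the squared axial importance with the angular one.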
    def fit_transform(
        self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
    ) -> np.ndarray:
        console = Console()
        self.embeddings = embeddings
        with console.status("Fitting model") as status:
            if self.embeddings is None:
                status.update("Encoding documents")
                self.embeddings = self.encoder_.encode(raw_documents)
                console.log("Documents encoded.")
            if self.n_components == "auto":
                status.update("Finding the number of components.")
                self.n_components_ = optimize_n_components(
                    partial(
                        bic_snmf, X=self.embeddings, sparsity=self.sparsity
                    ),
                    min_n=1,
                    verbose=True,
                )
                console.log("N components set at: " + str(self.n_components_))
            else:
                self.n_components_ = self.n_components
            self.decomposition = SNMF(
                self.n_components_,
                max_iter=self.max_iter,
                sparsity=self.sparsity,
                random_state=self.random_state,
            )
            status.update("Decomposing embeddings")
            doc_topic = self.decomposition.fit_transform(self.embeddings, y=y)
            console.log("Decomposition done.")
            status.update("Extracting terms.")
            vocab = self.vectorizer.fit(raw_documents).get_feature_names_out()
            console.log("Term extraction done.")
            if getattr(self, "vocab_embeddings", None) is None:
                status.update("Encoding vocabulary")
                self.vocab_embeddings = self.encoder_.encode(vocab)
            if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]:
                raise ValueError(
                    NOT_MATCHING_ERROR.format(
                        n_dims=self.embeddings.shape[1],
                        n_word_dims=self.vocab_embeddings.shape[1],
                    )
                )
            console.log("Vocabulary encoded.")
            status.update("Estimating term importances")
            vocab_topic = self.decomposition.transform(self.vocab_embeddings)
            self.axial_components_ = vocab_topic.T
            if self.feature_importance == "axial":
                self.components_ = self.axial_components_
            elif self.feature_importance == "angular":
                self.components_ = self.angular_components_
            elif self.feature_importance == "combined":
                self.components_ = (
                    np.square(self.axial_components_)
                    * self.angular_components_
                )
            self.top_documents = self.get_top_documents(
                raw_documents, document_topic_matrix=doc_topic
            )
            self.document_topic_matrix = doc_topic
            console.log("Model fitting done.")
        return doc_topic

    def fit_transform_multimodal(
        self,
        raw_documents: list[str],
        images: list[ImageRepr],
        y=None,
        embeddings: Optional[MultimodalEmbeddings] = None,
    ) -> np.ndarray:
        self.validate_embeddings(embeddings)
        console = Console()
        self.images = images
        self.multimodal_embeddings = embeddings
        with console.status("Fitting model") as status:
            if self.multimodal_embeddings is None:
                status.update("Encoding documents")
                self.multimodal_embeddings = self.encode_multimodal(
                    raw_documents, images
                )
                console.log("Documents encoded.")
            self.embeddings = self.multimodal_embeddings["document_embeddings"]
            if self.n_components == "auto":
                status.update("Finding the number of components.")
                self.n_components_ = optimize_n_components(
                    partial(
                        bic_snmf, X=self.embeddings, sparsity=self.sparsity
                    ),
                    min_n=1,
                    verbose=True,
                )
                console.log("N components set at: " + str(self.n_components_))
            else:
                self.n_components_ = self.n_components
            self.decomposition = SNMF(
                self.n_components_,
                max_iter=self.max_iter,
                sparsity=self.sparsity,
                random_state=self.random_state,
            )
            status.update("Decomposing embeddings")
            doc_topic = self.decomposition.fit_transform(self.embeddings, y=y)
            console.log("Decomposition done.")
            status.update("Extracting terms.")
            vocab = self.vectorizer.fit(raw_documents).get_feature_names_out()
            console.log("Term extraction done.")
            status.update("Encoding vocabulary")
            self.vocab_embeddings = self.encode_documents(vocab)
            if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]:
                raise ValueError(
                    NOT_MATCHING_ERROR.format(
                        n_dims=self.embeddings.shape[1],
                        n_word_dims=self.vocab_embeddings.shape[1],
                    )
                )
            console.log("Vocabulary encoded.")
            status.update("Estimating term importances")
            vocab_topic = self.decomposition.transform(self.vocab_embeddings)
            self.axial_components_ = vocab_topic.T
            if self.feature_importance == "axial":
                self.components_ = self.axial_components_
            elif self.feature_importance == "angular":
                self.components_ = self.angular_components_
            elif self.feature_importance == "combined":
                self.components_ = (
                    np.square(self.axial_components_)
                    * self.angular_components_
                )
            console.log("Model fitting done.")
            status.update("Transforming images")
            self.image_topic_matrix = self.transform(
                [], embeddings=self.multimodal_embeddings["image_embeddings"]
            )
            self.top_images = self.collect_top_images(
                images, self.image_topic_matrix
            )
            self.top_documents = self.get_top_documents(
                raw_documents, document_topic_matrix=doc_topic
            )
            # Mirrors fit_transform so plotting utilities also work for
            # multimodal fits.
            self.document_topic_matrix = doc_topic
            console.log("Images transformed")
        return doc_topic
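    # Dynamic fitting (below) first fits the model on the whole corpus, then
    # re-estimates components per time bin with SNMF.fit_timeslice, keeping
    # each bin's document-topic proportions fixed at their global values.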
    def fit_transform_dynamic(
        self,
        raw_documents,
        timestamps: list[datetime],
        embeddings: Optional[np.ndarray] = None,
        bins: Union[int, list[datetime]] = 10,
    ) -> np.ndarray:
        document_topic_matrix = self.fit_transform(
            raw_documents, embeddings=embeddings
        )
        time_labels, self.time_bin_edges = self.bin_timestamps(
            timestamps, bins
        )
        n_comp, n_vocab = self.components_.shape
        n_bins = len(self.time_bin_edges) - 1
        self.axial_temporal_components_ = np.full(
            (n_bins, n_comp, n_vocab),
            np.nan,
            dtype=self.components_.dtype,
        )
        self.temporal_importance_ = np.zeros((n_bins, n_comp))
        for i_timebin in np.unique(time_labels):
            topic_importances = document_topic_matrix[
                time_labels == i_timebin
            ].mean(axis=0)
            self.temporal_importance_[i_timebin, :] = topic_importances
            t_doc_topic = document_topic_matrix[time_labels == i_timebin]
            t_embeddings = self.embeddings[time_labels == i_timebin]
            t_components = self.decomposition.fit_timeslice(
                t_embeddings, t_doc_topic
            )
            ax_t = np.maximum(
                self.vocab_embeddings @ np.linalg.pinv(t_components), 0
            )
            self.axial_temporal_components_[i_timebin, :, :] = ax_t.T
        self.estimate_components(self.feature_importance)
        return document_topic_matrix

    @property
    def angular_components_(self):
        """Reweights words based on their angle in SNMF-space to the axis
        base vectors.
        """
        if not hasattr(self, "axial_components_"):
            raise NotFittedError("Model has not been fitted yet.")
        word_vectors = self.axial_components_.T
        n_topics = self.axial_components_.shape[0]
        axis_vectors = np.eye(n_topics)
        cosine_components = cosine_similarity(axis_vectors, word_vectors)
        return cosine_components

    @property
    def angular_temporal_components_(self):
        """Reweights words based on their angle in SNMF-space to the axis
        base vectors in a dynamic model.
        """
        if not hasattr(self, "axial_temporal_components_"):
            raise NotFittedError("Model has not been fitted dynamically.")
        components = []
        for axial_components in self.axial_temporal_components_:
            word_vectors = axial_components.T
            n_topics = axial_components.shape[0]
            axis_vectors = np.eye(n_topics)
            cosine_components = cosine_similarity(axis_vectors, word_vectors)
            components.append(cosine_components)
        return np.stack(components)

    def plot_topic_decay(self):
        try:
            import plotly.graph_objects as go
        except (ImportError, ModuleNotFoundError) as e:
            raise ModuleNotFoundError(
                "Please install plotly if you intend to use plots in Turftopic."
            ) from e
        doc_topic = self.document_topic_matrix
        topic_proportions = []
        for dt in doc_topic:
            sum_dt = dt.sum()
            if sum_dt > 0:
                # Divide into a copy: in-place division would silently
                # overwrite rows of self.document_topic_matrix.
                dt = dt / sum_dt
            dt = -np.sort(-dt)
            topic_proportions.append(dt)
        topic_proportions = np.stack(topic_proportions)
        med_prop = np.median(topic_proportions, axis=0)
        upper = np.quantile(topic_proportions, 0.975, axis=0)
        lower = np.quantile(topic_proportions, 0.025, axis=0)
        fig = go.Figure(
            [
                go.Scatter(
                    name="Median",
                    x=np.arange(self.n_components_),
                    y=med_prop,
                    mode="lines",
                    line=dict(color="rgb(31, 119, 180)"),
                ),
                go.Scatter(
                    name="Upper Bound",
                    x=np.arange(self.n_components_),
                    y=upper,
                    mode="lines",
                    marker=dict(color="#444"),
                    line=dict(width=0),
                    showlegend=False,
                ),
                go.Scatter(
                    name="Lower Bound",
                    x=np.arange(self.n_components_),
                    y=lower,
                    marker=dict(color="#444"),
                    line=dict(width=0),
                    mode="lines",
                    fillcolor="rgba(68, 68, 68, 0.3)",
                    fill="tonexty",
                    showlegend=False,
                ),
            ]
        )
        fig = fig.update_layout(
            template="plotly_white",
            xaxis_title="Topic Rank",
            yaxis_title="Topic Proportion",
            title="Topic Decay",
            font=dict(family="Merriweather", size=16),
        )
        return fig
    def plot_components(
        self, hover_text: Optional[list[str]] = None, **kwargs
    ):
        """Creates an interactive browser plot of the topics in your data using plotly.

        Parameters
        ----------
        hover_text: list of str, optional
            Text to show when hovering over a document.

        Returns
        -------
        plot
            Interactive datamap plot, you can call the `.show()` method to
            display it in your default browser or save it as static HTML using `.write_html()`.
        """
        doc_topic = self.document_topic_matrix
        coords = TSNE(2, metric="cosine").fit_transform(doc_topic)
        labels = np.argmax(doc_topic, axis=1)
        topics_present = np.sort(np.unique(labels))
        names = [self.topic_names[i] for i in topics_present]
        if getattr(self, "topic_descriptions", None) is not None:
            desc = [self.topic_descriptions[i] for i in topics_present]
        else:
            desc = None
        all_words = self.get_top_words()
        keywords = [all_words[i] for i in topics_present]
        fig = build_datamapplot(
            coords,
            labels=labels,
            topic_names=names,
            top_words=keywords,
            hover_text=hover_text,
            topic_descriptions=desc,
            classes=topics_present,
            # Boundaries are unlikely to be very clear
            cluster_boundary_polygons=False,
        )
        return fig

    def plot_components_datamapplot(
        self, hover_text: Optional[list[str]] = None, **kwargs
    ):
        """Creates an interactive browser plot of the topics in your data using datamapplot.

        Parameters
        ----------
        hover_text: list of str, optional
            Text to show when hovering over a document.

        Returns
        -------
        plot
            Interactive datamap plot, you can call the `.show()` method to
            display it in your default browser or save it as static HTML using `.write_html()`.
        """
        # The implementation is currently identical to plot_components.
        return self.plot_components(hover_text=hover_text, **kwargs)
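
# A quick usage sketch (illustrative comments only, not executed; `corpus`
# stands in for any list of strings):
#
#     from turftopic import SensTopic
#
#     model = SensTopic(10, sparsity=1.0)
#     doc_topic = model.fit_transform(corpus)
#     model.print_topics()
#     model.plot_topic_decay().show()
#     model.plot_components(hover_text=corpus).show()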