diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 143566f..1fb4f1b 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -1,35 +1,30 @@
name: Tests
-on:
- push:
- branches: [main]
- pull_request:
- branches: [main]
+
+on: [push]
jobs:
- pytest:
+ build:
+
runs-on: ubuntu-latest
- strategy:
- matrix:
- python-version: ["3.11"]
- #
- # This allows a subsequently queued workflow run to interrupt previous runs
- concurrency:
- group: "${{ github.workflow }}-${{ matrix.python-version}}-${{ matrix.os }} @ ${{ github.ref }}"
- cancel-in-progress: true
steps:
- - uses: actions/checkout@v4
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ - uses: actions/checkout@v5
+ - name: Set up Python
+ # This is the version of the action for setting up Python, not the Python version.
+ uses: actions/setup-python@v5
with:
- python-version: ${{ matrix.python-version }}
+      # Semantic version range syntax or an exact Python version
+ python-version: '3.11'
+ # Optional - x64 or x86 architecture, defaults to x64
+ architecture: 'x64'
cache: "pip"
-      # You can test your matrix by printing the current Python version
+      # Print the current Python version as a sanity check
- name: Display Python version
- run: python3 -c "import sys; print(sys.version)"
-
- - name: Install dependencies
- run: python3 -m pip install --upgrade turftopic[pyro-ppl] pandas pytest plotly igraph datasets pillow
+ run: python -c "import sys; print(sys.version)"
+ - name: Install package
+ run: |
+ python -m pip install --upgrade pip
+ pip install .[dev]
+ pip install pytest
- name: Run tests
- run: python3 -m pytest tests/
-
+ run: python -m pytest tests/
diff --git a/docs/SensTopic.md b/docs/SensTopic.md
new file mode 100644
index 0000000..deac1b8
--- /dev/null
+++ b/docs/SensTopic.md
@@ -0,0 +1,161 @@
+# SensTopic (BETA)
+
+SensTopic is a variant of Semantic Signal Separation that only discovers positive signals, while allowing components to be unbounded.
+This is achieved with an algorithm called Semi-nonnegative Matrix Factorization (SNMF).
+
+> :warning: This model is still in an experimental phase. More documentation and a paper are on their way. :warning:
+
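+In matrix terms, given a document embedding matrix $X$ with documents in its rows, SNMF finds a factorization of the form
+
+$$
+X^T \approx F G^T, \quad G \geq 0,
+$$
+
+where the rows of $G$ hold the nonnegative document-topic proportions, and the columns of $F$ are the topic axes, which are left unbounded.
+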
+SensTopic ships with an efficient implementation of the SNMF algorithm, written in raw NumPy with an optional JAX backend.
+If you want to enable hardware acceleration and JIT compilation, make sure to install JAX before running the model:
+
+```bash
+pip install jax
+```
+
+Here's an example of running SensTopic on the 20 Newsgroups dataset:
+
+```python
+from sklearn.datasets import fetch_20newsgroups
+from turftopic import SensTopic
+
+corpus = fetch_20newsgroups(
+ subset="all",
+ remove=("headers", "footers", "quotes"),
+).data
+
+model = SensTopic(25)
+model.fit(corpus)
+
+model.print_topics()
+```
+
+
+| Topic ID | Highest Ranking |
+| - | - |
+| | ... |
+| 8 | gospels, mormon, catholics, protestant, mormons, synagogues, seminary, catholic, liturgy, churches |
+| 9 | encryption, encrypt, encrypting, crypt, cryptosystem, cryptography, cryptosystems, decryption, encrypted, spying |
+| 10 | palestinians, israelis, palestinian, israeli, gaza, israel, gazans, palestine, zionist, aviv |
+| 11 | nasa, spacecraft, spaceflight, satellites, interplanetary, astronomy, astronauts, astronomical, orbiting, astronomers |
+| 12 | imagewriter, colormaps, bitmap, bitmaps, pkzip, imagemagick, colormap, formats, adobe, ghostscript |
+| | ... |
+
+## Sparsity
+
+SensTopic has a sparsity hyperparameter that roughly dictates how many topics get assigned to a single document: assigning many topics to one document is penalized.
+This means that the model is a matrix factorization model, but depending on this parameter it can also function as a soft clustering model.
+Unlike clustering models, however, it may assign multiple topics to documents that warrant them, and won't force every document to contain only one topic.
+
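+For example, here is a minimal sketch of fitting two models at opposite ends of this spectrum, using the `sparsity` constructor argument:
+
+```python
+from turftopic import SensTopic
+
+# High sparsity: most documents end up dominated by a single topic (cluster-like)
+clustering_like = SensTopic(25, sparsity=5.0)
+# Low sparsity: topic weight is spread over several topics per document
+decomposition_like = SensTopic(25, sparsity=0.05)
+```
+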
+Higher values will make your model more like a clustering model, while lower values will make it more like a decomposition model:
+
+??? info "Click to see code"
+ ```python
+ import pandas as pd
+ import numpy as np
+ import plotly.express as px
+ from sentence_transformers import SentenceTransformer
+ from datasets import load_dataset
+
+ from turftopic import SensTopic
+
+ ds = load_dataset("gopalkalpande/bbc-news-summary", split="train")
+ corpus = list(ds["Summaries"])
+
+ encoder = SentenceTransformer("all-MiniLM-L6-v2")
+ embeddings = encoder.encode(corpus, show_progress_bar=True)
+
+ models = []
+ doc_topic_ms = []
+ sparsities = np.array(
+ [
+ 0.05,
+ 0.1,
+ 0.25,
+ 0.5,
+ 0.75,
+ 1.0,
+ 2.5,
+ 5.0,
+ 10.0,
+ ]
+ )
+ for i, sparsity in enumerate(sparsities):
+ model = SensTopic(
+ n_components=3, random_state=42, sparsity=sparsity, encoder=encoder
+ )
+ doc_topic = model.fit_transform(corpus, embeddings=embeddings)
+ doc_topic = (doc_topic.T / doc_topic.sum(axis=1)).T
+ models.append(model)
+ doc_topic_ms.append(doc_topic)
+ a_name, b_name, c_name = models[0].topic_names
+ records = []
+ for i, doc_topic in enumerate(doc_topic_ms):
+ for dt in doc_topic:
+ a, b, c, *_ = dt
+ records.append(
+ {
+ "sparsity": sparsities[i],
+ a_name: a,
+ b_name: b,
+ c_name: c,
+ "topic": models[0].topic_names[np.argmax(dt)],
+ }
+ )
+ df = pd.DataFrame.from_records(records)
+ fig = px.scatter_ternary(
+ df, a=a_name, b=b_name, c=c_name, animation_frame="sparsity", color="topic"
+ )
+ fig.show()
+ ```
+
+*Ternary plot of topic distributions in a 3-topic SensTopic model, varying with sparsity.*
+
+You can see that as the sparsity increases, topics get clustered much more clearly, and more weight gets allocated to the corners and edges of the triangle.
+
+To see how many topics are present in your documents, you can use the `plot_topic_decay()` method, which shows how topic weights get distributed across documents.
+
+```python
+model.plot_topic_decay()
+```
+
+*Topic decay in a SensTopic model with sparsity=1.*
+
+## Automatic number of topics
+
+SensTopic can learn the number of topics in a given dataset.
+In order to determine this quantity, we use a version of the Bayesian Information Criterion modified for NMF.
+This does not work equally well for all corpora, but it can be a powerful tool when the number of topics is not known a priori.
+
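+For an embedding matrix with $n$ documents and $p$ dimensions, the criterion computed for a candidate number of topics $k$ is
+
+$$
+\mathrm{BIC}_1(k) = \log \mathrm{RSS}_k + k \, \frac{n + p}{np} \, \log \frac{np}{n + p},
+$$
+
+where $\mathrm{RSS}_k$ is the squared reconstruction error of a $k$-topic SNMF fit; the $k$ with the lowest criterion value is selected.
+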
+In this example the model finds 6 topics in the BBC News dataset:
+
+```python
+# pip install datasets
+from datasets import load_dataset
+
+from turftopic import SensTopic
+
+ds = load_dataset("gopalkalpande/bbc-news-summary", split="train")
+corpus = list(ds["Summaries"])
+
+model = SensTopic("auto")
+model.fit(corpus)
+model.print_topics()
+```
+
+| Topic ID | Highest Ranking |
+| - | - |
+| 0 | liverpool, mourinho, chelsea, premiership, arsenal, striker, madrid, midfield, uefa, manchester |
+| 1 | oscar, bafta, oscars, cast, cinema, hollywood, actor, screenplay, actors, films |
+| 2 | mobile, mobiles, broadband, devices, digital, internet, computers, microsoft, phones, telecoms |
+| 3 | tory, blair, minister, ministers, parliamentary, mps, parliament, politicians, constituency, ukip |
+| 4 | tennis, competing, federer, wimbledon, iaaf, olympic, tournament, athlete, rugby, olympics |
+| 5 | gdp, stock, economy, earnings, investments, investment, invest, exports, finance, economies |
+
+
+## API Reference
+
+::: turftopic.models.senstopic.SensTopic
diff --git a/docs/images/ternary_sparsity.html b/docs/images/ternary_sparsity.html
new file mode 100644
index 0000000..7d6981a
--- /dev/null
+++ b/docs/images/ternary_sparsity.html
@@ -0,0 +1,3892 @@
diff --git a/docs/images/topic_decay.html b/docs/images/topic_decay.html
new file mode 100644
index 0000000..820d474
--- /dev/null
+++ b/docs/images/topic_decay.html
@@ -0,0 +1,3888 @@
diff --git a/mkdocs.yml b/mkdocs.yml
index dff95ab..f0a885e 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -25,6 +25,7 @@ nav:
- Topic Models:
- Model Overview: model_overview.md
- Semantic Signal Separation (S³): s3.md
+ - SensTopic (BETA): SensTopic.md
- KeyNMF: KeyNMF.md
- Topeax: Topeax.md
- GMM: GMM.md
diff --git a/pyproject.toml b/pyproject.toml
index 245fb63..6531db2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,6 @@ dependencies = [
[project.optional-dependencies]
pyro-ppl = ["pyro-ppl>=1.8.0,<2.0.0"]
openai = ["openai>=1.40.0,<2.0.0"]
-opentsne = ["openTSNE>=1.0.0,<2.0.0"]
datamapplot=["datamapplot>=0.4.2, <1.0.0"]
jieba = ["jieba>=0.40.0,<1.0.0"]
spacy = ["spacy>=3.6.0,<4.0.0"]
@@ -52,7 +51,6 @@ docs = [
dev = [
"pyro-ppl>=1.8.0,<2.0.0",
"openai>=1.40.0,<2.0.0",
- "openTSNE>=1.0.0,<2.0.0",
"datamapplot>=0.4.2, <1.0.0",
"jieba>=0.40.0,<1.0.0",
"snowballstemmer>=2.0.0,<3.0.0",
@@ -65,6 +63,7 @@ dev = [
"mkdocstrings==0.22.0",
"mkdocstrings-python==1.8.0",
"griffe==0.40.0",
+ "datasets>=4.3.0"
]
[build-system]
diff --git a/tests/test_integration.py b/tests/test_integration.py
index efe2f05..181e591 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -18,6 +18,8 @@
FASTopic,
KeyNMF,
SemanticSignalSeparation,
+ SensTopic,
+ Topeax,
load_model,
)
@@ -79,6 +81,8 @@ def generate_dates(
),
AutoEncodingTopicModel(3, combined=True),
FASTopic(3, batch_size=None),
+ SensTopic(),
+ Topeax(),
]
dynamic_models = [
diff --git a/turftopic/__init__.py b/turftopic/__init__.py
index 99250de..8372841 100644
--- a/turftopic/__init__.py
+++ b/turftopic/__init__.py
@@ -3,10 +3,11 @@
from turftopic.base import ContextualModel
from turftopic.error import NotInstalled
from turftopic.models.cluster import BERTopic, ClusteringTopicModel, Top2Vec
-from turftopic.models.decomp import SemanticSignalSeparation
+from turftopic.models.decomp import S3, SemanticSignalSeparation
from turftopic.models.fastopic import FASTopic
from turftopic.models.gmm import GMM
from turftopic.models.keynmf import KeyNMF
+from turftopic.models.senstopic import SensTopic
from turftopic.models.topeax import Topeax
from turftopic.serialization import load_model
@@ -31,4 +32,6 @@
"load_model",
"build_datamapplot",
"create_concept_browser",
+ "S3",
+ "SensTopic",
]
diff --git a/turftopic/models/_snmf.py b/turftopic/models/_snmf.py
new file mode 100644
index 0000000..fcc1d7d
--- /dev/null
+++ b/turftopic/models/_snmf.py
@@ -0,0 +1,143 @@
+"""This file implements semi-NMF, where doc_topic proportions are not allowed to be negative, but components are unbounded."""
+
+import warnings
+from typing import Optional
+
+import numpy as np
+from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.cluster import KMeans
+from sklearn.preprocessing import label_binarize
+from tqdm import trange
+
+EPSILON = np.finfo(np.float32).eps
+
+try:
+ import jax.numpy as jnp
+ from jax import jit
+except ModuleNotFoundError:
+ warnings.warn("JAX not found, continuing with NumPy implementation.")
+ jnp = np
+
+ # Dummy JIT as the identity function
+ def jit(f):
+ return f
+
+
+def init_G(
+ X, n_components: int, constant=0.2, random_state=None
+) -> np.ndarray:
+ """Returns W"""
+ kmeans = KMeans(n_components, random_state=random_state).fit(X.T)
+    # One-hot cluster memberships of shape (n_documents, n_components)
+ G = label_binarize(kmeans.labels_, classes=np.arange(n_components))
+ return G + constant
+
+
+@jit
+def separate(A):
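+    """Splits A into its positive and negative parts, so that A = pos - neg."""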
+ abs_A = jnp.abs(A)
+ pos = (abs_A + A) / 2
+ neg = (abs_A - A) / 2
+ return pos, neg
+
+
+@jit
+def update_F(X, G):
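+    """Closed-form least-squares update for F given fixed G: F = X G (G^T G)^-1."""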
+ return X @ G @ jnp.linalg.inv(G.T @ G)
+
+
+@jit
+def update_G(X, G, F, sparsity=0):
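+    """Multiplicative semi-NMF update for G, with the L1 sparsity penalty added to the denominator."""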
+ pos_xtf, neg_xtf = separate(X.T @ F)
+ pos_gftf, neg_gftf = separate(G @ (F.T @ F))
+ numerator = pos_xtf + neg_gftf
+ denominator = neg_xtf + pos_gftf
+ denominator += sparsity
+ denominator = jnp.maximum(denominator, EPSILON)
+ delta_G = jnp.sqrt(numerator / denominator)
+ G *= delta_G
+ return G
+
+
+@jit
+def rec_err(X, F, G):
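+    """Frobenius norm of the reconstruction error X - F G^T."""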
+ err = X - (F @ G.T)
+ return jnp.linalg.norm(err)
+
+
+class SNMF(TransformerMixin, BaseEstimator):
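+    """Semi-nonnegative matrix factorization: X^T ~ F G^T, where G (the
+    document-topic matrix) is nonnegative and F (the components) is unbounded.
+
+    A minimal usage sketch, with random data standing in for document embeddings:
+
+    ```python
+    import numpy as np
+
+    X = np.random.default_rng(0).normal(size=(100, 384))
+    snmf = SNMF(n_components=5, random_state=42)
+    doc_topic = snmf.fit_transform(X)  # (100, 5), nonnegative
+    components = snmf.components_  # (5, 384), unbounded
+    ```
+    """
+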
+ def __init__(
+ self,
+ n_components: int,
+ tol: float = 1e-5,
+ max_iter: int = 200,
+ progress_bar: bool = True,
+ random_state: Optional[int] = None,
+ sparsity: float = 0.0,
+ verbose: bool = False,
+ ):
+ self.n_components = n_components
+ self.tol = tol
+ self.max_iter = max_iter
+ self.progress_bar = progress_bar
+ self.random_state = random_state
+ self.sparsity = sparsity
+ self.verbose = verbose
+
+ def fit_transform(self, X: np.ndarray, y=None):
+ G = init_G(X.T, self.n_components, random_state=self.random_state)
+ F = update_F(X.T, G)
+ error_at_init = rec_err(X.T, F, G)
+ prev_error = error_at_init
+ for i in trange(
+ self.max_iter,
+ desc="Iterative updates.",
+ disable=not self.progress_bar,
+ ):
+ G = update_G(X.T, G, F, self.sparsity)
+ F = update_F(X.T, G)
+ error = rec_err(X.T, F, G)
+ difference = prev_error - error
+ if (error < error_at_init) and (
+ (prev_error - error) / error_at_init
+ ) < self.tol:
+ if self.verbose:
+ print(f"Converged after {i} iterations")
+ self.n_iter_ = i
+ break
+ prev_error = error
+ if self.verbose:
+ print(
+ f"Iteration: {i}, Error: {error}, init_error: {error_at_init}, difference from previous: {difference}"
+ )
+ else:
+ warnings.warn(
+ "SNMF did not converge, try specifying a higher max_iter."
+ )
+ self.components_ = np.array(F.T)
+ self.reconstruction_err_ = error
+ self.n_iter_ = i
+ return np.array(G)
+
+ def fit_timeslice(self, X_t: np.ndarray, G_t: np.ndarray):
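+        """Re-estimates the components for one time slice, keeping the slice's document-topic matrix G_t fixed."""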
+ F = update_F(X_t.T, G_t)
+ return F.T
+
+ def transform(self, X: np.ndarray):
+ G = jnp.maximum(X @ jnp.linalg.pinv(self.components_), 0)
+ return np.array(G)
+
+ def inverse_transform(self, X):
+ """Transform data back to its original space.
+
+ Parameters
+ ----------
+ X : ndarray of shape (n_samples, n_components)
+ Transformed data matrix.
+
+ Returns
+ -------
+ X_original : ndarray of shape (n_samples, n_features)
+ Returns a data matrix of the original shape.
+ """
+ return X @ self.components_
diff --git a/turftopic/models/decomp.py b/turftopic/models/decomp.py
index 9e4a221..ef21911 100644
--- a/turftopic/models/decomp.py
+++ b/turftopic/models/decomp.py
@@ -3,7 +3,6 @@
from typing import Literal, Optional, Union
import numpy as np
-from PIL import Image
from rich.console import Console
from sentence_transformers import SentenceTransformer
from sklearn.base import TransformerMixin
@@ -131,7 +130,7 @@ def estimate_components(
@property
def has_negative_side(self) -> bool:
- return False
+ return True
def fit_transform(
self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
@@ -181,6 +180,7 @@ def fit_transform(
self.negative_documents = self.get_top_documents(
raw_documents, document_topic_matrix=doc_topic, positive=False
)
+ self.document_topic_matrix = doc_topic
console.log("Model fitting done.")
return doc_topic
@@ -484,7 +484,7 @@ def fit_transform_dynamic(
)
self.temporal_importance_ = np.zeros((n_bins, n_comp))
whitened_embeddings = np.copy(self.embeddings)
- if getattr(self.decomposition, "whiten"):
+ if getattr(self.decomposition, "whiten", False):
whitened_embeddings -= self.decomposition.mean_
# doc_topic = np.dot(X, self.components_.T)
for i_timebin in np.unique(time_labels):
@@ -641,19 +641,21 @@ def transform(
def print_topics(
self,
- top_k: int = 5,
+ top_k: int = 10,
show_scores: bool = False,
- show_negative: bool = True,
+ show_negative: bool = False,
):
+ show_negative = self.has_negative_side
super().print_topics(top_k, show_scores, show_negative)
def export_topics(
self,
- top_k: int = 5,
+ top_k: int = 10,
show_scores: bool = False,
show_negative: bool = True,
format: str = "csv",
) -> str:
+ show_negative = self.has_negative_side
return super().export_topics(top_k, show_scores, show_negative, format)
def print_representative_documents(
@@ -661,9 +663,10 @@ def print_representative_documents(
topic_id,
raw_documents,
document_topic_matrix=None,
- top_k=5,
- show_negative: bool = True,
+ top_k=10,
+ show_negative: bool = False,
):
+ show_negative = self.has_negative_side
super().print_representative_documents(
topic_id,
raw_documents,
@@ -677,10 +680,11 @@ def export_representative_documents(
topic_id,
raw_documents,
document_topic_matrix=None,
- top_k=5,
- show_negative: bool = True,
+ top_k=10,
+ show_negative: bool = False,
format: str = "csv",
):
+ show_negative = self.has_negative_side
return super().export_representative_documents(
topic_id,
raw_documents,
@@ -958,3 +962,7 @@ def _topics_over_time(
fields.append(concat_words)
rows.append(fields)
return [columns, *rows]
+
+
+# Convenience alias for SemanticSignalSeparation
+S3 = SemanticSignalSeparation
diff --git a/turftopic/models/senstopic.py b/turftopic/models/senstopic.py
new file mode 100644
index 0000000..cadc37a
--- /dev/null
+++ b/turftopic/models/senstopic.py
@@ -0,0 +1,494 @@
+from datetime import datetime
+from functools import partial
+from typing import Literal, Optional, Union
+
+import numpy as np
+from rich.console import Console
+from sentence_transformers import SentenceTransformer
+from sklearn.exceptions import NotFittedError
+from sklearn.feature_extraction.text import CountVectorizer
+from sklearn.manifold import TSNE
+from sklearn.metrics.pairwise import cosine_similarity
+
+from turftopic._datamapplot import build_datamapplot
+from turftopic.base import ContextualModel, Encoder
+from turftopic.dynamic import DynamicTopicModel
+from turftopic.encoders.multimodal import MultimodalEncoder
+from turftopic.models._snmf import SNMF
+from turftopic.multimodal import (
+ ImageRepr,
+ MultimodalEmbeddings,
+ MultimodalModel,
+)
+from turftopic.optimization import (
+ optimize_n_components,
+)
+from turftopic.vectorizers.default import default_vectorizer
+
+NOT_MATCHING_ERROR = (
+ "Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
+ + "Perhaps you are using precomputed embeddings but forgot to pass an encoder to your model. "
+ + "Try to initialize the model with the encoder you used for computing the embeddings."
+)
+
+
+def bic_snmf(
+ n_components: int, sparsity: float, X, random_state: int = 42
+) -> float:
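+    """Fits an SNMF with `n_components` topics on X and scores the fit with a BIC-type criterion (lower is better)."""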
+ decomp = SNMF(
+ n_components=n_components,
+ sparsity=sparsity,
+        random_state=random_state,
+ verbose=False,
+ progress_bar=False,
+ )
+ doc_topic = decomp.fit_transform(X)
+ rss = np.square(decomp.reconstruction_err_)
+ n_docs, n_dims = X.shape
+ # BIC1 from https://pmc.ncbi.nlm.nih.gov/articles/PMC9181460/
+ bic1 = np.log(rss) + n_components * (
+ (n_docs + n_dims) / (n_docs * n_dims)
+ ) * np.log((n_docs * n_dims) / (n_docs + n_dims))
+ return bic1
+
+
+class SensTopic(ContextualModel, DynamicTopicModel, MultimodalModel):
+ """Semi-nonnegative Semantic Signal Separation.
+
+ ```python
+ from turftopic import SensTopic
+
+ corpus: list[str] = ["some text", "more text", ...]
+
+ model = SensTopic(10).fit(corpus)
+ model.print_topics()
+ ```
+
+ Parameters
+ ----------
+    n_components: int or "auto", default "auto"
+ Number of topics.
+ If "auto", the number of topics is determined using BIC.
+ encoder: str or SentenceTransformer
+ Model to encode documents/terms, all-MiniLM-L6-v2 is the default.
+ vectorizer: CountVectorizer, default None
+ Vectorizer used for term extraction.
+ Can be used to prune or filter the vocabulary.
+ max_iter: int, default 200
+ Maximum number of iterations for S-NMF.
+ feature_importance: "axial", "angular" or "combined", default "combined"
+        Defines whether the word's position on an axis ('axial'), its angle to the axis ('angular')
+ or their combination ('combined') should determine the word's importance for a topic.
+ random_state: int, default None
+ Random state to use so that results are exactly reproducible.
+ sparsity: float, default 1
+ L1 penalty applied to document-topic proportions.
+        Higher values push the model to assign fewer topics to a single document,
+        while lower values spread weight across more topics per document.
+ """
+
+ def __init__(
+ self,
+ n_components: Union[int, Literal["auto"]] = "auto",
+ encoder: Union[
+ Encoder, str, MultimodalEncoder
+ ] = "sentence-transformers/all-MiniLM-L6-v2",
+ vectorizer: Optional[CountVectorizer] = None,
+ max_iter: int = 200,
+ feature_importance: Literal[
+ "axial", "angular", "combined"
+ ] = "combined",
+ random_state: Optional[int] = None,
+ sparsity: float = 1,
+ ):
+ self.n_components = n_components
+ self.encoder = encoder
+ self.feature_importance = feature_importance
+ if isinstance(encoder, str):
+ self.encoder_ = SentenceTransformer(encoder)
+ else:
+ self.encoder_ = encoder
+ self.validate_encoder()
+ if vectorizer is None:
+ self.vectorizer = default_vectorizer()
+ else:
+ self.vectorizer = vectorizer
+ self.max_iter = max_iter
+ self.random_state = random_state
+ self.sparsity = sparsity
+
+ def estimate_components(
+ self, feature_importance: Literal["axial", "angular", "combined"]
+ ) -> np.ndarray:
+ """Reestimates components based on the chosen feature_importance method."""
+ if feature_importance == "axial":
+ self.components_ = self.axial_components_
+ elif feature_importance == "angular":
+ self.components_ = self.angular_components_
+ elif feature_importance == "combined":
+ self.components_ = (
+ np.square(self.axial_components_) * self.angular_components_
+ )
+ if hasattr(self, "axial_temporal_components_"):
+ if feature_importance == "axial":
+ self.temporal_components_ = self.axial_temporal_components_
+ elif feature_importance == "angular":
+ self.temporal_components_ = self.angular_temporal_components_
+ elif feature_importance == "combined":
+ self.temporal_components_ = (
+ np.square(self.axial_temporal_components_)
+ * self.angular_temporal_components_
+ )
+ return self.components_
+
+ def fit_transform(
+ self, raw_documents, y=None, embeddings: Optional[np.ndarray] = None
+ ) -> np.ndarray:
+ console = Console()
+ self.embeddings = embeddings
+ with console.status("Fitting model") as status:
+ if self.embeddings is None:
+ status.update("Encoding documents")
+ self.embeddings = self.encoder_.encode(raw_documents)
+ console.log("Documents encoded.")
+ if self.n_components == "auto":
+ status.update("Finding the number of components.")
+ self.n_components_ = optimize_n_components(
+ partial(
+ bic_snmf, X=self.embeddings, sparsity=self.sparsity
+ ),
+ min_n=1,
+ verbose=True,
+ )
+ console.log("N components set at: " + str(self.n_components_))
+ else:
+ self.n_components_ = self.n_components
+ self.decomposition = SNMF(
+ self.n_components_,
+ max_iter=self.max_iter,
+ sparsity=self.sparsity,
+ random_state=self.random_state,
+ )
+ status.update("Decomposing embeddings")
+ doc_topic = self.decomposition.fit_transform(self.embeddings, y=y)
+ console.log("Decomposition done.")
+ status.update("Extracting terms.")
+ vocab = self.vectorizer.fit(raw_documents).get_feature_names_out()
+ console.log("Term extraction done.")
+ if getattr(self, "vocab_embeddings", None) is None:
+ status.update("Encoding vocabulary")
+ self.vocab_embeddings = self.encoder_.encode(vocab)
+ if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]:
+ raise ValueError(
+ NOT_MATCHING_ERROR.format(
+ n_dims=self.embeddings.shape[1],
+ n_word_dims=self.vocab_embeddings.shape[1],
+ )
+ )
+ console.log("Vocabulary encoded.")
+ status.update("Estimating term importances")
+ vocab_topic = self.decomposition.transform(self.vocab_embeddings)
+ self.axial_components_ = vocab_topic.T
+ if self.feature_importance == "axial":
+ self.components_ = self.axial_components_
+ elif self.feature_importance == "angular":
+ self.components_ = self.angular_components_
+ elif self.feature_importance == "combined":
+ self.components_ = (
+ np.square(self.axial_components_)
+ * self.angular_components_
+ )
+ self.top_documents = self.get_top_documents(
+ raw_documents, document_topic_matrix=doc_topic
+ )
+ self.document_topic_matrix = doc_topic
+ console.log("Model fitting done.")
+ return doc_topic
+
+ def fit_transform_multimodal(
+ self,
+ raw_documents: list[str],
+ images: list[ImageRepr],
+ y=None,
+ embeddings: Optional[MultimodalEmbeddings] = None,
+ ) -> np.ndarray:
+ self.validate_embeddings(embeddings)
+ console = Console()
+ self.images = images
+ self.multimodal_embeddings = embeddings
+ with console.status("Fitting model") as status:
+ if self.multimodal_embeddings is None:
+ status.update("Encoding documents")
+ self.multimodal_embeddings = self.encode_multimodal(
+ raw_documents, images
+ )
+ console.log("Documents encoded.")
+ self.embeddings = self.multimodal_embeddings["document_embeddings"]
+ if self.n_components == "auto":
+ status.update("Finding the number of components.")
+ self.n_components_ = optimize_n_components(
+ partial(
+ bic_snmf, X=self.embeddings, sparsity=self.sparsity
+ ),
+ min_n=1,
+ verbose=True,
+ )
+ console.log("N components set at: " + str(self.n_components_))
+ else:
+ self.n_components_ = self.n_components
+ self.decomposition = SNMF(
+ self.n_components_,
+ max_iter=self.max_iter,
+ sparsity=self.sparsity,
+ random_state=self.random_state,
+ )
+ status.update("Decomposing embeddings")
+ doc_topic = self.decomposition.fit_transform(self.embeddings, y=y)
+ console.log("Decomposition done.")
+ status.update("Extracting terms.")
+ vocab = self.vectorizer.fit(raw_documents).get_feature_names_out()
+ console.log("Term extraction done.")
+ status.update("Encoding vocabulary")
+ self.vocab_embeddings = self.encode_documents(vocab)
+ if self.vocab_embeddings.shape[1] != self.embeddings.shape[1]:
+ raise ValueError(
+ NOT_MATCHING_ERROR.format(
+ n_dims=self.embeddings.shape[1],
+ n_word_dims=self.vocab_embeddings.shape[1],
+ )
+ )
+ console.log("Vocabulary encoded.")
+ status.update("Estimating term importances")
+ vocab_topic = self.decomposition.transform(self.vocab_embeddings)
+ self.axial_components_ = vocab_topic.T
+ if self.feature_importance == "axial":
+ self.components_ = self.axial_components_
+ elif self.feature_importance == "angular":
+ self.components_ = self.angular_components_
+ elif self.feature_importance == "combined":
+ self.components_ = (
+ np.square(self.axial_components_)
+ * self.angular_components_
+ )
+ console.log("Model fitting done.")
+ status.update("Transforming images")
+ self.image_topic_matrix = self.transform(
+ [], embeddings=self.multimodal_embeddings["image_embeddings"]
+ )
+ self.top_images = self.collect_top_images(
+ images, self.image_topic_matrix
+ )
+ self.top_documents = self.get_top_documents(
+ raw_documents, document_topic_matrix=doc_topic
+ )
+ console.log("Images transformed")
+ return doc_topic
+
+ def fit_transform_dynamic(
+ self,
+ raw_documents,
+ timestamps: list[datetime],
+ embeddings: Optional[np.ndarray] = None,
+ bins: Union[int, list[datetime]] = 10,
+ ) -> np.ndarray:
+ document_topic_matrix = self.fit_transform(
+ raw_documents, embeddings=embeddings
+ )
+ time_labels, self.time_bin_edges = self.bin_timestamps(
+ timestamps, bins
+ )
+ n_comp, n_vocab = self.components_.shape
+ n_bins = len(self.time_bin_edges) - 1
+ self.axial_temporal_components_ = np.full(
+ (n_bins, n_comp, n_vocab),
+ np.nan,
+ dtype=self.components_.dtype,
+ )
+ self.temporal_importance_ = np.zeros((n_bins, n_comp))
+ # doc_topic = np.dot(X, self.components_.T)
+ for i_timebin in np.unique(time_labels):
+ topic_importances = document_topic_matrix[
+ time_labels == i_timebin
+ ].mean(axis=0)
+ self.temporal_importance_[i_timebin, :] = topic_importances
+ t_doc_topic = document_topic_matrix[time_labels == i_timebin]
+ t_embeddings = self.embeddings[time_labels == i_timebin]
+ t_components = self.decomposition.fit_timeslice(
+ t_embeddings, t_doc_topic
+ )
+ ax_t = np.maximum(
+ self.vocab_embeddings @ np.linalg.pinv(t_components), 0
+ )
+ self.axial_temporal_components_[i_timebin, :, :] = ax_t.T
+ self.estimate_components(self.feature_importance)
+ return document_topic_matrix
+
+ @property
+ def angular_components_(self):
+ """Reweights words based on their angle in ICA-space to the axis
+ base vectors.
+ """
+ if not hasattr(self, "axial_components_"):
+ raise NotFittedError("Model has not been fitted yet.")
+ word_vectors = self.axial_components_.T
+ n_topics = self.axial_components_.shape[0]
+ axis_vectors = np.eye(n_topics)
+ cosine_components = cosine_similarity(axis_vectors, word_vectors)
+ return cosine_components
+
+ @property
+ def angular_temporal_components_(self):
+ """Reweights words based on their angle in ICA-space to the axis
+ base vectors in a dynamic model.
+ """
+ if not hasattr(self, "axial_temporal_components_"):
+ raise NotFittedError("Model has not been fitted dynamically.")
+ components = []
+ for axial_components in self.axial_temporal_components_:
+ word_vectors = axial_components.T
+ n_topics = axial_components.shape[0]
+ axis_vectors = np.eye(n_topics)
+ cosine_components = cosine_similarity(axis_vectors, word_vectors)
+ components.append(cosine_components)
+ return np.stack(components)
+
+ def plot_topic_decay(self):
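+        """Plots the median topic proportion per topic rank (with a 95% interval),
+        showing how quickly topic weight decays within documents."""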
+ try:
+ import plotly.graph_objects as go
+ except (ImportError, ModuleNotFoundError) as e:
+ raise ModuleNotFoundError(
+ "Please install plotly if you intend to use plots in Turftopic."
+ ) from e
+ doc_topic = self.document_topic_matrix
+ topic_proportions = []
+ for dt in doc_topic:
+ sum_dt = dt.sum()
+            if sum_dt > 0:
+                # Divide out of place so document_topic_matrix is not mutated
+                dt = dt / sum_dt
+            # Sort each document's proportions in descending order
+            dt = -np.sort(-dt)
+ topic_proportions.append(dt)
+ topic_proportions = np.stack(topic_proportions)
+ med_prop = np.median(topic_proportions, axis=0)
+ upper = np.quantile(topic_proportions, 0.975, axis=0)
+ lower = np.quantile(topic_proportions, 0.025, axis=0)
+ fig = go.Figure(
+ [
+ go.Scatter(
+ name="Median",
+ x=np.arange(self.n_components_),
+ y=med_prop,
+ mode="lines",
+ line=dict(color="rgb(31, 119, 180)"),
+ ),
+ go.Scatter(
+ name="Upper Bound",
+ x=np.arange(self.n_components_),
+ y=upper,
+ mode="lines",
+ marker=dict(color="#444"),
+ line=dict(width=0),
+ showlegend=False,
+ ),
+ go.Scatter(
+ name="Lower Bound",
+ x=np.arange(self.n_components_),
+ y=lower,
+ marker=dict(color="#444"),
+ line=dict(width=0),
+ mode="lines",
+ fillcolor="rgba(68, 68, 68, 0.3)",
+ fill="tonexty",
+ showlegend=False,
+ ),
+ ]
+ )
+ fig = fig.update_layout(
+ template="plotly_white",
+ xaxis_title="Topic Rank",
+ yaxis_title="Topic Proportion",
+ title="Topic Decay",
+ font=dict(family="Merriweather", size=16),
+ )
+ return fig
+
+ def plot_components(
+ self, hover_text: Optional[list[str]] = None, **kwargs
+ ):
+ """Creates an interactive browser plot of the topics in your data using plotly.
+
+ Parameters
+ ----------
+ hover_text: list of str, optional
+ Text to show when hovering over a document.
+
+ Returns
+ -------
+ plot
+            Interactive datamap plot; you can call the `.show()` method to
+            display it in your default browser or save it as static HTML using `.write_html()`.
+ """
+ doc_topic = self.document_topic_matrix
+ coords = TSNE(2, metric="cosine").fit_transform(doc_topic)
+ labels = np.argmax(doc_topic, axis=1)
+ topics_present = np.sort(np.unique(labels))
+ names = [self.topic_names[i] for i in topics_present]
+ if getattr(self, "topic_descriptions", None) is not None:
+ desc = [self.topic_descriptions[i] for i in topics_present]
+ else:
+ desc = None
+ all_words = self.get_top_words()
+ keywords = [all_words[i] for i in topics_present]
+ fig = build_datamapplot(
+ coords,
+ labels=labels,
+ topic_names=names,
+ top_words=keywords,
+ hover_text=hover_text,
+ topic_descriptions=desc,
+ classes=topics_present,
+ # Boundaries are unlikely to be very clear
+ cluster_boundary_polygons=False,
+ )
+ return fig
+
+ def plot_components_datamapplot(
+ self, hover_text: Optional[list[str]] = None, **kwargs
+ ):
+ """Creates an interactive browser plot of the topics in your data using datamapplot.
+
+ Parameters
+ ----------
+ hover_text: list of str, optional
+ Text to show when hovering over a document.
+
+ Returns
+ -------
+ plot
+            Interactive datamap plot; you can call the `.show()` method to
+            display it in your default browser or save it as static HTML using `.write_html()`.
+ """
+ doc_topic = self.document_topic_matrix
+ coords = TSNE(2, metric="cosine").fit_transform(doc_topic)
+ labels = np.argmax(doc_topic, axis=1)
+ topics_present = np.sort(np.unique(labels))
+ names = [self.topic_names[i] for i in topics_present]
+ if getattr(self, "topic_descriptions", None) is not None:
+ desc = [self.topic_descriptions[i] for i in topics_present]
+ else:
+ desc = None
+ all_words = self.get_top_words()
+ keywords = [all_words[i] for i in topics_present]
+ fig = build_datamapplot(
+ coords,
+ labels=labels,
+ topic_names=names,
+ top_words=keywords,
+ hover_text=hover_text,
+ topic_descriptions=desc,
+ classes=topics_present,
+ # Boundaries are unlikely to be very clear
+ cluster_boundary_polygons=False,
+ )
+ return fig