diff --git a/docs/Topeax.md b/docs/Topeax.md index f7c8a4d..56b3227 100644 --- a/docs/Topeax.md +++ b/docs/Topeax.md @@ -1,15 +1,127 @@ # Topeax -Topeax is a probabilistic topic model based on the Peax clustering model, which finds topics based on peaks in point density in the embedding space. -It can recover the number of topics automatically. +Topeax is a probabilistic topic model based on the Peax clustering model, which finds topics based on peaks in point density in the embedding space. The model can recover the number of topics automatically. + +In the following example I run a Topeax model on the BBC News corpus, and plot the steps of the algorithm to inspect how our documents have been clustered and why: + +```python +# pip install datasets, plotly +from datasets import load_dataset +from turftopic import Topeax + +ds = load_dataset("gopalkalpande/bbc-news-summary", split="train") +topeax = Topeax(random_state=42) +doc_topic = topeax.fit_transform(list(ds["Summaries"])) + +fig = topeax.plot_steps(hover_text=[text[:200] for text in corpus]) +fig.show() +``` + +
+ +
Figure 1: Steps in a Topeax model fitted on BBC News displayed on an interactive graph.
+
+ +```python +topeax.print_topics() +``` + + +| Topic ID | Highest Ranking | +| - | - | +| 0 | mobile, microsoft, digital, technology, broadband, phones, devices, internet, mobiles, computer | +| 1 | economy, growth, economic, deficit, prices, gdp, inflation, currency, rates, exports | +| 2 | profits, shareholders, shares, takeover, shareholder, company, profit, merger, investors, financial | +| 3 | film, actor, oscar, films, actress, oscars, bafta, movie, awards, actors | +| 4 | band, album, song, singer, concert, rock, songs, rapper, rap, grammy | +| 5 | tory, blair, labour, ukip, mps, minister, election, tories, mr, ministers | +| 6 | olympic, tennis, iaaf, federer, wimbledon, doping, roddick, champion, athletics, olympics | +| 7 | rugby, liverpool, england, mourinho, chelsea, premiership, arsenal, gerrard, hodgson, gareth | + +## How does Topeax work? + +The Topeax algorithm, similar to clustering topic models consists of two consecutive steps. +One of them discovers the underlying clusters in the data, the other one estimates term importance scores for each topic in the corpus.
-
Schematic overview of the steps of the Peax clustering algorithm
+
Figure 2: Schematic overview of the steps of the Peax clustering algorithm
-:warning: **This part of the documentation is still under construction, as more details and a paper are on their way.** :warning: +### 1. Clustering + + +Documents embeddings first get projected into two-dimensional space using t-SNE. +In order to identify clusters, we first calculate a Kernel Density Estimate over the embedding space, +then find local maxima in the KDE by grid approximation. +When we discover local maxima (peaks), we assume these to be cluster means. +Cluster density is then approximated with a Gaussian Mixture, where we fix means to the density peaks and then use expectation-maximization to fit the rest of the parameters. (see Figure 2) +Documents are then assigned to the component with the highest responsibility: + +$$\hat{z_d} = arg max_k (r_{kd}); r_{kd}=p(z_k=1 | \hat{x}_d)$$ + +where $z_d$ is the cluster label for document $d$, $r_{kd}$ is the responsibility of component $k$ for document $d$ and $\hat{x}_d$ is the 2D embedding of document $d$. + +### 2. Term Importance Estimation + +Topeax uses a combined semantic-lexical term importance, which is the geometric mean of the NPMI method (see [Clustering Topic Models](clustering.md) for more detail) and a slightly modified centroid-based method. +The modified centroids are calculated like so: + +$$t_k = \frac{\sum_d r_{kd} \cdot x_d}{\sum_d r_{kd}}$$ + +where $t_k$ is the embedding of topic $k$ and $x_d$ is the embedding of document $d$. + +## Visualization + +Topeax has a number of plots available that can aid you when interpreting your results: + +### Density Plots + +One can plot the kernel density estimate on both a 2D and a 3D plot. + +```python +topeax.plot_density() +``` + +
+ +
Figure 2: Density contour plot of the Topeax model.
+
+ +```python +topeax.plot_density3d() +``` + +
+ +
Figure 3: 3D Density Surface of the Topeax model.
+
+ +### Component Plots + +You can also create a plot over the mixture components/clusters found by the model. + +```python +topeax.plot_components() +``` + +
+ +
Figure 4: Gaussian components estimated for the model.
+
+ +You can also create a datamapplot figure similar to clustering models: + +```python +# pip install turftopic[datamapplot] +topeax.plot_components_datamapplot() +``` + +
+ +
Figure 5: Datapoints colored by mixture components on a datamapplot.
+
## API Reference diff --git a/docs/images/topeax_components.html b/docs/images/topeax_components.html new file mode 100644 index 0000000..1f55909 --- /dev/null +++ b/docs/images/topeax_components.html @@ -0,0 +1,3888 @@ + + + +
+
+ + \ No newline at end of file diff --git a/docs/images/topeax_components_datamapplot.html b/docs/images/topeax_components_datamapplot.html new file mode 100644 index 0000000..f99a33c --- /dev/null +++ b/docs/images/topeax_components_datamapplot.html @@ -0,0 +1,676 @@ + + + + + Interactive Data Map + + + + + + + + + + + + + + + + + + + +
+ Loading... +
+
+
+
+
+ + + + +
+
+ +
+
+ +
+
+
+
+
+ +
+ +
+ +
+
+
+ + Point Data: 0% + +
+
+ + Label Data: 0% + +
+
+ + Meta Data: 0% + +
+
+ +
+

1_economy_growth_economic_deficit

+

Keywords: economy, growth, economic, deficit, gdp, prices, inflation, currency, rates, exports

+ + 7.28% of all documents +
+

+
+ + + + + + \ No newline at end of file diff --git a/docs/images/topeax_density.html b/docs/images/topeax_density.html new file mode 100644 index 0000000..53d6390 --- /dev/null +++ b/docs/images/topeax_density.html @@ -0,0 +1,3888 @@ + + + +
+
+ + \ No newline at end of file diff --git a/docs/images/topeax_density_3d.html b/docs/images/topeax_density_3d.html new file mode 100644 index 0000000..bc35c88 --- /dev/null +++ b/docs/images/topeax_density_3d.html @@ -0,0 +1,3888 @@ + + + +
+
+ + \ No newline at end of file diff --git a/docs/images/topeax_steps.html b/docs/images/topeax_steps.html new file mode 100644 index 0000000..8f4c121 --- /dev/null +++ b/docs/images/topeax_steps.html @@ -0,0 +1,3888 @@ + + + +
+
+ + \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index d7c928e..dff95ab 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -26,6 +26,7 @@ nav: - Model Overview: model_overview.md - Semantic Signal Separation (S³): s3.md - KeyNMF: KeyNMF.md + - Topeax: Topeax.md - GMM: GMM.md - Clustering Models (BERTopic & Top2Vec): clustering.md - Autoencoding Models (ZeroShotTM & CombinedTM): ctm.md diff --git a/papers/topeax/Merriweather.ttf b/papers/topeax/Merriweather.ttf new file mode 100644 index 0000000..558f666 Binary files /dev/null and b/papers/topeax/Merriweather.ttf differ diff --git a/papers/topeax/figures/.xdp-performance.svg.2025_11_06_21_47_55.0.svg-LKVaas b/papers/topeax/figures/.xdp-performance.svg.2025_11_06_21_47_55.0.svg-LKVaas new file mode 100644 index 0000000..e30a7c2 --- /dev/null +++ b/papers/topeax/figures/.xdp-performance.svg.2025_11_06_21_47_55.0.svg-LKVaas @@ -0,0 +1,112 @@ + + + +TopeaxTop2VecBERTopic diff --git a/papers/topeax/figures/20ng_groups.png b/papers/topeax/figures/20ng_groups.png new file mode 100644 index 0000000..e83244a Binary files /dev/null and b/papers/topeax/figures/20ng_groups.png differ diff --git a/papers/topeax/figures/20ng_groups.svg b/papers/topeax/figures/20ng_groups.svg new file mode 100644 index 0000000..605207b --- /dev/null +++ b/papers/topeax/figures/20ng_groups.svg @@ -0,0 +1,138 @@ + + + + diff --git a/papers/topeax/figures/bbc_density_paper_ready.svg b/papers/topeax/figures/bbc_density_paper_ready.svg new file mode 100644 index 0000000..e25012d --- /dev/null +++ b/papers/topeax/figures/bbc_density_paper_ready.svg @@ -0,0 +1,68 @@ + + + + diff --git a/papers/topeax/figures/bbc_news_density.png b/papers/topeax/figures/bbc_news_density.png new file mode 100644 index 0000000..0d1ad89 Binary files /dev/null and b/papers/topeax/figures/bbc_news_density.png differ diff --git a/papers/topeax/figures/bbc_news_light.png b/papers/topeax/figures/bbc_news_light.png new file mode 100644 index 0000000..1fcd94e Binary files /dev/null and b/papers/topeax/figures/bbc_news_light.png differ diff --git a/papers/topeax/figures/cluster_overlap_20news.png b/papers/topeax/figures/cluster_overlap_20news.png new file mode 100644 index 0000000..44bffd6 Binary files /dev/null and b/papers/topeax/figures/cluster_overlap_20news.png differ diff --git a/papers/topeax/figures/clustering_models.png b/papers/topeax/figures/clustering_models.png new file mode 100644 index 0000000..c9fa218 Binary files /dev/null and b/papers/topeax/figures/clustering_models.png differ diff --git a/papers/topeax/figures/clustering_models.svg b/papers/topeax/figures/clustering_models.svg new file mode 100755 index 0000000..720c4b9 --- /dev/null +++ b/papers/topeax/figures/clustering_models.svg @@ -0,0 +1,2878 @@ + + + +...catcarmomEncodingReductionClusteringEstimationXDX^z𝛽t0t1t2...catcarmomEncodingReductionClusteringEstimationXDX^z𝛽t0t1t2 diff --git a/papers/topeax/figures/density_20news.png b/papers/topeax/figures/density_20news.png new file mode 100644 index 0000000..7ae79fe Binary files /dev/null and b/papers/topeax/figures/density_20news.png differ diff --git a/papers/topeax/figures/peax.png b/papers/topeax/figures/peax.png new file mode 100644 index 0000000..fa2a7bc Binary files /dev/null and b/papers/topeax/figures/peax.png differ diff --git a/papers/topeax/figures/peax.svg b/papers/topeax/figures/peax.svg new file mode 100644 index 0000000..0cc17f9 --- /dev/null +++ b/papers/topeax/figures/peax.svg @@ -0,0 +1,173 @@ + + + +0. TSN Embeddings1. Kernel Density Estimate2. Peak detection3. GMM Approximation diff --git a/papers/topeax/figures/performance.png b/papers/topeax/figures/performance.png new file mode 100644 index 0000000..ddf6054 Binary files /dev/null and b/papers/topeax/figures/performance.png differ diff --git a/papers/topeax/figures/performance.svg b/papers/topeax/figures/performance.svg new file mode 100644 index 0000000..9f436ad --- /dev/null +++ b/papers/topeax/figures/performance.svg @@ -0,0 +1,112 @@ + + + + diff --git a/papers/topeax/figures/perplexity_robustness.png b/papers/topeax/figures/perplexity_robustness.png new file mode 100644 index 0000000..5b39f3f Binary files /dev/null and b/papers/topeax/figures/perplexity_robustness.png differ diff --git a/papers/topeax/figures/perplexity_robustness.svg b/papers/topeax/figures/perplexity_robustness.svg new file mode 100644 index 0000000..052a9b4 --- /dev/null +++ b/papers/topeax/figures/perplexity_robustness.svg @@ -0,0 +1,101 @@ + + + + diff --git a/papers/topeax/figures/robustness_perplexity.png b/papers/topeax/figures/robustness_perplexity.png new file mode 100644 index 0000000..df437fa Binary files /dev/null and b/papers/topeax/figures/robustness_perplexity.png differ diff --git a/papers/topeax/figures/robustness_sample_size.png b/papers/topeax/figures/robustness_sample_size.png new file mode 100644 index 0000000..9fccd1f Binary files /dev/null and b/papers/topeax/figures/robustness_sample_size.png differ diff --git a/papers/topeax/figures/subsampling.png b/papers/topeax/figures/subsampling.png new file mode 100644 index 0000000..d453068 Binary files /dev/null and b/papers/topeax/figures/subsampling.png differ diff --git a/papers/topeax/figures/subsampling.svg b/papers/topeax/figures/subsampling.svg new file mode 100644 index 0000000..0354ab8 --- /dev/null +++ b/papers/topeax/figures/subsampling.svg @@ -0,0 +1,81 @@ + + + + diff --git a/papers/topeax/main.html b/papers/topeax/main.html new file mode 100644 index 0000000..b4ec68e --- /dev/null +++ b/papers/topeax/main.html @@ -0,0 +1,10 @@ + + + + + + + +

A Fluid Dynamic Model for Glacier Flow

+ + diff --git a/papers/topeax/main.pdf b/papers/topeax/main.pdf new file mode 100644 index 0000000..975d253 Binary files /dev/null and b/papers/topeax/main.pdf differ diff --git a/papers/topeax/main.typ b/papers/topeax/main.typ new file mode 100644 index 0000000..769aecb --- /dev/null +++ b/papers/topeax/main.typ @@ -0,0 +1,455 @@ +#show title: set text(size: 18pt) +#show title: set align(left) +#show figure.caption: set align(left) + +#set text( + size: 12pt, + weight: "medium", +) +#set page( + paper: "a4", + margin: (x: 1.8cm, y: 1.5cm), +) +#set highlight( + fill: rgb("#ddddff"), + radius: 5pt, + extent: 3pt +) + +#let appendix(body) = { + set heading(numbering: "A", supplement: [Appendix]) + counter(heading).update(0) + body +} + +#title[ + #highlight[Topeax] - + An Improved Clustering Topic Model with Density Peak Detection and Lexical-Semantic Term Importance +] + +#v(10pt) +#par[ + *Márton Kardos* \ + Aarhus University \ +#link("mailto:martonkardos@cas.au.dk") +] + +== Abstract + +#text[ + Text clustering is today the most popular paradigm for topic modelling, both in academia and industry. + Despite clustering topic models' apparent success, we identify a number of issues in Top2Vec and BERTopic, which remain largely unsolved. + Firstly, these approaches are unreliable at discovering natural clusters in corpora, due to extreme sensitivity to sample size and hyperparameters, the default values of which result in suboptimal behaviour. + Secondly, when estimating term importance, BERTopic ignores the semantic distance of keywords to topic vectors, while Top2Vec ignores word counts in the corpus. + This results in, on the one hand, less coherent topics due to the presence of stop words and junk words, + and lack of variety and trust on the other. + In this paper, I introduce a new approach, *#highlight[Topeax]*, which discovers the number of clusters from peaks in density estimates, + and combines lexical and semantic indices of term importance to gain high-quality topic keywords. + Topeax is demonstrated to be better at both cluster recovery and cluster description than Top2Vec and BERTopic, + while also exhibiting less erratic behaviour in response to changing sample size and hyperparameters. +] + +#set heading(numbering: "1.") + += Introduction + +== Clustering Topic Models + +#figure( + image("figures/clustering_models.png", width: 100%), + caption: [Schematic overview of clustering topic models' steps.], +) + + += Model Specification + +I introduce Topeax, a novel topic modelling approach based on document clustering. +The model differs in a number of aspects from traditional clustering topic models like BERTopic and Top2Vec. The model is implemented in the Turftopic Python package (cite), following scikit-learn API conventions. +Example usage is presented in @example_code. + +#figure( + image("figures/peax.png", width: 100%), + caption: [A schematic overview of the Peax clustering algorithm. + \ Illustrations were generated from the _political ideologies dataset#footnote[https://huggingface.co/datasets/JyotiNayak/political_ideologies]._], +) + +== Dimensionality Reduction + +Unlike other clustering topic models, Topeax relies on +t-Distributed Stochastic Neighbour Embeddings (cite it here) instead of UMAP. +I use the the cosine metric to calculate document similarities for TSNE, +as it is widely used for model training and downstream applications. +The number of dimensions was fixed to 2 in all of our experiments, +as this allows us to visualize the reduced embeddings. +Additionally, TSNE has fewer hyperparameters than UMAP. +While it has been demonstrated that TSNE can be sensitive the chosen value of `perplexity`, +we will show that, within a reasonable range, this will not have an effect on the number of topics +or topic quality. + + +== The Peax Clustering Model + +While HDBSCAN is the choice of clustering model for both BERTopic and Top2Vec, +I introduce a new technique for document clustering, termed *#highlight[Peax]*, which, +instead, clusters documents based on density peaks in the reduced document space. + + +The Peax algorithm consists of the following steps: + ++ A Gaussian Kernel Density Estimate (KDE) is obtained over the reduced document embeddings. + Bandwidth is determined with the Scott method. ++ The KDE is evaluated on a 100x100 grid over the embedding space. + Density peaks are then detected by applying a local-maximum filter to the KDE heatmap. + A neighbourhood connectivity of 25 is used, which means, + every pixel is included within a 5 unit radius. ++ Cluster centres are assigned to these density peaks. + The density structure of each cluster is estimated + by fitting a Gaussian mixture model, with its means fixed to the peaks, using the Expectation-Maximization algorithm. + Documents are assigned to the component with the highest responsibility: + \ #align(center)[$accent(z_d, "^") = arg max_(k) (r_("kd"))" , and " r_("kd")=p(z_k=1 | accent(x, "^")_d)$] + where $accent(z_d, "^")$ is the estimated underlying component assigned to document $d$, + $accent(x, "^")_d$ is the TSN embedding of document $d$, and $r_("kd")$ is the responsibility of component $k$ for document $d$. + +#figure( + placement: top, + image("figures/bbc_news_light.png", width: 80%), + caption: [Topeax model illustrated on the BBC News dataset. Topics are identified at density peaks, and keywords get selected based on combined term importance.\ + _Left_: Density plot in 2D with topic names and keywords. + _Right_: Density landscape in 3D with topic names. + +], +) + +== Term Importance Estimation + +To mitigate the issues experienced with c-TF-IDF and centroid-based term importance estimation in previously proposed clustering topic models, +I introduce a novel approach that uses a combination of a semantic and a lexical cluster-term importance. + +=== Semantic Importance + +Semantic term importance is estimated similar to (cite Top2Vec), but, +since we have access to a probabilistic, non-spherical model, and cluster boundaries are not hard, +topic vectors are estimated from the responsibility-weighted average of document embeddings. \ +#align(center)[$t_k = frac(sum_(d) r_("kd") dot x_d, sum_(d) r_("kd"))$] +where $t_k$ is the embedding of topic $k$ and $x_d$ is the embedding of document $d$. +Let the embedding of term $j$ be $w_j$. The semantic importance of term $j$ for cluster $k$ is then: +#align(center)[$s_("kj") = cos(t_k, w_j)$] + +=== Lexical Importance + +Instead of relying on a tf-idf-based measure for computing the valence of a term in a corpus, +an information-theoretical approach is used. +Theoretically, we can estimate the lexical importance of a term for a cluster, +by computing the mutual information of the term's occurrence with the cluster's occurrence. +Due to its convenient interpretability properties, I opt for using normalized pointwise mutual information (NPMI), +which has been historically used for phrase detection (cite) and topic-coherence evaluation (cite). + +We calculate the pointwise mutual information by taking the logarithm of the fraction of conditional and marginal word probabilities: +#align(center)[$"pmi"_("kj") = log_2 frac(p(v_j|z_k), p(v_j))$] +where $p(v_j|z_k)$ is the conditional probability of word $j$ given the presence of topic $z$, +and $p(v_j)$ is the probability of word $j$ occurring. + +A naive approach might include estimating these probabilities empirically: +#align(center)[$p(v_j) = frac(n_j, sum_i n_i)", and " p(v_j | z_k) = frac(n_("jt"), sum_i n_("it"))$] +where $n_j$ is the number of times word $j$ occurs, $n_"jt"$ is the number of times word $j$ occurs in cluster $t$. + +This would, however, overestimate the importance of rare words in the clusters where they appear. +We can therefore opt for a mean-a-posteriori estimate under a symmetric dirichlet prior with an $alpha$ _smoothing_ parameter, +which is analyticaly tractable: +#align(center)[$p(v_j) = frac(n_j + alpha, N alpha + sum_i n_i)", and " p(v_j | z_k) = frac(n_("jt") + alpha, N alpha + sum_i n_("it"))$] +where $N$ is the size of the vocabulary. In further analysis, $alpha=2$ will be used. +Since regular PMI scores have no lower bound, we normalize them to obtain NPMI: +#align(center)[$"npmi"_("kj") = frac("pmi"_("kj"), -log_2 p(v_j, z_k))", where " p(v_j, z_k) = p(v_j|z_k) dot p(z_k)$] + +=== Combined Term Importance + +To balance the semantic proximity of keywords to topic embeddings and cluster-term occurrences, +a I introduce a combined approach, which consists of the geometric mean of min-max normalized lexical and semantic scores: + +#align(center)[$beta_("kj") = sqrt(frac(1 + "npmi"_("kj"), 2) dot frac(1 + s_("kj"), 2))$] + + += Experimental Methods + +Since one of the main strengths of clustering approaches, that they can supposedly find the number of clusters in the data, and are not given this information a-priori, +a good clustering topic model should be able to faithfully replicate a human-assigned clustering of the data, and should be able to describe these clusters in a manner that is human-interpretable. I will therefore utilize datasets with gold-standard labels. +In this section I will outline the criteria and considerations taken into account when designing an evaluation procedure: + ++ The number of clusters in the topic model should preferably be not too far from the number of gold categories. ++ Preferably, if two points are in the same gold category, they should also belong together in the predicted clustering, while points that do not, shouldn't. ++ For topic modelling purposes, it is often preferable that the number of clusters is not overly large. + Topic models should, in theory, aid the understanding of a corpus. Using a topic model becomes impractical when the number of topics one has to interpret is over a couple hundred. ++ Topics should be distinct and easily readable. + +Reproducible scripts used for evaluation, along with instructions on how to run them, are made available in the `x-tabdeveloping/topeax-eval`#footnote("https://github.com/x-tabdeveloping/topeax-eval") Github repository. Results for all evaluations can be found in the `results/` directory. + +== Datasets + +In order to evaluate these properties, I used a number of openly available datasets with gold-standard category metadata. +This included all clustering tasks from the new version of the Massive Text Embedding Benchmark `MTEB(eng, v2)` (cite). +To avoid evaluating on the same corpus twice, the P2P variants of the tasks where used. +In addition an annotated Twitter topic-classification dataset, and a BBC News dataset was used. + +#figure( + caption: [Descriptive statistics of the datasets used for evaluation\ _Document length is reported as mean±standard deviation_], + table( + columns: 4, + stroke: none, + align: (left, center, center, center), + table.hline(), + table.header[*Dataset*][*Document Length*\ _N characters_ ][*Corpus Size*\ _N documents_ ][*Clusters* \ _N unique gold labels_], + table.hline(), + [ArXivHierarchicalClusteringP2P],[1008.44±438.01],[2048],[23], + [BiorxivClusteringP2P.v2],[1663.97±541.93],[53787],[26], + [MedrxivClusteringP2P.v2],[1981.20±922.01],[37500],[51], + [StackExchangeClusteringP2P.v2],[1091.06±808.88],[74914],[524], + [TwentyNewsgroupsClustering.v2],[32.04±14.60],[59545],[20], + [TweetTopicClustering],[165.66±68.19],[4374],[6], + [BBCNewsClustering],[1000.46±638.41],[2224],[5], + table.hline(), + ) +) + +== Models + +To compare Topeax with existing approaches, it was run on all corpora alongside BERTopic and Top2Vec. +Implementations were sourced from the Turftopic (cite) Python package. +For the main analysis, default hyperparameters were used from the original BERTopic and Top2Vec packages respectively, +as these give different clusterings, despite having the same pipeline. +All models were run with both the `all-MiniLM-L6-v2`, the slightly larger and higher performing `all-mpnet-base-v2` sentence encoders (cite sbert), as well as Google's `embeddinggemma-300m` +to control for embedding size and quality. +The models were fitted without filtering for stop words and uncommon terms, +since state-of-the art topic models have been shown to be able to handle such information without issues (cite S3). + +== Metrics + +For evaluating model performance, both clustering quality and topic quality was evaluated. +I evaluated the faithfulness of the predicted clustering to the gold labels using the Fowlkes-Mallows index (cite). +The FMI, is very similar to the F1 score for classification, in that it also intends to balance precision and recall. +Unlike F1, however, FMI uses the geometric mean of these quantities: +#align(center)[$"FMI" = N_("TP")/sqrt((N_("TP") + N_("FP")) dot (N_("TP") + N_("FN")))$] +where $N_("TP")$ is the number of pairs of points that get clustered together in both clusterings (true positives), +$N_("FP")$ is the number of pairs that get clustered together in the predicted clustering but not in the gold labels (false positives) and +$N_("FN")$ is the number of pairs that do not get clustered together in the predicted clustering, despite them belonging together in the gold labels (false negatives). + +For topic quality, I adopt the methodology of (cite S3), with minor differences. +I use GloVe embeddings (cite GloVe) for evaluating internal word embedding coherence instead of Skip-gram. +As such, topic quality was evaluated on topic diversity $d$, external word embedding coherence $C_("ex")$ using the `word2vec-google-news-300` word embedding model, +as well as internal word embedding coherence $C_("in")$ with a GloVe model trained on each corpus. +Ideally a model should both have high intrinsic and extrinsic coherence, and thus an aggregate measure of coherence can give a better +estimate of topic quality: $accent(C, -) = sqrt(C_("in") dot C_("ex"))$. +In addition an aggregate metric of topic quality can be calculated by taking the geometric mean of coherence and diversity $I = sqrt(accent(C, -) dot d)$. +We will also refer to this quantity as _interpretability_. + +== Sensitivity to Perplexity + +Both TSNE and UMAP, have a hyperparameter that determines, how many neighbours of a given point are considered when generating lower-dimensional projections, this hyperparameter is usually referred to as _perplexity_. +It is also known that both methods are sensitive to the choice of hyperparameters, and depending on these, structures, that do not exist in the higher-dimensional feature space might occur (cite Distill article and "Understanding UMAP"). +In order to see how this affects the Topeax algorithm, and how robust it is to the choice of this hyperparameter in comparison with other clustering topic models, I fitted each model to the 20 Newsgroups corpus from `scikit-learn`, using `all-MiniLM-L6-v2` with `perplexities=[2, 5, 30, 50, 100]`. +This choice of values was inspired by (cite Distill). Each model was evaluated on the metrics outlined above. + +== Subsampling Invariance + +Ideally, a good topic model should roughly recover the same topics, and same number of topics in a corpus even when we only have access to a subsample of that corpus, assuming that the underlying categories are the same. +On the other hand, we would reasonably assume that a model having access to the full corpus, instead of a subsample, should increase the accuracy of the results, not decrease it. +To evaluate models' ability to cope with subsampling, I fit each model on the same corpus and embeddings as in the perplexity sensitivity test, and evaluate them on the previously outlined metrics. +Subsample sizes are the following: `[250, 1000, 5000, 10_000, "full"]`. + += Results + +Topeax substantially outperformed both Top2Vec and BERTopic in cluster recovery, as well as the quality of the topic keywords (see @performance). +A regression analysis predicting Fowlkes-Mallows index from model type, with random effects and intercepts for encoders and datasets was conducted. +The regression was significant at $alpha=0.05$. ($R^2=0.127$, $F=4.368$, $p=0.0169$). +Both BERTopic and Top2Vec had significantly negative slopes (see @coeffs). + +#figure( + table( + columns: 4, + align: (left, center, center, center), + stroke: none, + table.hline(), + table.header([*Coefficients*], [*Estimate*], [*p-value*], [*95% CI*]), + table.hline(), + [Intercept (_Topeax_)], [0.3405], [0.000], [[0.267, 0.414]], + [Topeax], [-0.1106], [0.038], [[-0.215, -0.006]], + [BERTopic], [-0.1479], [0.006], [[-0.252, -0.044]], + table.hline(), + + ), + caption: [Regression coefficients for predicting Fowlkes-Mallows Index from choice of topic model] +) + +Topeax also exhibited the lowest absolute percentage error in recovering the number of topics (see @performance) with $"MAPE" = 60.52$ ($"SD"=26.19$), +while Top2Vec ($M=1797.29%, "SD"=2622.52$) and BERTopic ($M = 2438.91%,"SD" = 3011.63$) drastically deviated from the number of gold labels in the datasets. +It is also important to note the opposite directionality of these errors. +While Topeax almost universally underestimated the number of topics, especially in `StackExchangeClusteringP2P` and `MedrxivClusteringP2P`, where the number of unique labels was very large, Top2Vec and BERTopic almost always grossly overestimated the number of clusters in the data. +This is undesirable behaviour for a topic model, as topic interpretation requires manual effort, and vast numbers of topics (>500) become difficult and labour-intensive to label for any individual. + + +#figure( +table( + columns: 5, + align: (left, center, center, center, center), + stroke: none, + table.hline(), + table.vline(x: 4), + table.header([*Model*], [*$C_("in")$*], [*$C_("ex")$*], [*$d$*], [*$I$*]), + table.hline(), + [Topeax], [*0.35±0.15*], [#underline[0.32±0.09]], [*0.96±0.05*], [*0.55±0.10*], + [Top2Vec], [0.21±0.11], [*0.39±0.09*], [#underline[0.57±0.29]], [#underline[0.38±0.15]], + [BERTopic], [#underline[0.24±0.12]], [0.17±0.04], [0.64±0.17], [0.35±0.10], + table.hline(), +), +caption: [Metrics of topic quality compared between different models. Best bold, second best underlined. Uncertainty is standard deviation. Higher is better.] +) + +#figure( + image("figures/performance.png", width: 100%), + caption: [Performance comparison of clustering topic models.\ + _Left (Higher is better)_: Fowlkes-Mallows Index against topic interpretability. Large point with error bar represents mean with bootstrapped 95% confidence interval. \ + _Right (Lower is better)_: Distribution of absolute percentage error in finding the number of topics. + ], +) + + +== Perplexity + +Metrics of quality and number of topics across perplexity values can are displayed on @perplexity_robustness. +Topeax converges very early on the number of topics with perplexity, and remains stable from `perplexity=5`, while converges at around `perplexity=30` for quality metrics. It is reasonable to conclude that 50 is a reasonable recommendation and default value. +Meanwhile, BERTopic converges at around `perplexity=50`, and has the lowest performance on all metrics. Top2Vec does not seem to converge at all for the values of perplexity tested, and is most unstable. It does seem to improve with larger values of the hyperparameter. +Keep in mind, that while BERTopic and Top2Vec improve with higher values, their default is set at `perplexity=15`, which, in light of these evaluations, seems rather unreasonable. + + +#figure( + image("figures/robustness_perplexity.png", width: 100%), + caption: [Clustering model's performance at different perplexity values.\ + _Left_: Fowlkes-Mallows Index at different perplexity values, + _Middle_: Topic Interpretability Score at different values of Perplexity, + _Right_: Number of Topics at each value of perplexity against Gold label. + +], +) + + +== Subsampling +Number of topics, topic quality and cluster quality are displayed on @subsampling. +Topeax is relatively well-behaved, and converges to the highest performance when it has access to the full corpus. +The number of topics is also relatively stable across from a sample size of 5000 (hovers around 10-12). +In contrast, BERTopic and Top2Vec do not converge to a single value of N topics and keep growing with the size of the subsample. +This also has an impact on cluster and topic quality. BERTopic has highest performance on the smallest subsamples (250-1000), while Top2Vec has best performance on a subsample of 5000, both methods decrease in performance as the number of topics grows with sample size. This behaviour is far from ideal, and it is apparent that Topeax is much more reliable at determining the number and structure of clusters in subsampled and full corpora. + +#figure( + image("figures/robustness_sample_size.png", width: 100%), + caption: [Topic models' performance at different subsample sizes.\ + _Left_: Fowlkes-Mallows Index as a function of sample size, + _Middle_: Topic Interpretability Score at different subsamples, + _Right_: Number of Topics discovered in each subsample per sample size. + +], +) + +== Qualitative Considerations + +As per the experimental evaluations presented above, Topeax systematically underestimates the number of clusters in a given dataset, despite matching the gold labels better as per the Fowlkes-Mallows index. +This warrants further investigation. A Topeax model was run on 20 Newsgroups with `all-MiniLM-L6-v2` embeddings, where the estimated number of clusters was 11, while the original dataset contains data from 20 categories, as suggested by its name. +Adjusted mutual information was calculated between each topic discovered by the model and each newsgroup (see @20ng_groups). + +#figure( + image("figures/20ng_groups.png", width: 100%), + caption: [Topeax model fit on the 20 Newsgroups Corpus in relation to the gold labels provided in the corpus.\ + _Left_: Density estimate and density peaks annotated with top 4 keywords from each topic.\ + _Right_: Adjusted Mutual Information between cluster labels in the model, and gold labels in the corpus. +], +) <20ng_groups> +While, indeed the number of clusters is less than the categories in the original dataset, the clustering provided by Topeax is arguably just as natural. +Most clusters ended up compressing information from one or two newsgroups, that were in some way related. +For instance the `1_god_atheism_christians_christianity` topic contained documents from `alt.atheism`, `talk.religion.misc` and `soc.religion.christian`, thereby combining discourse on religion into a single topic. Likewise `6_car_bikes_bmw` compresses the `rec.autos` and `rec.motorcycles` newsgroups. +In addition, the model uncovered a topic of outlier documents (`7_yer_umm_ahh__i_`), which were either empty, or only contained a few words, no coherent sentences. + +Meanwhile, BERTopic discovered 232, and Top2Vec 145 topics in the same corpus using the same embeddings, while labelling 34.15% and 35.07% of documents as outliers respectively. +While different users and use cases might have different tolerance levels for time spent on analyzing topics, and the number of outliers, this behaviour seems far from ideal under most circumstances. +Interpreting, and labelling the topics would take a considerable amount of time in both cases. +In addition, regarding more than a third of documents as outliers means that a substantial amount of information is not covered by these models. +This will inevitably prompt users of these topic models to a) hierarchically reduce topics, where they are required to specify the number of topics or b) fiddle with hyperparameters until they arrive at a result they deem sensible. +It is thus questionable, how much these models are at all able to identify the number of natural clusters in a corpus, and until better and more rigorous heuristics are established for hyperparameter selection, their use remains highly subjective and circular. + += Conclusion + +I propose a novel method, Topeax for finding natural clusters in text data, and assigning keywords to these clusters +based on peak finding in kernel-density estimates. +The model is compared to popular clustering topic models, Top2Vec and BERTopic on a number of clustering datasets from the Massive Text Embedding Benchmark. +In addition, models' robustness and stability to sample size and hyperparameter choices is evaluated. +Topeax approximates human clusterings significantly more faithfully than previous approaches and describes topics with more diverse and coherent keywords. +Furthermore, the model exhibits much more sensible behaviour under changing circumstances and hyperparameters. +It is found, however, that Topeax underestimates the number of clusters systematically. +Qualitative investigation suggests that this is due to the model grouping together related clusters in the case of 20 Newsgroups. +In light of these findings, Topeax seems a better choice for text clustering, + += Limitations + +While the model has been shown to perform better than the baselines discussed, there are a number of issues it still exhibits: ++ Topeax underestimates the number of clusters, compared to humans. ++ The model, as of now, cannot be used in an online setting, when new topics have are as new information comes in. + +Some of these issues might be addressed by using emerging dimensionality reduction techniques that allow for aligning between multiple datasets, and projection of out-of-distribution points. +These issues should be subject to further investigation. + +In addition the evaluation methodology also has a number of limitations of its own: ++ Quantitative metrics of topic quality, while roughly correlate with human preference, do not perfectly capture interpretability. Preferably, future research should evaluate topic quality with human subjects. ++ Subsampling and perplexity were only tested on the 20NG corpus in the interest of time and compute. This is of course a limitation, and evaluation on multiple corpora would be preferable. + +#pagebreak() +#heading(level:1, numbering: none, "Appendix") + +#show: appendix + += Example code + +Due to the model being implemented in Turftopic, +it is very easy to run on a new corpus. One first has to install the package: + +```bash +pip install turftopic +``` + +Then run fit the model to a corpus, here's an example with 20 Newsgroups: + +```python +from sklearn.datasets import fetch_20newsgroups +from turftopic import Topeax + +ds = fetch_20newsgroups( + subset="all", + remove=("headers", "footers", "quotes"), +) +corpus = ds.data + +model = Topeax() +model.fit(corpus) +model.print_topics() +``` + +#figure( + caption: [Topics found in the 20 Newsgroups corpus], + table( + columns: 2, + stroke: none, + align: (right, left), + table.hline(), + table.header([ *ID* ], [*Highest Ranking*]), + table.hline(), + [ 0 ], [armenians, armenian, israel, israeli, jews, genocide, turkish, palestinians, palestinian, israelis ], + [ 1 ], [god, christians, atheism, christianity, bible, scripture, christian, theology, faith, church ], + [ 2 ], [ pitching, pitcher, hitter, baseball, braves, batting, pitchers, cubs, sox, fielder ], + [ 3] ,[ hockey, nhl, puck, leafs, sabres, bruins, flyers, islanders, team, canucks ], + [ 4],[ gun, guns, militia, amendment, firearms, homicides, nra, fbi, crime, homicide], + [ 5],[ patients, disease, medical, treatment, doctor, clinical, vitamin, medicine, treatments, infection ], + [ 6],[ car, bike, cars, bmw, honda, engine, motorcycle, ford, dealer, bikes ], + [ 7], [yer, umm, ahhh, \_i\_, \_you\_, cheek, expresses, reacted, ths, advertisement ], + [ 8], [ ax, nasa, spacecraft, a86, satellite, detectors, satellites, spaceflight, max, langley ], + [ 9],[ encryption, nsa, key, privacy, security, clipper, chip, encrypted, crypto, cryptography ], + [ 10], [motherboard, scsi, card, ram, mhz, chipset, bios, hardware, monitor, modem ], + [ 11],[ windows, xfree86, x11r5, x11, openwindows, jpeg, window, xterm, x11r4, microsoft ], + table.hline(), + ), +) diff --git a/pyproject.toml b/pyproject.toml index c91abef..245fb63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ profile = "black" [project] name = "turftopic" -version = "0.20.0" +version = "0.21.0" description = "Topic modeling with contextual representations from sentence transformers." authors = [ { name = "Márton Kardos ", email = "martonkardos@cas.au.dk" } diff --git a/turftopic/models/gmm.py b/turftopic/models/gmm.py index 69ebad0..0ceca05 100644 --- a/turftopic/models/gmm.py +++ b/turftopic/models/gmm.py @@ -28,6 +28,7 @@ MultimodalModel, ) from turftopic.optimization import optimize_n_components +from turftopic.utils import confidence_ellipse from turftopic.vectorizers.default import default_vectorizer FEATURE_IMPORTANCE_METHODS = { @@ -411,7 +412,11 @@ def plot_components_datamapplot( return plot def plot_density( - self, hover_text: list[str] = None, show_points=False, light_mode=False + self, + hover_text: list[str] = None, + show_keywords=True, + show_points=False, + light_mode=False, ): try: import plotly.graph_objects as go @@ -428,9 +433,9 @@ def plot_density( warnings.warn( "Embeddings are not in 2d space, only using first 2 dimensions" ) - - coord_min, coord_max = np.min(self.reduced_embeddings), np.max( - self.reduced_embeddings + reduced_embeddings = self.reduced_embeddings[:, :2] + coord_min, coord_max = np.min(reduced_embeddings), np.max( + reduced_embeddings ) coord_spread = coord_max - coord_min coord_min = coord_min - coord_spread * 0.05 @@ -464,8 +469,8 @@ def plot_density( ] if show_points: scatter = go.Scatter( - x=self.reduced_embeddings[:, 0], - y=self.reduced_embeddings[:, 1], + x=reduced_embeddings[:, 0], + y=reduced_embeddings[:, 1], mode="markers", showlegend=False, text=hover_text, @@ -488,13 +493,195 @@ def plot_density( self.gmm_.means_, self.topic_names, self.get_top_words() ): _keys = "" - for i, key in enumerate(keywords): - if (i % 5) == 0: - _keys += "
" - _keys += key - if i < (len(keywords) - 1): - _keys += "," - _keys += " " + if show_keywords: + for i, key in enumerate(keywords): + if (i % 5) == 0: + _keys += "
" + _keys += key + if i < (len(keywords) - 1): + _keys += "," + _keys += " " + text = f"{name} {_keys} " + fig.add_annotation( + text=text, + x=mean[0], + y=mean[1], + align="left", + showarrow=False, + xshift=0, + yshift=50, + font=dict(family="Roboto Mono", size=18, color="black"), + bgcolor="rgba(255,255,255,0.9)", + bordercolor="black", + borderwidth=2, + ) + return fig + + def plot_density_3d(self, show_keywords=False): + try: + import plotly.graph_objects as go + except (ImportError, ModuleNotFoundError) as e: + raise ModuleNotFoundError( + "Please install plotly if you intend to use plots in Turftopic." + ) from e + + if not hasattr(self, "reduced_embeddings"): + raise ValueError( + "No reduced embeddings found, can't display in 2d space." + ) + if self.reduced_embeddings.shape[1] != 2: + warnings.warn( + "Embeddings are not in 2d space, only using first 2 dimensions" + ) + reduced_embeddings = self.reduced_embeddings[:, :2] + coord_min, coord_max = np.min(reduced_embeddings), np.max( + reduced_embeddings + ) + coord_spread = coord_max - coord_min + coord_min = coord_min - coord_spread * 0.05 + coord_max = coord_max + coord_spread * 0.05 + coord = np.linspace(coord_min, coord_max, num=100) + z = [] + for yval in coord: + points = np.stack([coord, np.full(coord.shape, yval)]).T + prob = np.exp(self.gmm_.score_samples(points)) + z.append(prob) + z = np.stack(z) + means = self.gmm_.means_ + means_z = np.exp(self.gmm_.score_samples(means)) + annotations = [] + for (x_mean, y_mean), z_mean, name, keywords in zip( + means, means_z, self.topic_names, self.get_top_words() + ): + _keys = "" + if show_keywords: + for i, key in enumerate(keywords): + if (i % 5) == 0: + _keys += "
" + _keys += key + if i < (len(keywords) - 1): + _keys += "," + _keys += " " + text = f"{name} {_keys} " + annotations.append( + dict( + showarrow=True, + x=x_mean, + y=y_mean, + z=z_mean, + text=text, + font=dict(family="Roboto Mono", size=18, color="black"), + bgcolor="rgba(255,255,255,0.9)", + bordercolor="black", + borderwidth=2, + ) + ) + color_grid = [0.0, 0.25, 0.5, 0.75, 1.0] + colorscale = [ + "#01014B", + "#000080", + "#5D5DEF", + "#B7B7FF", + "#ffffff", + ] + fig = go.Figure( + data=[ + go.Surface( + z=z, + x=coord, + y=coord, + colorscale=list(zip(color_grid, colorscale)), + showscale=False, + ) + ] + ) + fig = fig.update_layout( + margin=dict(l=0, r=0, b=0, t=0), + template="plotly_white", + scene=dict(annotations=annotations), + ) + return fig + + def plot_components( + self, + show_points=False, + show_keywords=True, + hover_text: Optional[list[str]] = None, + ): + try: + import plotly.express as px + import plotly.graph_objects as go + except (ImportError, ModuleNotFoundError) as e: + raise ModuleNotFoundError( + "Please install plotly if you intend to use plots in Turftopic." + ) from e + + if not hasattr(self, "reduced_embeddings"): + raise ValueError( + "No reduced embeddings found, can't display in 2d space." + ) + if self.reduced_embeddings.shape[1] != 2: + warnings.warn( + "Embeddings are not in 2d space, only using first 2 dimensions" + ) + reduced_embeddings = self.reduced_embeddings[:, :2] + coord_min, coord_max = np.min(reduced_embeddings), np.max( + reduced_embeddings + ) + coord_spread = coord_max - coord_min + coord_min = coord_min - coord_spread * 0.05 + coord_max = coord_max + coord_spread * 0.05 + coord = np.linspace(coord_min, coord_max, num=100) + z = [] + for yval in coord: + points = np.stack([coord, np.full(coord.shape, yval)]).T + prob = np.exp(self.gmm_.score_samples(points)) + z.append(prob) + z = np.stack(z) + fig = go.Figure( + [ + go.Contour( + z=z, + x=coord, + y=coord, + colorscale="Greys", + opacity=0.25, + hoverinfo="skip", + showscale=False, + ), + ] + ) + gmm_colors = px.colors.qualitative.Dark24 + for i_std, n_std in enumerate(np.linspace(0.1, 3.0, num=5)): + for name, color, mean, cov in zip( + self.topic_names, + gmm_colors, + self.gmm_.means_, + self.gmm_.covariances_, + ): + fig.add_shape( + legend="legend", + showlegend=False, + type="path", + path=confidence_ellipse(mean, cov, n_std=n_std), + legendgroup=name, + name=0, + legendwidth=0, + fillcolor=color, + opacity=0.1, + ) + for mean, name, keywords in zip( + self.gmm_.means_, self.topic_names, self.get_top_words() + ): + _keys = "" + if show_keywords: + for i, key in enumerate(keywords): + if (i % 5) == 0: + _keys += "
" + _keys += key + if i < (len(keywords) - 1): + _keys += "," + _keys += " " text = f"{name} {_keys} " fig.add_annotation( text=text, @@ -509,4 +696,41 @@ def plot_density( bordercolor="black", borderwidth=2, ) + fig = fig.update_layout( + margin=dict(l=0, r=0, b=0, t=0), + template="plotly_white", + ) + if show_points: + for i, (name, color) in enumerate( + zip(self.topic_names, gmm_colors) + ): + include = self.labels_ == i + text = ( + None + if hover_text is None + else [ + text + for text, in_cluster in zip(hover_text, include) + if in_cluster + ] + ) + scatter = go.Scatter( + x=reduced_embeddings[:, 0][include], + y=reduced_embeddings[:, 1][include], + mode="markers", + showlegend=False, + text=text, + name=name, + legendgroup=name, + hovertemplate=f"{name}
%{{text}}", + marker=dict( + symbol="circle", + opacity=0.5, + color=color, + size=6, + line=dict(width=1), + ), + ) + fig.add_trace(scatter) + fig = fig.update_layout(coloraxis=dict(showscale=False)) return fig diff --git a/turftopic/models/topeax.py b/turftopic/models/topeax.py index 7f17608..a6089f6 100644 --- a/turftopic/models/topeax.py +++ b/turftopic/models/topeax.py @@ -109,6 +109,7 @@ def fit(self, X, y=None): self.classes_ = np.sort(np.unique(self.labels_)) self.means_ = self.gmm_.means_ self.weights_ = self.gmm_.weights_ + self.covariances_ = self.gmm_.covariances_ return self.labels_ @property @@ -158,6 +159,7 @@ def __init__( perplexity=perplexity, random_state=random_state, ) + self.perplexity = perplexity super().__init__( n_components=0, encoder=encoder, @@ -209,3 +211,118 @@ def estimate_components( def _init_model(self, n_components: int): mixture = Peax() return mixture + + def plot_steps(self, hover_text=None): + try: + import plotly.express as px + from plotly.subplots import make_subplots + except (ImportError, ModuleNotFoundError) as e: + raise ModuleNotFoundError( + "Please install plotly if you intend to use plots in Turftopic." + ) from e + dens_3d = self.plot_density_3d() + component_plot = self.plot_components( + show_points=True, hover_text=hover_text + ) + points_plot = px.scatter( + x=self.reduced_embeddings[:, 0], + y=self.reduced_embeddings[:, 1], + template="plotly_white", + ) + points_plot = points_plot.update_layout( + margin=dict(l=0, r=0, b=0, t=0), + ) + points_plot = points_plot.update_traces( + marker=dict( + color="#B7B7FF", + size=6, + opacity=0.5, + line=dict(color="#01014B", width=2), + ) + ) + colormap = { + name: color + for name, color in zip( + self.topic_names, px.colors.qualitative.Dark24 + ) + } + bar = px.bar( + y=self.topic_names, + x=self.weights_, + template="plotly_white", + color_discrete_map=colormap, + color=self.topic_names, + text=[f"{p:.2f}" for p in self.weights_], + ) + bar = bar.update_traces( + marker_line_color="black", + marker_line_width=1.5, + opacity=0.8, + ) + + def update_annotation(a): + name = a.text.removeprefix("").split("<")[0] + return a.update( + # text=name, + font=dict(size=8, color=colormap[name]), + arrowsize=1, + arrowhead=1, + arrowwidth=1, + bgcolor=None, + opacity=0.7, + # bgcolor=colormap[name], + bordercolor=colormap[name], + borderwidth=0, + ) + + fig = make_subplots( + horizontal_spacing=0.0, + vertical_spacing=0.1, + rows=2, + cols=2, + subplot_titles=[ + "t-SN Embeddings", + "Peaks in Kernel Density Estimate", + "Gaussian Mixture Approximation", + "Component Probabilities", + ], + specs=[ + [ + {"type": "xy"}, + {"type": "surface"}, + ], + [ + {"type": "xy"}, + {"type": "bar"}, + ], + ], + ) + for i, sub in enumerate([points_plot, dens_3d, component_plot, bar]): + row = i // 2 + col = i % 2 + for trace in sub.data: + fig.add_trace(trace, row=row + 1, col=col + 1) + for shape in sub.layout.shapes: + fig.add_shape(shape, row=row + 1, col=col + 1) + fig = fig.update_layout( + template="plotly_white", + font=dict(family="Merriweather", size=14, color="black"), + width=1200, + height=800, + autosize=False, + margin=dict(r=0, l=0, t=40, b=0), + ) + fig = fig.update_scenes( + annotations=[ + update_annotation(annotation) + for annotation in dens_3d.layout.scene.annotations + ], + col=2, + row=1, + ) + fig = fig.for_each_annotation(lambda a: a.update(yshift=0)) + fig = fig.update_yaxes(visible=False, row=2, col=2) + fig = fig.update_xaxes( + title=dict(text="$P(z)$", font=dict(size=16)), row=2, col=2 + ) + return fig diff --git a/turftopic/utils.py b/turftopic/utils.py index 25d99b6..1492e8e 100644 --- a/turftopic/utils.py +++ b/turftopic/utils.py @@ -71,3 +71,35 @@ def sanitize_for_html(text: str) -> str: # Removing unnecessary whitespace text = " ".join(text.split()) return text + + +def confidence_ellipse(mean, cov, n_std=1, size=100): + pearson = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1]) + ell_radius_x = np.sqrt(1 + pearson) + ell_radius_y = np.sqrt(1 - pearson) + theta = np.linspace(0, 2 * np.pi, size) + ellipse_coords = np.column_stack( + [ell_radius_x * np.cos(theta), ell_radius_y * np.sin(theta)] + ) + x_scale = np.sqrt(cov[0, 0]) * n_std + y_scale = np.sqrt(cov[1, 1]) * n_std + x_mean, y_mean = mean + translation_matrix = np.tile( + [x_mean, y_mean], (ellipse_coords.shape[0], 1) + ) + rotation_matrix = np.array( + [ + [np.cos(np.pi / 4), np.sin(np.pi / 4)], + [-np.sin(np.pi / 4), np.cos(np.pi / 4)], + ] + ) + scale_matrix = np.array([[x_scale, 0], [0, y_scale]]) + ellipse_coords = ( + ellipse_coords.dot(rotation_matrix).dot(scale_matrix) + + translation_matrix + ) + path = f"M {ellipse_coords[0, 0]}, {ellipse_coords[0, 1]}" + for k in range(1, len(ellipse_coords)): + path += f"L{ellipse_coords[k, 0]}, {ellipse_coords[k, 1]}" + path += " Z" + return path