diff --git a/docs/SensTopic.md b/docs/SensTopic.md
index deac1b8..7bee147 100644
--- a/docs/SensTopic.md
+++ b/docs/SensTopic.md
@@ -12,6 +12,13 @@ If you want to enable hardware acceleration and JIT compilation, make sure to in
 pip install jax
 ```
+SensTopic produces topics of the same quality as $S^3$, without requiring you to interpret negative topic descriptions, and it also scales better. Our implementation of SNMF is considerably faster than FastICA, on which $S^3$ is based (a reproduction sketch follows the figure below).
+
+<figure>
+    <!-- interactive plot (images/fastica_vs_nmf.html) -->
+    <figcaption>FastICA vs SNMF runtime on the 20 Newsgroups dataset with different numbers of topics.</figcaption>
+</figure>
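+The sketch below is not the exact script behind the figure: it times both models on precomputed embeddings, so that the comparison isolates the factorization step. Exact numbers will depend on your hardware and on whether JAX is installed.
+
+```python
+import time
+
+from sentence_transformers import SentenceTransformer
+from sklearn.datasets import fetch_20newsgroups
+from turftopic import SemanticSignalSeparation, SensTopic
+
+corpus = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes")).data
+encoder = SentenceTransformer("all-MiniLM-L6-v2")
+# Encode once up front, so that only the decomposition itself is timed
+embeddings = encoder.encode(corpus, show_progress_bar=True)
+
+for n_topics in [10, 20, 50]:
+    for model_class in [SemanticSignalSeparation, SensTopic]:
+        model = model_class(n_components=n_topics, encoder=encoder, random_state=42)
+        start_time = time.time()
+        model.fit_transform(corpus, embeddings=embeddings)
+        print(f"{model_class.__name__}, {n_topics} topics: {time.time() - start_time:.1f}s")
+```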
+
 
 Here's an example of running SensTopic on the 20 Newsgroups dataset:
 
 ```python
diff --git a/docs/benchmark.md b/docs/benchmark.md
new file mode 100644
index 0000000..19756bb
--- /dev/null
+++ b/docs/benchmark.md
@@ -0,0 +1,347 @@
+# Model Leaderboard
+
+To aid you in choosing the best model for your use case, we have built a topic model benchmark and leaderboard.
+The benchmark consists of all English P2P clustering tasks from the most recent version of [MTEB](https://huggingface.co/spaces/mteb/leaderboard), plus a tweet dataset and a news dataset, as these are not present in MTEB.
+
+Models were tested for topic quality using the methodology of [Kardos et al. 2025](https://aclanthology.org/2025.acl-long.32/),
+and for cluster quality using adjusted mutual information (AMI), the Fowlkes-Mallows index (FMI), and V-measure (computed from its homogeneity and completeness components).
+All models were run on an older but still powerful Dell Precision laptop with 32 GB of RAM and an i7 CPU, which was apparently not enough,
+as some models ran out of memory on some of the larger datasets.
+Because of this, and because the scale of the scores differs from task to task, we report the **average percentile** scores on these metrics in the table below.
+
+
+??? info "Click to see Benchmark code"
+    ```python
+    import argparse
+    import json
+    import time
+    from itertools import chain, combinations
+    from pathlib import Path
+    from typing import Callable, Iterable
+
+    import gensim.downloader as api
+    import mteb
+    import numpy as np
+    from datasets import load_dataset
+    from glovpy import GloVe
+    from sentence_transformers import SentenceTransformer
+    from sklearn import metrics
+    from sklearn.feature_extraction.text import CountVectorizer
+    from turftopic import (GMM, AutoEncodingTopicModel, BERTopic, FASTopic, KeyNMF,
+                           SemanticSignalSeparation, SensTopic, Top2Vec, Topeax)
+
+    topic_models = {
+        "Topeax(Auto)": lambda encoder, n_components: Topeax(
+            encoder=encoder, random_state=42
+        ),
+        "BERTopic(Auto)": lambda encoder, n_components: BERTopic(
+            encoder=encoder, random_state=42
+        ),
+        "Top2Vec(Auto)": lambda encoder, n_components: Top2Vec(
+            encoder=encoder, random_state=42
+        ),
+        "SensTopic(Auto)": lambda encoder, n_components: SensTopic(
+            n_components="auto", encoder=encoder, random_state=42
+        ),
+        "SensTopic": lambda encoder, n_components: SensTopic(
+            n_components=n_components, encoder=encoder, random_state=42
+        ),
+        "KeyNMF(Auto)": lambda encoder, n_components: KeyNMF(
+            n_components="auto", encoder=encoder, random_state=42
+        ),
+        "KeyNMF": lambda encoder, n_components: KeyNMF(
+            n_components=n_components, encoder=encoder, random_state=42
+        ),
+        "GMM": lambda encoder, n_components: GMM(
+            n_components=n_components, encoder=encoder, random_state=42
+        ),
+        "Top2Vec(Reduce)": lambda encoder, n_components: Top2Vec(
+            n_reduce_to=n_components, encoder=encoder, random_state=42
+        ),
+        "BERTopic(Reduce)": lambda encoder, n_components: BERTopic(
+            n_reduce_to=n_components, encoder=encoder, random_state=42
+        ),
+        "ZeroShotTM": lambda encoder, n_components: AutoEncodingTopicModel(
+            n_components=n_components, encoder=encoder, random_state=42, combined=False
+        ),
+        "SemanticSignalSeparation": lambda encoder, n_components: SemanticSignalSeparation(
+            n_components=n_components, encoder=encoder, random_state=42
+        ),
+        "FASTopic": lambda encoder, n_components: FASTopic(
+            n_components=n_components, encoder=encoder, random_state=42
+        ),
+    }
+
+
+    def load_corpora() -> Iterable[tuple[str, Callable]]:
+        mteb_tasks = mteb.get_tasks(
+            [
"ArXivHierarchicalClusteringP2P", + "BiorxivClusteringP2P.v2", + "MedrxivClusteringP2P.v2", + "StackExchangeClusteringP2P.v2", + "TwentyNewsgroupsClustering.v2", + ] + ) + for task in mteb_tasks: + + def _load_dataset(): + task.load_data() + ds = task.dataset["test"] + corpus = list(ds["sentences"]) + if isinstance(ds["labels"][0], list): + true_labels = [label[0] for label in ds["labels"]] + else: + true_labels = list(ds["labels"]) + return corpus, true_labels + + yield task.metadata.name, _load_dataset + + def _load_dataset(): + # Taken from here cardiffnlp/tweet_topic_single with "train_all" + ds = load_dataset("kardosdrur/tweet_topic_clustering", split="train_all") + corpus = list(ds["text"]) + labels = list(ds["label"]) + return corpus, labels + + yield "TweetTopicClustering", _load_dataset + + def _load_dataset(): + ds = load_dataset("gopalkalpande/bbc-news-summary", split="train") + corpus = list(ds["Summaries"]) + labels = list(ds["File_path"]) + return corpus, labels + + yield "BBCNewsClustering", _load_dataset + + + def diversity(keywords: list[list[str]]) -> float: + all_words = list(chain.from_iterable(keywords)) + unique_words = set(all_words) + total_words = len(all_words) + return float(len(unique_words) / total_words) + + + def word_embedding_coherence(keywords, wv): + arrays = [] + for index, topic in enumerate(keywords): + if len(topic) > 0: + local_simi = [] + for word1, word2 in combinations(topic, 2): + if word1 in wv.index_to_key and word2 in wv.index_to_key: + local_simi.append(wv.similarity(word1, word2)) + arrays.append(np.nanmean(local_simi)) + return float(np.nanmean(arrays)) + + + def evaluate_clustering(true_labels, pred_labels) -> dict[str, float]: + res = {} + for metric in [ + metrics.fowlkes_mallows_score, + metrics.homogeneity_score, + metrics.completeness_score, + metrics.adjusted_mutual_info_score, + ]: + res[metric.__name__] = metric(true_labels, pred_labels) + return res + + + def get_keywords(model) -> list[list[str]]: + """Get top words and ignore outlier topic.""" + n_topics = model.components_.shape[0] + try: + classes = model.classes_ + except AttributeError: + classes = list(range(n_topics)) + res = [] + for topic_id, words in zip(classes, model.get_top_words()): + if topic_id != -1: + res.append(words) + return res + + + def evaluate_topic_quality(keywords, ex_wv, in_wv) -> dict[str, float]: + res = { + "diversity": diversity(keywords), + "c_in": word_embedding_coherence(keywords, in_wv), + "c_ex": word_embedding_coherence(keywords, ex_wv), + } + return res + + + def load_cache(out_path): + cache_entries = [] + with out_path.open() as cache_file: + for line in cache_file: + entry = json.loads(line.strip()) + cache_entry = (entry["task"], entry["model"]) + cache_entries.append(cache_entry) + return set(cache_entries) + + + def main(encoder_name: str = "all-MiniLM-L6-v2"): + out_dir = Path("results") + out_dir.mkdir(exist_ok=True) + encoder_path_name = encoder_name.replace("/", "__") + out_path = out_dir.joinpath(f"{encoder_path_name}.jsonl") + if out_path.is_file(): + cache = load_cache(out_path) + else: + cache = set() + # Create file if doesn't exist + with out_path.open("w"): + pass + print("Loading external word embeddings") + ex_wv = api.load("word2vec-google-news-300") + print("Loading benchmark") + tasks = load_corpora() + for task_name, load in tasks: + if all([(task_name, model_name) in cache for model_name in topic_models]): + print("All models already completed, skipping.") + continue + print("Load corpus") + corpus, true_labels = 
+            print("Training internal word embeddings using GloVe...")
+            tokenizer = CountVectorizer().build_analyzer()
+            glove = GloVe(vector_size=50)
+            tokenized_corpus = [tokenizer(text) for text in corpus]
+            glove.train(tokenized_corpus)
+            in_wv = glove.wv
+            encoder = SentenceTransformer(encoder_name, device="cpu")
+            print("Encoding task corpus.")
+            embeddings = encoder.encode(corpus, show_progress_bar=True)
+            for model_name in topic_models:
+                if (task_name, model_name) in cache:
+                    print(f"{model_name} already done, skipping.")
+                    continue
+                print(f"Running {model_name}.")
+                true_n = len(set(true_labels))
+                model = topic_models[model_name](encoder=encoder, n_components=true_n)
+                start_time = time.time()
+                doc_topic_matrix = model.fit_transform(corpus, embeddings=embeddings)
+                end_time = time.time()
+                labels = getattr(model, "labels_", None)
+                if labels is None:
+                    labels = np.argmax(doc_topic_matrix, axis=1)
+                keywords = get_keywords(model)
+                print("Evaluating model.")
+                clust_scores = evaluate_clustering(true_labels, labels)
+                topic_scores = evaluate_topic_quality(keywords, ex_wv, in_wv)
+                runtime = end_time - start_time
+                res = {
+                    "encoder": encoder_name,
+                    "task": task_name,
+                    "model": model_name,
+                    "auto": "(Auto)" in model_name,
+                    "runtime": runtime,
+                    "dps": len(corpus) / runtime,
+                    "n_components": model.components_.shape[0],
+                    "true_n": true_n,
+                    **clust_scores,
+                    **topic_scores,
+                }
+                print("Results: ", res)
+                res["keywords"] = keywords
+                with out_path.open("a") as out_file:
+                    out_file.write(json.dumps(res) + "\n")
+
+
+    if __name__ == "__main__":
+        parser = argparse.ArgumentParser(prog="Evaluate clustering.")
+        parser.add_argument("embedding_model")
+        args = parser.parse_args()
+        encoder = args.embedding_model
+        main(encoder)
+        print("DONE")
+    ```
+
+
+For models that can detect the number of topics automatically, we ran the benchmark with that setting; such runs are marked ***(Auto)*** in our tables and plots.
+For models where users can set the number of topics, we also ran the benchmark with the correct number of topics set a priori.
+
+#### Topic Quality
+
+Judging by the interpretability scores, Auto models, in particular Topeax, SensTopic, KeyNMF and GMM, were best at generating high-quality topics.
+Among the non-auto models, KeyNMF, GMM, ZeroShotTM, FASTopic and SensTopic did best, though ZeroShotTM and FASTopic failed to run on some of the more challenging datasets because they ran out of memory.
+
+#### Cluster Quality
+
+The clear winners in cluster quality were GMM, Topeax (also GMM-based) and SensTopic. FASTopic also did reasonably well at recovering the gold clusters in the data. A minimal sketch of how these cluster-quality metrics are computed follows the figure below.
+
+<figure>
+    <!-- interactive plot (images/radar_chart.html) -->
+    <figcaption>Performance profile of all models on different metrics.
+    Top 5 models on average performance are highlighted; click on the legend to show the others.</figcaption>
+</figure>
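+Here is that sketch: the scores come straight from scikit-learn, shown with toy label assignments (the benchmark computes them on each model's predicted document labels):
+
+```python
+from sklearn import metrics
+
+true_labels = [0, 0, 0, 1, 1, 2, 2, 2]
+pred_labels = [0, 0, 1, 1, 1, 2, 2, 0]
+
+print(metrics.adjusted_mutual_info_score(true_labels, pred_labels))  # AMI
+print(metrics.fowlkes_mallows_score(true_labels, pred_labels))  # FMI
+print(metrics.v_measure_score(true_labels, pred_labels))  # V-measure
+```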
+
+## Computational Efficiency
+
+<figure>
+    <!-- interactive plot; likely images/model_speed.html -->
+</figure>
+
+#### Speed
+
+We recorded the number of documents each model could process per second on every run.
+Matrix factorization approaches ($S^3$, SensTopic, KeyNMF) were the fastest, while neural approaches (FASTopic, ZeroShotTM) were the slowest.
+While SensTopic appears slower than SemanticSignalSeparation in our measurements, note that SensTopic gains built-in JIT compilation once JAX is installed, and is then likely to be even faster than $S^3$. For more detail, see [SensTopic](SensTopic.md).
+We plotted model speed against performance on the accompanying interactive graph; marker size represents the Fowlkes-Mallows index. A sketch of how throughput was measured follows the figure.
+
+<figure>
+    <!-- interactive plot; likely images/performance_speed_plot.html -->
+</figure>
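+Throughput (`dps` in the benchmark code) is plain wall-clock timing of a full fit, with document encoding excluded. A minimal sketch, using KeyNMF as an example:
+
+```python
+import time
+
+from sentence_transformers import SentenceTransformer
+from sklearn.datasets import fetch_20newsgroups
+from turftopic import KeyNMF
+
+corpus = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes")).data
+encoder = SentenceTransformer("all-MiniLM-L6-v2")
+# Encoding is excluded from the timing, mirroring the benchmark script
+embeddings = encoder.encode(corpus, show_progress_bar=True)
+
+model = KeyNMF(n_components=20, encoder=encoder, random_state=42)
+start_time = time.time()
+model.fit_transform(corpus, embeddings=embeddings)
+docs_per_second = len(corpus) / (time.time() - start_time)
+print(f"{docs_per_second:.1f} documents per second")
+```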
+
+#### Out of Memory
+
+While we did not record memory usage, three models stood out for being unable to complete some of the more challenging tasks on the test hardware.
+FASTopic failed twice, on some of the larger corpora, while Top2Vec and BERTopic ran into trouble when reducing the number of topics to a desired amount.
+This is likely due to the computational and memory burden of hierarchical clustering, so we recommend that you avoid topic reduction if you are unsure whether your hardware can handle it.
+If you have your heart set on using FASTopic, we recommend getting a lot of memory, and preferably a GPU too.
+Unfortunately, neural topic modelling still takes a lot of resources to run.
+
+## Discovering the Number of Topics
+
+A number of methods are, in theory, able to discover the number of topics in a dataset.
+We tested this and found the claim to be rather exaggerated, especially in the case of BERTopic and Top2Vec,
+which consistently overestimated the number of topics, sometimes by orders of magnitude.
+This effect gets worse with larger corpora.
+Topeax was the most accurate at this task, especially on larger corpora, but it was still substantially off most of the time.
+KeyNMF and SensTopic also got reasonably close at times, while completely missing the mark at others.
+
+We conclude that this area needs a lot of improvement.
+
+The table below lists the number of topics each model detected on each dataset, with the true number of topics in parentheses in the column headers and the closest estimate in bold. A sketch of how to request automatic detection follows the table.
+
+| Model | ArXivHierarchical (23) | BBCNews (5) | Biorxiv (26) | Medrxiv (51) | StackExchange (524) | TweetTopic (6) | TwentyNewsgroups (20) |
+|--------------|--------------------------|-------------|---------------|---------------|-----------------------|----------------|-------------------------|
+| BERTopic | **25** | 42 | 602 | 1583 | 2542 | 76 | 1861 |
+| KeyNMF | 3 | **5** | 250 | 250 | **250** | 2 | 10 |
+| SensTopic | 8 | 6 | 14 | 14 | 6 | 11 | 4 |
+| Top2Vec | 18 | 18 | 405 | 1000 | 1495 | 49 | 1612 |
+| Topeax | 6 | 8 | **19** | **23** | 21 | **8** | **13** |
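+
+Models that support automatic detection accept `n_components="auto"`, as in the benchmark code above. A minimal sketch:
+
+```python
+from sklearn.datasets import fetch_20newsgroups
+from turftopic import SensTopic
+
+corpus = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes")).data
+
+model = SensTopic(n_components="auto", random_state=42)
+model.fit(corpus)
+# Number of topics the model settled on
+print(model.components_.shape[0])
+```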
diff --git a/docs/images/coherence_diversity.html b/docs/images/coherence_diversity.html
new file mode 100644
index 0000000..bd69b1b
--- /dev/null
+++ b/docs/images/coherence_diversity.html
@@ -0,0 +1,3888 @@
+<!-- generated interactive figure HTML (3888 lines) omitted -->
\ No newline at end of file
diff --git a/docs/images/fastica_vs_nmf.html b/docs/images/fastica_vs_nmf.html
new file mode 100644
index 0000000..6c3d667
--- /dev/null
+++ b/docs/images/fastica_vs_nmf.html
@@ -0,0 +1,3888 @@
+<!-- generated interactive figure HTML (3888 lines) omitted -->
\ No newline at end of file
diff --git a/docs/images/leaderboard_screenshot.png b/docs/images/leaderboard_screenshot.png
new file mode 100644
index 0000000..2474e8e
Binary files /dev/null and b/docs/images/leaderboard_screenshot.png differ
diff --git a/docs/images/model_speed.html b/docs/images/model_speed.html
new file mode 100644
index 0000000..541f743
--- /dev/null
+++ b/docs/images/model_speed.html
@@ -0,0 +1,3888 @@
+<!-- generated interactive figure HTML (3888 lines) omitted -->
\ No newline at end of file
diff --git a/docs/images/oom_error.html b/docs/images/oom_error.html
new file mode 100644
index 0000000..0efa9ee
--- /dev/null
+++ b/docs/images/oom_error.html
@@ -0,0 +1,3888 @@
+<!-- generated interactive figure HTML (3888 lines) omitted -->
\ No newline at end of file
diff --git a/docs/images/performance_speed_plot.html b/docs/images/performance_speed_plot.html
new file mode 100644
index 0000000..b02ad35
--- /dev/null
+++ b/docs/images/performance_speed_plot.html
@@ -0,0 +1,3888 @@
+<!-- generated interactive figure HTML (3888 lines) omitted -->
\ No newline at end of file
diff --git a/docs/images/radar_chart.html b/docs/images/radar_chart.html
new file mode 100644
index 0000000..cf0e2fe
--- /dev/null
+++ b/docs/images/radar_chart.html
@@ -0,0 +1,3888 @@
+<!-- generated interactive figure HTML (3888 lines) omitted -->
\ No newline at end of file
diff --git a/docs/model_overview.md b/docs/model_overview.md
index 0312f38..04e1d34 100644
--- a/docs/model_overview.md
+++ b/docs/model_overview.md
@@ -4,65 +4,56 @@
 Turftopic contains implementations of a number of contemporary topic models.
 Some of these models are similar to each other in many respects, but they differ in others.
 It is quite important that you choose the right topic model for your use case.
 
+!!! tip "Looking for Model Performance?"
 
-| ⚡ **Speed** | 📖 **Long Documents** | 🐘 **Scalability** | 🔩 **Flexibility** |
-|-------------|-----------------------|--------------------|---------------------|
-| [SemanticSignalSeparation](s3.md) | [KeyNMF](KeyNMF.md) | [KeyNMF](KeyNMF.md) | [ClusteringTopicModel](clustering.md) |
+    If you are interested in seeing how these models perform on a range of datasets, and would like to base your model choice on quantitative evaluations,
+    make sure to check out the [Model Leaderboard](benchmark.md) tab:
 
-_Table 1: You should tailor your model choice to your needs_
+
+    <figure>
+        <a href="benchmark.md"><img src="images/leaderboard_screenshot.png" alt="Model Leaderboard"></a>
+    </figure>
 
-<figure>
-    <figcaption>Figure 1: Speed of Different Models on 20 Newsgroups <br>
-    (Documents per Second; Higher is better)</figcaption>
-</figure>
-
+| Model | Summary | Strengths | Weaknesses |
+| - | - | - | - |
+| [Topeax](Topeax.md) | Density peak detection + Gaussian mixture approximation | Cluster quality, Topic quality, Stability, Automatic n-topics | Underestimates the number of topics, Slower, No inference for new documents |
+| [KeyNMF](KeyNMF.md) | Keyword importance estimation + matrix factorization | Reliability, Topic quality, Scalability to large corpora and long documents | Poor automatic topic number detection, Weaker multilingual performance, Sometimes includes stop words |
+| [SensTopic(BETA)](SensTopic.md) | Regularized semi-nonnegative matrix factorization in embedding space | Very fast, High-quality topics and clusters, Can assign multiple soft clusters to documents, GPU support | Unreliable automatic n-topics detection |
+| [GMM](GMM.md) | Soft clustering with Gaussian mixtures and soft-cTF-IDF | Reliability, Speed, Cluster quality | Manual n-topics, Lower-quality keywords, [Curse of dimensionality](https://en.wikipedia.org/wiki/Curse_of_dimensionality) |
+| [FASTopic](FASTopic.md) | Neural topic modelling with Dual Semantic-relation Reconstruction | High-quality topics and clusters, GPU support | Very slow, Memory-hungry, Manual n-topics |
+| [$S^3$](s3.md) | Semantic axis discovery in embedding space | Fastest, Human-readable topics | Axes can be very unintuitive, Manual n-topics |
+| [BERTopic and Top2Vec](clustering.md) | Embed -> Reduce -> Cluster | Flexible, Feature-rich | Slow, Unreliable and unstable, Wildly overestimates the number of clusters, Low topic and cluster quality |
+| [AutoEncodingTopicModel](ctm.md) | Discovers topics by generating BoW with a variational autoencoder | GPU support | Slow, Sometimes low-quality topics |
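+
+Whichever model you pick, the core API is largely the same across turftopic. A minimal sketch, shown here with KeyNMF, though any model from the table above can be swapped in:
+
+```python
+from sklearn.datasets import fetch_20newsgroups
+from turftopic import KeyNMF
+
+corpus = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes")).data
+
+model = KeyNMF(n_components=10, random_state=42)
+model.fit(corpus)
+model.print_topics()
+```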
+
+Different models will naturally be good at different things, because they conceptualize topics differently. For instance:
+
-- `SemanticSignalSeparation`($S^3$) conceptualizes topics as **semantic axes**, along which topics are distributed
-- `ClusteringTopicModel` finds **clusters** of documents and treats those as topics
-- `KeyNMF` conceptualizes topics as **factors**, or looked at it from a different angle, it finds **clusters of words**
+- `BERTopic`, `Top2Vec`, `GMM` and `Topeax` find **clusters** of documents and treat those as topics
+- `KeyNMF`, `SensTopic`, `FASTopic` and `AutoEncodingTopicModel` conceptualize topics as latent nonnegative **factors** that generate the documents.
+- `SemanticSignalSeparation`($S^3$) conceptualizes topics as **semantic axes**, along which topics are distributed.
 
 You can find a detailed overview of how each of these models works in their respective tabs.
 
-Some models are also capable of being used in a dynamic context, some can be fitted online, some can detect the number of topics for you and some can detect topic hierarchies. You can find an overview of these features in Table 2 below.
-
-<figure>
-    <figcaption>Figure 2: Models' Coherence and Diversity on 20 Newsgroups <br>
-    (Higher is better)</figcaption>
-</figure>
-
-!!! warning
-    You should take the results presented here with a grain of salt. A more comprehensive and in-depth analysis can be found in [Kardos et al., 2024](https://arxiv.org/abs/2406.09556), though the general tendencies are similar.
-    Note that some topic models are also less stable than others, and they might require tweaking optimal results (like BERTopic), while others perform well out-of-the-box, but are not as flexible ($S^3$)
-
-The quality of the topics you can get out of your topic model can depend on a lot of things, including your choice of [vectorizer](vectorizers.md) and [encoder model](encoders.md).
-More rigorous evaluation regimes can be found in a number of studies on topic modeling.
-
-Two usual metrics to evaluate models by are *coherence* and *diversity*.
-These metrics indicate how easy it is to interpret the topics provided by the topic model.
-Good models typically balance these to metrics, and should produce highly coherent and diverse topics.
-On Figure 2 you can see how good different models are on these metrics on 20 Newsgroups.
-
-In general, the most balanced models are $S^3$, Clustering models with `centroid` feature importance, GMM and KeyNMF, while FASTopic excels at diversity.
-
+## Model Features
 
+Some models are also capable of being used in a dynamic context, some can be fitted online, some can detect the number of topics for you, and some can detect topic hierarchies. You can find an overview of these features in the table below.
 
 | Model | :1234: Multiple Topics per Document | :hash: Detecting Number of Topics | :chart_with_upwards_trend: Dynamic Modeling | :evergreen_tree: Hierarchical Modeling | :star: Inference over New Documents | :globe_with_meridians: Cross-Lingual | :ocean: Online Fitting |
 | - | - | - | - | - | - | - | - |
-| **[KeyNMF](KeyNMF.md)** | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| **[KeyNMF](KeyNMF.md)** | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| **[SensTopic](SensTopic.md)** | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: |
+| **[Topeax](Topeax.md)** | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
 | **[SemanticSignalSeparation](s3.md)** | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: |
 | **[ClusteringTopicModel](clustering.md)** | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: |
 | **[GMM](GMM.md)** | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: |
 | **[AutoEncodingTopicModel](ctm.md)** | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: |
 | **[FASTopic](fastopic.md)** | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: |
 
-_Table 2: Comparison of the models based on their capabilities_
-
-## API Reference
+## Model API Reference
 
 :::turftopic.base.ContextualModel
diff --git a/mkdocs.yml b/mkdocs.yml
index f0a885e..160a465 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -22,8 +22,9 @@ nav:
     - Discourse Analysis on Morality and Religion: tutorials/religious.md
     - Discovering a Data-driven Political Compass: tutorials/ideologies.md
     - Customer Dissatisfaction Analysis: tutorials/reviews.md
-  - Topic Models:
+  - Topic Models (Overview and Performance):
     - Model Overview: model_overview.md
+    - Model Leaderboard: benchmark.md
     - Semantic Signal Separation (S³): s3.md
     - SensTopic (BETA): SensTopic.md
     - KeyNMF: KeyNMF.md
diff --git a/pyproject.toml b/pyproject.toml
index 8bfe462..0f6d837 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@
 profile = "black"
 
 [project]
 name = "turftopic"
-version = "0.22.0"
+version = "0.23.0"
 description = "Topic modeling with contextual representations from sentence transformers."
 authors = [
   { name = "Márton Kardos", email = "martonkardos@cas.au.dk" }
diff --git a/turftopic/models/_snmf.py b/turftopic/models/_snmf.py
index fcc1d7d..7b9b016 100644
--- a/turftopic/models/_snmf.py
+++ b/turftopic/models/_snmf.py
@@ -1,6 +1,7 @@
 """This file implements semi-NMF, where doc_topic proportions are not allowed to be negative, but components are unbounded."""
 
 import warnings
+from functools import partial
 from typing import Optional
 
 import numpy as np
@@ -33,7 +34,6 @@ def init_G(
     return G + constant
 
 
-@jit
 def separate(A):
     abs_A = jnp.abs(A)
     pos = (abs_A + A) / 2
@@ -41,12 +41,10 @@
     return pos, neg
 
 
-@jit
 def update_F(X, G):
     return X @ G @ jnp.linalg.inv(G.T @ G)
 
 
-@jit
 def update_G(X, G, F, sparsity=0):
     pos_xtf, neg_xtf = separate(X.T @ F)
     pos_gftf, neg_gftf = separate(G @ (F.T @ F))
@@ -59,12 +57,19 @@
     return G
 
 
-@jit
 def rec_err(X, F, G):
     err = X - (F @ G.T)
    return jnp.linalg.norm(err)
 
 
+@jit
+def step(G, F, X, sparsity=0):
+    # Fuse a full iteration (G update, F update and reconstruction error)
+    # into one JIT-compiled function instead of jitting each helper,
+    # so XLA can optimize across the whole update.
+    G = update_G(X.T, G, F, sparsity)
+    F = update_F(X.T, G)
+    error = rec_err(X.T, F, G)
+    return G, F, error
+
+
 class SNMF(TransformerMixin, BaseEstimator):
     def __init__(
         self,
@@ -89,14 +94,13 @@ def fit_transform(self, X: np.ndarray, y=None):
         F = update_F(X.T, G)
         error_at_init = rec_err(X.T, F, G)
         prev_error = error_at_init
+        # Bind X and sparsity once; the compiled step is then reused every iteration.
+        _step = partial(step, sparsity=self.sparsity, X=X)
         for i in trange(
             self.max_iter,
             desc="Iterative updates.",
             disable=not self.progress_bar,
         ):
-            G = update_G(X.T, G, F, self.sparsity)
-            F = update_F(X.T, G)
-            error = rec_err(X.T, F, G)
+            G, F, error = _step(G, F)
             difference = prev_error - error
             if (error < error_at_init) and (
                 (prev_error - error) / error_at_init