diff --git a/pyproject.toml b/pyproject.toml index 41fd802..45add48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ profile = "black" [project] name = "turftopic" -version = "0.23.1" +version = "0.23.2" description = "Topic modeling with contextual representations from sentence transformers." authors = [ { name = "Márton Kardos ", email = "martonkardos@cas.au.dk" } diff --git a/turftopic/models/_snmf.py b/turftopic/models/_snmf.py index 7b9b016..37efd47 100644 --- a/turftopic/models/_snmf.py +++ b/turftopic/models/_snmf.py @@ -54,6 +54,7 @@ def update_G(X, G, F, sparsity=0): denominator = jnp.maximum(denominator, EPSILON) delta_G = jnp.sqrt(numerator / denominator) G *= delta_G + G = G / jnp.linalg.norm(G) return G @@ -128,7 +129,24 @@ def fit_timeslice(self, X_t: np.ndarray, G_t: np.ndarray): return F.T def transform(self, X: np.ndarray): - G = jnp.maximum(X @ jnp.linalg.pinv(self.components_), 0) + G = init_G( + X.T, + n_components=self.n_components, + random_state=self.random_state, + ) + F = self.components_.T + update = jit(lambda G: update_G(X.T, G, F, sparsity=self.sparsity)) + error_at_init = rec_err(X.T, F, G) + prev_error = error_at_init + for i in range(self.max_iter): + G = update(G) + err = rec_err(X.T, F, G) + if (err < error_at_init) and ( + (prev_error - err) / error_at_init + ) < self.tol: + if self.verbose: + print(f"Converged after {i} iterations") + break return np.array(G) def inverse_transform(self, X): diff --git a/turftopic/models/senstopic.py b/turftopic/models/senstopic.py index cadc37a..42cc44e 100644 --- a/turftopic/models/senstopic.py +++ b/turftopic/models/senstopic.py @@ -205,6 +205,11 @@ def fit_transform( console.log("Model fitting done.") return doc_topic + def transform(self, raw_documents, embeddings=None): + if embeddings is None: + embeddings = self.encoder_.encode(raw_documents) + return self.decomposition.transform(embeddings) + def fit_transform_multimodal( self, raw_documents: list[str],