Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/Topeax.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Topeax

Topeax is a probabilistic topic model based on the Peax clustering model, which finds topics based on peaks in point density in the embedding space.
It can recover the number of topics automatically.

<br>
<figure>
<img src="../images/peax.png" width="100%" style="margin-left: auto;margin-right: auto;">
<figcaption>Schematic overview of the steps of the Peax clustering algorithm</figcaption>
</figure>

:warning: **This part of the documentation is still under construction, as more details and a paper are on their way.** :warning:

## API Reference

::: turftopic.models.topeax.Topeax

::: turftopic.models.topeax.Peax
Binary file added docs/images/peax.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ profile = "black"

[project]
name = "turftopic"
version = "0.19.1"
version = "0.20.0"
description = "Topic modeling with contextual representations from sentence transformers."
authors = [
{ name = "Márton Kardos <[email protected]>", email = "[email protected]" }
Expand Down
2 changes: 2 additions & 0 deletions turftopic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from turftopic.models.fastopic import FASTopic
from turftopic.models.gmm import GMM
from turftopic.models.keynmf import KeyNMF
from turftopic.models.topeax import Topeax
from turftopic.serialization import load_model

try:
Expand All @@ -20,6 +21,7 @@
"ClusteringTopicModel",
"SemanticSignalSeparation",
"GMM",
"Topeax",
"KeyNMF",
"AutoEncodingTopicModel",
"ContextualModel",
Expand Down
48 changes: 48 additions & 0 deletions turftopic/encoders/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import itertools
from typing import Iterable, List

import numpy as np
import torch
from tqdm import trange


def batched(iterable, n: int) -> Iterable[List[str]]:
"Batch data into tuples of length n. The last batch may be shorter."
Expand All @@ -10,3 +14,47 @@ def batched(iterable, n: int) -> Iterable[List[str]]:
it = iter(iterable)
while batch := list(itertools.islice(it, n)):
yield batch


def encode_chunks(
encoder,
sentences,
batch_size=64,
window_size=50,
step_size=40,
return_chunks=False,
show_progress_bar=False,
):
chunks = []
chunk_embeddings = []
for start_index in trange(
0,
len(sentences),
batch_size,
desc="Encoding batches...",
disable=not show_progress_bar,
):
batch = sentences[start_index : start_index + batch_size]
features = encoder.tokenize(batch)
with torch.no_grad():
output_features = encoder.forward(features)
n_tokens = output_features["attention_mask"].sum(axis=1)
for i_doc in range(len(batch)):
for chunk_start in range(0, n_tokens[i_doc], step_size):
chunk_end = min(chunk_start + window_size, n_tokens[i_doc])
_emb = output_features["token_embeddings"][
i_doc, chunk_start:chunk_end, :
].mean(axis=0)
chunk_embeddings.append(_emb)
if return_chunks:
chunks.append(
encoder.tokenizer.decode(
features["input_ids"][i_doc, chunk_start:chunk_end]
)
.replace("[CLS]", "")
.replace("[SEP]", "")
)
if not return_chunks:
chunks = None
chunk_embeddings = np.stack(chunk_embeddings)
return chunk_embeddings, chunks
31 changes: 29 additions & 2 deletions turftopic/models/_keynmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import warnings
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Literal, Optional
from functools import partial
from typing import Iterable, Literal, Optional, Union

import igraph as ig
import numpy as np
Expand All @@ -21,6 +22,10 @@
from sklearn.utils.validation import check_non_negative

from turftopic.base import Encoder
from turftopic.optimization import (
decomposition_gaussian_bic,
optimize_n_components,
)

NOT_MATCHING_ERROR = (
"Document embedding dimensionality ({n_dims}) doesn't match term embedding dimensionality ({n_word_dims}). "
Expand Down Expand Up @@ -242,7 +247,7 @@ def batch_extract_keywords(
class KeywordNMF:
def __init__(
self,
n_components: int,
n_components: Union[int, Literal["auto"]],
seed: Optional[int] = None,
top_n: Optional[int] = None,
):
Expand Down Expand Up @@ -318,6 +323,15 @@ def vectorize(

def fit_transform(self, keywords: list[dict[str, float]]) -> np.ndarray:
X = self.vectorize(keywords, fitting=True)
if self.n_components == "auto":
# Finding N components with BIC
bic_fn = partial(
decomposition_gaussian_bic,
decomp_class=NMF,
X=X,
)
n_components = optimize_n_components(bic_fn, min_n=1, verbose=True)
self.n_components = n_components
check_non_negative(X, "NMF (input X)")
W, H = _initialize_nmf(X, self.n_components, random_state=self.seed)
W, H, self.n_iter = NMF(
Expand All @@ -339,6 +353,10 @@ def transform(self, keywords: list[dict[str, float]]):
return W.astype(X.dtype)

def partial_fit(self, keyword_batch: list[dict[str, float]]):
if self.n_components == "auto":
raise ValueError(
"Cannot infer number of components with BIC when online fitting the model."
)
X = self.vectorize(keyword_batch, fitting=True)
try:
check_non_negative(X, "NMF (input X)")
Expand All @@ -365,6 +383,15 @@ def fit_transform_dynamic(
n_bins = len(time_bin_edges) - 1
document_term_matrix = self.vectorize(keywords, fitting=True)
check_non_negative(document_term_matrix, "NMF (input X)")
if self.n_components == "auto":
# Finding N components with BIC
bic_fn = partial(
decomposition_gaussian_bic,
decomp_class=NMF,
X=X,
)
n_components = optimize_n_components(bic_fn, verbose=True)
self.n_components = n_components
document_topic_matrix, H = _initialize_nmf(
document_term_matrix,
self.n_components,
Expand Down
Loading