
Commit 39b986d

Merge pull request #84 from x-tabdeveloping/multimodal
Beta feature: Multimodal topic modelling
2 parents dc42d80 + fd05205 commit 39b986d

File tree: 17 files changed, +1111 −44 lines changed


.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -29,7 +29,7 @@ jobs:
       run: python3 -c "import sys; print(sys.version)"

     - name: Install dependencies
-      run: python3 -m pip install --upgrade turftopic[pyro-ppl] pandas pytest plotly igraph datasets
+      run: python3 -m pip install --upgrade turftopic[pyro-ppl] pandas pytest plotly igraph datasets pillow
     - name: Run tests
       run: python3 -m pytest tests/
```

docs/images/multimodal.html

Lines changed: 14 additions & 0 deletions
Large diffs are not rendered by default.

docs/multimodal.md

Lines changed: 129 additions & 0 deletions
# Multimodal Topic Modelling ***(BETA)***

!!! note
    Multimodal modeling is still a BETA feature in Turftopic; we are likely to add more features and change the interface in the near future.

Some corpora are spread across multiple modalities.
A good example of this would be news articles with images attached.
Turftopic now supports multimodal modelling with a number of models.

## Multimodal Encoders

For images to be usable in Turftopic, you will need an embedding model that can encode both texts and images.
You can use either models that are supported in SentenceTransformers, or models that support the MTEB multimodal encoder interface.

!!! quote "Use a multimodal encoder model"
    === "SentenceTransformers"

        ```python
        from turftopic import KeyNMF

        multimodal_keynmf = KeyNMF(10, encoder="clip-ViT-B-32")
        ```

    === "MTEB/MIEB"
        !!! tip
            You can find current state-of-the-art embedding models and their capabilities on the [Massive Image Embedding Benchmark leaderboard](http://mteb-leaderboard.hf.space/?benchmark_name=MIEB%28Multilingual%29).

        ```bash
        pip install "mteb<2.0.0"
        ```

        ```python
        from turftopic import KeyNMF
        import mteb

        encoder = mteb.get_model("kakaobrain/align-base")

        multimodal_keynmf = KeyNMF(10, encoder=encoder)
        ```

## Corpus Structure

Currently, every document **has to have** exactly one image attached to it.
This is a limitation, and we will address it in the future.
Images can be represented either as file paths or as `PIL.Image` objects.

```python
from PIL import Image

images: list[Image.Image] = [Image.open("file_path/something.jpeg"), ...]
texts: list[str] = [...]

len(images) == len(texts)
```
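
For illustration, here is a minimal sketch of how such a paired corpus might be assembled from files on disk; the `data/` directory layout and the caption file naming scheme are hypothetical, not something Turftopic prescribes.

```python
from pathlib import Path

from PIL import Image

# Hypothetical layout: data/ holds one .jpg per document and a .txt caption with the same stem
image_paths = sorted(Path("data").glob("*.jpg"))

images: list[Image.Image] = [Image.open(path) for path in image_paths]
texts: list[str] = [path.with_suffix(".txt").read_text() for path in image_paths]

assert len(images) == len(texts)
```
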
## Basic Usage

All multimodal models have a `fit_multimodal()`/`fit_transform_multimodal()` method that you can use to discover topics in multimodal corpora.

!!! quote "Fit a multimodal model on a corpus"
    === "KeyNMF"

        ```python
        from turftopic import KeyNMF

        model = KeyNMF(12, encoder="clip-ViT-B-32")
        model.fit_multimodal(texts, images=images)
        model.plot_topics_with_images()
        ```

    === "SemanticSignalSeparation"

        ```python
        from turftopic import SemanticSignalSeparation

        model = SemanticSignalSeparation(12, encoder="clip-ViT-B-32")
        model.fit_multimodal(texts, images=images)
        model.plot_topics_with_images()
        ```

    === "Clustering Models"

        ```python
        from turftopic import ClusteringTopicModel

        # BERTopic-style
        model = ClusteringTopicModel(encoder="clip-ViT-B-32", feature_importance="c-tf-idf")
        # Top2Vec-style
        model = ClusteringTopicModel(encoder="clip-ViT-B-32", feature_importance="centroid")
        model.fit_multimodal(texts, images=images)
        model.plot_topics_with_images()
        ```

    === "GMM"

        ```python
        from turftopic import GMM

        model = GMM(12, encoder="clip-ViT-B-32")
        model.fit_multimodal(texts, images=images)
        model.plot_topics_with_images()
        ```

    === "AutoEncodingTopicModel"

        ```python
        from turftopic import AutoEncodingTopicModel

        # CombinedTM
        model = AutoEncodingTopicModel(12, combined=True, encoder="clip-ViT-B-32")
        # ZeroShotTM
        model = AutoEncodingTopicModel(12, combined=False, encoder="clip-ViT-B-32")
        model.fit_multimodal(texts, images=images)
        model.plot_topics_with_images()
        ```

<iframe src="../images/multimodal.html" title="Multimodal KeyNMF on IKEA catalogue" style="height:350px;width:100%;padding:0px;border:none;"></iframe>

## API reference

::: turftopic.multimodal.MultimodalModel

::: turftopic.encoders.multimodal.MultimodalEncoder

mkdocs.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -11,6 +11,7 @@ nav:
   - Online Topic Modeling: online.md
   - Hierarchical Topic Modeling: hierarchical.md
   - Cross-Lingual Topic Modeling: cross_lingual.md
+  - Multimodal Modeling (BETA): multimodal.md
   - Modifying and Finetuning Models: finetuning.md
   - Saving and Loading: persistence.md
   - Using TopicData: topic_data.md
```

pyproject.toml

Lines changed: 2 additions & 1 deletion
```diff
@@ -9,7 +9,7 @@ profile = "black"

 [tool.poetry]
 name = "turftopic"
-version = "0.14.1"
+version = "0.15.0"
 description = "Topic modeling with contextual representations from sentence transformers."
 authors = ["Márton Kardos <power.up1163@gmail.com>"]
 license = "MIT"
@@ -26,6 +26,7 @@ rich = "^13.6.0"
 huggingface-hub = ">=0.23.2"
 joblib = "^1.2.0"
 igraph = "~0.11.6"
+pillow = "~10.4.0"
 snowballstemmer = {version=">=2.0.0", optional=true}
 spacy = {version=">=3.6.0", optional=true}
 jieba = {version=">=0.40.0", optional=true}
```

tests/test_multimodal.py

Lines changed: 54 additions & 0 deletions
```python
import pytest
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.feature_extraction.text import CountVectorizer

from turftopic import (
    GMM,
    AutoEncodingTopicModel,
    ClusteringTopicModel,
    KeyNMF,
    SemanticSignalSeparation,
)


@pytest.fixture
def multimodal_models():
    encoder = SentenceTransformer("sentence-transformers/clip-ViT-B-16")
    return [
        AutoEncodingTopicModel(
            2, combined=True, encoder=encoder, vectorizer=CountVectorizer()
        ),
        GMM(2, encoder=encoder, vectorizer=CountVectorizer()),
        KeyNMF(2, encoder=encoder, vectorizer=CountVectorizer()),
        SemanticSignalSeparation(
            2, encoder=encoder, vectorizer=CountVectorizer()
        ),
        ClusteringTopicModel(
            dimensionality_reduction=PCA(10),
            clustering=KMeans(3),
            feature_importance="c-tf-idf",
            encoder=encoder,
        ),
        ClusteringTopicModel(
            dimensionality_reduction=PCA(10),
            clustering=KMeans(3),
            feature_importance="centroid",
            encoder=encoder,
        ),
    ]


flowers = load_dataset("kardosdrur/flowers_multimodal_test", split="train")
texts = flowers["blip_caption"]
images = flowers["image"]


def test_multimodal(multimodal_models):
    for model in multimodal_models:
        doc_topic_matrix = model.fit_transform_multimodal(texts, images=images)
        fig = model.plot_topics_with_images()
        assert len(model.top_images) == model.components_.shape[0]
        assert doc_topic_matrix.shape[1] == model.components_.shape[0]
```

turftopic/base.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -39,6 +39,8 @@ def encode_documents(self, raw_documents: Iterable[str]) -> np.ndarray:
         ndarray of shape (n_documents, n_dimensions)
             Matrix of document embeddings.
         """
+        if not hasattr(self.encoder_, "encode"):
+            return self.encoder.get_text_embeddings(list(raw_documents))
         return self.encoder_.encode(raw_documents)

     @abstractmethod
```

turftopic/encoders/multimodal.py

Lines changed: 31 additions & 0 deletions
```python
from typing import Protocol

from PIL import Image


class MultimodalEncoder(Protocol):
    """Base class for external encoder models."""

    def get_text_embeddings(
        self,
        texts: list[str],
        *,
        batch_size: int = 8,
        **kwargs,
    ): ...

    def get_image_embeddings(
        self,
        images: list[Image.Image],
        *,
        batch_size: int = 8,
        **kwargs,
    ): ...

    def get_fused_embeddings(
        self,
        texts: list[str] = None,
        images: list[Image.Image] = None,
        batch_size: int = 8,
        **kwargs,
    ): ...
```
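
As a rough illustration of how this protocol can be satisfied, the sketch below wraps a CLIP SentenceTransformer model (which can embed both texts and PIL images). The `ClipMultimodalEncoder` class and its mean-pooling fusion are illustrative assumptions, not part of the commit.

```python
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer


class ClipMultimodalEncoder:
    """Illustrative encoder satisfying the MultimodalEncoder protocol above."""

    def __init__(self, model_name: str = "clip-ViT-B-32"):
        # CLIP SentenceTransformer models can encode both texts and PIL images
        self.model = SentenceTransformer(model_name)

    def get_text_embeddings(self, texts: list[str], *, batch_size: int = 8, **kwargs):
        return self.model.encode(texts, batch_size=batch_size)

    def get_image_embeddings(self, images: list[Image.Image], *, batch_size: int = 8, **kwargs):
        return self.model.encode(images, batch_size=batch_size)

    def get_fused_embeddings(self, texts=None, images=None, batch_size: int = 8, **kwargs):
        # Simple fusion strategy for illustration: average text and image embeddings
        text_emb = np.asarray(self.get_text_embeddings(texts, batch_size=batch_size))
        image_emb = np.asarray(self.get_image_embeddings(images, batch_size=batch_size))
        return (text_emb + image_emb) / 2
```

Since the protocol is structural (`typing.Protocol`), no explicit inheritance is needed; any object exposing these three methods should be accepted wherever a `MultimodalEncoder` is expected.
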

turftopic/feature_importance.py

Lines changed: 32 additions & 4 deletions
```diff
@@ -1,6 +1,9 @@
 import numpy as np
 import scipy.sparse as spr
+from sklearn.feature_extraction.text import TfidfTransformer
 from sklearn.metrics.pairwise import cosine_similarity
+from sklearn.preprocessing import normalize
+from sklearn.utils import check_array


 def cluster_centroid_distance(
@@ -34,7 +37,9 @@ def cluster_centroid_distance(


 def soft_ctf_idf(
-    doc_topic_matrix: np.ndarray, doc_term_matrix: spr.csr_matrix
+    doc_topic_matrix: np.ndarray,
+    doc_term_matrix: spr.csr_matrix,
+    return_idf: bool = False,
 ) -> np.ndarray:
     """Computes feature importances using Soft C-TF-IDF

@@ -57,11 +62,23 @@ def soft_ctf_idf(
     tf = (term_importance.T / (overall_in_topic + eps)).T
     idf = np.log(n_docs / (np.abs(term_importance).sum(axis=0) + eps))
     ctf_idf = tf * idf
-    return ctf_idf
+    idf_diag = spr.diags(
+        idf,
+        offsets=0,
+        shape=(doc_term_matrix.shape[1], doc_term_matrix.shape[1]),
+        format="csr",
+        dtype=tf.dtype,
+    )
+    if not return_idf:
+        return ctf_idf
+    else:
+        return ctf_idf, idf_diag


 def ctf_idf(
-    doc_topic_matrix: np.ndarray, doc_term_matrix: spr.csr_matrix
+    doc_topic_matrix: np.ndarray,
+    doc_term_matrix: spr.csr_matrix,
+    return_idf: bool = False,
 ) -> np.ndarray:
     """Computes feature importances using standard C-TF-IDF

@@ -89,7 +106,18 @@
         )
         component = freq * np.log(1 + average / overall_freq)
         components.append(component)
-    return np.stack(components)
+    idf = np.log((average / overall_freq) + 1)
+    idf_diag = spr.diags(
+        idf,
+        offsets=0,
+        shape=(doc_term_matrix.shape[1], doc_term_matrix.shape[1]),
+        format="csr",
+        dtype=doc_term_matrix.dtype,
+    )
+    if not return_idf:
+        return np.stack(components)
+    else:
+        return np.stack(components), idf_diag


 def bayes_rule(
```
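
To make the Soft C-TF-IDF formulas visible in the hunk above concrete, here is a small self-contained sketch on dense toy matrices. The toy values and the way `term_importance` is derived (topic-weighted term counts) are assumptions about the unchanged part of `soft_ctf_idf`, shown here only to illustrate the tf/idf computation.

```python
import numpy as np

# Toy corpus: 3 documents, 2 topics, 4 vocabulary terms (made-up values)
doc_topic_matrix = np.array([[0.9, 0.1],
                             [0.2, 0.8],
                             [0.5, 0.5]])
doc_term_matrix = np.array([[2, 0, 1, 0],
                            [0, 3, 0, 1],
                            [1, 1, 1, 1]])

eps = np.finfo(float).eps
n_docs = doc_term_matrix.shape[0]

# Assumed: per-topic term importance as topic-weighted term counts
term_importance = doc_topic_matrix.T @ doc_term_matrix
overall_in_topic = np.abs(term_importance).sum(axis=1)

# Same tf / idf / ctf_idf expressions as in the context lines of the diff
tf = (term_importance.T / (overall_in_topic + eps)).T
idf = np.log(n_docs / (np.abs(term_importance).sum(axis=0) + eps))
ctf_idf = tf * idf

print(ctf_idf.shape)  # (n_topics, n_terms) == (2, 4)
```
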

turftopic/models/_hierarchical_clusters.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -187,7 +187,7 @@ def _estimate_children_components(self) -> dict[int, np.ndarray]:
                 ) # type: ignore
             elif self.model.feature_importance == "centroid":
                 if not hasattr(self.model, "vocab_embeddings"):
-                    self.model.vocab_embeddings = self.model.encoder_.encode(
+                    self.model.vocab_embeddings = self.model.encode_documents(
                         self.model.vectorizer.get_feature_names_out()
                     ) # type: ignore
                 if (
```

0 commit comments
