Skip to content

Commit 3af440b

Browse files
Merge pull request #87 from x-tabdeveloping/prepare_multimodal_topic_data
Multimodal TopicData
2 parents 6fdb291 + ac047fa commit 3af440b

File tree

5 files changed

+367
-110
lines changed

5 files changed

+367
-110
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ profile = "black"
99

1010
[tool.poetry]
1111
name = "turftopic"
12-
version = "0.15.0"
12+
version = "0.16.0"
1313
description = "Topic modeling with contextual representations from sentence transformers."
1414
authors = ["Márton Kardos <power.up1163@gmail.com>"]
1515
license = "MIT"

tests/test_multimodal.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ def multimodal_models():
4848

4949
def test_multimodal(multimodal_models):
5050
for model in multimodal_models:
51-
doc_topic_matrix = model.fit_transform_multimodal(texts, images=images)
52-
fig = model.plot_topics_with_images()
53-
assert len(model.top_images) == model.components_.shape[0]
54-
assert doc_topic_matrix.shape[1] == model.components_.shape[0]
51+
topic_data = model.prepare_multimodal_topic_data(texts, images=images)
52+
fig = topic_data.plot_topics_with_images()
53+
assert len(topic_data.top_images) == model.components_.shape[0]
54+
assert (
55+
topic_data.document_topic_matrix.shape[1]
56+
== model.components_.shape[0]
57+
)

turftopic/container.py

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import itertools
2+
import warnings
23
from abc import ABC
34
from datetime import datetime
45
from typing import Any, Iterable, List, Optional, Tuple, Union
56

67
import numpy as np
8+
from PIL import Image
79
from rich.console import Console
810
from rich.table import Table
911

@@ -76,6 +78,92 @@ def _top_terms(
7678
terms.append(list(vocab[highest]))
7779
return terms
7880

81+
def get_top_words(
    self, top_k: int = 10, positive: bool = True
) -> list[list[str]]:
    """Collect the top ranking words of every topic.

    Thin public wrapper that forwards both arguments, positionally,
    to ``self._top_terms``.

    Parameters
    ----------
    top_k: int, default 10
        How many words to return for each topic.
    positive: bool, default True
        When True the highest scoring terms are returned,
        otherwise the lowest scoring ones.

    Returns
    -------
    list[list[str]]
        Whatever ``self._top_terms`` produces: one list of terms per topic.
    """
    top_words = self._top_terms(top_k, positive)
    return top_words
95+
96+
def get_top_documents(
    self,
    raw_documents=None,
    document_topic_matrix=None,
    top_k: int = 10,
    positive: bool = True,
) -> list[list[str]]:
    """Returns list of top documents for each topic.

    Parameters
    ----------
    raw_documents: indexable collection of documents, default None
        Corpus to pick representative documents from.
        When None, the ``corpus`` attribute of the object is used if present.
    document_topic_matrix: array of shape (n_documents, n_topics), default None
        Topic importances for each document in ``raw_documents``.
        When None, the ``document_topic_matrix`` attribute of the object
        is used if present, otherwise ``self.transform(raw_documents)``.
    top_k: int, default 10
        Number of documents to return per topic.
    positive: bool, default True
        Indicates whether the highest
        or lowest scoring documents should be returned.

    Returns
    -------
    list[list[str]]
        For each topic, the ``top_k`` documents ordered from most to least
        extreme score.

    Raises
    ------
    ValueError
        If no corpus is available, or if no document-topic matrix was given
        and the object cannot infer one with ``transform``.
    """
    # Explicit `is None` checks: `x or fallback` evaluates the truth value
    # of x, which raises for NumPy arrays with more than one element.
    if raw_documents is None:
        raw_documents = getattr(self, "corpus", None)
    if raw_documents is None:
        raise ValueError(
            "No corpus was passed, can't search for representative documents."
        )
    if document_topic_matrix is None:
        document_topic_matrix = getattr(self, "document_topic_matrix", None)
    if document_topic_matrix is None:
        try:
            document_topic_matrix = self.transform(raw_documents)
        except AttributeError as err:
            raise ValueError(
                "Transductive methods cannot "
                "infer topical content in documents.\n"
                "Please pass a document_topic_matrix."
            ) from err
    docs = []
    # Transposing yields one vector of document scores per topic.
    for topic_doc_vec in document_topic_matrix.T:
        # argsort is ascending, so negate to rank the highest scores first.
        scores = -topic_doc_vec if positive else topic_doc_vec
        highest = np.argsort(scores)[:top_k]
        docs.append([raw_documents[i_doc] for i_doc in highest])
    return docs
137+
138+
def get_top_images(self, top_k: int = 10, positive: bool = True):
    """Returns list of top images for each topic.

    Parameters
    ----------
    top_k: int, default 10
        Number of images to return.
    positive: bool, default True
        Indicates whether the highest
        or lowest scoring images should be returned.

    Returns
    -------
    list
        For each topic, at most ``top_k`` of its stored images.

    Raises
    ------
    ValueError
        If the object holds no ``top_images``, or if ``positive`` is False
        and it holds no ``negative_images``.

    Warns
    -----
    UserWarning
        If some topic stores fewer than ``top_k`` images; all stored images
        are returned for that topic.
    """
    # The attributes may exist but be None (TopicData initialises them that
    # way), so test the value instead of relying on hasattr().
    if getattr(self, "top_images", None) is None:
        raise ValueError(
            "Model either has not been fit or was fit without images. top_images property missing."
        )
    if positive:
        ranked_images = self.top_images
    else:
        ranked_images = getattr(self, "negative_images", None)
        if ranked_images is None:
            raise ValueError(
                "Model either has not been fit or was fit without images. negative_images property missing."
            )
    if any(len(topic_images) < top_k for topic_images in ranked_images):
        warnings.warn(
            "Number of images stored in the topic model is smaller than the specified top_k, returning all that the model has."
        )
    return [topic_images[:top_k] for topic_images in ranked_images]
166+
79167
def _rename_automatic(self, namer: TopicNamer) -> list[str]:
80168
self.topic_names_ = namer.name_topics(self._top_terms())
81169
return self.topic_names_
@@ -928,3 +1016,218 @@ def plot_topics_over_time(
9281016
fig.update_xaxes(title="Time Slice Start")
9291017
fig.update_yaxes(title="Topic Importance")
9301018
return fig
1019+
1020+
@staticmethod
def _image_grid(
    images: list[Image.Image],
    final_size=(1200, 1200),
    grid_size: tuple[int, int] = (4, 4),
):
    """Tiles images into a single RGB image with a white background.

    Parameters
    ----------
    images: list[Image.Image]
        Images to place into the grid, filled row by row from the top left.
        Images that do not fit into the grid are ignored.
    final_size: tuple of int, default (1200, 1200)
        Width and height of the resulting image in pixels.
    grid_size: tuple[int, int], default (4, 4)
        Number of columns and number of rows in the grid.

    Returns
    -------
    Image.Image
        The composed grid image. Cells without an image stay white.
    """
    # The first entry divides the width, so it is the number of columns.
    n_cols, n_rows = grid_size
    cell_width = final_size[0] // n_cols
    cell_height = final_size[1] // n_rows
    grid_img = Image.new("RGB", final_size, (255, 255, 255))
    for idx, img in enumerate(images[: n_rows * n_cols]):
        img = img.resize(
            (cell_width, cell_height), resample=Image.Resampling.LANCZOS
        )
        # Row index must be derived from the column count, otherwise
        # non-square grids place images in the wrong rows.
        row, col = divmod(idx, n_cols)
        grid_img.paste(img, (col * cell_width, row * cell_height))
    return grid_img
1038+
1039+
def plot_topics_with_images(self, n_cols: int = 3, grid_size: int = 4):
    """Plots the most important images for each topic, along with keywords.

    Note that you will need to `pip install plotly` to use plots in Turftopic.

    If the object has ``negative_images``, every topic gets its own row with
    the top images on the left and the negative images on the right, and
    ``n_cols`` is ignored. Otherwise topics are laid out in a grid with
    ``n_cols`` columns.

    Parameters
    ----------
    n_cols: int, default 3
        Number of columns you want to have in the grid of topics.
    grid_size: int, default 4
        The square root of the number of images you want to display for a given topic.
        For instance if grid_size==4, all topics will have 16 images displayed,
        since the joint image will have 4 columns and 4 rows.

    Returns
    -------
    go.Figure
        Plotly figure containing top images and keywords for topics.

    Raises
    ------
    ValueError
        If the object holds no ``top_images``.
    ModuleNotFoundError
        If plotly is not installed.
    """
    # top_images may exist but be None (TopicData initialises it that way),
    # so check the value rather than just the presence of the attribute.
    if getattr(self, "top_images", None) is None:
        raise ValueError(
            "Model either has not been fit or was fit without images. top_images property missing."
        )
    try:
        import plotly.graph_objects as go
    except (ImportError, ModuleNotFoundError) as e:
        raise ModuleNotFoundError(
            "Please install plotly if you intend to use plots in Turftopic."
        ) from e

    # Each topic tile is rendered at width x height and displayed scaled down.
    width, height = 1200, 1200
    scale_factor = 0.25
    w, h = width * scale_factor, height * scale_factor
    padding = 10
    n_keywords = 7

    n_components = self.components_.shape[0]
    negative_images = getattr(self, "negative_images", None)
    if negative_images is not None:
        # One row per topic: positive side in column 0, negative in column 1.
        n_rows, n_figure_cols = n_components, 2
    else:
        # Ceiling division: a partially filled last row still needs space.
        n_rows = n_components // n_cols + int(bool(n_components % n_cols))
        n_figure_cols = n_cols
    figure_height = (h + padding) * n_rows
    figure_width = (w + padding) * n_figure_cols

    fig = go.Figure()
    # Invisible markers span the full coordinate range of the figure.
    fig = fig.add_trace(
        go.Scatter(
            x=[0, figure_width],
            y=[0, figure_height],
            mode="markers",
            marker_opacity=0,
        )
    )
    vocab = self.get_vocab()

    def add_tile(images, keywords, row, col, bgcolor):
        # Places one image collage and its keyword label at (row, col).
        tile = self._image_grid(
            images, (width, height), grid_size=(grid_size, grid_size)
        )
        x0 = (w + padding) * col
        # y is the top edge of the tile; rows are counted from the top.
        y0 = (h + padding) * (n_rows - row)
        fig.add_layout_image(
            dict(
                x=x0,
                sizex=w,
                y=y0,
                sizey=h,
                xref="x",
                yref="y",
                opacity=1.0,
                layer="below",
                sizing="stretch",
                source=tile,
            ),
        )
        fig.add_annotation(
            x=x0 + (w / 2),
            y=y0 - (h / 2),
            text="<b> " + "<br> ".join(keywords),
            font=dict(
                size=16,
                family="Times New Roman",
                color="white",
            ),
            bgcolor=bgcolor,
        )

    for i, component in enumerate(self.components_):
        highest = vocab[np.argsort(-component)[:n_keywords]]
        if negative_images is not None:
            lowest = vocab[np.argsort(component)[:n_keywords]]
            add_tile(self.top_images[i], highest, i, 0, "rgba(0,0,255, 0.5)")
            add_tile(negative_images[i], lowest, i, 1, "rgba(255,0,0, 0.5)")
        else:
            add_tile(
                self.top_images[i],
                highest,
                i // n_cols,
                i % n_cols,
                "rgba(0,0,0, 0.5)",
            )

    fig = fig.update_xaxes(visible=False, range=[0, figure_width])
    fig = fig.update_yaxes(
        visible=False,
        range=[0, figure_height],
        # the scaleanchor attribute ensures that the aspect ratio stays constant
        scaleanchor="x",
    )
    fig = fig.update_layout(
        width=figure_width,
        height=figure_height,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
    )
    return fig

turftopic/data.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import joblib
1111
import numpy as np
12+
from PIL import Image
1213
from rich.console import Console
1314
from rich.tree import Tree
1415

@@ -63,6 +64,13 @@ class TopicData(Mapping, TopicContainer):
6364
This is in contrast to KeyNMF for instance, where only positive word importance should be considered.
6465
hierarchy: TopicNode, default None
6566
Optional topic hierarchy for models that support hierarchical topic modeling.
67+
images: list[ImageRepr], default None
68+
Images the model has been fit on
69+
top_images: list[list[Image]], default None
70+
Top images discovered by the topic model.
71+
negative_images: list[list[Image]], default None
72+
Lowest ranking images discovered by the topic model.
73+
(Only relevant with models like S^3)
6674
"""
6775

6876
def __init__(
@@ -82,6 +90,9 @@ def __init__(
8290
temporal_importance: Optional[np.ndarray] = None,
8391
has_negative_side: bool = False,
8492
hierarchy: Optional[TopicNode] = None,
93+
images: Optional[list[str | Image.Image]] = None,
94+
top_images: Optional[list[list[Image.Image]]] = None,
95+
negative_images: Optional[list[list[Image.Image]]] = None,
8596
**kwargs,
8697
):
8798
self.corpus = corpus
@@ -98,6 +109,9 @@ def __init__(
98109
self.temporal_importance = temporal_importance
99110
self.hierarchy = hierarchy
100111
self._has_negative_side = has_negative_side
112+
self.top_images = top_images
113+
self.negative_images = negative_images
114+
self.images = images
101115
for key, value in kwargs:
102116
setattr(self, key, value)
103117
self._attributes = [

0 commit comments

Comments
 (0)