Commit a3fb166

Added multimodal KeyNMF
1 parent 9d8e30e commit a3fb166

File tree

2 files changed (+67 additions, -20 deletions)


turftopic/models/_keynmf.py

Lines changed: 13 additions & 18 deletions
@@ -8,12 +8,8 @@
 import numpy as np
 import scipy.sparse as spr
 from sklearn.base import clone
-from sklearn.decomposition._nmf import (
-    NMF,
-    MiniBatchNMF,
-    _initialize_nmf,
-    _update_coordinate_descent,
-)
+from sklearn.decomposition._nmf import (NMF, MiniBatchNMF, _initialize_nmf,
+                                        _update_coordinate_descent)
 from sklearn.exceptions import NotFittedError
 from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
@@ -144,19 +140,23 @@ def is_encoder_promptable(self) -> bool:
         if ("query" in prompts) and ("passage" in prompts):
             return True
 
+    def encode(
+        self, texts: Iterable[str], prompt_name: str = None
+    ) -> np.ndarray:
+        if not hasattr(self.encoder, "encode"):
+            return self.encoder.get_text_embeddings(list(texts))
+        if (prompt_name is not None) and (self.is_encoder_promptable):
+            return self.encoder.encode(texts, prompt_name=prompt_name)
+        return self.encoder.encode(texts)
+
     @property
     def n_vocab(self) -> int:
         return len(self.key_to_index)
 
     def _add_terms(self, new_terms: list[str]):
         for term in new_terms:
             self.key_to_index[term] = self.n_vocab
-        if not self.is_encoder_promptable:
-            term_encodings = self.encoder.encode(new_terms)
-        else:
-            term_encodings = self.encoder.encode(
-                new_terms, prompt_name="passage"
-            )
+        term_encodings = self.encode(new_terms, prompt_name="passage")
         if self.term_embeddings is not None:
             self.term_embeddings = np.concatenate(
                 (self.term_embeddings, term_encodings), axis=0
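
The new encode method centralizes prompt handling: vocabulary terms are encoded with prompt_name="passage" (above) and documents with prompt_name="query" (next hunk), but only when the encoder actually declares those prompts. A minimal sketch of an encoder this extractor would treat as promptable; the model name and prompt strings are assumptions, not part of this commit:

from sentence_transformers import SentenceTransformer
from turftopic import KeyNMF

# Assumed example: any SentenceTransformer whose prompts dict contains both
# "query" and "passage" makes is_encoder_promptable return True, so the new
# encode() forwards prompt_name to encoder.encode().
encoder = SentenceTransformer(
    "intfloat/e5-small-v2",
    prompts={"query": "query: ", "passage": "passage: "},
)
model = KeyNMF(n_components=10, encoder=encoder)
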
@@ -174,12 +174,7 @@ def batch_extract_keywords(
         if not len(documents):
             return []
         if embeddings is None:
-            if not self.is_encoder_promptable:
-                embeddings = self.encoder.encode(documents)
-            else:
-                embeddings = self.encoder.encode(
-                    documents, prompt_name="query"
-                )
+            embeddings = self.encode(documents, prompt_name="query")
         if len(embeddings) != len(documents):
             raise ValueError(
                 "Number of documents doesn't match number of embeddings."

turftopic/models/keynmf.py

Lines changed: 54 additions & 2 deletions
@@ -4,6 +4,7 @@
 
 import numpy as np
 import scipy.sparse as spr
+from PIL import Image
 from rich.console import Console
 from sentence_transformers import SentenceTransformer
 from sklearn.exceptions import NotFittedError
@@ -13,13 +14,16 @@
 from turftopic.base import ContextualModel, Encoder
 from turftopic.data import TopicData
 from turftopic.dynamic import DynamicTopicModel
+from turftopic.encoders.multimodal import MultimodalEncoder
 from turftopic.hierarchical import DivisibleTopicNode
 from turftopic.models._keynmf import KeywordNMF, SBertKeywordExtractor
 from turftopic.models.wnmf import weighted_nmf
+from turftopic.multimodal import (ImageRepr, MultimodalEmbeddings,
+                                  MultimodalModel)
 from turftopic.vectorizers.default import default_vectorizer
 
 
-class KeyNMF(ContextualModel, DynamicTopicModel):
+class KeyNMF(ContextualModel, DynamicTopicModel, MultimodalModel):
     """Extracts keywords from documents based on semantic similarity of
     term encodings to document encodings.
     Topics are then extracted with non-negative matrix factorization from
@@ -64,7 +68,7 @@ def __init__(
         self,
         n_components: int,
         encoder: Union[
-            Encoder, str
+            Encoder, str, MultimodalEncoder
         ] = "sentence-transformers/all-MiniLM-L6-v2",
         vectorizer: Optional[CountVectorizer] = None,
         top_n: int = 25,
@@ -235,6 +239,54 @@ def fit_transform(
         )
         return doc_topic_matrix
 
+    def fit_transform_multimodal(
+        self,
+        raw_documents: list[str],
+        images: list[ImageRepr],
+        y=None,
+        embeddings: Optional[MultimodalEmbeddings] = None,
+    ) -> np.ndarray:
+        console = Console()
+        self.multimodal_embeddings = embeddings
+        with console.status("Fitting model") as status:
+            if self.multimodal_embeddings is None:
+                status.update("Encoding documents")
+                self.multimodal_embeddings = self.encode_multimodal(
+                    raw_documents, images
+                )
+                console.log("Documents encoded.")
+            status.update("Extracting keywords")
+            document_keywords = self.extract_keywords(
+                raw_documents,
+                embeddings=self.multimodal_embeddings["document_embeddings"],
+            )
+            image_keywords = self.extract_keywords(
+                raw_documents,
+                embeddings=self.multimodal_embeddings["image_embeddings"],
+            )
+            console.log("Keyword extraction done.")
+            status.update("Decomposing with NMF")
+            try:
+                doc_topic_matrix = self.model.transform(document_keywords)
+            except (NotFittedError, AttributeError):
+                doc_topic_matrix = self.model.fit_transform(document_keywords)
+                self.components_ = self.model.components
+            console.log("Model fitting done.")
+            status.update("Transforming images")
+            self.image_topic_matrix = self.model.transform(image_keywords)
+            self.top_images: list[list[Image.Image]] = []
+            for image_topic_vector in self.image_topic_matrix.T:
+                top_im_ind = np.argsort(-image_topic_vector)[:9]
+                top_im = [images[i] for i in top_im_ind]
+                self.top_images.append(top_im)
+            console.log("Images transformed")
+        self.document_topic_matrix = doc_topic_matrix
+        self.document_term_matrix = self.model.vectorize(document_keywords)
+        self.hierarchy = DivisibleTopicNode.create_root(
+            self, self.components_, self.document_topic_matrix
+        )
+        return doc_topic_matrix
+
     def fit(
         self,
         raw_documents=None,
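
A hedged usage sketch of the new multimodal path; the encoder name, file names, and image loading below are illustrative assumptions, not part of this commit:

from PIL import Image
from turftopic import KeyNMF

corpus = ["a short post about cats", "a short post about cars"]
images = [Image.open("cat.jpg"), Image.open("car.jpg")]

# Assumed CLIP-style encoder so that documents and images share an embedding space.
model = KeyNMF(n_components=10, encoder="clip-ViT-B-32")
doc_topic_matrix = model.fit_transform_multimodal(corpus, images=images)

# fit_transform_multimodal also stores the highest-scoring images per topic.
first_topic_images = model.top_images[0]

The document-topic matrix is returned as in fit_transform, while image_topic_matrix and top_images (up to nine images per topic, per the loop over image_topic_matrix.T above) are kept on the model.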

0 commit comments
