|
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | import scipy.sparse as spr |
| 7 | +from PIL import Image |
7 | 8 | from rich.console import Console |
8 | 9 | from sentence_transformers import SentenceTransformer |
9 | 10 | from sklearn.exceptions import NotFittedError |
|
13 | 14 | from turftopic.base import ContextualModel, Encoder |
14 | 15 | from turftopic.data import TopicData |
15 | 16 | from turftopic.dynamic import DynamicTopicModel |
| 17 | +from turftopic.encoders.multimodal import MultimodalEncoder |
16 | 18 | from turftopic.hierarchical import DivisibleTopicNode |
17 | 19 | from turftopic.models._keynmf import KeywordNMF, SBertKeywordExtractor |
18 | 20 | from turftopic.models.wnmf import weighted_nmf |
| 21 | +from turftopic.multimodal import (ImageRepr, MultimodalEmbeddings, |
| 22 | + MultimodalModel) |
19 | 23 | from turftopic.vectorizers.default import default_vectorizer |
20 | 24 |
|
21 | 25 |
|
22 | | -class KeyNMF(ContextualModel, DynamicTopicModel): |
| 26 | +class KeyNMF(ContextualModel, DynamicTopicModel, MultimodalModel): |
23 | 27 | """Extracts keywords from documents based on semantic similarity of |
24 | 28 | term encodings to document encodings. |
25 | 29 | Topics are then extracted with non-negative matrix factorization from |
@@ -64,7 +68,7 @@ def __init__( |
64 | 68 | self, |
65 | 69 | n_components: int, |
66 | 70 | encoder: Union[ |
67 | | - Encoder, str |
| 71 | + Encoder, str, MultimodalEncoder |
68 | 72 | ] = "sentence-transformers/all-MiniLM-L6-v2", |
69 | 73 | vectorizer: Optional[CountVectorizer] = None, |
70 | 74 | top_n: int = 25, |
@@ -235,6 +239,54 @@ def fit_transform( |
235 | 239 | ) |
236 | 240 | return doc_topic_matrix |
237 | 241 |
|
def fit_transform_multimodal(
    self,
    raw_documents: list[str],
    images: list[ImageRepr],
    y=None,
    embeddings: Optional[MultimodalEmbeddings] = None,
) -> np.ndarray:
    """Run the keyword/NMF pipeline on documents paired with images.

    Keywords are extracted from ``raw_documents`` twice: once against the
    document embeddings and once against the image embeddings. The
    document keywords are passed to ``self.model`` (transformed if the
    model is already usable, otherwise fitted), and the image keywords
    are then transformed with the same model.

    Parameters
    ----------
    raw_documents: list[str]
        Texts that keywords are extracted from.
    images: list[ImageRepr]
        Images passed to ``encode_multimodal`` and indexed when picking
        the top images (presumably aligned one-to-one with
        ``raw_documents`` -- TODO confirm against callers).
    y: None
        Not used by this method.
    embeddings: MultimodalEmbeddings, optional
        Precomputed embeddings, indexed here by the keys
        ``"document_embeddings"`` and ``"image_embeddings"``.
        When ``None``, they are computed with ``self.encode_multimodal``.

    Returns
    -------
    np.ndarray
        The result of transforming the document keywords with
        ``self.model``; also stored as ``self.document_topic_matrix``.

    Notes
    -----
    Sets ``multimodal_embeddings``, ``image_topic_matrix``,
    ``top_images``, ``document_topic_matrix``, ``document_term_matrix``
    and ``hierarchy`` on the instance. ``components_`` is only assigned
    when the model had to be fitted in this call.
    """
    # Number of highest-weighted images retained per row of the
    # transposed image-topic matrix.
    n_top_images = 9
    log_console = Console()
    self.multimodal_embeddings = embeddings
    with log_console.status("Fitting model") as progress:
        if self.multimodal_embeddings is None:
            progress.update("Encoding documents")
            encoded = self.encode_multimodal(raw_documents, images)
            self.multimodal_embeddings = encoded
            log_console.log("Documents encoded.")
        progress.update("Extracting keywords")
        # Both passes draw keywords from the texts; only the embeddings
        # they are compared against differ. The list comprehension keeps
        # the original order: document pass first, image pass second.
        document_keywords, image_keywords = [
            self.extract_keywords(
                raw_documents,
                embeddings=self.multimodal_embeddings[modality_key],
            )
            for modality_key in ("document_embeddings", "image_embeddings")
        ]
        log_console.log("Keyword extraction done.")
        progress.update("Decomposing with NMF")
        try:
            # Succeeds when the underlying model can already transform.
            doc_topics = self.model.transform(document_keywords)
        except (NotFittedError, AttributeError):
            # Otherwise fit on the document keywords and record the
            # resulting components.
            doc_topics = self.model.fit_transform(document_keywords)
            self.components_ = self.model.components
        log_console.log("Model fitting done.")
        progress.update("Transforming images")
        self.image_topic_matrix = self.model.transform(image_keywords)
        # Negating the weights makes argsort rank them in descending
        # order, so the slice keeps the highest-weighted images.
        top_images: list[list[Image.Image]] = [
            [images[idx] for idx in np.argsort(-topic_weights)[:n_top_images]]
            for topic_weights in self.image_topic_matrix.T
        ]
        self.top_images = top_images
        log_console.log("Images transformed")
        self.document_topic_matrix = doc_topics
        self.document_term_matrix = self.model.vectorize(document_keywords)
        self.hierarchy = DivisibleTopicNode.create_root(
            self, self.components_, self.document_topic_matrix
        )
    return doc_topics
| 289 | + |
238 | 290 | def fit( |
239 | 291 | self, |
240 | 292 | raw_documents=None, |
|
0 commit comments