Skip to content

Commit 4a88a00

Browse files
Added validation to encoder model in multimodal models
1 parent fe2c87b commit 4a88a00

File tree

6 files changed

+20
-1
lines changed

6 files changed

+20
-1
lines changed

turftopic/models/cluster.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def __init__(
208208
self.encoder_ = SentenceTransformer(encoder)
209209
else:
210210
self.encoder_ = encoder
211+
self.validate_encoder()
211212
if vectorizer is None:
212213
self.vectorizer = default_vectorizer()
213214
else:

turftopic/models/ctm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def __init__(
159159
self.encoder_ = SentenceTransformer(encoder)
160160
else:
161161
self.encoder_ = encoder
162+
self.validate_encoder()
162163
if vectorizer is None:
163164
self.vectorizer = default_vectorizer()
164165
else:

turftopic/models/decomp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
self.encoder_ = SentenceTransformer(encoder)
8888
else:
8989
self.encoder_ = encoder
90+
self.validate_encoder()
9091
if vectorizer is None:
9192
self.vectorizer = default_vectorizer()
9293
else:

turftopic/models/gmm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
self.encoder_ = SentenceTransformer(encoder)
8888
else:
8989
self.encoder_ = encoder
90+
self.validate_encoder()
9091
if vectorizer is None:
9192
self.vectorizer = default_vectorizer()
9293
else:

turftopic/models/keynmf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(
9595
self.encoder_ = SentenceTransformer(encoder)
9696
else:
9797
self.encoder_ = encoder
98+
self.validate_encoder()
9899
if vectorizer is None:
99100
self.vectorizer = default_vectorizer()
100101
else:

turftopic/multimodal.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
import numpy as np
77
from PIL import Image
8+
from sentence_transformers import SentenceTransformer
9+
10+
from turftopic.encoders.multimodal import MultimodalEncoder
811

912
UrlStr = str
1013

@@ -74,7 +77,8 @@ def encode_multimodal(
7477
"document_embeddings": document_embeddings,
7578
}
7679

77-
def validate_embeddings(self, embeddings: Optional[MultimodalEmbeddings]):
80+
@staticmethod
81+
def validate_embeddings(embeddings: Optional[MultimodalEmbeddings]):
7882
if embeddings is None:
7983
return
8084
try:
@@ -89,6 +93,16 @@ def validate_embeddings(self, embeddings: Optional[MultimodalEmbeddings]):
8993
f"Shape mismatch between document_embeddings {document_embeddings.shape} and image_embeddings {image_embeddings.shape}"
9094
)
9195

96+
def validate_encoder(self):
    """Check that ``self.encoder_`` exposes a usable embedding interface.

    The encoder is accepted when it has an ``encode`` attribute, or,
    failing that, when it has both ``get_text_embeddings`` and
    ``get_image_embeddings`` attributes.

    Raises
    ------
    TypeError
        If the encoder has neither ``encode`` nor both of
        ``get_text_embeddings`` and ``get_image_embeddings``.
    """
    encoder = self.encoder_
    if hasattr(encoder, "encode"):
        return
    required_methods = ("get_text_embeddings", "get_image_embeddings")
    # all() takes a single iterable; passing the hasattr() results as
    # separate positional arguments raises an unrelated TypeError and
    # rejects valid encoders that implement both methods.
    if not all(hasattr(encoder, name) for name in required_methods):
        raise TypeError(
            "An encoder must either have an encode() method or a get_text_embeddings and get_image_embeddings method (optionally get_fused_embeddings)"
        )
105+
92106
@abstractmethod
93107
def fit_transform_multimodal(
94108
self,

0 commit comments

Comments
 (0)