Skip to content

Commit a408b99

Browse files
DimasfromLavoisier, iulusoy, and pre-commit-ci[bot]
authored
Multimodal search (#276)
* fix: include audio model class in init * fix: remove model from init, and reference model module in notebook instead * add multimodal search module * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add multi query support * small fixes for code improvement * small adjustments * upd notebook * add tests --------- Co-authored-by: Inga Ulusoy <inga.ulusoy@uni-heidelberg.de> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ac19b1d commit a408b99

File tree

10 files changed

+1456
-12
lines changed

10 files changed

+1456
-12
lines changed

ammico/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
from ammico.display import AnalysisExplorer
2-
from ammico.model import MultimodalSummaryModel
2+
from ammico.model import (
3+
MultimodalSummaryModel,
4+
AudioToTextModel,
5+
MultimodalEmbeddingsModel,
6+
)
37
from ammico.text import TextDetector, TextAnalyzer, privacy_disclosure
48
from ammico.image_summary import ImageSummaryDetector
59
from ammico.utils import find_files, get_dataframe, AnalysisType, find_videos
610
from ammico.video_summary import VideoSummaryDetector
11+
from ammico.multimodal_search import MultimodalSearch
712

813
# Export the version defined in project metadata
914
try:
@@ -17,6 +22,9 @@
1722
"AnalysisType",
1823
"AnalysisExplorer",
1924
"MultimodalSummaryModel",
25+
"MultimodalEmbeddingsModel",
26+
"AudioToTextModel",
27+
"MultimodalSearch",
2028
"TextDetector",
2129
"TextAnalyzer",
2230
"ImageSummaryDetector",

ammico/model.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
BitsAndBytesConfig,
1414
AutoTokenizer,
1515
)
16-
from typing import Optional
16+
from typing import Optional, List, Union
17+
from PIL import Image
18+
from sentence_transformers import SentenceTransformer
19+
import numpy as np
1720

1821

1922
class MultimodalSummaryModel:
@@ -208,3 +211,108 @@ def close(self) -> None:
208211
RuntimeWarning,
209212
stacklevel=2,
210213
)
214+
215+
216+
class MultimodalEmbeddingsModel:
    """Multimodal (text + image) embeddings model based on Jina CLIP-V2.

    Wraps a ``SentenceTransformer`` loaded from ``jinaai/jina-clip-v2`` and
    exposes L2-normalized text and image embeddings, with optional
    truncation to a smaller dimensionality.
    """

    def __init__(
        self,
        device: Optional[str] = None,
    ) -> None:
        """
        Class for Multimodal Embeddings model loading and inference. Uses Jina CLIP-V2 model.
        Args:
            device: "cuda" or "cpu" (auto-detected when None).
        """
        self.device = resolve_model_device(device)

        model_id = "jinaai/jina-clip-v2"

        # trust_remote_code is required because jina-clip-v2 ships custom
        # modeling code; torch_dtype="auto" lets HF pick the checkpoint dtype.
        self.model = SentenceTransformer(
            model_id,
            device=self.device,
            trust_remote_code=True,
            model_kwargs={"torch_dtype": "auto"},
        )

        self.model.eval()

        # Full output dimensionality of jina-clip-v2; lower bound for
        # truncate_dim validation is 64 (Matryoshka-style truncation).
        self.embedding_dim = 1024

    def _truncate(
        self,
        embeddings: Union[torch.Tensor, np.ndarray],
        truncate_dim: Optional[int],
    ) -> Union[torch.Tensor, np.ndarray]:
        """Optionally keep only the first ``truncate_dim`` embedding components.

        Args:
            embeddings: 2-D batch of embeddings (tensor or ndarray).
            truncate_dim: target dimensionality, or None for no truncation.

        Returns:
            The (possibly truncated) embeddings.

        Raises:
            ValueError: if truncate_dim is outside [64, embedding_dim].

        NOTE(review): vectors are normalized by ``encode`` *before* truncation,
        so truncated embeddings are no longer unit-norm — confirm whether
        downstream cosine-similarity code re-normalizes.
        """
        if truncate_dim is None:
            return embeddings
        if not (64 <= truncate_dim <= self.embedding_dim):
            raise ValueError(
                f"truncate_dim must be between 64 and {self.embedding_dim}"
            )
        return embeddings[:, :truncate_dim]

    @torch.inference_mode()
    def encode_text(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 64,
        truncate_dim: Optional[int] = None,
    ) -> Union[torch.Tensor, np.ndarray]:
        """Embed one or more texts.

        Args:
            texts: a single string or a list of strings.
            batch_size: encoding batch size.
            truncate_dim: optional output dimensionality in [64, 1024].

        Returns:
            Normalized embeddings — a torch.Tensor on CUDA devices,
            otherwise a numpy.ndarray.
        """
        if isinstance(texts, str):
            texts = [texts]

        # Keep results on-device as tensors when running on CUDA;
        # fall back to numpy on CPU.
        convert_to_tensor = self.device == "cuda"
        convert_to_numpy = not convert_to_tensor

        embeddings = self.model.encode(
            texts,
            batch_size=batch_size,
            convert_to_tensor=convert_to_tensor,
            convert_to_numpy=convert_to_numpy,
            normalize_embeddings=True,
        )

        return self._truncate(embeddings, truncate_dim)

    @torch.inference_mode()
    def encode_image(
        self,
        images: Union[Image.Image, List[Image.Image]],
        batch_size: int = 32,
        truncate_dim: Optional[int] = None,
    ) -> Union[torch.Tensor, np.ndarray]:
        """Embed one or more PIL images.

        Args:
            images: a single PIL.Image or a list of PIL.Image objects.
            batch_size: encoding batch size.
            truncate_dim: optional output dimensionality in [64, 1024].

        Returns:
            Normalized embeddings — a torch.Tensor on CUDA devices,
            otherwise a numpy.ndarray.

        Raises:
            ValueError: if images is not a PIL.Image or a list thereof.
        """
        if not isinstance(images, (Image.Image, list)):
            raise ValueError(
                "images must be a PIL.Image or a list of PIL.Image objects. Please load images properly."
            )

        convert_to_tensor = self.device == "cuda"
        convert_to_numpy = not convert_to_tensor

        embeddings = self.model.encode(
            images if isinstance(images, list) else [images],
            batch_size=batch_size,
            convert_to_tensor=convert_to_tensor,
            convert_to_numpy=convert_to_numpy,
            normalize_embeddings=True,
        )

        return self._truncate(embeddings, truncate_dim)

    def close(self) -> None:
        """Free model resources (helpful in long-running processes)."""
        try:
            if self.model is not None:
                del self.model
                self.model = None
        finally:
            # Best-effort cache release; failure here is non-fatal.
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                warnings.warn(
                    "Failed to empty CUDA cache. This is not critical, but may lead to memory lingering: "
                    f"{e!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )

0 commit comments

Comments
 (0)