We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a3fb166 · commit 26a313c (Copy full SHA for 26a313c)
turftopic/encoders/multimodal.py
@@ -0,0 +1,33 @@
1
+from typing import Protocol
2
+
3
+from PIL import Image
4
+from torch import Tensor
5
+from torch.utils.data import DataLoader
6
7
8
+class MultimodalEncoder(Protocol):
9
+ """Protocol describing the embedding methods of an external multimodal encoder: text, image, and fused (texts plus images) embeddings."""
10
11
# Text embedding method: takes a list of strings; `batch_size` (default 8) and
# any extra `**kwargs` are keyword-only. Annotated to return a torch Tensor.
# The body is `...` because this is only an interface declaration.
+ def get_text_embeddings(
12
+ self,
13
+ texts: list[str],
14
+ *,
15
+ batch_size: int = 8,
16
+ **kwargs,
17
+ ) -> Tensor: ...
18
19
# Image embedding method: `images` is either a list of PIL images or a DataLoader.
# NOTE(review): the rest of this signature (self, any keyword arguments, the
# closing parenthesis and return annotation) is not visible in this view —
# presumably it mirrors get_text_embeddings; confirm against the full file.
+ def get_image_embeddings(
20
21
+ images: list[Image.Image] | DataLoader,
22
23
24
25
26
27
# Fused embedding method: accepts `texts` and `images`, both defaulting to None.
# NOTE(review): the annotations `list[str]` / `list[Image.Image]` do not include
# None even though None is the default; consider `list[str] | None` and
# `list[Image.Image] | None`. The remainder of the signature is not visible in
# this view — confirm against the full file.
+ def get_fused_embeddings(
28
29
+ texts: list[str] = None,
30
+ images: list[Image.Image] = None,
31
32
33
0 commit comments