Skip to content

Commit 26a313c

Browse files
Added multimodal encoder interface
1 parent a3fb166 commit 26a313c

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

turftopic/encoders/multimodal.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import Protocol
2+
3+
from PIL import Image
4+
from torch import Tensor
5+
from torch.utils.data import DataLoader
6+
7+
8+
class MultimodalEncoder(Protocol):
9+
"""Base class for external encoder models."""
10+
11+
def get_text_embeddings(
12+
self,
13+
texts: list[str],
14+
*,
15+
batch_size: int = 8,
16+
**kwargs,
17+
) -> Tensor: ...
18+
19+
def get_image_embeddings(
20+
self,
21+
images: list[Image.Image] | DataLoader,
22+
*,
23+
batch_size: int = 8,
24+
**kwargs,
25+
) -> Tensor: ...
26+
27+
def get_fused_embeddings(
28+
self,
29+
texts: list[str] = None,
30+
images: list[Image.Image] = None,
31+
batch_size: int = 8,
32+
**kwargs,
33+
) -> Tensor: ...

0 commit comments

Comments
 (0)