
Commit b08febb

Late interaction type hints (#461)

* chore: Add type hints
* new: Add late_interaction type hints
* fix: ndarray -> numpy array

Co-authored-by: George Panchuk <[email protected]>

1 parent 6dbdd6d · commit b08febb

File tree: 4 files changed, +37 -36 lines

fastembed/late_interaction/colbert.py
fastembed/late_interaction/jina_colbert.py
fastembed/late_interaction/late_interaction_embedding_base.py
fastembed/late_interaction/late_interaction_text_embedding.py
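For context: across all four files the diff swaps raw np.ndarray annotations for a NumpyArray alias imported from fastembed.common.types. The alias definition itself is not part of this commit, so the sketch below is only an assumption about what such an alias commonly looks like, not the library's actual code.

# Hypothetical sketch of a NumpyArray alias; fastembed/common/types.py is not shown in this diff.
from typing import Union

import numpy as np
from numpy.typing import NDArray

NumpyArray = Union[
    NDArray[np.float32],
    NDArray[np.float64],
    NDArray[np.int64],
]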

fastembed/late_interaction/colbert.py
Lines changed: 19 additions & 14 deletions

@@ -4,6 +4,7 @@
 import numpy as np
 from tokenizers import Encoding
 
+from fastembed.common.types import NumpyArray
 from fastembed.common import OnnxProvider
 from fastembed.common.onnx_model import OnnxOutputContext
 from fastembed.common.utils import define_cache_dir
@@ -39,15 +40,15 @@
 ]
 
 
-class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[np.ndarray]):
+class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[NumpyArray]):
     QUERY_MARKER_TOKEN_ID = 1
     DOCUMENT_MARKER_TOKEN_ID = 2
     MIN_QUERY_LENGTH = 31  # it's 32, we add one additional special token in the beginning
     MASK_TOKEN = "[MASK]"
 
     def _post_process_onnx_output(
         self, output: OnnxOutputContext, is_doc: bool = True
-    ) -> Iterable[np.ndarray]:
+    ) -> Iterable[NumpyArray]:
         if not is_doc:
             return output.model_output.astype(np.float32)
 
@@ -68,11 +69,15 @@ def _post_process_onnx_output(
         return output.model_output.astype(np.float32)
 
     def _preprocess_onnx_input(
-        self, onnx_input: dict[str, np.ndarray], is_doc: bool = True, **kwargs: Any
-    ) -> dict[str, np.ndarray]:
+        self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
+    ) -> dict[str, NumpyArray]:
         marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID
-        onnx_input["input_ids"] = np.insert(onnx_input["input_ids"], 1, marker_token, axis=1)
-        onnx_input["attention_mask"] = np.insert(onnx_input["attention_mask"], 1, 1, axis=1)
+        onnx_input["input_ids"] = np.insert(
+            onnx_input["input_ids"].astype(np.int64), 1, marker_token, axis=1
+        )
+        onnx_input["attention_mask"] = np.insert(
+            onnx_input["attention_mask"].astype(np.int64), 1, 1, axis=1
+        )
         return onnx_input
 
     def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) -> list[Encoding]:
@@ -166,17 +171,17 @@ def __init__(
         self.device_id = None
 
         self.model_description = self._get_model_description(model_name)
-        self.cache_dir = define_cache_dir(cache_dir)
+        self.cache_dir = str(define_cache_dir(cache_dir))
 
         self._model_dir = self.download_model(
             self.model_description,
             self.cache_dir,
             local_files_only=self._local_files_only,
             specific_model_path=specific_model_path,
         )
-        self.mask_token_id = None
-        self.pad_token_id = None
-        self.skip_list = set()
+        self.mask_token_id: Optional[int] = None
+        self.pad_token_id: Optional[int] = None
+        self.skip_list: set[str] = set()
 
         if not self.lazy_load:
             self.load_onnx_model()
@@ -206,7 +211,7 @@ def embed(
         batch_size: int = 256,
         parallel: Optional[int] = None,
         **kwargs: Any,
-    ) -> Iterable[np.ndarray]:
+    ) -> Iterable[NumpyArray]:
         """
         Encode a list of documents into list of embeddings.
         We use mean pooling with attention so that the model can handle variable-length inputs.
@@ -234,7 +239,7 @@ def embed(
             **kwargs,
         )
 
-    def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[np.ndarray]:
+    def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
         if isinstance(query, str):
             query = [query]
 
@@ -247,11 +252,11 @@ def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterab
         )
 
     @classmethod
-    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
+    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
        return ColbertEmbeddingWorker
 
 
-class ColbertEmbeddingWorker(TextEmbeddingWorker):
+class ColbertEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Colbert:
         return Colbert(
             model_name=model_name,
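Most of this file's diff is annotation-only; the functional tweaks are the str(...) wrapper around define_cache_dir and the explicit int64 cast in _preprocess_onnx_input before the marker token is inserted. A minimal, self-contained sketch of what that insertion does (the token IDs below are made up for illustration and are not real tokenizer output):

import numpy as np

DOCUMENT_MARKER_TOKEN_ID = 2  # same constant the Colbert class defines

# Pretend tokenizer output for two documents, arriving as int32.
input_ids = np.array([[101, 7592, 102], [101, 2088, 102]], dtype=np.int32)
attention_mask = np.ones_like(input_ids)

# Insert the marker token as a new column at position 1 (right after the
# leading special token), casting to int64 so the ONNX runtime receives
# the expected integer dtype.
input_ids = np.insert(input_ids.astype(np.int64), 1, DOCUMENT_MARKER_TOKEN_ID, axis=1)
attention_mask = np.insert(attention_mask.astype(np.int64), 1, 1, axis=1)

print(input_ids)
# [[ 101    2 7592  102]
#  [ 101    2 2088  102]]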

fastembed/late_interaction/jina_colbert.py
Lines changed: 6 additions & 8 deletions

@@ -1,9 +1,7 @@
 from typing import Any, Type
 
-import numpy as np
-
-from fastembed.late_interaction.colbert import Colbert
-from fastembed.text.onnx_text_model import TextEmbeddingWorker
+from fastembed.common.types import NumpyArray
+from fastembed.late_interaction.colbert import Colbert, ColbertEmbeddingWorker
 
 
 supported_jina_colbert_models = [
@@ -29,7 +27,7 @@ class JinaColbert(Colbert):
     MASK_TOKEN = "<mask>"
 
     @classmethod
-    def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
+    def _get_worker_class(cls) -> Type[ColbertEmbeddingWorker]:
         return JinaColbertEmbeddingWorker
 
     @classmethod
@@ -42,8 +40,8 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
         return supported_jina_colbert_models
 
     def _preprocess_onnx_input(
-        self, onnx_input: dict[str, np.ndarray], is_doc: bool = True, **kwargs: Any
-    ) -> dict[str, np.ndarray]:
+        self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
+    ) -> dict[str, NumpyArray]:
         onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc)
 
         # the attention mask for jina-colbert-v2 is always 1 in queries
@@ -52,7 +50,7 @@ def _preprocess_onnx_input(
         return onnx_input
 
 
-class JinaColbertEmbeddingWorker(TextEmbeddingWorker):
+class JinaColbertEmbeddingWorker(ColbertEmbeddingWorker):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> JinaColbert:
         return JinaColbert(
             model_name=model_name,
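Re-basing JinaColbertEmbeddingWorker on ColbertEmbeddingWorker (rather than the generic TextEmbeddingWorker) means the NumpyArray type parameter is picked up by inheritance instead of being left unbound. A reduced sketch of that typing pattern, using stand-in names rather than fastembed's real classes:

from typing import Generic, TypeVar

T = TypeVar("T")


class Worker(Generic[T]):  # stand-in for TextEmbeddingWorker[T]
    def result(self) -> T:
        raise NotImplementedError()


class FloatListWorker(Worker[list[float]]):  # stand-in for ColbertEmbeddingWorker
    def result(self) -> list[float]:
        return [1.0, 2.0]


class SpecialFloatListWorker(FloatListWorker):  # stand-in for JinaColbertEmbeddingWorker:
    pass  # the list[float] parameter is inherited; no need to re-declare it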

fastembed/late_interaction/late_interaction_embedding_base.py
Lines changed: 7 additions & 8 deletions

@@ -1,7 +1,6 @@
 from typing import Iterable, Optional, Union, Any
 
-import numpy as np
-
+from fastembed.common.types import NumpyArray
 from fastembed.common.model_management import ModelManagement
 
 
@@ -24,10 +23,10 @@ def embed(
         batch_size: int = 256,
         parallel: Optional[int] = None,
         **kwargs: Any,
-    ) -> Iterable[np.ndarray]:
+    ) -> Iterable[NumpyArray]:
         raise NotImplementedError()
 
-    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[np.ndarray]:
+    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
         """
         Embeds a list of text passages into a list of embeddings.
 
@@ -36,25 +35,25 @@ def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[np.ndar
             **kwargs: Additional keyword argument to pass to the embed method.
 
         Yields:
-            Iterable[np.ndarray]: The embeddings.
+            Iterable[NdArray]: The embeddings.
         """
 
         # This is model-specific, so that different models can have specialized implementations
         yield from self.embed(texts, **kwargs)
 
-    def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[np.ndarray]:
+    def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
         """
         Embeds queries
 
         Args:
             query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.
 
         Returns:
-            Iterable[np.ndarray]: The embeddings.
+            Iterable[NdArray]: The embeddings.
         """
 
         # This is model-specific, so that different models can have specialized implementations
         if isinstance(query, str):
             yield from self.embed([query], **kwargs)
-        if isinstance(query, Iterable):
+        else:
             yield from self.embed(query, **kwargs)
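The switch from a second isinstance(query, Iterable) check to a plain else in query_embed is a small correctness fix: a str is itself an Iterable, so under the old code a string query matched both branches and would be embedded a second time, character by character. A quick stdlib-only illustration of why:

from collections.abc import Iterable

query = "what is late interaction?"

print(isinstance(query, str))       # True
print(isinstance(query, Iterable))  # True -> with two independent `if`s, both branches ran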

fastembed/late_interaction/late_interaction_text_embedding.py
Lines changed: 5 additions & 6 deletions

@@ -1,7 +1,6 @@
 from typing import Any, Iterable, Optional, Sequence, Type, Union
 
-import numpy as np
-
+from fastembed.common.types import NumpyArray
 from fastembed.common import OnnxProvider
 from fastembed.late_interaction.colbert import Colbert
 from fastembed.late_interaction.jina_colbert import JinaColbert
@@ -38,7 +37,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
            ]
            ```
         """
-        result = []
+        result: list[dict[str, Any]] = []
         for embedding in cls.EMBEDDINGS_REGISTRY:
             result.extend(embedding.list_supported_models())
         return result
@@ -81,7 +80,7 @@ def embed(
         batch_size: int = 256,
         parallel: Optional[int] = None,
         **kwargs: Any,
-    ) -> Iterable[np.ndarray]:
+    ) -> Iterable[NumpyArray]:
         """
         Encode a list of documents into list of embeddings.
         We use mean pooling with attention so that the model can handle variable-length inputs.
@@ -99,15 +98,15 @@ def embed(
         """
         yield from self.model.embed(documents, batch_size, parallel, **kwargs)
 
-    def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[np.ndarray]:
+    def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
         """
         Embeds queries
 
         Args:
             query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.
 
         Returns:
-            Iterable[np.ndarray]: The embeddings.
+            Iterable[NdArray]: The embeddings.
         """
 
         # This is model-specific, so that different models can have specialized implementations
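The result: list[dict[str, Any]] = [] annotation in list_supported_models exists for the type checker: a bare [] literal gives mypy no element type to infer, so annotating it up front lets the later extend() calls and the declared return type check cleanly. A small, fastembed-independent illustration of the same pattern:

from typing import Any


def merge_model_lists(sources: list[list[dict[str, Any]]]) -> list[dict[str, Any]]:
    # Annotating the empty list pins its element type for the type checker,
    # instead of leaving it to be inferred from the extend() calls below.
    merged: list[dict[str, Any]] = []
    for models in sources:
        merged.extend(models)
    return merged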
