Skip to content

Commit 0d5821d

Browse files
committed
refactor
1 parent c8eabdb commit 0d5821d

File tree

2 files changed

+32
-13
lines changed

2 files changed

+32
-13
lines changed

fastembed/text/multitask_embedding.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,14 @@ def _list_supported_models(cls) -> list[DenseModelDescription]:
6060
return supported_multitask_models
6161

6262
def _preprocess_onnx_input(
63-
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
63+
self,
64+
onnx_input: dict[str, NumpyArray],
65+
task_id: Optional[Union[int, Task]] = None,
66+
**kwargs: Any,
6467
) -> dict[str, NumpyArray]:
65-
onnx_input["task_id"] = np.array(kwargs["task_id"], dtype=np.int64)
68+
if task_id is None:
69+
raise ValueError(f"task_id must be provided for JinaEmbeddingV3, got <{task_id}>")
70+
onnx_input["task_id"] = np.array(task_id, dtype=np.int64)
6671
return onnx_input
6772

6873
def embed(
@@ -73,18 +78,16 @@ def embed(
7378
task_id: Optional[int] = None,
7479
**kwargs: Any,
7580
) -> Iterable[NumpyArray]:
76-
kwargs["task_id"] = (
81+
task_id = (
7782
task_id if task_id is not None else self.default_task_id
7883
) # required for multiprocessing
79-
yield from super().embed(documents, batch_size, parallel, **kwargs)
84+
yield from super().embed(documents, batch_size, parallel, task_id=task_id, **kwargs)
8085

8186
def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
82-
kwargs["task_id"] = self.QUERY_TASK
83-
yield from super().embed(query, **kwargs)
87+
yield from super().embed(query, task_id=self.QUERY_TASK, **kwargs)
8488

8589
def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
86-
kwargs["task_id"] = self.PASSAGE_TASK
87-
yield from super().embed(texts, **kwargs)
90+
yield from super().embed(texts, task_id=self.PASSAGE_TASK, **kwargs)
8891

8992

9093
class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):
@@ -94,15 +97,15 @@ def init_embedding(
9497
cache_dir: str,
9598
**kwargs: Any,
9699
) -> JinaEmbeddingV3:
97-
model = JinaEmbeddingV3(
100+
return JinaEmbeddingV3(
98101
model_name=model_name,
99102
cache_dir=cache_dir,
100103
threads=1,
101104
**kwargs,
102105
)
103-
return model
104106

105107
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
108+
self.model: JinaEmbeddingV3 # mypy complaints `self.model` does not have `default_task_id`
106109
for idx, batch in items:
107110
onnx_output = self.model.onnx_embed(batch, task_id=self.model.default_task_id)
108111
yield idx, onnx_output

tests/test_text_multitask_embeddings.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,25 @@ def test_single_embedding():
109109

110110
canonical_vector = task["vectors"]
111111
assert np.allclose(
112-
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
112+
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
113113
), model_desc.model
114114

115+
classification_embeddings = list(model.embed(documents=docs, task_id=Task.CLASSIFICATION))
116+
classification_embeddings = np.stack(classification_embeddings, axis=0)
117+
118+
assert classification_embeddings.shape == (len(docs), dim)
119+
120+
model = TextEmbedding(model_name=model_name, task_id=Task.CLASSIFICATION)
121+
default_embeddings = list(model.embed(documents=docs))
122+
default_embeddings = np.stack(default_embeddings, axis=0)
123+
124+
assert default_embeddings.shape == (len(docs), dim)
125+
126+
assert np.allclose(
127+
classification_embeddings,
128+
default_embeddings,
129+
atol=1e-4,
130+
), model_desc.model
115131
if is_ci:
116132
delete_model_cache(model.model._model_dir)
117133

@@ -140,7 +156,7 @@ def test_single_embedding_query():
140156

141157
canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
142158
assert np.allclose(
143-
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
159+
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
144160
), model_desc.model
145161

146162
if is_ci:
@@ -172,7 +188,7 @@ def test_single_embedding_passage():
172188

173189
canonical_vector = CANONICAL_VECTOR_VALUES[model_name][task_id]["vectors"]
174190
assert np.allclose(
175-
embeddings[: len(docs), : canonical_vector.shape[1]], canonical_vector, atol=1e-4
191+
embeddings[:, : canonical_vector.shape[1]], canonical_vector, atol=1e-4
176192
), model_desc.model
177193

178194
if is_ci:

0 commit comments

Comments
 (0)