Skip to content

Commit 16aebc0

Browse files
refactor: Refactor query_embed and passage_embed
1 parent dd6111c commit 16aebc0

File tree

1 file changed

+2
-15
lines changed

1 file changed

+2
-15
lines changed

fastembed/text/multitask_embedding.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,11 @@ def embed(
7575

7676
def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
7777
self._current_task_id = self.QUERY_TASK
78-
79-
if isinstance(query, str):
80-
query = [query]
81-
82-
if not hasattr(self, "model") or self.model is None:
83-
self.load_onnx_model()
84-
85-
for text in query:
86-
yield from self._post_process_onnx_output(self.onnx_embed([text]))
78+
yield from super().embed(query, **kwargs)
8779

8880
def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
8981
self._current_task_id = self.PASSAGE_TASK
90-
91-
if not hasattr(self, "model") or self.model is None:
92-
self.load_onnx_model()
93-
94-
for text in texts:
95-
yield from self._post_process_onnx_output(self.onnx_embed([text]))
82+
yield from super().embed(texts, **kwargs)
9683

9784

9885
class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):

0 commit comments

Comments
 (0)