Commit faf3d9f

batch inference should return same shape as individual inference (#547)
1 parent: acec312

2 files changed: +27 -8 lines changed
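In practical terms: before this commit, embedding documents with batch_size > 1 returned arrays padded to the longest document in the batch, so a document's embedding shape depended on what it happened to be batched with. A minimal check of the fixed behavior, sketched from the new test added below (the import path is the usual fastembed one; the model name and texts are taken from that test):

from fastembed import LateInteractionTextEmbedding

model = LateInteractionTextEmbedding("answerdotai/answerai-colbert-small-v1")
docs = ["short document", "A bit longer document, which should not affect the size"]

# After this commit, each document yields its own (num_tokens, dim) array,
# so per-document shapes no longer depend on how the documents were batched.
one_by_one = list(model.embed(docs, batch_size=1))
batched = list(model.embed(docs, batch_size=2))
assert one_by_one[0].shape == batched[0].shape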

fastembed/late_interaction/colbert.py

Lines changed: 5 additions & 2 deletions

@@ -46,7 +46,8 @@ def _post_process_onnx_output(
         self, output: OnnxOutputContext, is_doc: bool = True, **kwargs: Any
     ) -> Iterable[NumpyArray]:
         if not is_doc:
-            return output.model_output
+            for embedding in output.model_output:
+                yield embedding
 
         if output.input_ids is None or output.attention_mask is None:
             raise ValueError(
@@ -62,7 +63,9 @@ def _post_process_onnx_output(
         norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
         norm_clamped = np.maximum(norm, 1e-12)
         output.model_output /= norm_clamped
-        return output.model_output
+
+        for embedding, attention_mask in zip(output.model_output, output.attention_mask):
+            yield embedding[attention_mask == 1]
 
     def _preprocess_onnx_input(
         self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
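The root cause: the ONNX output for a batch is one padded array of shape (batch, max_tokens, dim), and returning it wholesale leaked that padding into every document's embedding. The patched code instead yields one array per document, trimmed to its real tokens by the attention mask. A self-contained NumPy sketch of that trimming step (shapes and values are made up for illustration):

import numpy as np

# Hypothetical padded batch: 2 documents, padded to 4 tokens, embedding dim 3.
model_output = np.random.rand(2, 4, 3).astype(np.float32)
attention_mask = np.array([
    [1, 1, 0, 0],  # first doc has 2 real tokens
    [1, 1, 1, 1],  # second doc has 4 real tokens
])

# Same indexing as the patched loop: keep only positions the mask marks real.
per_doc = [emb[mask == 1] for emb, mask in zip(model_output, attention_mask)]
print(per_doc[0].shape)  # (2, 3) -- no longer padded to the batch maximum
print(per_doc[1].shape)  # (4, 3)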

tests/test_late_interaction_embeddings.py

Lines changed: 22 additions & 6 deletions

@@ -170,6 +170,23 @@ def test_batch_embedding(model_name: str):
         delete_model_cache(model.model._model_dir)
 
 
+@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
+def test_batch_inference_size_same_as_single_inference(model_name: str):
+    is_ci = os.getenv("CI")
+
+    model = LateInteractionTextEmbedding(model_name=model_name)
+    docs_to_embed = [
+        "short document",
+        "A bit longer document, which should not affect the size"
+    ]
+    result = list(model.embed(docs_to_embed, batch_size=1))
+    result_2 = list(model.embed(docs_to_embed, batch_size=2))
+    assert len(result[0]) == len(result_2[0])
+
+    if is_ci:
+        delete_model_cache(model.model._model_dir)
+
+
 @pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
 def test_single_embedding(model_name: str):
     is_ci = os.getenv("CI")
@@ -219,17 +236,16 @@ def test_parallel_processing(token_dim: int, model_name: str):
 
     docs = ["hello world", "flag embedding"] * 100
     embeddings = list(model.embed(docs, batch_size=10, parallel=2))
-    embeddings = np.stack(embeddings, axis=0)
 
     embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
-    embeddings_2 = np.stack(embeddings_2, axis=0)
 
     embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
-    embeddings_3 = np.stack(embeddings_3, axis=0)
 
-    assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == token_dim
-    assert np.allclose(embeddings, embeddings_2, atol=1e-3)
-    assert np.allclose(embeddings, embeddings_3, atol=1e-3)
+    assert len(embeddings) == len(docs) and embeddings[0].shape[-1] == token_dim
+
+    for i in range(len(embeddings)):
+        assert np.allclose(embeddings[i], embeddings_2[i], atol=1e-3)
+        assert np.allclose(embeddings[i], embeddings_3[i], atol=1e-3)
 
     if is_ci:
         delete_model_cache(model.model._model_dir)
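A note on the test changes: the np.stack calls had to go because the fix makes the output ragged. Each document now keeps only its own tokens, so the per-document arrays generally differ in their first dimension and can no longer be stacked into a single ndarray. A small sketch of that failure mode (shapes are illustrative):

import numpy as np

ragged = [np.zeros((2, 96)), np.zeros((4, 96))]  # differing token counts per doc
try:
    np.stack(ragged, axis=0)
except ValueError as err:
    print(err)  # np.stack requires all input arrays to share one shape
# Hence the updated test compares documents one at a time with np.allclose.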
