
Commit 4dc76e3

fix: fix colbert query postprocessing (#557)
* fix: fix colbert query postprocessing
* fix: improve colbert single embedding tests
1 parent 6efe06b commit 4dc76e3

File tree

2 files changed: +25 -24 lines changed


fastembed/late_interaction/colbert.py

Lines changed: 18 additions & 18 deletions
@@ -49,24 +49,24 @@ def _post_process_onnx_output(
         if not is_doc:
             for embedding in output.model_output:
                 yield embedding
-
-        if output.input_ids is None or output.attention_mask is None:
-            raise ValueError(
-                "input_ids and attention_mask must be provided for document post-processing"
-            )
-
-        for i, token_sequence in enumerate(output.input_ids):
-            for j, token_id in enumerate(token_sequence):  # type: ignore
-                if token_id in self.skip_list or token_id == self.pad_token_id:
-                    output.attention_mask[i, j] = 0
-
-        output.model_output *= np.expand_dims(output.attention_mask, 2)
-        norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
-        norm_clamped = np.maximum(norm, 1e-12)
-        output.model_output /= norm_clamped
-
-        for embedding, attention_mask in zip(output.model_output, output.attention_mask):
-            yield embedding[attention_mask == 1]
+        else:
+            if output.input_ids is None or output.attention_mask is None:
+                raise ValueError(
+                    "input_ids and attention_mask must be provided for document post-processing"
+                )
+
+            for i, token_sequence in enumerate(output.input_ids):
+                for j, token_id in enumerate(token_sequence):  # type: ignore
+                    if token_id in self.skip_list or token_id == self.pad_token_id:
+                        output.attention_mask[i, j] = 0
+
+            output.model_output *= np.expand_dims(output.attention_mask, 2)
+            norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
+            norm_clamped = np.maximum(norm, 1e-12)
+            output.model_output /= norm_clamped
+
+            for embedding, attention_mask in zip(output.model_output, output.attention_mask):
+                yield embedding[attention_mask == 1]
 
     def _preprocess_onnx_input(
         self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
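The added else: is the whole fix. Previously, a query (is_doc=False) first yielded its raw embeddings and then fell through into the document post-processing below, so masked and normalized copies were yielded as well. A minimal, self-contained sketch of that fall-through pattern, with an illustrative function and toy shapes rather than the actual fastembed code:

import numpy as np

def broken_post_process(model_output, attention_mask, is_doc):
    # Mirrors the pre-fix control flow: there is no `else:` after the query branch.
    if not is_doc:
        for embedding in model_output:
            yield embedding
    # Document post-processing runs unconditionally, so queries fall through here
    # and a second, masked copy of every embedding is yielded.
    for embedding, mask in zip(model_output, attention_mask):
        yield embedding[mask == 1]

# One input, 3 tokens, 4-dimensional token embeddings (toy shapes).
output = np.random.rand(1, 3, 4)
mask = np.ones((1, 3), dtype=np.int64)

assert len(list(broken_post_process(output, mask, is_doc=True))) == 1
assert len(list(broken_post_process(output, mask, is_doc=False))) == 2  # bug: should be 1

Wrapping the document branch in else: keeps the two paths mutually exclusive, which is exactly what the tightened tests below check.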

tests/test_late_interaction_embeddings.py

Lines changed: 7 additions & 6 deletions
@@ -175,10 +175,7 @@ def test_batch_inference_size_same_as_single_inference(model_name: str):
     is_ci = os.getenv("CI")
 
     model = LateInteractionTextEmbedding(model_name=model_name)
-    docs_to_embed = [
-        "short document",
-        "A bit longer document, which should not affect the size"
-    ]
+    docs_to_embed = ["short document", "A bit longer document, which should not affect the size"]
     result = list(model.embed(docs_to_embed, batch_size=1))
     result_2 = list(model.embed(docs_to_embed, batch_size=2))
     assert len(result[0]) == len(result_2[0])
@@ -199,7 +196,9 @@ def test_single_embedding(model_name: str):
 
     print("evaluating", model_name)
     model = LateInteractionTextEmbedding(model_name=model_name)
-    result = next(iter(model.embed(docs_to_embed, batch_size=6)))
+    whole_result = list(model.embed(docs_to_embed, batch_size=6))
+    assert len(whole_result) == 1
+    result = whole_result[0]
     expected_result = CANONICAL_COLUMN_VALUES[model_name]
     token_num, abridged_dim = expected_result.shape
     assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)
@@ -220,7 +219,9 @@ def test_single_embedding_query(model_name: str):
 
     print("evaluating", model_name)
     model = LateInteractionTextEmbedding(model_name=model_name)
-    result = next(iter(model.query_embed(queries_to_embed)))
+    whole_result = list(model.query_embed(queries_to_embed))
+    assert len(whole_result) == 1
+    result = whole_result[0]
     expected_result = CANONICAL_QUERY_VALUES[model_name]
     token_num, abridged_dim = expected_result.shape
     assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)
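Both single-embedding tests now materialize the generator instead of taking only its first item. next(iter(...)) still passes on the first embedding even when a fall-through bug yields extras, whereas list(...) plus a length assertion fails as soon as more than one embedding comes back for a single input. A usage sketch of the new pattern; the top-level import and the colbert-ir/colbertv2.0 checkpoint are assumptions, not taken from this diff:

# Sketch only: import path and model name are assumptions.
from fastembed import LateInteractionTextEmbedding

model = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0")

# Old pattern, silently ignores any spurious extra embeddings:
#   result = next(iter(model.query_embed(["hello world"])))

# New pattern: one query in, exactly one multi-vector embedding out.
whole_result = list(model.query_embed(["hello world"]))
assert len(whole_result) == 1
result = whole_result[0]  # shape: (num_query_tokens, embedding_dim)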

0 commit comments
