Skip to content

Commit 33d5ae2

Browse files
Speedup ci (#489)
* chore: Trigger CI test * chore: Trigger CI test * chore: Trigger CI test * chore: Trigger CI test * chore: Trigger CI test * chore: Trigger CI test * chore: Trigger CI test * chore: Trigger CI test * chore: Trigger CI test * Trigger CI * Trigger CI * Trigger CI * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * Trigger CI test * new: Added on workflow dispatch * tests: Updated tests * fix: Fix CI * fix: Fix CI * fix: Fix CI * improve: Prevent stop iteration error caused by next * fix: Fix variable might be referenced before assignment * refactor: Revised the way of getting models to test * fix: Fix test in image model * refactor: Call one model * fix: Fix ci * fix: Fix splade model name * tests: Updated tests * chore: Remove cache * tests: Update multi task tests * tests: Update multi task tests * tests: Updated tests * refactor: refactor utils func, add comments, conditions refactor --------- Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
1 parent 3ee7c3f commit 33d5ae2

9 files changed

+223
-191
lines changed

.github/workflows/python-tests.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ name: Tests
22
run-name: Tests (gpu)
33

44
on:
5-
push:
6-
branches: [ master, main, gpu ]
75
pull_request:
6+
branches: [ master, main, gpu ]
7+
workflow_dispatch:
8+
89

910
env:
1011
CARGO_TERM_COLOR: always
@@ -41,4 +42,4 @@ jobs:
4142
4243
- name: Run pytest
4344
run: |
44-
poetry run pytest
45+
poetry run pytest

tests/test_image_onnx_embeddings.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from fastembed import ImageEmbedding
1010
from tests.config import TEST_MISC_DIR
11-
from tests.utils import delete_model_cache
11+
from tests.utils import delete_model_cache, should_test_model
1212

1313
CANONICAL_VECTOR_VALUES = {
1414
"Qdrant/clip-ViT-B-32-vision": np.array([-0.0098, 0.0128, -0.0274, 0.002, -0.0059]),
@@ -27,11 +27,13 @@
2727
}
2828

2929

30-
def test_embedding() -> None:
30+
@pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"])
31+
def test_embedding(model_name: str) -> None:
3132
is_ci = os.getenv("CI")
33+
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
3234

3335
for model_desc in ImageEmbedding._list_supported_models():
34-
if not is_ci and model_desc.size_in_GB > 1:
36+
if not should_test_model(model_desc, model_name, is_ci, is_manual):
3537
continue
3638

3739
dim = model_desc.dim
@@ -74,8 +76,12 @@ def test_batch_embedding(n_dims: int, model_name: str) -> None:
7476

7577
embeddings = list(model.embed(images, batch_size=10))
7678
embeddings = np.stack(embeddings, axis=0)
79+
assert np.allclose(embeddings[1], embeddings[2])
80+
81+
canonical_vector = CANONICAL_VECTOR_VALUES[model_name]
7782

7883
assert embeddings.shape == (len(test_images) * n_images, n_dims)
84+
assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3)
7985
if is_ci:
8086
delete_model_cache(model.model._model_dir)
8187

tests/test_late_interaction_embeddings.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fastembed.late_interaction.late_interaction_text_embedding import (
77
LateInteractionTextEmbedding,
88
)
9-
from tests.utils import delete_model_cache
9+
from tests.utils import delete_model_cache, should_test_model
1010

1111
# vectors are abridged and rounded for brevity
1212
CANONICAL_COLUMN_VALUES = {
@@ -153,57 +153,70 @@
153153
docs = ["Hello World"]
154154

155155

156-
def test_batch_embedding():
156+
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
157+
def test_batch_embedding(model_name: str):
157158
is_ci = os.getenv("CI")
158159
docs_to_embed = docs * 10
159160

160-
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
161-
print("evaluating", model_name)
162-
model = LateInteractionTextEmbedding(model_name=model_name)
163-
result = list(model.embed(docs_to_embed, batch_size=6))
161+
model = LateInteractionTextEmbedding(model_name=model_name)
162+
result = list(model.embed(docs_to_embed, batch_size=6))
163+
expected_result = CANONICAL_COLUMN_VALUES[model_name]
164164

165-
for value in result:
166-
token_num, abridged_dim = expected_result.shape
167-
assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3)
165+
for value in result:
166+
token_num, abridged_dim = expected_result.shape
167+
assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3)
168168

169-
if is_ci:
170-
delete_model_cache(model.model._model_dir)
169+
if is_ci:
170+
delete_model_cache(model.model._model_dir)
171171

172172

173-
def test_single_embedding():
173+
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
174+
def test_single_embedding(model_name: str):
174175
is_ci = os.getenv("CI")
176+
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
175177
docs_to_embed = docs
176178

177-
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
179+
for model_desc in LateInteractionTextEmbedding._list_supported_models():
180+
if not should_test_model(model_desc, model_name, is_ci, is_manual):
181+
continue
182+
178183
print("evaluating", model_name)
179184
model = LateInteractionTextEmbedding(model_name=model_name)
180185
result = next(iter(model.embed(docs_to_embed, batch_size=6)))
186+
expected_result = CANONICAL_COLUMN_VALUES[model_name]
181187
token_num, abridged_dim = expected_result.shape
182188
assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)
183189

184190
if is_ci:
185191
delete_model_cache(model.model._model_dir)
186192

187193

188-
def test_single_embedding_query():
194+
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
195+
def test_single_embedding_query(model_name: str):
189196
is_ci = os.getenv("CI")
197+
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
190198
queries_to_embed = docs
191199

192-
for model_name, expected_result in CANONICAL_QUERY_VALUES.items():
200+
for model_desc in LateInteractionTextEmbedding._list_supported_models():
201+
if not should_test_model(model_desc, model_name, is_ci, is_manual):
202+
continue
203+
193204
print("evaluating", model_name)
194205
model = LateInteractionTextEmbedding(model_name=model_name)
195206
result = next(iter(model.query_embed(queries_to_embed)))
207+
expected_result = CANONICAL_QUERY_VALUES[model_name]
196208
token_num, abridged_dim = expected_result.shape
197209
assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)
198210

199211
if is_ci:
200212
delete_model_cache(model.model._model_dir)
201213

202214

203-
def test_parallel_processing():
215+
@pytest.mark.parametrize("token_dim,model_name", [(96, "answerdotai/answerai-colbert-small-v1")])
216+
def test_parallel_processing(token_dim: int, model_name: str):
204217
is_ci = os.getenv("CI")
205-
model = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0")
206-
token_dim = 128
218+
model = LateInteractionTextEmbedding(model_name=model_name)
219+
207220
docs = ["hello world", "flag embedding"] * 100
208221
embeddings = list(model.embed(docs, batch_size=10, parallel=2))
209222
embeddings = np.stack(embeddings, axis=0)
@@ -222,10 +235,7 @@ def test_parallel_processing():
222235
delete_model_cache(model.model._model_dir)
223236

224237

225-
@pytest.mark.parametrize(
226-
"model_name",
227-
["colbert-ir/colbertv2.0"],
228-
)
238+
@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
229239
def test_lazy_load(model_name: str):
230240
is_ci = os.getenv("CI")
231241

tests/test_late_interaction_multimodal.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22

3+
import pytest
34
from PIL import Image
45
import numpy as np
56

@@ -45,38 +46,38 @@
4546

4647

4748
def test_batch_embedding():
48-
is_ci = os.getenv("CI")
49+
if os.getenv("CI"):
50+
pytest.skip("Colpali is too large to test in CI")
4951

50-
if not is_ci:
51-
for model_name, expected_result in CANONICAL_IMAGE_VALUES.items():
52-
print("evaluating", model_name)
53-
model = LateInteractionMultimodalEmbedding(model_name=model_name)
54-
result = list(model.embed_image(images, batch_size=2))
52+
for model_name, expected_result in CANONICAL_IMAGE_VALUES.items():
53+
print("evaluating", model_name)
54+
model = LateInteractionMultimodalEmbedding(model_name=model_name)
55+
result = list(model.embed_image(images, batch_size=2))
5556

56-
for value in result:
57-
token_num, abridged_dim = expected_result.shape
58-
assert np.allclose(value[:token_num, :abridged_dim], expected_result, atol=2e-3)
57+
for value in result:
58+
token_num, abridged_dim = expected_result.shape
59+
assert np.allclose(value[:token_num, :abridged_dim], expected_result, atol=2e-3)
5960

6061

6162
def test_single_embedding():
62-
is_ci = os.getenv("CI")
63-
if not is_ci:
64-
for model_name, expected_result in CANONICAL_IMAGE_VALUES.items():
65-
print("evaluating", model_name)
66-
model = LateInteractionMultimodalEmbedding(model_name=model_name)
67-
result = next(iter(model.embed_image(images, batch_size=6)))
68-
token_num, abridged_dim = expected_result.shape
69-
assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3)
63+
if os.getenv("CI"):
64+
pytest.skip("Colpali is too large to test in CI")
65+
66+
for model_name, expected_result in CANONICAL_IMAGE_VALUES.items():
67+
print("evaluating", model_name)
68+
model = LateInteractionMultimodalEmbedding(model_name=model_name)
69+
result = next(iter(model.embed_image(images, batch_size=6)))
70+
token_num, abridged_dim = expected_result.shape
71+
assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3)
7072

7173

7274
def test_single_embedding_query():
73-
is_ci = os.getenv("CI")
74-
if not is_ci:
75-
queries_to_embed = queries
76-
77-
for model_name, expected_result in CANONICAL_QUERY_VALUES.items():
78-
print("evaluating", model_name)
79-
model = LateInteractionMultimodalEmbedding(model_name=model_name)
80-
result = next(iter(model.embed_text(queries_to_embed)))
81-
token_num, abridged_dim = expected_result.shape
82-
assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3)
75+
if os.getenv("CI"):
76+
pytest.skip("Colpali is too large to test in CI")
77+
78+
for model_name, expected_result in CANONICAL_QUERY_VALUES.items():
79+
print("evaluating", model_name)
80+
model = LateInteractionMultimodalEmbedding(model_name=model_name)
81+
result = next(iter(model.embed_text(queries)))
82+
token_num, abridged_dim = expected_result.shape
83+
assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3)

tests/test_sparse_embeddings.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
from fastembed.sparse.bm25 import Bm25
77
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
8-
from tests.utils import delete_model_cache
8+
from tests.utils import delete_model_cache, should_test_model
99

1010
CANONICAL_COLUMN_VALUES = {
11-
"prithvida/Splade_PP_en_v1": {
11+
"prithivida/Splade_PP_en_v1": {
1212
"indices": [
1313
2040,
1414
2047,
@@ -49,28 +49,41 @@
4949
docs = ["Hello World"]
5050

5151

52-
def test_batch_embedding() -> None:
52+
@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"])
53+
def test_batch_embedding(model_name: str) -> None:
5354
is_ci = os.getenv("CI")
5455
docs_to_embed = docs * 10
5556

56-
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
57-
model = SparseTextEmbedding(model_name=model_name)
58-
result = next(iter(model.embed(docs_to_embed, batch_size=6)))
59-
assert result.indices.tolist() == expected_result["indices"]
57+
model = SparseTextEmbedding(model_name=model_name)
58+
result = next(iter(model.embed(docs_to_embed, batch_size=6)))
59+
expected_result = CANONICAL_COLUMN_VALUES[model_name]
60+
assert result.indices.tolist() == expected_result["indices"]
6061

61-
for i, value in enumerate(result.values):
62-
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
63-
if is_ci:
64-
delete_model_cache(model.model._model_dir)
62+
for i, value in enumerate(result.values):
63+
assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
64+
if is_ci:
65+
delete_model_cache(model.model._model_dir)
6566

6667

67-
def test_single_embedding() -> None:
68+
@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"])
69+
def test_single_embedding(model_name: str) -> None:
6870
is_ci = os.getenv("CI")
69-
for model_name, expected_result in CANONICAL_COLUMN_VALUES.items():
71+
is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
72+
73+
for model_desc in SparseTextEmbedding._list_supported_models():
74+
if (
75+
model_desc.model not in CANONICAL_COLUMN_VALUES
76+
): # attention models and bm25 are also parts of
77+
# SparseTextEmbedding, however, they have their own tests
78+
continue
79+
if not should_test_model(model_desc, model_name, is_ci, is_manual):
80+
continue
81+
7082
model = SparseTextEmbedding(model_name=model_name)
7183

7284
passage_result = next(iter(model.embed(docs, batch_size=6)))
7385
query_result = next(iter(model.query_embed(docs)))
86+
expected_result = CANONICAL_COLUMN_VALUES[model_name]
7487
for result in [passage_result, query_result]:
7588
assert result.indices.tolist() == expected_result["indices"]
7689

@@ -80,9 +93,10 @@ def test_single_embedding() -> None:
8093
delete_model_cache(model.model._model_dir)
8194

8295

83-
def test_parallel_processing() -> None:
96+
@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"])
97+
def test_parallel_processing(model_name: str) -> None:
8498
is_ci = os.getenv("CI")
85-
model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
99+
model = SparseTextEmbedding(model_name=model_name)
86100
docs = ["hello world", "flag embedding"] * 30
87101
sparse_embeddings_duo = list(model.embed(docs, batch_size=10, parallel=2))
88102
sparse_embeddings_all = list(model.embed(docs, batch_size=10, parallel=0))
@@ -172,10 +186,7 @@ def test_disable_stemmer_behavior(disable_stemmer: bool) -> None:
172186
assert result == expected, f"Expected {expected}, but got {result}"
173187

174188

175-
@pytest.mark.parametrize(
176-
"model_name",
177-
["prithivida/Splade_PP_en_v1"],
178-
)
189+
@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1"])
179190
def test_lazy_load(model_name: str) -> None:
180191
is_ci = os.getenv("CI")
181192
model = SparseTextEmbedding(model_name=model_name, lazy_load=True)

0 commit comments

Comments
 (0)