diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml
index fda03169..949333dd 100644
--- a/.github/workflows/python-tests.yml
+++ b/.github/workflows/python-tests.yml
@@ -24,6 +24,20 @@ jobs:
           - ubuntu-latest
           - macos-latest
           - windows-latest
+        exclude:
+          # Exclude 3.10–3.12 for macOS and Windows
+          - os: macos-latest
+            python-version: '3.10.x'
+          - os: macos-latest
+            python-version: '3.11.x'
+          - os: macos-latest
+            python-version: '3.12.x'
+          - os: windows-latest
+            python-version: '3.10.x'
+          - os: windows-latest
+            python-version: '3.11.x'
+          - os: windows-latest
+            python-version: '3.12.x'
     runs-on: ${{ matrix.os }}
diff --git a/fastembed/sparse/bm42.py b/fastembed/sparse/bm42.py
index 5fb90d71..3e51404f 100644
--- a/fastembed/sparse/bm42.py
+++ b/fastembed/sparse/bm42.py
@@ -31,9 +31,17 @@
     ),
 ]

-MODEL_TO_LANGUAGE = {
+
+_MODEL_TO_LANGUAGE = {
     "Qdrant/bm42-all-minilm-l6-v2-attentions": "english",
 }
+MODEL_TO_LANGUAGE = {
+    model_name.lower(): language for model_name, language in _MODEL_TO_LANGUAGE.items()
+}
+
+
+def get_language_by_model_name(model_name: str) -> str:
+    return MODEL_TO_LANGUAGE[model_name.lower()]


 class Bm42(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]):
@@ -124,7 +132,7 @@ def __init__(
         self.special_tokens_ids: set[int] = set()
         self.punctuation = set(string.punctuation)
         self.stopwords = set(self._load_stopwords(self._model_dir))
-        self.stemmer = SnowballStemmer(MODEL_TO_LANGUAGE[model_name])
+        self.stemmer = SnowballStemmer(get_language_by_model_name(self.model_name))
         self.alpha = alpha

         if not self.lazy_load:
diff --git a/fastembed/sparse/minicoil.py b/fastembed/sparse/minicoil.py
index 475cc9d8..efaa9abb 100644
--- a/fastembed/sparse/minicoil.py
+++ b/fastembed/sparse/minicoil.py
@@ -46,9 +46,16 @@
     ),
 ]

-MODEL_TO_LANGUAGE = {
+_MODEL_TO_LANGUAGE = {
     "Qdrant/minicoil-v1": "english",
 }
+MODEL_TO_LANGUAGE = {
+    model_name.lower(): language for model_name, language in _MODEL_TO_LANGUAGE.items()
+}
+
+
+def get_language_by_model_name(model_name: str) -> str:
+    return MODEL_TO_LANGUAGE[model_name.lower()]


 class MiniCOIL(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]):
@@ -156,7 +163,7 @@ def load_onnx_model(self) -> None:
         self.special_tokens_ids = set(self.special_token_to_id.values())
         self.stopwords = set(self._load_stopwords(self._model_dir))

-        stemmer = SnowballStemmer(MODEL_TO_LANGUAGE[self.model_name])
+        stemmer = SnowballStemmer(get_language_by_model_name(self.model_name))

         self.vocab_resolver = VocabResolver(
             tokenizer=VocabTokenizer(self.tokenizer),
diff --git a/tests/test_attention_embeddings.py b/tests/test_attention_embeddings.py
index a3598e06..8f62af43 100644
--- a/tests/test_attention_embeddings.py
+++ b/tests/test_attention_embeddings.py
@@ -1,4 +1,5 @@
 import os
+from contextlib import contextmanager

 import numpy as np
 import pytest
@@ -7,98 +8,119 @@
 from tests.utils import delete_model_cache


-@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
-def test_attention_embeddings(model_name: str) -> None:
-    is_ci = os.getenv("CI")
-    model = SparseTextEmbedding(model_name=model_name)
-
-    output = list(
-        model.query_embed(
-            [
-                "I must not fear. Fear is the mind-killer.",
-            ]
-        )
-    )
-
-    assert len(output) == 1
-
-    for result in output:
-        assert len(result.indices) == len(result.values)
-        assert np.allclose(result.values, np.ones(len(result.values)))
-
-    quotes = [
-        "I must not fear. Fear is the mind-killer.",
-        "All animals are equal, but some animals are more equal than others.",
-        "It was a pleasure to burn.",
-        "The sky above the port was the color of television, tuned to a dead channel.",
-        "In the beginning, the universe was created."
-        " This has made a lot of people very angry and been widely regarded as a bad move.",
-        "It's a truth universally acknowledged that a zombie in possession of brains must be in want of more brains.",
-        "War is peace. Freedom is slavery. Ignorance is strength.",
-        "We're not in Infinity; we're in the suburbs.",
-        "I was a thousand times more evil than thou!",
-        "History is merely a list of surprises... It can only prepare us to be surprised yet again.",
-        ".",  # Empty string
-    ]
-
-    output = list(model.embed(quotes))
-
-    assert len(output) == len(quotes)
-
-    for result in output[:-1]:
-        assert len(result.indices) == len(result.values)
-        assert len(result.indices) > 0
-
-    assert len(output[-1].indices) == 0
-
-    # Test support for unknown languages
-    output = list(
-        model.query_embed(
-            [
-                "привет мир!",
-            ]
-        )
-    )
+_MODELS_TO_CACHE = ("Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25")
+MODELS_TO_CACHE = tuple([x.lower() for x in _MODELS_TO_CACHE])

-    assert len(output) == 1

-    for result in output:
-        assert len(result.indices) == len(result.values)
-        assert len(result.indices) == 2
+@pytest.fixture(scope="module")
+def model_cache():
+    is_ci = os.getenv("CI")
+    cache = {}
+
+    @contextmanager
+    def get_model(model_name: str):
+        lowercase_model_name = model_name.lower()
+        if lowercase_model_name not in cache:
+            cache[lowercase_model_name] = SparseTextEmbedding(lowercase_model_name)
+        yield cache[lowercase_model_name]
+        if lowercase_model_name not in MODELS_TO_CACHE:
+            print("deleting model")
+            model_inst = cache.pop(lowercase_model_name)
+            if is_ci:
+                delete_model_cache(model_inst.model._model_dir)
+            del model_inst
+
+    yield get_model

     if is_ci:
-        delete_model_cache(model.model._model_dir)
+        for name, model in cache.items():
+            delete_model_cache(model.model._model_dir)
+        cache.clear()


 @pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
-def test_parallel_processing(model_name: str) -> None:
-    is_ci = os.getenv("CI")
+def test_attention_embeddings(model_cache, model_name: str) -> None:
+    with model_cache(model_name) as model:
+        output = list(
+            model.query_embed(
+                [
+                    "I must not fear. Fear is the mind-killer.",
+                ]
+            )
+        )

-    model = SparseTextEmbedding(model_name=model_name)
+        assert len(output) == 1
+
+        for result in output:
+            assert len(result.indices) == len(result.values)
+            assert np.allclose(result.values, np.ones(len(result.values)))
+
+        quotes = [
+            "I must not fear. Fear is the mind-killer.",
+            "All animals are equal, but some animals are more equal than others.",
+            "It was a pleasure to burn.",
+            "The sky above the port was the color of television, tuned to a dead channel.",
+            "In the beginning, the universe was created."
+            " This has made a lot of people very angry and been widely regarded as a bad move.",
+            "It's a truth universally acknowledged that a zombie in possession of brains must be in want of more brains.",
+            "War is peace. Freedom is slavery. Ignorance is strength.",
+            "We're not in Infinity; we're in the suburbs.",
+            "I was a thousand times more evil than thou!",
+            "History is merely a list of surprises... It can only prepare us to be surprised yet again.",
+            ".",  # Empty string
+        ]
+
+        output = list(model.embed(quotes))
+
+        assert len(output) == len(quotes)
+
+        for result in output[:-1]:
+            assert len(result.indices) == len(result.values)
+            assert len(result.indices) > 0
+
+        assert len(output[-1].indices) == 0
+
+        # Test support for unknown languages
+        output = list(
+            model.query_embed(
+                [
+                    "привет мир!",
+                ]
+            )
+        )

-    docs = ["hello world", "attention embedding", "Mangez-vous vraiment des grenouilles?"] * 100
-    embeddings = list(model.embed(docs, batch_size=10, parallel=2))
+        assert len(output) == 1

-    embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
+        for result in output:
+            assert len(result.indices) == len(result.values)
+            assert len(result.indices) == 2

-    embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
-    assert len(embeddings) == len(docs)
+
+@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
+def test_parallel_processing(model_cache, model_name: str) -> None:
+    with model_cache(model_name) as model:
+        docs = [
+            "hello world",
+            "attention embedding",
+            "Mangez-vous vraiment des grenouilles?",
+        ] * 100
+        embeddings = list(model.embed(docs, batch_size=10, parallel=2))

-    for emb_1, emb_2, emb_3 in zip(embeddings, embeddings_2, embeddings_3):
-        assert np.allclose(emb_1.indices, emb_2.indices)
-        assert np.allclose(emb_1.indices, emb_3.indices)
-        assert np.allclose(emb_1.values, emb_2.values)
-        assert np.allclose(emb_1.values, emb_3.values)
+        embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))

-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
+        assert len(embeddings) == len(docs)
+
+        for emb_1, emb_2, emb_3 in zip(embeddings, embeddings_2, embeddings_3):
+            assert np.allclose(emb_1.indices, emb_2.indices)
+            assert np.allclose(emb_1.indices, emb_3.indices)
+            assert np.allclose(emb_1.values, emb_2.values)
+            assert np.allclose(emb_1.values, emb_3.values)


-@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
-def test_multilanguage(model_name: str) -> None:
-    is_ci = os.getenv("CI")
+@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
+def test_multilanguage(model_cache, model_name: str) -> None:
     docs = ["Mangez-vous vraiment des grenouilles?", "Je suis au lit"]
     model = SparseTextEmbedding(model_name=model_name, language="french")
@@ -109,39 +131,30 @@ def test_multilanguage(model_name: str) -> None:
     assert embeddings[1].values.shape == (1,)
     assert embeddings[1].indices.shape == (1,)

-    model = SparseTextEmbedding(model_name=model_name, language="english")
-    embeddings = list(model.embed(docs))[:2]
-    assert embeddings[0].values.shape == (5,)
-    assert embeddings[0].indices.shape == (5,)
+    with model_cache(model_name) as model:  # language = "english"
+        embeddings = list(model.embed(docs))[:2]
+        assert embeddings[0].values.shape == (5,)
+        assert embeddings[0].indices.shape == (5,)

-    assert embeddings[1].values.shape == (4,)
-    assert embeddings[1].indices.shape == (4,)
-
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        assert embeddings[1].values.shape == (4,)
+        assert embeddings[1].indices.shape == (4,)


 @pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
-def test_special_characters(model_name: str) -> None:
-    is_ci = os.getenv("CI")
-
-    docs = [
-        "Über den größten Flüssen Österreichs äußern sich Experten häufig: Öko-Systeme müssen geschützt werden!",
-        "L'élève français s'écrie : « Où est mon crayon ? J'ai besoin de finir cet exercice avant la récréation!",
-        "Într-o zi însorită, Ștefan și Ioana au mâncat mămăligă cu brânză și au băut țuică la cabană.",
-        "Üzgün öğretmen öğrencilere seslendi: Lütfen gürültü yapmayın, sınavınızı bitirmeye çalışıyorum!",
-        "Ο Ξενοφών είπε: «Ψάχνω για ένα ωραίο δώρο για τη γιαγιά μου. Ίσως ένα φυτό ή ένα βιβλίο;»",
-        "Hola! ¿Cómo estás? Estoy muy emocionado por el cumpleaños de mi hermano, ¡va a ser increíble! También quiero comprar un pastel de chocolate con fresas y un regalo especial: un libro titulado «Cien años de soledad",
-    ]
-
-    model = SparseTextEmbedding(model_name=model_name, language="english")
-    embeddings = list(model.embed(docs))
-    for idx, shape in enumerate([14, 18, 15, 10, 15]):
-        assert embeddings[idx].values.shape == (shape,)
-        assert embeddings[idx].indices.shape == (shape,)
-
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+def test_special_characters(model_cache, model_name: str) -> None:
+    with model_cache(model_name) as model:
+        docs = [
+            "Über den größten Flüssen Österreichs äußern sich Experten häufig: Öko-Systeme müssen geschützt werden!",
+            "L'élève français s'écrie : « Où est mon crayon ? J'ai besoin de finir cet exercice avant la récréation!",
+            "Într-o zi însorită, Ștefan și Ioana au mâncat mămăligă cu brânză și au băut țuică la cabană.",
+            "Üzgün öğretmen öğrencilere seslendi: Lütfen gürültü yapmayın, sınavınızı bitirmeye çalışıyorum!",
+            "Ο Ξενοφών είπε: «Ψάχνω για ένα ωραίο δώρο για τη γιαγιά μου. Ίσως ένα φυτό ή ένα βιβλίο;»",
+            "Hola! ¿Cómo estás? Estoy muy emocionado por el cumpleaños de mi hermano, ¡va a ser increíble! También quiero comprar un pastel de chocolate con fresas y un regalo especial: un libro titulado «Cien años de soledad",
+        ]
+        embeddings = list(model.embed(docs))
+        for idx, shape in enumerate([14, 18, 15, 10, 15]):
+            assert embeddings[idx].values.shape == (shape,)
+            assert embeddings[idx].indices.shape == (shape,)


 @pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions"])
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 321e5b3f..dcca7d89 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -70,9 +70,13 @@ def test_text_custom_model():
     assert embeddings.shape == (2, dim)
     assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3)

+
     if is_ci:
         delete_model_cache(model.model._model_dir)

+    CustomTextEmbedding.SUPPORTED_MODELS.clear()
+    CustomTextEmbedding.POSTPROCESSING_MAPPING.clear()
+

 def test_cross_encoder_custom_model():
     is_ci = os.getenv("CI")
@@ -110,6 +114,8 @@ def test_cross_encoder_custom_model():
     if is_ci:
         delete_model_cache(model.model._model_dir)

+    CustomTextCrossEncoder.SUPPORTED_MODELS.clear()
+

 def test_mock_add_custom_models():
     dim = 5
@@ -169,6 +175,9 @@ def test_mock_add_custom_models():
         )
         assert np.allclose(post_processed_output, expected_output[model_name], atol=1e-3)

+    CustomTextEmbedding.SUPPORTED_MODELS.clear()
+    CustomTextEmbedding.POSTPROCESSING_MAPPING.clear()
+

 def test_do_not_add_existing_model():
     existing_base_model = "sentence-transformers/all-MiniLM-L6-v2"
@@ -203,6 +212,9 @@ def test_do_not_add_existing_model():
             size_in_gb=0.47,
         )

+    CustomTextEmbedding.SUPPORTED_MODELS.clear()
+    CustomTextEmbedding.POSTPROCESSING_MAPPING.clear()
+

 def test_do_not_add_existing_cross_encoder():
     existing_base_model = "Xenova/ms-marco-MiniLM-L-6-v2"
@@ -227,3 +239,5 @@ def test_do_not_add_existing_cross_encoder():
             sources=ModelSource(hf=custom_model_name),
             size_in_gb=0.08,
         )
+
+    CustomTextCrossEncoder.SUPPORTED_MODELS.clear()
diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py
index 9b15ad7d..8369702e 100644
--- a/tests/test_image_onnx_embeddings.py
+++ b/tests/test_image_onnx_embeddings.py
@@ -1,4 +1,5 @@
 import os
+from contextlib import contextmanager
 from io import BytesIO

 import numpy as np
@@ -26,9 +27,37 @@
     ),
 }

+_MODELS_TO_CACHE = ("Qdrant/clip-ViT-B-32-vision",)
+MODELS_TO_CACHE = tuple([x.lower() for x in _MODELS_TO_CACHE])
+
+
+@pytest.fixture(scope="module")
+def model_cache():
+    is_ci = os.getenv("CI")
+    cache = {}
+
+    @contextmanager
+    def get_model(model_name: str):
+        lowercase_model_name = model_name.lower()
+        if lowercase_model_name not in cache:
+            cache[lowercase_model_name] = ImageEmbedding(lowercase_model_name)
+        yield cache[lowercase_model_name]
+        if lowercase_model_name not in MODELS_TO_CACHE:
+            model_inst = cache.pop(lowercase_model_name)
+            if is_ci:
+                delete_model_cache(model_inst.model._model_dir)
+            del model_inst
+
+    yield get_model
+
+    if is_ci:
+        for name, model in cache.items():
+            delete_model_cache(model.model._model_dir)
+        cache.clear()
+

 @pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"])
-def test_embedding(model_name: str) -> None:
+def test_embedding(model_cache, model_name: str) -> None:
     is_ci = os.getenv("CI")
     is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

@@ -38,80 +67,69 @@ def test_embedding(model_name: str) -> None:

         dim = model_desc.dim

-        model = ImageEmbedding(model_name=model_desc.model)
-
-        images = [
-            TEST_MISC_DIR / "image.jpeg",
-            str(TEST_MISC_DIR / "small_image.jpeg"),
-            Image.open((TEST_MISC_DIR / "small_image.jpeg")),
-            Image.open(BytesIO(requests.get("https://qdrant.tech/img/logo.png").content)),
-        ]
-        embeddings = list(model.embed(images))
-        embeddings = np.stack(embeddings, axis=0)
-        assert embeddings.shape == (len(images), dim)
-
-        canonical_vector = CANONICAL_VECTOR_VALUES[model_desc.model]
+        with model_cache(model_desc.model) as model:
+            images = [
+                TEST_MISC_DIR / "image.jpeg",
+                str(TEST_MISC_DIR / "small_image.jpeg"),
+                Image.open((TEST_MISC_DIR / "small_image.jpeg")),
+                Image.open(BytesIO(requests.get("https://qdrant.tech/img/logo.png").content)),
+            ]
+            embeddings = list(model.embed(images))
+            embeddings = np.stack(embeddings, axis=0)
+            assert embeddings.shape == (len(images), dim)

-        assert np.allclose(
-            embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
-        ), model_desc.model
+            canonical_vector = CANONICAL_VECTOR_VALUES[model_desc.model]

-        assert np.allclose(embeddings[1], embeddings[2]), model_desc.model
+            assert np.allclose(
+                embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
+            ), model_desc.model

-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+            assert np.allclose(embeddings[1], embeddings[2]), model_desc.model


 @pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
-def test_batch_embedding(n_dims: int, model_name: str) -> None:
-    is_ci = os.getenv("CI")
-    model = ImageEmbedding(model_name=model_name)
-    n_images = 32
-    test_images = [
-        TEST_MISC_DIR / "image.jpeg",
-        str(TEST_MISC_DIR / "small_image.jpeg"),
-        Image.open(TEST_MISC_DIR / "small_image.jpeg"),
-    ]
-    images = test_images * n_images
+def test_batch_embedding(model_cache, n_dims: int, model_name: str) -> None:
+    with model_cache(model_name) as model:
+        n_images = 32
+        test_images = [
+            TEST_MISC_DIR / "image.jpeg",
+            str(TEST_MISC_DIR / "small_image.jpeg"),
+            Image.open(TEST_MISC_DIR / "small_image.jpeg"),
+        ]
+        images = test_images * n_images

-    embeddings = list(model.embed(images, batch_size=10))
-    embeddings = np.stack(embeddings, axis=0)
-    assert np.allclose(embeddings[1], embeddings[2])
+        embeddings = list(model.embed(images, batch_size=10))
+        embeddings = np.stack(embeddings, axis=0)
+        assert np.allclose(embeddings[1], embeddings[2])

-    canonical_vector = CANONICAL_VECTOR_VALUES[model_name]
+        canonical_vector = CANONICAL_VECTOR_VALUES[model_name]

-    assert embeddings.shape == (len(test_images) * n_images, n_dims)
-    assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3)
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        assert embeddings.shape == (len(test_images) * n_images, n_dims)
+        assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3)


 @pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
-def test_parallel_processing(n_dims: int, model_name: str) -> None:
-    is_ci = os.getenv("CI")
-    model = ImageEmbedding(model_name=model_name)
-
-    n_images = 32
-    test_images = [
-        TEST_MISC_DIR / "image.jpeg",
-        str(TEST_MISC_DIR / "small_image.jpeg"),
-        Image.open(TEST_MISC_DIR / "small_image.jpeg"),
-    ]
-    images = test_images * n_images
-    embeddings = list(model.embed(images, batch_size=10, parallel=2))
-    embeddings = np.stack(embeddings, axis=0)
+def test_parallel_processing(model_cache, n_dims: int, model_name: str) -> None:
+    with model_cache(model_name) as model:
+        n_images = 32
+        test_images = [
+            TEST_MISC_DIR / "image.jpeg",
+            str(TEST_MISC_DIR / "small_image.jpeg"),
+            Image.open(TEST_MISC_DIR / "small_image.jpeg"),
+        ]
+        images = test_images * n_images
+        embeddings = list(model.embed(images, batch_size=10, parallel=2))
+        embeddings = np.stack(embeddings, axis=0)

-    embeddings_2 = list(model.embed(images, batch_size=10, parallel=None))
-    embeddings_2 = np.stack(embeddings_2, axis=0)
+        embeddings_2 = list(model.embed(images, batch_size=10, parallel=None))
+        embeddings_2 = np.stack(embeddings_2, axis=0)

-    embeddings_3 = list(model.embed(images, batch_size=10, parallel=0))
-    embeddings_3 = np.stack(embeddings_3, axis=0)
+        embeddings_3 = list(model.embed(images, batch_size=10, parallel=0))
+        embeddings_3 = np.stack(embeddings_3, axis=0)

-    assert embeddings.shape == (n_images * len(test_images), n_dims)
-    assert np.allclose(embeddings, embeddings_2, atol=1e-3)
-    assert np.allclose(embeddings, embeddings_3, atol=1e-3)
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        assert embeddings.shape == (n_images * len(test_images), n_dims)
+        assert np.allclose(embeddings, embeddings_2, atol=1e-3)
+        assert np.allclose(embeddings, embeddings_3, atol=1e-3)


 @pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"])
diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py
index 9c7f20c6..f89882f4 100644
--- a/tests/test_late_interaction_embeddings.py
+++ b/tests/test_late_interaction_embeddings.py
@@ -1,4 +1,5 @@
 import os
+from contextlib import contextmanager

 import pytest
 import numpy as np
@@ -150,42 +151,65 @@
     ),
 }

-docs = ["Hello World"]
+_MODELS_TO_CACHE = ("answerdotai/answerai-colbert-small-v1",)
+MODELS_TO_CACHE = tuple([x.lower() for x in _MODELS_TO_CACHE])


-@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
-def test_batch_embedding(model_name: str):
+@pytest.fixture(scope="module")
+def model_cache():
     is_ci = os.getenv("CI")
-    docs_to_embed = docs * 10
+    cache = {}
+
+    @contextmanager
+    def get_model(model_name: str):
+        lowercase_model_name = model_name.lower()
+        if lowercase_model_name not in cache:
+            cache[lowercase_model_name] = LateInteractionTextEmbedding(lowercase_model_name)
+        yield cache[lowercase_model_name]
+        if lowercase_model_name not in MODELS_TO_CACHE:
+            model_inst = cache.pop(lowercase_model_name)
+            if is_ci:
+                delete_model_cache(model_inst.model._model_dir)
+            del model_inst
+
+    yield get_model

-    model = LateInteractionTextEmbedding(model_name=model_name)
-    result = list(model.embed(docs_to_embed, batch_size=6))
-    expected_result = CANONICAL_COLUMN_VALUES[model_name]
+    if is_ci:
+        for name, model in cache.items():
+            delete_model_cache(model.model._model_dir)
+        cache.clear()

-    for value in result:
-        token_num, abridged_dim = expected_result.shape
-        assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3)

-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+docs = ["Hello World"]


 @pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
-def test_batch_inference_size_same_as_single_inference(model_name: str):
-    is_ci = os.getenv("CI")
+def test_batch_embedding(model_cache, model_name: str):
+    docs_to_embed = docs * 10

-    model = LateInteractionTextEmbedding(model_name=model_name)
-    docs_to_embed = ["short document", "A bit longer document, which should not affect the size"]
-    result = list(model.embed(docs_to_embed, batch_size=1))
-    result_2 = list(model.embed(docs_to_embed, batch_size=2))
-    assert len(result[0]) == len(result_2[0])
+    with model_cache(model_name) as model:
+        result = list(model.embed(docs_to_embed, batch_size=6))
+        expected_result = CANONICAL_COLUMN_VALUES[model_name]

-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        for value in result:
+            token_num, abridged_dim = expected_result.shape
+            assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3)


 @pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
-def test_single_embedding(model_name: str):
+def test_batch_inference_size_same_as_single_inference(model_cache, model_name: str):
+    with model_cache(model_name) as model:
+        docs_to_embed = [
+            "short document",
+            "A bit longer document, which should not affect the size",
+        ]
+        result = list(model.embed(docs_to_embed, batch_size=1))
+        result_2 = list(model.embed(docs_to_embed, batch_size=2))
+        assert len(result[0]) == len(result_2[0])
+
+
+@pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
+def test_single_embedding(model_cache, model_name: str):
     is_ci = os.getenv("CI")
     is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
     docs_to_embed = docs
@@ -195,20 +219,17 @@ def test_single_embedding(model_name: str):
             continue

         print("evaluating", model_name)
-        model = LateInteractionTextEmbedding(model_name=model_name)
-        whole_result = list(model.embed(docs_to_embed, batch_size=6))
-        assert len(whole_result) == 1
-        result = whole_result[0]
-        expected_result = CANONICAL_COLUMN_VALUES[model_name]
-        token_num, abridged_dim = expected_result.shape
-        assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)
-
-        if is_ci:
-            delete_model_cache(model.model._model_dir)
+        with model_cache(model_desc.model) as model:
+            whole_result = list(model.embed(docs_to_embed, batch_size=6))
+            assert len(whole_result) == 1
+            result = whole_result[0]
+            expected_result = CANONICAL_COLUMN_VALUES[model_desc.model]
+            token_num, abridged_dim = expected_result.shape
+            assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)


 @pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
-def test_single_embedding_query(model_name: str):
+def test_single_embedding_query(model_cache, model_name: str):
     is_ci = os.getenv("CI")
     is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
     queries_to_embed = docs
@@ -217,39 +238,34 @@
         if not should_test_model(model_desc, model_name, is_ci, is_manual):
             continue

-        print("evaluating", model_name)
-        model = LateInteractionTextEmbedding(model_name=model_name)
-        whole_result = list(model.query_embed(queries_to_embed))
-        assert len(whole_result) == 1
-        result = whole_result[0]
-        expected_result = CANONICAL_QUERY_VALUES[model_name]
-        token_num, abridged_dim = expected_result.shape
-        assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)
-
-        if is_ci:
-            delete_model_cache(model.model._model_dir)
+        print("evaluating", model_desc.model)
+        with model_cache(model_desc.model) as model:
+            whole_result = list(model.query_embed(queries_to_embed))
+            assert len(whole_result) == 1
+            result = whole_result[0]
+            expected_result = CANONICAL_QUERY_VALUES[model_desc.model]
+            token_num, abridged_dim = expected_result.shape
+            assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3)


 @pytest.mark.parametrize("token_dim,model_name", [(96, "answerdotai/answerai-colbert-small-v1")])
-def test_parallel_processing(token_dim: int, model_name: str):
-    is_ci = os.getenv("CI")
-    model = LateInteractionTextEmbedding(model_name=model_name)
-
-    docs = ["hello world", "flag embedding"] * 100
-    embeddings = list(model.embed(docs, batch_size=10, parallel=2))
+def test_parallel_processing(model_cache, token_dim: int, model_name: str):
+    with model_cache(model_name) as model:
+        docs = ["hello world", "flag embedding"] * 100
+        embeddings = list(model.embed(docs, batch_size=10, parallel=2))

-    embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
+        embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))

-    embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
+        # embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))  # inherits OnnxTextModel which
+        # # is tested in TextEmbedding, disabling it here to reduce number of requests to hf
+        # # multiprocessing is enough to test with `parallel=2`, and `parallel=None` is okay to test since it reuses
+        # # model from cache

-    assert len(embeddings) == len(docs) and embeddings[0].shape[-1] == token_dim
+        assert len(embeddings) == len(docs) and embeddings[0].shape[-1] == token_dim

-    for i in range(len(embeddings)):
-        assert np.allclose(embeddings[i], embeddings_2[i], atol=1e-3)
-        assert np.allclose(embeddings[i], embeddings_3[i], atol=1e-3)
-
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        for i in range(len(embeddings)):
+            assert np.allclose(embeddings[i], embeddings_2[i], atol=1e-3)
+            # assert np.allclose(embeddings[i], embeddings_3[i], atol=1e-3)


 @pytest.mark.parametrize("model_name", ["answerdotai/answerai-colbert-small-v1"])
diff --git a/tests/test_sparse_embeddings.py b/tests/test_sparse_embeddings.py
index 297996c5..90514967 100644
--- a/tests/test_sparse_embeddings.py
+++ b/tests/test_sparse_embeddings.py
@@ -1,4 +1,5 @@
 import os
+from contextlib import contextmanager

 import pytest
 import numpy as np
@@ -76,6 +77,35 @@
 }


+_MODELS_TO_CACHE = ("prithivida/Splade_PP_en_v1", "Qdrant/minicoil-v1", "Qdrant/bm25")
+MODELS_TO_CACHE = tuple([x.lower() for x in _MODELS_TO_CACHE])
+
+
+@pytest.fixture(scope="module")
+def model_cache():
+    is_ci = os.getenv("CI")
+    cache = {}
+
+    @contextmanager
+    def get_model(model_name: str):
+        lowercase_model_name = model_name.lower()
+        if lowercase_model_name not in cache:
+            cache[lowercase_model_name] = SparseTextEmbedding(lowercase_model_name)
+        yield cache[lowercase_model_name]
+        if lowercase_model_name not in MODELS_TO_CACHE:
+            model_inst = cache.pop(lowercase_model_name)
+            if is_ci:
+                delete_model_cache(model_inst.model._model_dir)
+            del model_inst
+
+    yield get_model
+
+    if is_ci:
+        for name, model in cache.items():
+            delete_model_cache(model.model._model_dir)
+        cache.clear()
+
+
 docs = ["Hello World"]
@@ -83,23 +113,19 @@
     "model_name",
     ["prithivida/Splade_PP_en_v1", "Qdrant/minicoil-v1"],
 )
-def test_batch_embedding(model_name: str) -> None:
-    is_ci = os.getenv("CI")
+def test_batch_embedding(model_cache, model_name: str) -> None:
     docs_to_embed = docs * 10

-    model = SparseTextEmbedding(model_name=model_name)
-    result = next(iter(model.embed(docs_to_embed, batch_size=6)))
-    expected_result = CANONICAL_COLUMN_VALUES[model_name]
-    assert result.indices.tolist() == expected_result["indices"]
+    with model_cache(model_name) as model:
+        result = next(iter(model.embed(docs_to_embed, batch_size=6)))
+        expected_result = CANONICAL_COLUMN_VALUES[model_name]
+        assert result.indices.tolist() == expected_result["indices"]

-    for i, value in enumerate(result.values):
-        assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        for i, value in enumerate(result.values):
+            assert pytest.approx(value, abs=0.001) == expected_result["values"][i]


-@pytest.mark.parametrize("model_name", ["prithivida/Splade_PP_en_v1", "Qdrant/minicoil-v1"])
-def test_single_embedding(model_name: str) -> None:
+def test_single_embedding(model_cache) -> None:
     is_ci = os.getenv("CI")
     is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

@@ -109,100 +135,106 @@ def test_single_embedding(model_name: str) -> None:
         ):  # attention models and bm25 are also parts of
             # SparseTextEmbedding, however, they have their own tests
             continue
-        if not should_test_model(model_desc, model_name, is_ci, is_manual):
+        if not should_test_model(model_desc, model_desc.model, is_ci, is_manual):
             continue

-        model = SparseTextEmbedding(model_name=model_name)
-
-        passage_result = next(iter(model.embed(docs, batch_size=6)))
-        query_result = next(iter(model.query_embed(docs)))
-        expected_result = CANONICAL_COLUMN_VALUES[model_name]
-        expected_query_result = CANONICAL_QUERY_VALUES.get(model_name, expected_result)
-        assert passage_result.indices.tolist() == expected_result["indices"]
-        for i, value in enumerate(passage_result.values):
-            assert pytest.approx(value, abs=0.001) == expected_result["values"][i]
-
-        assert query_result.indices.tolist() == expected_query_result["indices"]
-        for i, value in enumerate(query_result.values):
-            assert pytest.approx(value, abs=0.001) == expected_query_result["values"][i]
+        with model_cache(model_desc.model) as model:
+            passage_result = next(iter(model.embed(docs, batch_size=6)))
+            query_result = next(iter(model.query_embed(docs)))
+            expected_result = CANONICAL_COLUMN_VALUES[model_desc.model]
+            expected_query_result = CANONICAL_QUERY_VALUES.get(model_desc.model, expected_result)
+            assert passage_result.indices.tolist() == expected_result["indices"]
+            for i, value in enumerate(passage_result.values):
+                assert pytest.approx(value, abs=0.001) == expected_result["values"][i]

-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+            assert query_result.indices.tolist() == expected_query_result["indices"]
+            for i, value in enumerate(query_result.values):
+                assert pytest.approx(value, abs=0.001) == expected_query_result["values"][i]


 @pytest.mark.parametrize(
     "model_name",
     ["prithivida/Splade_PP_en_v1", "Qdrant/minicoil-v1"],
 )
-def test_parallel_processing(model_name: str) -> None:
-    is_ci = os.getenv("CI")
-    model = SparseTextEmbedding(model_name=model_name)
-    docs = ["hello world", "flag embedding"] * 30
-    sparse_embeddings_duo = list(model.embed(docs, batch_size=10, parallel=2))
-    sparse_embeddings_all = list(model.embed(docs, batch_size=10, parallel=0))
-    sparse_embeddings = list(model.embed(docs, batch_size=10, parallel=None))
-
-    assert (
-        len(sparse_embeddings)
-        == len(sparse_embeddings_duo)
-        == len(sparse_embeddings_all)
-        == len(docs)
-    )
-
-    for sparse_embedding, sparse_embedding_duo, sparse_embedding_all in zip(
-        sparse_embeddings, sparse_embeddings_duo, sparse_embeddings_all
-    ):
+def test_parallel_processing(model_cache, model_name: str) -> None:
+    with model_cache(model_name) as model:
+        docs = ["hello world", "flag embedding"] * 30
+        sparse_embeddings_duo = list(model.embed(docs, batch_size=10, parallel=2))
+        # sparse_embeddings_all = list(model.embed(docs, batch_size=10, parallel=0))  # inherits OnnxTextModel which
+        # is tested in TextEmbedding, disabling it here to reduce number of requests to hf
+        # multiprocessing is enough to test with `parallel=2`, and `parallel=None` is okay to test since it reuses
+        # model from cache
+        sparse_embeddings = list(model.embed(docs, batch_size=10, parallel=None))
+
         assert (
-            sparse_embedding.indices.tolist()
-            == sparse_embedding_duo.indices.tolist()
-            == sparse_embedding_all.indices.tolist()
+            len(sparse_embeddings)
+            == len(sparse_embeddings_duo)
+            # == len(sparse_embeddings_all)
+            == len(docs)
         )
-        assert np.allclose(sparse_embedding.values, sparse_embedding_duo.values, atol=1e-3)
-        assert np.allclose(sparse_embedding.values, sparse_embedding_all.values, atol=1e-3)
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
-
-
-@pytest.fixture
-def bm25_instance() -> None:
-    ci = os.getenv("CI", True)
-    model = Bm25("Qdrant/bm25", language="english")
-    yield model
-    if ci:
-        delete_model_cache(model._model_dir)
-
-
-def test_stem_with_stopwords_and_punctuation(bm25_instance: Bm25) -> None:
-    # Setup
-    bm25_instance.stopwords = {"the", "is", "a"}
-    bm25_instance.punctuation = {".", ",", "!"}
-
-    # Test data
-    tokens = ["The", "quick", "brown", "fox", "is", "a", "test", "sentence", ".", "!"]
+        for (
+            sparse_embedding,
+            sparse_embedding_duo,
+            # sparse_embedding_all
+        ) in zip(
+            sparse_embeddings,
+            sparse_embeddings_duo,
+            # sparse_embeddings_all
+        ):
+            assert (
+                sparse_embedding.indices.tolist() == sparse_embedding_duo.indices.tolist()
+                # == sparse_embedding_all.indices.tolist()
+            )
+            assert np.allclose(sparse_embedding.values, sparse_embedding_duo.values, atol=1e-3)
+            # assert np.allclose(sparse_embedding.values, sparse_embedding_all.values, atol=1e-3)
+
+
+def test_stem_with_stopwords_and_punctuation(model_cache) -> None:
+    with model_cache("Qdrant/bm25") as model:
+        bm25_instance = model.model
+        # Setup
+        original_stopwords = bm25_instance.stopwords.copy()
+        original_punctuation = bm25_instance.punctuation.copy()
+
+        bm25_instance.stopwords = {"the", "is", "a"}
+        bm25_instance.punctuation = {".", ",", "!"}
+
+        # Test data
+        tokens = ["The", "quick", "brown", "fox", "is", "a", "test", "sentence", ".", "!"]
"is", "a", "test", "sentence", ".", "!"] + + # Execute + result = bm25_instance._stem(tokens) + + # Assert + expected = ["quick", "brown", "fox", "test", "sentenc"] + assert result == expected, f"Expected {expected}, but got {result}" - # Execute - result = bm25_instance._stem(tokens) + bm25_instance.stopwords = original_stopwords + bm25_instance.punctuation = original_punctuation - # Assert - expected = ["quick", "brown", "fox", "test", "sentenc"] - assert result == expected, f"Expected {expected}, but got {result}" +def test_stem_case_insensitive_stopwords(model_cache) -> None: + with model_cache("Qdrant/bm25") as model: + bm25_instance = model.model + original_stopwords = bm25_instance.stopwords.copy() + original_punctuation = bm25_instance.punctuation.copy() -def test_stem_case_insensitive_stopwords(bm25_instance: Bm25) -> None: - # Setup - bm25_instance.stopwords = {"the", "is", "a"} - bm25_instance.punctuation = {".", ",", "!"} + # Setup + bm25_instance.stopwords = {"the", "is", "a"} + bm25_instance.punctuation = {".", ",", "!"} - # Test data - tokens = ["THE", "Quick", "Brown", "Fox", "IS", "A", "Test", "Sentence", ".", "!"] + # Test data + tokens = ["THE", "Quick", "Brown", "Fox", "IS", "A", "Test", "Sentence", ".", "!"] - # Execute - result = bm25_instance._stem(tokens) + # Execute + result = bm25_instance._stem(tokens) - # Assert - expected = ["quick", "brown", "fox", "test", "sentenc"] - assert result == expected, f"Expected {expected}, but got {result}" + # Assert + expected = ["quick", "brown", "fox", "test", "sentenc"] + assert result == expected, f"Expected {expected}, but got {result}" + bm25_instance.stopwords = original_stopwords + bm25_instance.punctuation = original_punctuation @pytest.mark.parametrize("disable_stemmer", [True, False]) diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py index 76362fdc..925ae8a3 100644 --- a/tests/test_text_cross_encoder.py +++ b/tests/test_text_cross_encoder.py @@ -1,4 +1,5 @@ import os +from contextlib import contextmanager import numpy as np import pytest @@ -16,8 +17,37 @@ } +_MODELS_TO_CACHE = ("Xenova/ms-marco-MiniLM-L-6-v2",) +MODELS_TO_CACHE = tuple([x.lower() for x in _MODELS_TO_CACHE]) + + +@pytest.fixture(scope="module") +def model_cache(): + is_ci = os.getenv("CI") + cache = {} + + @contextmanager + def get_model(model_name: str): + lowercase_model_name = model_name.lower() + if lowercase_model_name not in cache: + cache[lowercase_model_name] = TextCrossEncoder(lowercase_model_name) + yield cache[lowercase_model_name] + if lowercase_model_name not in MODELS_TO_CACHE: + model_inst = cache.pop(lowercase_model_name) + if is_ci: + delete_model_cache(model_inst.model._model_dir) + del model_inst + + yield get_model + + if is_ci: + for name, model in cache.items(): + delete_model_cache(model.model._model_dir) + cache.clear() + + @pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) -def test_rerank(model_name: str) -> None: +def test_rerank(model_cache, model_name: str) -> None: is_ci = os.getenv("CI") is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" @@ -25,11 +55,29 @@ def test_rerank(model_name: str) -> None: if not should_test_model(model_desc, model_name, is_ci, is_manual): continue - model = TextCrossEncoder(model_name=model_name) + with model_cache(model_desc.model) as model: + query = "What is the capital of France?" 
+ documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] + scores = np.array(list(model.rerank(query, documents))) + + pairs = [(query, doc) for doc in documents] + scores2 = np.array(list(model.rerank_pairs(pairs))) + assert np.allclose( + scores, scores2, atol=1e-5 + ), f"Model: {model_desc.model}, Scores: {scores}, Scores2: {scores2}" + + canonical_scores = CANONICAL_SCORE_VALUES[model_desc.model] + assert np.allclose( + scores, canonical_scores, atol=1e-3 + ), f"Model: {model_desc.model}, Scores: {scores}, Expected: {canonical_scores}" + +@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) +def test_batch_rerank(model_cache, model_name: str) -> None: + with model_cache(model_name) as model: query = "What is the capital of France?" - documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] - scores = np.array(list(model.rerank(query, documents))) + documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] * 50 + scores = np.array(list(model.rerank(query, documents, batch_size=10))) pairs = [(query, doc) for doc in documents] scores2 = np.array(list(model.rerank_pairs(pairs))) @@ -37,38 +85,12 @@ def test_rerank(model_name: str) -> None: scores, scores2, atol=1e-5 ), f"Model: {model_name}, Scores: {scores}, Scores2: {scores2}" - canonical_scores = CANONICAL_SCORE_VALUES[model_name] + canonical_scores = np.tile(CANONICAL_SCORE_VALUES[model_name], 50) + + assert scores.shape == canonical_scores.shape, f"Unexpected shape for model {model_name}" assert np.allclose( scores, canonical_scores, atol=1e-3 ), f"Model: {model_name}, Scores: {scores}, Expected: {canonical_scores}" - if is_ci: - delete_model_cache(model.model._model_dir) - - -@pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) -def test_batch_rerank(model_name: str) -> None: - is_ci = os.getenv("CI") - - model = TextCrossEncoder(model_name=model_name) - - query = "What is the capital of France?" - documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] * 50 - scores = np.array(list(model.rerank(query, documents, batch_size=10))) - - pairs = [(query, doc) for doc in documents] - scores2 = np.array(list(model.rerank_pairs(pairs))) - assert np.allclose( - scores, scores2, atol=1e-5 - ), f"Model: {model_name}, Scores: {scores}, Scores2: {scores2}" - - canonical_scores = np.tile(CANONICAL_SCORE_VALUES[model_name], 50) - - assert scores.shape == canonical_scores.shape, f"Unexpected shape for model {model_name}" - assert np.allclose( - scores, canonical_scores, atol=1e-3 - ), f"Model: {model_name}, Scores: {scores}, Expected: {canonical_scores}" - if is_ci: - delete_model_cache(model.model._model_dir) @pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) @@ -86,21 +108,17 @@ def test_lazy_load(model_name: str) -> None: @pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) -def test_rerank_pairs_parallel(model_name: str) -> None: - is_ci = os.getenv("CI") - - model = TextCrossEncoder(model_name=model_name) - query = "What is the capital of France?" 
- documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] * 10 - pairs = [(query, doc) for doc in documents] - scores_parallel = np.array(list(model.rerank_pairs(pairs, parallel=2, batch_size=10))) - scores_sequential = np.array(list(model.rerank_pairs(pairs, batch_size=10))) - assert np.allclose( - scores_parallel, scores_sequential, atol=1e-5 - ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Scores (Sequential): {scores_sequential}" - canonical_scores = CANONICAL_SCORE_VALUES[model_name] - assert np.allclose( - scores_parallel[: len(canonical_scores)], canonical_scores, atol=1e-3 - ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Expected: {canonical_scores}" - if is_ci: - delete_model_cache(model.model._model_dir) +def test_rerank_pairs_parallel(model_cache, model_name: str) -> None: + with model_cache(model_name) as model: + query = "What is the capital of France?" + documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] * 10 + pairs = [(query, doc) for doc in documents] + scores_parallel = np.array(list(model.rerank_pairs(pairs, parallel=2, batch_size=10))) + scores_sequential = np.array(list(model.rerank_pairs(pairs, batch_size=10))) + assert np.allclose( + scores_parallel, scores_sequential, atol=1e-5 + ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Scores (Sequential): {scores_sequential}" + canonical_scores = CANONICAL_SCORE_VALUES[model_name] + assert np.allclose( + scores_parallel[: len(canonical_scores)], canonical_scores, atol=1e-3 + ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Expected: {canonical_scores}" diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 6b25d900..46ce6554 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -1,5 +1,6 @@ import os import platform +from contextlib import contextmanager import numpy as np import pytest @@ -71,9 +72,37 @@ MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"] +_MODELS_TO_CACHE = ("BAAI/bge-small-en-v1.5",) +MODELS_TO_CACHE = tuple([x.lower() for x in _MODELS_TO_CACHE]) + + +@pytest.fixture(scope="module") +def model_cache(): + is_ci = os.getenv("CI") + cache = {} + + @contextmanager + def get_model(model_name: str): + lowercase_model_name = model_name.lower() + if lowercase_model_name not in cache: + cache[lowercase_model_name] = TextEmbedding(lowercase_model_name) + yield cache[lowercase_model_name] + if lowercase_model_name not in MODELS_TO_CACHE: + model_inst = cache.pop(lowercase_model_name) + if is_ci: + delete_model_cache(model_inst.model._model_dir) + del model_inst + + yield get_model + + if is_ci: + for name, model in cache.items(): + delete_model_cache(model.model._model_dir) + cache.clear() + @pytest.mark.parametrize("model_name", ["BAAI/bge-small-en-v1.5"]) -def test_embedding(model_name: str) -> None: +def test_embedding(model_cache, model_name: str) -> None: is_ci = os.getenv("CI") is_mac = platform.system() == "Darwin" is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" @@ -88,55 +117,44 @@ def test_embedding(model_name: str) -> None: dim = model_desc.dim - model = TextEmbedding(model_name=model_desc.model) - docs = ["hello world", "flag embedding"] - embeddings = list(model.embed(docs)) - embeddings = np.stack(embeddings, axis=0) - assert embeddings.shape == (2, dim) + with model_cache(model_desc.model) as model: + docs = ["hello world", "flag embedding"] + embeddings = list(model.embed(docs)) + 
+            embeddings = np.stack(embeddings, axis=0)
+            assert embeddings.shape == (2, dim)

-        canonical_vector = CANONICAL_VECTOR_VALUES[model_desc.model]
-        assert np.allclose(
-            embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
-        ), model_desc.model
-        if is_ci:
-            delete_model_cache(model.model._model_dir)
+            canonical_vector = CANONICAL_VECTOR_VALUES[model_desc.model]
+            assert np.allclose(
+                embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
+            ), model_desc.model


 @pytest.mark.parametrize("n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5")])
-def test_batch_embedding(n_dims: int, model_name: str) -> None:
-    is_ci = os.getenv("CI")
-    model = TextEmbedding(model_name=model_name)
+def test_batch_embedding(model_cache, n_dims: int, model_name: str) -> None:
+    with model_cache(model_name) as model:
+        docs = ["hello world", "flag embedding"] * 100
+        embeddings = list(model.embed(docs, batch_size=10))
+        embeddings = np.stack(embeddings, axis=0)

-    docs = ["hello world", "flag embedding"] * 100
-    embeddings = list(model.embed(docs, batch_size=10))
-    embeddings = np.stack(embeddings, axis=0)
-
-    assert embeddings.shape == (len(docs), n_dims)
-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        assert embeddings.shape == (len(docs), n_dims)


 @pytest.mark.parametrize("n_dims,model_name", [(384, "BAAI/bge-small-en-v1.5")])
-def test_parallel_processing(n_dims: int, model_name: str) -> None:
-    is_ci = os.getenv("CI")
-    model = TextEmbedding(model_name=model_name)
-
-    docs = ["hello world", "flag embedding"] * 100
-    embeddings = list(model.embed(docs, batch_size=10, parallel=2))
-    embeddings = np.stack(embeddings, axis=0)
-
-    embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
-    embeddings_2 = np.stack(embeddings_2, axis=0)
+def test_parallel_processing(model_cache, n_dims: int, model_name: str) -> None:
+    with model_cache(model_name) as model:
+        docs = ["hello world", "flag embedding"] * 100
+        embeddings = list(model.embed(docs, batch_size=10, parallel=2))
+        embeddings = np.stack(embeddings, axis=0)

-    embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
-    embeddings_3 = np.stack(embeddings_3, axis=0)
+        embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
+        embeddings_2 = np.stack(embeddings_2, axis=0)

-    assert embeddings.shape == (len(docs), n_dims)
-    assert np.allclose(embeddings, embeddings_2, atol=1e-3)
-    assert np.allclose(embeddings, embeddings_3, atol=1e-3)
+        embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
+        embeddings_3 = np.stack(embeddings_3, axis=0)

-    if is_ci:
-        delete_model_cache(model.model._model_dir)
+        assert embeddings.shape == (len(docs), n_dims)
+        assert np.allclose(embeddings, embeddings_2, atol=1e-3)
+        assert np.allclose(embeddings, embeddings_3, atol=1e-3)


 @pytest.mark.parametrize("model_name", ["BAAI/bge-small-en-v1.5"])
diff --git a/tests/utils.py b/tests/utils.py
index 40a7febf..5fc0fd17 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -49,7 +49,7 @@ def should_test_model(
     Tests can be run either in ci or locally.
     Testing all models each time in ci is too long.

-    The testing scheme in ci and on a local machine are different, therefore, there are 3 possible scenarious.
+    The testing schemes in ci and on a local machine are different, therefore, there are 3 possible scenarios.
     1) Run lightweight tests in ci:
         - test only one model that has been manually chosen as a representative for a certain class family
     2) Run heavyweight (manual) tests in ci: