77import shutil
88from pathlib import Path
99
10- import numpy as np
1110import pytest
1211
1312from wurzel .utils import HAS_LANGCHAIN_CORE , HAS_REQUESTS , HAS_SPACY , HAS_TIKTOKEN
2120from wurzel .steps .embedding .huggingface import HuggingFaceInferenceAPIEmbeddings
2221from wurzel .steps .embedding .step_multivector import EmbeddingMultiVectorStep
2322
24- SPLITTER_TOKENIZER_MODEL = "gpt-3.5-turbo"
25- SENTENCE_SPLITTER_MODEL = "de_core_news_sm"
2623
27-
@pytest.fixture(scope="module")
def mock_embedding():
    """Provide a replacement embedding selector for the tests in this module.

    The fixture value is a callable, not an embedding object. Tests assign
    it to ``EmbeddingStep._select_embedding``; calling it with any arguments
    returns a new ``MockEmbedding`` instance.

    Returns:
    -------
    Callable
        A function that ignores its arguments and returns a ``MockEmbedding``
        whose ``embed_query`` yields 768 random values.

    """

    class MockEmbedding:
        def embed_query(self, _: str) -> list[float]:
            """Return a random vector in place of a real query embedding.

            Parameters
            ----------
            _ : str
                The query text; it is not used.

            Returns:
            -------
            list[float]
                768 values drawn with ``np.random.random``.

            """
            vector = np.random.random(768)
            return list(vector)

    def build_mock_embedding(*args, **kwargs):
        # Arguments are accepted and discarded; every call returns a fresh mock.
        return MockEmbedding()

    return build_mock_embedding
64-
65-
@pytest.fixture
def default_embedding_data(tmp_path):
    """Stage the sample markdown JSON in a temporary input directory.

    Creates ``<tmp_path>/input`` and copies ``tests/data/markdown.json``
    into it. The source path is relative, so it is resolved against the
    current working directory. ``<tmp_path>/out`` is only computed here,
    not created.

    Returns:
    -------
    tuple[Path, Path]
        ``(input_folder, output_folder)``.

    """
    source_dir = tmp_path / "input"
    target_dir = tmp_path / "out"

    source_dir.mkdir()
    shutil.copy(Path("tests/data/markdown.json"), source_dir)

    return (source_dir, target_dir)
74-
75-
76- def test_embedding_step (mock_embedding , default_embedding_data , env ):
24+ def test_embedding_step (mock_embedding , default_embedding_data , env , splitter_tokenizer_model , sentence_splitter_model ):
7725 """Tests the execution of the `EmbeddingStep` with a mock input file.
7826
7927 Parameters
@@ -92,8 +40,8 @@ def test_embedding_step(mock_embedding, default_embedding_data, env):
9240 env .set ("EMBEDDINGSTEP__TOKEN_COUNT_MIN" , "64" )
9341 env .set ("EMBEDDINGSTEP__TOKEN_COUNT_MAX" , "256" )
9442 env .set ("EMBEDDINGSTEP__TOKEN_COUNT_BUFFER" , "32" )
95- env .set ("EMBEDDINGSTEP__TOKENIZER_MODEL" , SPLITTER_TOKENIZER_MODEL )
96- env .set ("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL" , SENTENCE_SPLITTER_MODEL )
43+ env .set ("EMBEDDINGSTEP__TOKENIZER_MODEL" , splitter_tokenizer_model )
44+ env .set ("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL" , sentence_splitter_model )
9745
9846 EmbeddingStep ._select_embedding = mock_embedding
9947 input_folder , output_folder = default_embedding_data
@@ -152,15 +100,17 @@ def _select_embedding(*args, **kwargs) -> HuggingFaceInferenceAPIEmbeddings:
152100 assert sf .value .message .endswith (EXPECTED_EXCEPTION )
153101
154102
155- def test_embedding_step_log_statistics (mock_embedding , default_embedding_data , env , caplog ):
103+ def test_embedding_step_log_statistics (
104+ mock_embedding , default_embedding_data , env , caplog , splitter_tokenizer_model , sentence_splitter_model
105+ ):
156106 """Tests the logging of descriptive statistics in the `EmbeddingStep` with a mock input file."""
157107 env .set ("EMBEDDINGSTEP__API" , "https://example-embedding.com/embed" )
158108 env .set ("EMBEDDINGSTEP__NUM_THREADS" , "1" ) # Ensure deterministic behavior with single thread
159109 env .set ("EMBEDDINGSTEP__TOKEN_COUNT_MIN" , "64" )
160110 env .set ("EMBEDDINGSTEP__TOKEN_COUNT_MAX" , "256" )
161111 env .set ("EMBEDDINGSTEP__TOKEN_COUNT_BUFFER" , "32" )
162- env .set ("EMBEDDINGSTEP__TOKENIZER_MODEL" , SPLITTER_TOKENIZER_MODEL )
163- env .set ("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL" , SENTENCE_SPLITTER_MODEL )
112+ env .set ("EMBEDDINGSTEP__TOKENIZER_MODEL" , splitter_tokenizer_model )
113+ env .set ("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL" , sentence_splitter_model )
164114
165115 EmbeddingStep ._select_embedding = mock_embedding
166116 input_folder , output_folder = default_embedding_data
0 commit comments