|
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | 5 | # Standard library imports |
| 6 | +import logging |
6 | 7 | import shutil |
7 | 8 | from pathlib import Path |
8 | 9 |
|
9 | 10 | import numpy as np |
10 | 11 | import pytest |
11 | 12 |
|
12 | | -from wurzel.utils import HAS_LANGCHAIN_CORE, HAS_REQUESTS |
| 13 | +from wurzel.utils import HAS_LANGCHAIN_CORE, HAS_REQUESTS, HAS_SPACY, HAS_TIKTOKEN |
13 | 14 |
|
14 | | -if not HAS_LANGCHAIN_CORE or not HAS_REQUESTS: |
15 | | - pytest.skip("Embedding dependencies (langchain-core, requests) are not available", allow_module_level=True) |
| 15 | +if not HAS_LANGCHAIN_CORE or not HAS_REQUESTS or not HAS_SPACY or not HAS_TIKTOKEN: |
| 16 | + pytest.skip("Embedding dependencies (langchain-core, requests, spacy, tiktoken) are not available", allow_module_level=True) |
16 | 17 |
|
17 | 18 | from wurzel.exceptions import StepFailed |
18 | 19 | from wurzel.step_executor import BaseStepExecutor |
19 | | - |
20 | | -# Local application/library specific imports |
21 | 20 | from wurzel.steps import EmbeddingStep |
22 | 21 | from wurzel.steps.embedding.huggingface import HuggingFaceInferenceAPIEmbeddings |
23 | 22 | from wurzel.steps.embedding.step_multivector import EmbeddingMultiVectorStep |
24 | 23 |
|
| 24 | +SPLITTER_TOKENIZER_MODEL = "gpt-3.5-turbo" |
| 25 | +SENTENCE_SPLITTER_MODEL = "de_core_news_sm" |
| 26 | + |
25 | 27 |
|
26 | 28 | @pytest.fixture(scope="module") |
27 | 29 | def mock_embedding(): |
@@ -87,12 +89,23 @@ def test_embedding_step(mock_embedding, default_embedding_data, env): |
87 | 89 |
|
88 | 90 | """ |
89 | 91 | env.set("EMBEDDINGSTEP__API", "https://example-embedding.com/embed") |
| 92 | + env.set("EMBEDDINGSTEP__TOKEN_COUNT_MIN", "64") |
| 93 | + env.set("EMBEDDINGSTEP__TOKEN_COUNT_MAX", "256") |
| 94 | + env.set("EMBEDDINGSTEP__TOKEN_COUNT_BUFFER", "32") |
| 95 | + env.set("EMBEDDINGSTEP__TOKENIZER_MODEL", SPLITTER_TOKENIZER_MODEL) |
| 96 | + env.set("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL", SENTENCE_SPLITTER_MODEL) |
| 97 | + |
90 | 98 | EmbeddingStep._select_embedding = mock_embedding |
91 | 99 | input_folder, output_folder = default_embedding_data |
92 | | - BaseStepExecutor(dont_encapsulate=False).execute_step(EmbeddingStep, [input_folder], output_folder) |
| 100 | + step_res = BaseStepExecutor(dont_encapsulate=False).execute_step(EmbeddingStep, [input_folder], output_folder) |
93 | 101 | assert output_folder.is_dir() |
94 | 102 | assert len(list(output_folder.glob("*"))) > 0 |
95 | 103 |
|
| 104 | + step_output, step_report = step_res[0] |
| 105 | + |
| 106 | + assert len(step_output) == 11, "Step outputs have wrong count." |
| 107 | + assert step_report.results == 11, "Step report has wrong count of outputs." |
| 108 | + |
96 | 109 |
|
97 | 110 | def test_mutlivector_embedding_step(mock_embedding, tmp_path, env): |
98 | 111 | """Tests the execution of the `EmbeddingMultiVectorStep` with a mock input file. |
@@ -137,3 +150,60 @@ def _select_embedding(*args, **kwargs) -> HuggingFaceInferenceAPIEmbeddings: |
137 | 150 | with BaseStepExecutor() as ex: |
138 | 151 | ex(InheritedStep, [inp], out) |
139 | 152 | assert sf.value.message.endswith(EXPECTED_EXCEPTION) |
| 153 | + |
| 154 | + |
| 155 | +def test_embedding_step_log_statistics(mock_embedding, default_embedding_data, env, caplog): |
| 156 | + """Tests the logging of descriptive statistics in the `EmbeddingStep` with a mock input file.""" |
| 157 | + env.set("EMBEDDINGSTEP__API", "https://example-embedding.com/embed") |
| 158 | + env.set("EMBEDDINGSTEP__NUM_THREADS", "1") # Ensure deterministic behavior with single thread |
| 159 | + env.set("EMBEDDINGSTEP__TOKEN_COUNT_MIN", "64") |
| 160 | + env.set("EMBEDDINGSTEP__TOKEN_COUNT_MAX", "256") |
| 161 | + env.set("EMBEDDINGSTEP__TOKEN_COUNT_BUFFER", "32") |
| 162 | + env.set("EMBEDDINGSTEP__TOKENIZER_MODEL", SPLITTER_TOKENIZER_MODEL) |
| 163 | + env.set("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL", SENTENCE_SPLITTER_MODEL) |
| 164 | + |
| 165 | + EmbeddingStep._select_embedding = mock_embedding |
| 166 | + input_folder, output_folder = default_embedding_data |
| 167 | + |
| 168 | + with caplog.at_level(logging.INFO): |
| 169 | + BaseStepExecutor(dont_encapsulate=False).execute_step(EmbeddingStep, [input_folder], output_folder) |
| 170 | + |
| 171 | + # check if output log exists |
| 172 | + assert "Distribution of char length" in caplog.text, "Missing log output for char length" |
| 173 | + assert "Distribution of token length" in caplog.text, "Missing log output for token length" |
| 174 | + assert "Distribution of chunks count" in caplog.text, "Missing log output for chunks count" |
| 175 | + |
| 176 | + # check extras |
| 177 | + char_length_record = None |
| 178 | + token_length_record = None |
| 179 | + chunks_count_record = None |
| 180 | + |
| 181 | + for record in caplog.records: |
| 182 | + if "Distribution of char length" in record.message: |
| 183 | + char_length_record = record |
| 184 | + |
| 185 | + if "Distribution of token length" in record.message: |
| 186 | + token_length_record = record |
| 187 | + |
| 188 | + if "Distribution of chunks count" in record.message: |
| 189 | + chunks_count_record = record |
| 190 | + |
| 191 | + expected_char_length_count = 11 |
| 192 | + |
| 193 | +    # Check values with a small tolerance |
| 194 | + expected_char_length_mean = pytest.approx(609.18, abs=0.1) |
| 195 | + expected_token_length_mean = pytest.approx(257.18, abs=0.1) |
| 196 | + expected_chunks_count_mean = pytest.approx(3.18, abs=0.2) |
| 197 | + |
| 198 | + assert char_length_record.count == expected_char_length_count, ( |
| 199 | + f"Invalid char length count: expected {expected_char_length_count}, got {char_length_record.count}" |
| 200 | + ) |
| 201 | + assert char_length_record.mean == expected_char_length_mean, ( |
| 202 | + f"Invalid char length mean: expected {expected_char_length_mean}, got {char_length_record.mean}" |
| 203 | + ) |
| 204 | + assert token_length_record.mean == expected_token_length_mean, ( |
| 205 | + f"Invalid token length mean: expected {expected_token_length_mean}, got {token_length_record.mean}" |
| 206 | + ) |
| 207 | + assert chunks_count_record.mean == expected_chunks_count_mean, ( |
| 208 | + f"Invalid chunks count mean: expected {expected_chunks_count_mean}, got {chunks_count_record.mean}" |
| 209 | + ) |
0 commit comments