Skip to content

Commit a92d004

Browse files
authored
feat: Adding TruncatedEmbeddingStep (2) and document metadata to Qdrant (#183)
1 parent bdf30a5 commit a92d004

File tree

15 files changed

+368
-126
lines changed

15 files changed

+368
-126
lines changed

tests/data/embedded.csv

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-FileCopyrightText: 2025 Deutsche Telekom AG (opensource@telekom.de)
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from wurzel.steps.data import EmbeddingResult
6+
7+
8+
def test_load_from_csv_converts_dict_columns(tmp_path):
    """`load_from_path` must parse a stringified dict in the ``metadata`` CSV column back into a dict.

    Parameters
    ----------
    tmp_path : pathlib.Path
        pytest-provided temporary directory for the throwaway CSV file.

    """
    csv = tmp_path / "embedded.csv"
    # One data row in header order: text, vector, url, keywords,
    # embedding_input_text, metadata.
    # Fix: the original row had an empty ``url`` and the URL shifted into
    # ``keywords`` (and ``kw`` into ``embedding_input_text``); the values are
    # realigned here to match the header columns.
    csv.write_text(
        'text,vector,url,keywords,embedding_input_text,metadata\n"foo","[0.1, 0.2]","https://example.com","kw","foo","{\'foo\': \'bar\'}"\n',
        encoding="utf-8",
    )

    df = EmbeddingResult.load_from_path(csv)

    metadata = df["metadata"].iloc[0]
    assert isinstance(metadata, dict)
    assert metadata["foo"] == "bar"

tests/splitter_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,17 @@ def test_split_markdown_document(Splitter):
100100
)
101101

102102
result = Splitter.split_markdown_document(contract)
103-
assert len(result) > 1
104-
assert "TV HD Recorder Fehlerbehebun" in result[-1].md
103+
assert len(result) == 5, "Splitter produce invalid number of chunks"
104+
assert "TV HD Recorder Fehlerbehebun" in result[-1].md, "Invalid chunk content"
105+
106+
# Check metadata
107+
expected_hash = "1b5098dbc4584f019bb00cbbb42a36ef27e908b216f40e09ae77f30ca1cddc2f" # pragma: allowlist secret
108+
assert result[0].metadata["source_sha256_hash"] == expected_hash, "Invalid source hash"
109+
110+
assert result[0].metadata["source_sha256_hash"] == result[-1].metadata["source_sha256_hash"], "Source hashes are not the same"
111+
assert result[0].metadata["chunks_count"] == 5, "Chunk metadata is invalid"
112+
assert result[0].metadata["chunk_index"] == 0, "Chunk metadata is invalid"
113+
assert result[-1].metadata["chunk_index"] == 4, "Chunk metadata is invalid"
105114

106115

107116
def test_sentence_splitter(Splitter: SemanticSplitter):

tests/step_executor/sort_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
pytest.param(
2626
DataFrame[EmbeddingResult](
2727
[
28-
{"text": "a", "url": "url", "vector": [0.1], "keywords": "kw"},
29-
{"text": "b", "url": "url", "vector": [0.1], "keywords": "kw"},
28+
{"text": "a", "url": "url", "vector": [0.1], "keywords": "kw", "embedding_input_text": "a", "metadata": {"foo": "bar"}},
29+
{"text": "b", "url": "url", "vector": [0.1], "keywords": "kw", "embedding_input_text": "b", "metadata": {"foo": "bar"}},
3030
]
3131
),
3232
id="DataFrame",
@@ -63,14 +63,14 @@ def test_unsorted(run_num, expected):
6363
def test_unsorted_df():
6464
unsorted = DataFrame[EmbeddingResult](
6565
[
66-
{"text": "b", "url": "url", "vector": [0.1], "keywords": "kw"},
67-
{"text": "a", "url": "url", "vector": [0.1], "keywords": "kw"},
66+
{"text": "b", "url": "url", "vector": [0.1], "keywords": "kw", "embedding_input_text": "b", "metadata": {"foo": "bar"}},
67+
{"text": "a", "url": "url", "vector": [0.1], "keywords": "kw", "embedding_input_text": "a", "metadata": {"foo": "bar"}},
6868
]
6969
)
7070
sort = DataFrame[EmbeddingResult](
7171
[
72-
{"text": "a", "url": "url", "vector": [0.1], "keywords": "kw"},
73-
{"text": "b", "url": "url", "vector": [0.1], "keywords": "kw"},
72+
{"text": "a", "url": "url", "vector": [0.1], "keywords": "kw", "embedding_input_text": "a", "metadata": {"foo": "bar"}},
73+
{"text": "b", "url": "url", "vector": [0.1], "keywords": "kw", "embedding_input_text": "b", "metadata": {"foo": "bar"}},
7474
]
7575
)
7676
assert not sort.equals(unsorted), "sanity check"

tests/steps/embedding/conftest.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import shutil
6+
from pathlib import Path
7+
8+
import numpy as np
59
import pytest
610
import requests_mock
711

@@ -32,3 +36,61 @@ def embedding_service_mock():
3236
m.post("/embed", text=POST_RESULT_EMBEDDING_STR)
3337
m.get("/info", text=GET_RESULT_INFO)
3438
yield
39+
40+
41+
@pytest.fixture(scope="module")
def mock_embedding():
    """Provide a factory suitable as a drop-in for ``EmbeddingStep._select_embedding``.

    The factory ignores all arguments and returns a stand-in embedder whose
    ``embed_query`` yields a random 768-dimensional vector, so tests never
    contact a real embedding service.

    Returns:
    -------
    Callable
        A function that builds the mock embedder when called.

    """

    class _FakeEmbedder:
        def embed_query(self, _: str) -> list[float]:
            """Return a random 768-dimensional vector; the query text is ignored.

            Parameters
            ----------
            _ : str
                The input query string (unused in this mock).

            Returns:
            -------
            list[float]
                A random vector of size 768.

            """
            return list(np.random.random(768))

    def _factory(*args, **kwargs):
        # Signature mirrors _select_embedding; everything is discarded.
        return _FakeEmbedder()

    return _factory
77+
78+
79+
@pytest.fixture
def default_embedding_data(tmp_path):
    """Stage the markdown fixture file into a fresh input folder and return (input, output) paths."""
    source = Path("tests/data/markdown.json")
    in_dir = tmp_path / "input"
    out_dir = tmp_path / "out"  # not created here; the step under test writes it
    in_dir.mkdir()
    shutil.copy(source, in_dir)
    return (in_dir, out_dir)
87+
88+
89+
@pytest.fixture
def splitter_tokenizer_model():
    """Tokenizer model name shared by the embedding-step tests."""
    model_name = "gpt-3.5-turbo"
    return model_name
92+
93+
94+
@pytest.fixture
def sentence_splitter_model():
    """Sentence-splitter (spaCy) model name shared by the embedding-step tests."""
    model_name = "de_core_news_sm"
    return model_name

tests/steps/embedding/e2e_test.py

Lines changed: 8 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import shutil
88
from pathlib import Path
99

10-
import numpy as np
1110
import pytest
1211

1312
from wurzel.utils import HAS_LANGCHAIN_CORE, HAS_REQUESTS, HAS_SPACY, HAS_TIKTOKEN
@@ -21,59 +20,8 @@
2120
from wurzel.steps.embedding.huggingface import HuggingFaceInferenceAPIEmbeddings
2221
from wurzel.steps.embedding.step_multivector import EmbeddingMultiVectorStep
2322

24-
SPLITTER_TOKENIZER_MODEL = "gpt-3.5-turbo"
25-
SENTENCE_SPLITTER_MODEL = "de_core_news_sm"
2623

27-
28-
@pytest.fixture(scope="module")
29-
def mock_embedding():
30-
"""A pytest fixture that provides a mock embedding class for testing.
31-
32-
Overrides the `_select_embedding` method of the `EmbeddingStep` class
33-
to return an instance of the mock embedding class, which generates
34-
a fixed-size random vector upon calling `embed_query`.
35-
36-
Returns:
37-
-------
38-
MockEmbedding
39-
An instance of the mock embedding class.
40-
41-
"""
42-
43-
class MockEmbedding:
44-
def embed_query(self, _: str) -> list[float]:
45-
"""Simulates embedding of a query string into a fixed-size random vector.
46-
47-
Parameters
48-
----------
49-
_ : str
50-
The input query string (ignored in this mock implementation).
51-
52-
Returns:
53-
-------
54-
np.ndarray
55-
A random vector of size 768.
56-
57-
"""
58-
return list(np.random.random(768))
59-
60-
def mock_func(*args, **kwargs):
61-
return MockEmbedding()
62-
63-
return mock_func
64-
65-
66-
@pytest.fixture
67-
def default_embedding_data(tmp_path):
68-
mock_file = Path("tests/data/markdown.json")
69-
input_folder = tmp_path / "input"
70-
input_folder.mkdir()
71-
shutil.copy(mock_file, input_folder)
72-
output_folder = tmp_path / "out"
73-
return (input_folder, output_folder)
74-
75-
76-
def test_embedding_step(mock_embedding, default_embedding_data, env):
24+
def test_embedding_step(mock_embedding, default_embedding_data, env, splitter_tokenizer_model, sentence_splitter_model):
7725
"""Tests the execution of the `EmbeddingStep` with a mock input file.
7826
7927
Parameters
@@ -92,8 +40,8 @@ def test_embedding_step(mock_embedding, default_embedding_data, env):
9240
env.set("EMBEDDINGSTEP__TOKEN_COUNT_MIN", "64")
9341
env.set("EMBEDDINGSTEP__TOKEN_COUNT_MAX", "256")
9442
env.set("EMBEDDINGSTEP__TOKEN_COUNT_BUFFER", "32")
95-
env.set("EMBEDDINGSTEP__TOKENIZER_MODEL", SPLITTER_TOKENIZER_MODEL)
96-
env.set("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL", SENTENCE_SPLITTER_MODEL)
43+
env.set("EMBEDDINGSTEP__TOKENIZER_MODEL", splitter_tokenizer_model)
44+
env.set("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL", sentence_splitter_model)
9745

9846
EmbeddingStep._select_embedding = mock_embedding
9947
input_folder, output_folder = default_embedding_data
@@ -152,15 +100,17 @@ def _select_embedding(*args, **kwargs) -> HuggingFaceInferenceAPIEmbeddings:
152100
assert sf.value.message.endswith(EXPECTED_EXCEPTION)
153101

154102

155-
def test_embedding_step_log_statistics(mock_embedding, default_embedding_data, env, caplog):
103+
def test_embedding_step_log_statistics(
104+
mock_embedding, default_embedding_data, env, caplog, splitter_tokenizer_model, sentence_splitter_model
105+
):
156106
"""Tests the logging of descriptive statistics in the `EmbeddingStep` with a mock input file."""
157107
env.set("EMBEDDINGSTEP__API", "https://example-embedding.com/embed")
158108
env.set("EMBEDDINGSTEP__NUM_THREADS", "1") # Ensure deterministic behavior with single thread
159109
env.set("EMBEDDINGSTEP__TOKEN_COUNT_MIN", "64")
160110
env.set("EMBEDDINGSTEP__TOKEN_COUNT_MAX", "256")
161111
env.set("EMBEDDINGSTEP__TOKEN_COUNT_BUFFER", "32")
162-
env.set("EMBEDDINGSTEP__TOKENIZER_MODEL", SPLITTER_TOKENIZER_MODEL)
163-
env.set("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL", SENTENCE_SPLITTER_MODEL)
112+
env.set("EMBEDDINGSTEP__TOKENIZER_MODEL", splitter_tokenizer_model)
113+
env.set("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL", sentence_splitter_model)
164114

165115
EmbeddingStep._select_embedding = mock_embedding
166116
input_folder, output_folder = default_embedding_data
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# SPDX-FileCopyrightText: 2025 Deutsche Telekom AG (opensource@telekom.de)
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import pytest
5+
6+
from wurzel.utils import HAS_LANGCHAIN_CORE, HAS_REQUESTS, HAS_SPACY, HAS_TIKTOKEN
7+
8+
if not HAS_LANGCHAIN_CORE or not HAS_REQUESTS or not HAS_SPACY or not HAS_TIKTOKEN:
9+
pytest.skip("Embedding dependencies (langchain-core, requests, spacy, tiktoken) are not available", allow_module_level=True)
10+
11+
from wurzel.step_executor import BaseStepExecutor
12+
from wurzel.steps.embedding.step import TruncatedEmbeddingStep
13+
14+
15+
@pytest.mark.parametrize(
    "token_count_max,mean_text_length",
    [
        # Large limits leave texts untruncated (same mean), smaller limits
        # progressively shorten the embedding input.
        (99999, 959.4),
        (9999, 959.4),
        (256, 491.1),
        (128, 309.8),
        (32, 103.4),
    ],
)
def test_truncated_embedding_step(
    token_count_max, mean_text_length, mock_embedding, default_embedding_data, env, splitter_tokenizer_model, sentence_splitter_model
):
    """Run `TruncatedEmbeddingStep` end-to-end and check output count and mean truncated-text length.

    Parameters
    ----------
    token_count_max : int
        The TOKEN_COUNT_MAX setting under test (parametrized).
    mean_text_length : float
        Expected mean character length of ``embedding_input_text`` for that limit.
    mock_embedding : Callable
        Fixture factory patched over ``_select_embedding`` so no real service is called.
    default_embedding_data : tuple[Path, Path]
        Fixture providing (input_folder, output_folder) with the markdown test file staged.
    env
        Fixture for setting step configuration environment variables.
    splitter_tokenizer_model, sentence_splitter_model : str
        Model-name fixtures used for tokenizing/splitting.

    Asserts
    -------
    Output folder is populated, the step yields 7 results, and the mean length
    of ``embedding_input_text`` matches the expected value for the token limit.

    """
    # Configure the step via its env-var namespace (TRUNCATEDEMBEDDINGSTEP__*).
    env.set("TRUNCATEDEMBEDDINGSTEP__API", "https://example-embedding.com/embed")
    env.set("TRUNCATEDEMBEDDINGSTEP__TOKEN_COUNT_MIN", "64")
    env.set("TRUNCATEDEMBEDDINGSTEP__TOKEN_COUNT_MAX", str(token_count_max))
    env.set("TRUNCATEDEMBEDDINGSTEP__TOKEN_COUNT_BUFFER", "32")
    env.set("TRUNCATEDEMBEDDINGSTEP__TOKENIZER_MODEL", splitter_tokenizer_model)
    env.set("TRUNCATEDEMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL", sentence_splitter_model)

    # Patch the class attribute so the step uses the mock embedder.
    # NOTE(review): this mutates the class globally for the test session --
    # presumably acceptable since all embedding tests patch it the same way.
    TruncatedEmbeddingStep._select_embedding = mock_embedding
    input_folder, output_folder = default_embedding_data
    step_res = BaseStepExecutor(dont_encapsulate=False).execute_step(TruncatedEmbeddingStep, [input_folder], output_folder)
    assert output_folder.is_dir()
    assert len(list(output_folder.glob("*"))) > 0

    # execute_step returns a list of (output, report) pairs; one input folder -> one pair.
    step_output, step_report = step_res[0]

    assert len(step_output) == 7, "Step outputs have wrong count."
    assert step_report.results == 7, "Step report has wrong count of outputs."

    # Mean length compared with a 0.1-char tolerance to absorb float rounding.
    assert step_output.embedding_input_text.str.len().mean() == pytest.approx(mean_text_length, abs=0.1), (
        "Invalid mean length of embedding_input_text"
    )

tests/steps/qdrant/e2e_test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,17 @@ def test_qdrant_connector_first(input_output_folder: tuple[Path, Path], dummy_co
3131
input_file = input_path / "qdrant_at.csv"
3232
output_file = output_path / "QdrantConnectorStep"
3333
shutil.copy("./tests/data/embedded.csv", input_file)
34-
BaseStepExecutor().execute_step(QdrantConnectorStep, {input_path}, output_file)
34+
35+
with BaseStepExecutor() as ex:
36+
step_res = ex(QdrantConnectorStep, {input_path}, output_file)
37+
38+
step_output, step_report = step_res[0]
39+
40+
# Validate step output
41+
assert step_report.results == 2, "Invalid step results"
42+
43+
assert step_output["collection"][1] == "dummy_v1", "Invalid step output in collection name"
44+
assert step_output["metadata"][0]["foo"] == "bar", "Invalid step output in metadata"
3545

3646

3747
def test_qdrant_connector_has_previous(input_output_folder: tuple[Path, Path], dummy_collection):

wurzel/datacontract/datacontract.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
def load_from_path(cls, path: Path, *args) -> Self:
    """Load a DataFrame of this contract from a CSV file at *path*.

    Columns whose schema dtype is ``list`` or ``dict`` are re-parsed with
    ``literal_eval``, because pandas keeps every CSV cell as a plain string.
    """
    import pandas as pd  # pylint: disable=import-outside-toplevel

    # Load CSV from path
    read_data = pd.read_csv(path.open(encoding="utf-8"))

    def _literal_eval_or_passthrough(value):
        """Convert stringified literals to Python objects because pandas keeps CSV cells as strings."""
        # Non-string cells (e.g. NaN or already-parsed objects) pass through untouched.
        if not isinstance(value, str):
            return value
        stripped = value.strip()
        # Empty cells become None instead of raising inside literal_eval.
        if stripped == "":
            return None
        try:
            return literal_eval(stripped)
        except (ValueError, SyntaxError):
            # Not a valid Python literal (plain text) -- keep the raw string.
            return value

    # Iterate over schema columns and re-parse the ones typed list/dict.
    for key, atr in cls.to_schema().columns.items():
        # Skip schema columns absent from this CSV (e.g. optional columns).
        if key not in read_data.columns:
            continue
        if atr.dtype.type in {list, dict}:
            read_data[key] = read_data[key].apply(_literal_eval_or_passthrough)

    return patyp.DataFrame[cls](read_data)
6078

6179

wurzel/steps/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
if HAS_LANGCHAIN_CORE and HAS_REQUESTS:
1313
try:
1414
from .embedding import * # noqa: F403 Allow importing Step classes
15-
from .embedding import EmbeddingStep # noqa: F401
15+
from .embedding import (
16+
EmbeddingStep, # noqa: F401
17+
TruncatedEmbeddingStep, # noqa: F401
18+
)
1619
except ImportError:
1720
pass
1821

0 commit comments

Comments
 (0)