
Commit 4a276d1

feat: Splitter tracks token and char length, embedding step logs the length statistics (#167)
1 parent 212a6e2 commit 4a276d1

File tree: 10 files changed, +366 -31 lines


tests/datacontract/md_test.py

Lines changed: 37 additions & 0 deletions
@@ -50,6 +50,9 @@ def test_manual_step_md_parsing(tmp_path, md, url, bread):
     else:
         assert s.md == "Text"
 
+    # check metadata field
+    assert s.metadata is None
+
 
 class MDCChild(MarkdownDataContract):
     pass
@@ -200,6 +203,40 @@ def test_topics_deprecation_warning(tmp_path):
     assert s.md.startswith("# Some title")
 
 
+def test_metadata_field_metadata(tmp_path):
+    md = """---
+keywords: "k1"
+url: foo/bar
+metadata:
+  foo: bar
+  bar: 123
+---
+# Title
+
+Text.
+"""
+    f = tmp_path / "file.md"
+    f.write_text(md)
+    s = MarkdownDataContract.from_file(f)
+
+    assert "# Title" in s.md
+    assert s.metadata is not None
+    assert s.metadata["foo"] == "bar"
+    assert s.metadata["bar"] == 123
+    assert s.url == "foo/bar"
+
+    assert s.__hash__() == 21317556317919954558699657768736304700342060298586059611903002870732316103488, "Invalid hash"
+
+    # save and load again
+    f2 = tmp_path / "file2.json"
+
+    MarkdownDataContract.save_to_path(f2, s)
+
+    s2 = MarkdownDataContract.load_from_path(f2, MarkdownDataContract)
+
+    assert s.__hash__() == s2.__hash__(), "Invalid hash after write/load file"
+
+
 def test_utf8_encoding(tmp_path):
     """Test that UTF-8 encoded files are read correctly, especially on Windows."""
     f = tmp_path / "file.md"
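
The new test relies on the nested `metadata:` mapping in the YAML front matter surviving the parse with its types intact. As a rough sketch of why `bar: 123` comes back as the integer `123`, here is how a standard YAML parser treats that block (whether wurzel uses PyYAML for front matter internally is an assumption):

```python
# Sketch only: assumes PyYAML semantics for the front-matter block used in the test.
import yaml

front_matter = """\
keywords: "k1"
url: foo/bar
metadata:
  foo: bar
  bar: 123
"""

meta = yaml.safe_load(front_matter)
assert meta["metadata"]["foo"] == "bar"
assert meta["metadata"]["bar"] == 123  # YAML keeps the integer type, matching the test's assertion
```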

tests/steps/embedding/e2e_test.py

Lines changed: 76 additions & 6 deletions
@@ -3,25 +3,27 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # Standard library imports
+import logging
 import shutil
 from pathlib import Path
 
 import numpy as np
 import pytest
 
-from wurzel.utils import HAS_LANGCHAIN_CORE, HAS_REQUESTS
+from wurzel.utils import HAS_LANGCHAIN_CORE, HAS_REQUESTS, HAS_SPACY, HAS_TIKTOKEN
 
-if not HAS_LANGCHAIN_CORE or not HAS_REQUESTS:
-    pytest.skip("Embedding dependencies (langchain-core, requests) are not available", allow_module_level=True)
+if not HAS_LANGCHAIN_CORE or not HAS_REQUESTS or not HAS_SPACY or not HAS_TIKTOKEN:
+    pytest.skip("Embedding dependencies (langchain-core, requests, spacy, tiktoken) are not available", allow_module_level=True)
 
 from wurzel.exceptions import StepFailed
 from wurzel.step_executor import BaseStepExecutor
-
-# Local application/library specific imports
 from wurzel.steps import EmbeddingStep
 from wurzel.steps.embedding.huggingface import HuggingFaceInferenceAPIEmbeddings
 from wurzel.steps.embedding.step_multivector import EmbeddingMultiVectorStep
 
+SPLITTER_TOKENIZER_MODEL = "gpt-3.5-turbo"
+SENTENCE_SPLITTER_MODEL = "de_core_news_sm"
+
 
 @pytest.fixture(scope="module")
 def mock_embedding():
@@ -87,12 +89,23 @@ def test_embedding_step(mock_embedding, default_embedding_data, env):
 
     """
     env.set("EMBEDDINGSTEP__API", "https://example-embedding.com/embed")
+    env.set("EMBEDDINGSTEP__TOKEN_COUNT_MIN", "64")
+    env.set("EMBEDDINGSTEP__TOKEN_COUNT_MAX", "256")
+    env.set("EMBEDDINGSTEP__TOKEN_COUNT_BUFFER", "32")
+    env.set("EMBEDDINGSTEP__TOKENIZER_MODEL", SPLITTER_TOKENIZER_MODEL)
+    env.set("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL", SENTENCE_SPLITTER_MODEL)
+
     EmbeddingStep._select_embedding = mock_embedding
     input_folder, output_folder = default_embedding_data
-    BaseStepExecutor(dont_encapsulate=False).execute_step(EmbeddingStep, [input_folder], output_folder)
+    step_res = BaseStepExecutor(dont_encapsulate=False).execute_step(EmbeddingStep, [input_folder], output_folder)
     assert output_folder.is_dir()
     assert len(list(output_folder.glob("*"))) > 0
 
+    step_output, step_report = step_res[0]
+
+    assert len(step_output) == 11, "Step outputs have wrong count."
+    assert step_report.results == 11, "Step report has wrong count of outputs."
+
 
 def test_mutlivector_embedding_step(mock_embedding, tmp_path, env):
     """Tests the execution of the `EmbeddingMultiVectorStep` with a mock input file.
@@ -137,3 +150,60 @@ def _select_embedding(*args, **kwargs) -> HuggingFaceInferenceAPIEmbeddings:
     with BaseStepExecutor() as ex:
         ex(InheritedStep, [inp], out)
     assert sf.value.message.endswith(EXPECTED_EXCEPTION)
+
+
+def test_embedding_step_log_statistics(mock_embedding, default_embedding_data, env, caplog):
+    """Tests the logging of descriptive statistics in the `EmbeddingStep` with a mock input file."""
+    env.set("EMBEDDINGSTEP__API", "https://example-embedding.com/embed")
+    env.set("EMBEDDINGSTEP__NUM_THREADS", "1")  # Ensure deterministic behavior with single thread
+    env.set("EMBEDDINGSTEP__TOKEN_COUNT_MIN", "64")
+    env.set("EMBEDDINGSTEP__TOKEN_COUNT_MAX", "256")
+    env.set("EMBEDDINGSTEP__TOKEN_COUNT_BUFFER", "32")
+    env.set("EMBEDDINGSTEP__TOKENIZER_MODEL", SPLITTER_TOKENIZER_MODEL)
+    env.set("EMBEDDINGSTEP__SENTENCE_SPLITTER_MODEL", SENTENCE_SPLITTER_MODEL)
+
+    EmbeddingStep._select_embedding = mock_embedding
+    input_folder, output_folder = default_embedding_data
+
+    with caplog.at_level(logging.INFO):
+        BaseStepExecutor(dont_encapsulate=False).execute_step(EmbeddingStep, [input_folder], output_folder)
+
+    # check if output log exists
+    assert "Distribution of char length" in caplog.text, "Missing log output for char length"
+    assert "Distribution of token length" in caplog.text, "Missing log output for token length"
+    assert "Distribution of chunks count" in caplog.text, "Missing log output for chunks count"
+
+    # check extras
+    char_length_record = None
+    token_length_record = None
+    chunks_count_record = None
+
+    for record in caplog.records:
+        if "Distribution of char length" in record.message:
+            char_length_record = record
+
+        if "Distribution of token length" in record.message:
+            token_length_record = record
+
+        if "Distribution of chunks count" in record.message:
+            chunks_count_record = record
+
+    expected_char_length_count = 11
+
+    # Check values with a small tolerance
+    expected_char_length_mean = pytest.approx(609.18, abs=0.1)
+    expected_token_length_mean = pytest.approx(257.18, abs=0.1)
+    expected_chunks_count_mean = pytest.approx(3.18, abs=0.2)
+
+    assert char_length_record.count == expected_char_length_count, (
+        f"Invalid char length count: expected {expected_char_length_count}, got {char_length_record.count}"
+    )
+    assert char_length_record.mean == expected_char_length_mean, (
+        f"Invalid char length mean: expected {expected_char_length_mean}, got {char_length_record.mean}"
+    )
+    assert token_length_record.mean == expected_token_length_mean, (
+        f"Invalid token length mean: expected {expected_token_length_mean}, got {token_length_record.mean}"
+    )
+    assert chunks_count_record.mean == expected_chunks_count_mean, (
+        f"Invalid chunks count mean: expected {expected_chunks_count_mean}, got {chunks_count_record.mean}"
+    )
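
The assertions on `record.count` and `record.mean` work because of standard-library behavior, not anything wurzel-specific: keys passed to a logger via `extra=` become attributes on the emitted `LogRecord`, which is exactly what `caplog.records` exposes. A minimal sketch, independent of the step code:

```python
# Minimal sketch: `extra` keys become LogRecord attributes, which is what the
# caplog-based assertions above rely on.
import logging


class _Capture(logging.Handler):
    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append(record)


logger = logging.getLogger("demo")
handler = _Capture()
logger.addHandler(handler)
logger.setLevel(logging.INFO)

logger.info("Distribution of char length: ...", extra={"count": 11, "mean": 609.18})

record = handler.records[0]
assert record.count == 11
assert record.mean == 609.18
```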

tests/steps/simple_splitter/__init__.py

Whitespace-only changes.
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: 2025 Deutsche Telekom AG (opensource@telekom.de)
+#
+# SPDX-License-Identifier: Apache-2.0
+import shutil
+from pathlib import Path
+
+import pytest
+
+from wurzel.utils import HAS_SPACY, HAS_TIKTOKEN
+
+if not HAS_SPACY or not HAS_TIKTOKEN:
+    pytest.skip("Simple splitter dependencies (spacy, tiktoken) are not available", allow_module_level=True)
+
+from wurzel.step_executor import BaseStepExecutor
+from wurzel.steps.splitter import SimpleSplitterStep
+
+
+@pytest.fixture
+def default_markdown_data(tmp_path):
+    mock_file = Path("tests/data/markdown.json")
+    input_folder = tmp_path / "input"
+    input_folder.mkdir()
+    shutil.copy(mock_file, input_folder)
+    output_folder = tmp_path / "out"
+    return (input_folder, output_folder)
+
+
+def test_simple_splitter_step(default_markdown_data, env):
+    """Tests the execution of the `SimpleSplitterStep` with a mock input file."""
+    env.set("SIMPLESPLITTERSTEP__TOKEN_COUNT_MIN", "64")
+    env.set("SIMPLESPLITTERSTEP__TOKEN_COUNT_MAX", "256")
+    env.set("SIMPLESPLITTERSTEP__TOKEN_COUNT_BUFFER", "32")
+
+    input_folder, output_folder = default_markdown_data
+    step_res = BaseStepExecutor(dont_encapsulate=False).execute_step(SimpleSplitterStep, [input_folder], output_folder)
+    assert output_folder.is_dir()
+    assert len(list(output_folder.glob("*"))) > 0
+
+    step_output, step_report = step_res[0]
+
+    assert len(step_output) == 11, "Step outputs have wrong count."
+    assert step_report.results == 11, "Step report has wrong count of outputs."
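
The `env` fixture used here is defined elsewhere in the suite and is not part of this diff. A plausible stand-in, assuming it simply wraps pytest's `monkeypatch` so that the `SIMPLESPLITTERSTEP__*` variables reach the settings loader, might look like this (the name and behavior are assumptions, not the actual fixture):

```python
# Hypothetical stand-in for the suite's `env` fixture -- an assumption, not the real code.
import pytest


@pytest.fixture
def env(monkeypatch):
    class _Env:
        def set(self, key: str, value: str) -> None:
            # Settings such as SIMPLESPLITTERSTEP__TOKEN_COUNT_MAX are read from
            # environment variables, so setting them per-test is enough.
            monkeypatch.setenv(key, value)

    return _Env()
```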

wurzel/datacontract/common.py

Lines changed: 17 additions & 1 deletion
@@ -19,7 +19,7 @@
 
 
 class MarkdownDataContract(PydanticModel):
-    """A data contract of the input of the EmbeddingStep representing a document in Markdown format.
+    """A data contract of the input/output of the various pipeline steps representing a document in Markdown format.
 
     The document consists of the Markdown body (document content) and additional metadata (keywords, url).
     The metadata is optional.
@@ -47,11 +47,25 @@ class MarkdownDataContract(PydanticModel):
     Another text.
     ```
 
+    Example 3 (with extra metadata fields)
+    ```md
+    ---
+    keywords: "bread,butter"
+    url: "some/file/path.md"
+    metadata:
+      token_len: 123
+      char_len: 550
+    ---
+    # Some title
+
+    A short text.
+    ```
     """
 
     md: str
     keywords: str
     url: str  # Url of pydantic is buggy in serialization
+    metadata: dict[str, Any] | None = None
 
     @classmethod
     @pydantic.validate_call
@@ -61,6 +75,7 @@ def from_dict_w_function(cls, doc: dict[str, Any], func: Callable[[str], str]):
             md=func(doc["text"]),
             url=doc["metadata"]["url"],
             keywords=doc["metadata"]["keywords"],
+            metadata=doc["metadata"].get("metadata", None),
         )
 
     @classmethod
@@ -115,4 +130,5 @@ def from_file(cls, path: Path, url_prefix: str = "") -> Self:
             # Extract metadata fields or use default value
             url=metadata.get("url", url_prefix + str(path.absolute())),
             keywords=metadata.get("keywords", path.name.split(".")[0]),
+            metadata=metadata.get("metadata", None),
         )
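
Putting the new field together: a Markdown file whose front matter carries a `metadata:` mapping now round-trips through `MarkdownDataContract.from_file` with that mapping preserved. A small usage sketch based on the docstring example above (the file path is illustrative):

```python
# Usage sketch for the new `metadata` field, following the docstring example.
from pathlib import Path

from wurzel.datacontract.common import MarkdownDataContract

doc = MarkdownDataContract.from_file(Path("some/file/path.md"))

if doc.metadata is not None:
    # e.g. {"token_len": 123, "char_len": 550} when the splitter produced the file
    print(doc.metadata.get("token_len"), doc.metadata.get("char_len"))
else:
    print("no extra metadata in the front matter")  # metadata is optional and defaults to None
```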

wurzel/datacontract/datacontract.py

Lines changed: 2 additions & 1 deletion
@@ -116,11 +116,12 @@ def load_from_path(cls, path: Path, model_type: type[Union[Self, list[Self]]]) -
         raise NotImplementedError(f"Can not load {model_type}")
 
     def __hash__(self) -> int:
+        """Compute a hash based on all not-none field values."""
         # pylint: disable-next=not-an-iterable
         return int(
             hashlib.sha256(
                 bytes(
-                    "".join([getattr(self, name) for name in sorted(type(self).model_fields)]),
+                    "".join([str(getattr(self, name) or "") for name in sorted(type(self).model_fields)]),
                     encoding="utf-8",
                 ),
                 usedforsecurity=False,
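
This one-line change matters because the new `metadata` field defaults to `None`, and `str.join` rejects non-string items. A minimal sketch of the failure mode the `str(... or "")` coercion guards against:

```python
# Why the guard is needed: "".join chokes on None, which the optional
# `metadata` field introduces into the field values.
fields = ["body text", None, "some/url"]

try:
    "".join(fields)
except TypeError as err:
    print(err)  # sequence item 1: expected str instance, NoneType found

# The committed fix coerces None (and any non-string value) to a string first:
joined = "".join([str(value or "") for value in fields])
assert joined == "body textsome/url"
```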

wurzel/steps/embedding/step.py

Lines changed: 50 additions & 0 deletions
@@ -7,10 +7,12 @@
 # Standard library imports
 import os
 import re
+from collections import defaultdict
 from io import StringIO
 from logging import getLogger
 from typing import Optional, TypedDict
 
+import numpy as np
 from markdown import Markdown
 from pandera.typing import DataFrame
 from tqdm.auto import tqdm
@@ -87,9 +89,18 @@ def run(self, inpt: list[MarkdownDataContract]) -> DataFrame[EmbeddingResult]:
         splitted_md_rows = self._split_markdown(inpt)
         rows = []
         failed = 0
+        stats = defaultdict(list)
+
         for row in tqdm(splitted_md_rows, desc="Calculate Embeddings"):
             try:
                 rows.append(self._get_embedding(row))
+
+                # collect statistics
+                if row.metadata is not None:
+                    stats["char length"].append(row.metadata.get("char_len", 0))
+                    stats["token length"].append(row.metadata.get("token_len", 0))
+                    stats["chunks count"].append(row.metadata.get("chunks_count", 0))
+
             except EmbeddingAPIException as err:
                 log.warning(
                     f"Skipped because EmbeddingAPIException: {err.message}",
@@ -100,8 +111,47 @@ def run(self, inpt: list[MarkdownDataContract]) -> DataFrame[EmbeddingResult]:
             log.warning(f"{failed}/{len(splitted_md_rows)} got skipped")
         if failed == len(splitted_md_rows):
             raise StepFailed(f"all {len(splitted_md_rows)} embeddings got skipped")
+
+        # log statistics
+        for k, v in stats.items():
+            self.log_statistics(series=np.array(v), name=k)
+
         return DataFrame[EmbeddingResult](DataFrame[EmbeddingResult](rows))
 
+    def log_statistics(self, series: np.ndarray, name: str):
+        """Log descriptive statistics for all documents.
+
+        Parameters
+        ----------
+        series : np.ndarray
+            Numerical values representing the documents.
+        name : str
+            The name of the document metric.
+        """
+        stats = {
+            "count": len(series),
+            "mean": None,
+            "std": None,
+        }
+
+        if len(series) > 0:
+            stats.update(
+                {
+                    "mean": np.mean(series),
+                    "median": np.median(series),
+                    "std": np.std(series),
+                    "var": np.var(series),
+                    "min": np.min(series),
+                    "percentile_5": np.percentile(series, 5),
+                    "percentile_25": np.percentile(series, 25),
+                    "percentile_75": np.percentile(series, 75),
+                    "percentile_95": np.percentile(series, 95),
+                    "max": np.max(series),
+                }
+            )
+
+        log.info(f"Distribution of {name}: count={stats['count']}; mean={stats['mean']}; std={stats['std']}", extra=stats)
+
     def get_embedding_input_from_document(self, doc: MarkdownDataContract) -> str:
         """Clean the document such that it can be used as input to the embedding model.
 

wurzel/steps/splitter.py

Lines changed: 4 additions & 4 deletions
@@ -24,7 +24,7 @@ class SplitterSettings(Settings):
     """Anything Embedding-related."""
 
     BATCH_SIZE: int = Field(100, gt=0)
-    NUM_THREADS: int = Field(4, gt=1)
+    NUM_THREADS: int = Field(4, ge=1)
     TOKEN_COUNT_MIN: int = Field(64, gt=0)
     TOKEN_COUNT_MAX: int = Field(1024, gt=1)
     TOKEN_COUNT_BUFFER: int = Field(32, gt=0)
@@ -94,13 +94,13 @@ def _split_markdown(self, markdowns: list[MarkdownDataContract]) -> list[Markdow
         """Creates data rows from a batch of markdown texts by splitting them and counting tokens."""
         rows = []
         skipped = 0
-        for s in markdowns:
+        for md_data_contract in markdowns:
             try:
-                rows.extend(self.splitter.split_markdown_document(s))
+                rows.extend(self.splitter.split_markdown_document(md_data_contract))
             except MarkdownException as err:
                 log.warning(
                     "skipped dokument ",
-                    extra={"reason": err.__class__.__name__, "doc": s},
+                    extra={"reason": err.__class__.__name__, "doc": md_data_contract},
                 )
                 skipped += 1
         if skipped == len(markdowns):
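
The `gt=1` to `ge=1` change is what allows the statistics test above to pin `NUM_THREADS=1` for deterministic log output: in pydantic, `gt` is a strict bound while `ge` is inclusive. A minimal sketch using standalone pydantic models (not the actual `SplitterSettings`):

```python
# Sketch of the gt -> ge difference with plain pydantic models; SplitterSettings
# itself is not reproduced here.
from pydantic import BaseModel, Field, ValidationError


class OldSettings(BaseModel):
    NUM_THREADS: int = Field(4, gt=1)  # strictly greater than 1: rejects 1


class NewSettings(BaseModel):
    NUM_THREADS: int = Field(4, ge=1)  # greater than or equal to 1: accepts 1


try:
    OldSettings(NUM_THREADS=1)
except ValidationError:
    print("gt=1 rejects a single-threaded configuration")

assert NewSettings(NUM_THREADS=1).NUM_THREADS == 1
```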
