Skip to content

Commit bdf30a5

Browse files
authored
fix: Setting token length correctly in splitter metadata (#186)
1 parent b783fcd commit bdf30a5

File tree

3 files changed

+24
-8
lines changed

3 files changed

+24
-8
lines changed

tests/steps/embedding/e2e_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def test_embedding_step_log_statistics(mock_embedding, default_embedding_data, e
192192

193193
# Check values with a small tolerance
194194
expected_char_length_mean = pytest.approx(609.18, abs=0.1)
195-
expected_token_length_mean = pytest.approx(257.18, abs=0.1)
195+
expected_token_length_mean = pytest.approx(188.3, abs=0.1)
196196
expected_chunks_count_mean = pytest.approx(3.18, abs=0.2)
197197

198198
assert char_length_record.count == expected_char_length_count, (

tests/steps/simple_splitter/e2e_simple_splitter_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,13 @@ def test_simple_splitter_step(default_markdown_data, env):
4040

4141
assert len(step_output) == 11, "Step outputs have wrong count."
4242
assert step_report.results == 11, "Step report has wrong count of outputs."
43+
44+
hash_count = [o.md.count("#") for o in step_output]
45+
nl_count = [o.md.count("\n") for o in step_output]
46+
token_lens = [o.metadata["token_len"] for o in step_output]
47+
char_lens = [o.metadata["char_len"] for o in step_output]
48+
49+
assert hash_count == [9, 3, 10, 13, 12, 12, 15, 1, 9, 3, 0], "Chunks have invalid hash count"
50+
assert nl_count == [4, 0, 16, 23, 23, 24, 22, 9, 4, 0, 6], "Chunks have invalid new line count"
51+
assert token_lens == [236, 74, 243, 278, 240, 225, 247, 136, 245, 67, 81], "Chunks have invalid token length"
52+
assert char_lens == [757, 235, 839, 917, 776, 699, 797, 447, 787, 227, 220], "Chunks have invalid char length"

wurzel/utils/splitters/semantic_splitter.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -520,14 +520,19 @@ def _handle_parsing_of_children(
520520
remaining_snipped = text_w_prev_child
521521
elif self._is_within_targetlen_w_buffer(text_w_prev_child):
522522
child["text"] = text_w_prev_child
523+
524+
# Make sure text is within token limit
525+
limited_child_text = self._cut_to_tokenlen(child["text"], self.token_limit)
526+
527+
# Build document from text and child metadata
523528
return_doc += [
524529
MarkdownDataContract(
525-
md=self._cut_to_tokenlen(child["text"], self.token_limit),
530+
md=limited_child_text,
526531
url=child["metadata"]["url"],
527532
keywords=child["metadata"]["keywords"],
528533
metadata={
529-
"token_len": self.token_limit,
530-
"char_len": len(child["text"]),
534+
"token_len": self._get_token_len(limited_child_text),
535+
"char_len": len(limited_child_text),
531536
},
532537
)
533538
]
@@ -583,7 +588,7 @@ def _md_data_from_dict_cut(self, doc: DocumentNode) -> MarkdownDataContract:
583588
url=doc["metadata"]["url"],
584589
keywords=doc["metadata"]["keywords"],
585590
metadata={
586-
"token_len": self.token_limit,
591+
"token_len": self._get_token_len(text),
587592
"char_len": len(text),
588593
},
589594
)
@@ -677,14 +682,15 @@ def _parse_hierarchical(
677682

678683
# add potential short remaining spillovers
679684
if self._get_token_len(remaining_snipped) >= self.token_limit_min:
685+
limited_remaining_snipped = self._cut_to_tokenlen(remaining_snipped, self.token_limit)
680686
return_doc += [
681687
MarkdownDataContract(
682-
md=self._cut_to_tokenlen(remaining_snipped, self.token_limit),
688+
md=limited_remaining_snipped,
683689
url=doc["metadata"]["url"],
684690
keywords=doc["metadata"]["keywords"],
685691
metadata={
686-
"token_len": self.token_limit,
687-
"char_len": len(remaining_snipped),
692+
"token_len": self._get_token_len(limited_remaining_snipped),
693+
"char_len": len(limited_remaining_snipped),
688694
},
689695
)
690696
]

0 commit comments

Comments
 (0)