-
Notifications
You must be signed in to change notification settings - Fork 78
Add length-constrained segmentation with configurable priors and algo… #164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
e0ca77e
b9e5b44
6832b97
471a6e2
d5949d1
1922168
6b52cc1
f392418
cbd47b8
bd4db82
300c2f9
b9b08c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| # noqa: E501 | ||
| from wtpsplit import WtP, SaT | ||
| import numpy as np | ||
|
|
||
|
|
||
| def test_weighting(): | ||
| sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) | ||
|
|
@@ -253,4 +255,148 @@ def test_split_threshold_wtp(): | |
|
|
||
| splits = wtp.split("This is a test sentence. This is another test sentence.", threshold=-1e-3) | ||
| # space might still be included in a character split | ||
| assert splits[:3] == list("Thi") | ||
| assert splits[:3] == list("Thi") | ||
|
|
||
| # ============================================================================ | ||
| # Length-Constrained Segmentation Tests | ||
| # ============================================================================ | ||
|
|
||
| def test_min_length_constraint_wtp(): | ||
| """Test minimum length constraint with WtP""" | ||
| wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) | ||
|
|
||
| text = "Short. Test. Hello. World. This is longer." | ||
| splits = wtp.split(text, min_length=15, threshold=0.005) | ||
|
|
||
| # All segments should be >= 15 characters | ||
| for segment in splits: | ||
| assert len(segment) >= 15, f"Segment '{segment}' is shorter than min_length" | ||
|
|
||
| # Text should be preserved | ||
| assert "".join(splits) == text | ||
|
|
||
|
|
||
| def test_max_length_constraint_sat(): | ||
| """Test maximum length constraint with SaT""" | ||
| sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) | ||
|
|
||
| text = "This is a test sentence. " * 10 | ||
| splits = sat.split(text, max_length=60, threshold=0.025) | ||
|
|
||
| # All segments should be <= 60 characters | ||
| for segment in splits: | ||
| assert len(segment) <= 60, f"Segment '{segment}' is longer than max_length" | ||
|
|
||
| # Text should be preserved | ||
| assert "".join(splits) == text | ||
|
|
||
|
|
||
| def test_min_max_constraints_together(): | ||
| """Test both constraints simultaneously""" | ||
| wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) | ||
|
|
||
| text = "Hello world. " * 15 | ||
| splits = wtp.split(text, min_length=25, max_length=65, threshold=0.005) | ||
|
|
||
| # All segments should satisfy both constraints | ||
| for segment in splits: | ||
| assert 25 <= len(segment) <= 65, f"Segment '{segment}' violates constraints" | ||
|
Comment on lines 296 to 305
|
||
|
|
||
| # Text should be preserved | ||
| assert "".join(splits) == text | ||
|
|
||
|
|
||
| def test_gaussian_prior(): | ||
| """Test Gaussian prior preference""" | ||
| sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) | ||
|
|
||
| text = "Sentence. " * 30 | ||
| splits = sat.split( | ||
| text, | ||
| min_length=20, | ||
| max_length=60, | ||
| prior_type="gaussian", | ||
| prior_kwargs={"mu": 40.0, "sigma": 5.0}, | ||
| threshold=0.025 | ||
| ) | ||
|
|
||
| # Should produce valid splits | ||
| for segment in splits: | ||
| assert 20 <= len(segment) <= 60 | ||
|
|
||
| # Text should be preserved | ||
| assert "".join(splits) == text | ||
|
|
||
|
|
||
| def test_greedy_algorithm(): | ||
| """Test greedy algorithm""" | ||
| wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) | ||
|
|
||
| text = "Test sentence. " * 10 | ||
| splits = wtp.split(text, min_length=20, max_length=50, algorithm="greedy", threshold=0.005) | ||
|
|
||
| # Should produce valid splits | ||
| for segment in splits: | ||
| assert 20 <= len(segment) <= 50 | ||
|
|
||
| # Text should be preserved | ||
| assert "".join(splits) == text | ||
|
|
||
|
|
||
| def test_constraints_with_paragraph_segmentation(): | ||
| """Test constraints with nested paragraph segmentation""" | ||
| wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) | ||
|
|
||
| text = " ".join([ | ||
| "First paragraph first sentence. First paragraph second sentence.", | ||
| "Second paragraph first sentence. Second paragraph second sentence." | ||
| ]) | ||
|
|
||
| paragraphs = wtp.split(text, do_paragraph_segmentation=True, min_length=20, max_length=70) | ||
|
|
||
| # Check structure | ||
| assert isinstance(paragraphs, list) | ||
| for paragraph in paragraphs: | ||
| assert isinstance(paragraph, list) | ||
| for sentence in paragraph: | ||
| assert 20 <= len(sentence) <= 70 | ||
|
|
||
|
|
||
| def test_constraints_preserved_in_batched(): | ||
| """Test constraints work with batched processing""" | ||
| sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) | ||
|
|
||
| texts = [ | ||
| "First batch text. " * 5, | ||
| "Second batch text. " * 5, | ||
| ] | ||
|
|
||
| results = list(sat.split(texts, min_length=25, max_length=55, threshold=0.025)) | ||
|
|
||
| assert len(results) == 2 | ||
| for splits in results: | ||
| for segment in splits: | ||
| assert 25 <= len(segment) <= 55 | ||
|
|
||
|
|
||
| def test_constraint_low_level(): | ||
| """Test constrained_segmentation directly""" | ||
| from wtpsplit.utils.constraints import constrained_segmentation | ||
| from wtpsplit.utils.priors import create_prior_function | ||
|
|
||
| probs = np.array([0.1, 0.3, 0.5, 0.7, 0.9, 0.2, 0.4, 0.6, 0.8, 1.0]) | ||
| prior_fn = create_prior_function("uniform", {"max_length": 5}) | ||
|
|
||
| indices = constrained_segmentation(probs, prior_fn, min_length=3, max_length=5, algorithm="viterbi") | ||
|
|
||
| # Verify constraints on chunk lengths | ||
| prev = 0 | ||
| for idx in indices: | ||
| chunk_len = idx - prev | ||
| assert 3 <= chunk_len <= 5, f"Chunk length {chunk_len} violates constraints" | ||
| prev = idx | ||
|
|
||
| # Check last chunk | ||
| if prev < len(probs): | ||
| last_len = len(probs) - prev | ||
| assert 3 <= last_len <= 5, f"Last chunk length {last_len} violates constraints" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the comprehensive test suite, line 273 asserts that all segments are `>= 15` characters when using `min_length=15`. However, `min_length` is documented as "best effort" and segments may be shorter if constraints cannot be satisfied. This test may fail in edge cases and doesn't align with the documented behavior. Consider either: