Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 147 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# noqa: E501
from wtpsplit import WtP, SaT
import numpy as np


def test_weighting():
sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])
Expand Down Expand Up @@ -253,4 +255,148 @@ def test_split_threshold_wtp():

splits = wtp.split("This is a test sentence. This is another test sentence.", threshold=-1e-3)
# space might still be included in a character split
assert splits[:3] == list("Thi")
assert splits[:3] == list("Thi")

# ============================================================================
# Length-Constrained Segmentation Tests
# ============================================================================

def test_min_length_constraint_wtp():
"""Test minimum length constraint with WtP"""
wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"])

text = "Short. Test. Hello. World. This is longer."
splits = wtp.split(text, min_length=15, threshold=0.005)

# All segments should be >= 15 characters
for segment in splits:
assert len(segment) >= 15, f"Segment '{segment}' is shorter than min_length"
Comment on lines 266 to 275
Copy link

Copilot AI Nov 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the comprehensive test suite, line 273 asserts that all segments are >= 15 characters when using min_length=15. However, min_length is documented as "best effort" and segments may be shorter if constraints cannot be satisfied. This test may fail in edge cases and doesn't align with the documented behavior. Consider either:

  1. Allowing for exceptions where segments can be shorter
  2. Using test data where min_length can always be guaranteed

Copilot uses AI. Check for mistakes.

# Text should be preserved
assert "".join(splits) == text


def test_max_length_constraint_sat():
"""Test maximum length constraint with SaT"""
sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])

text = "This is a test sentence. " * 10
splits = sat.split(text, max_length=60, threshold=0.025)

# All segments should be <= 60 characters
for segment in splits:
assert len(segment) <= 60, f"Segment '{segment}' is longer than max_length"

# Text should be preserved
assert "".join(splits) == text


def test_min_max_constraints_together():
"""Test both constraints simultaneously"""
wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"])

text = "Hello world. " * 15
splits = wtp.split(text, min_length=25, max_length=65, threshold=0.005)

# All segments should satisfy both constraints
for segment in splits:
assert 25 <= len(segment) <= 65, f"Segment '{segment}' violates constraints"
Comment on lines 296 to 305
Copy link

Copilot AI Nov 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 303 asserts that all segments satisfy both constraints 25 <= len(segment) <= 65. This strict assertion doesn't account for the "best effort" nature of min_length. While max_length is always strictly enforced, min_length may not be achievable in all cases. The test might fail for certain inputs where min_length cannot be satisfied without violating max_length.

Copilot uses AI. Check for mistakes.

# Text should be preserved
assert "".join(splits) == text


def test_gaussian_prior():
"""Test Gaussian prior preference"""
sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])

text = "Sentence. " * 30
splits = sat.split(
text,
min_length=20,
max_length=60,
prior_type="gaussian",
prior_kwargs={"mu": 40.0, "sigma": 5.0},
threshold=0.025
)

# Should produce valid splits
for segment in splits:
assert 20 <= len(segment) <= 60

# Text should be preserved
assert "".join(splits) == text


def test_greedy_algorithm():
"""Test greedy algorithm"""
wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"])

text = "Test sentence. " * 10
splits = wtp.split(text, min_length=20, max_length=50, algorithm="greedy", threshold=0.005)

# Should produce valid splits
for segment in splits:
assert 20 <= len(segment) <= 50

# Text should be preserved
assert "".join(splits) == text


def test_constraints_with_paragraph_segmentation():
"""Test constraints with nested paragraph segmentation"""
wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"])

text = " ".join([
"First paragraph first sentence. First paragraph second sentence.",
"Second paragraph first sentence. Second paragraph second sentence."
])

paragraphs = wtp.split(text, do_paragraph_segmentation=True, min_length=20, max_length=70)

# Check structure
assert isinstance(paragraphs, list)
for paragraph in paragraphs:
assert isinstance(paragraph, list)
for sentence in paragraph:
assert 20 <= len(sentence) <= 70


def test_constraints_preserved_in_batched():
"""Test constraints work with batched processing"""
sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])

texts = [
"First batch text. " * 5,
"Second batch text. " * 5,
]

results = list(sat.split(texts, min_length=25, max_length=55, threshold=0.025))

assert len(results) == 2
for splits in results:
for segment in splits:
assert 25 <= len(segment) <= 55


def test_constraint_low_level():
"""Test constrained_segmentation directly"""
from wtpsplit.utils.constraints import constrained_segmentation
from wtpsplit.utils.priors import create_prior_function

probs = np.array([0.1, 0.3, 0.5, 0.7, 0.9, 0.2, 0.4, 0.6, 0.8, 1.0])
prior_fn = create_prior_function("uniform", {"max_length": 5})

indices = constrained_segmentation(probs, prior_fn, min_length=3, max_length=5, algorithm="viterbi")

# Verify constraints on chunk lengths
prev = 0
for idx in indices:
chunk_len = idx - prev
assert 3 <= chunk_len <= 5, f"Chunk length {chunk_len} violates constraints"
prev = idx

# Check last chunk
if prev < len(probs):
last_len = len(probs) - prev
assert 3 <= last_len <= 5, f"Last chunk length {last_len} violates constraints"
Loading
Loading