diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 8d535119..26d71e63 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -8,26 +8,26 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff==0.1.5 pytest==7.4.3 + pip install ruff pytest pip install -e .[onnx-cpu] pip install torch --index-url https://download.pytorch.org/whl/cpu - name: Lint with ruff run: | # stop the build if there are Python syntax errors or undefined names - ruff --output-format=github --select=E9,F63,F7,F82 --target-version=py37 . + ruff check --output-format=github --select=E9,F63,F7,F82 --target-version=py39 . # default set of ruff rules with GitHub Annotations - ruff --output-format=github --target-version=py37 . + ruff check --output-format=github --target-version=py39 . - name: Test with pytest run: | pytest test.py diff --git a/.gitignore b/.gitignore index 637bb077..2f0ff645 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,5 @@ data_subset/** *.txt *.log xlmr-*/** -**/checkpoint-*/** \ No newline at end of file +**/checkpoint-*/** +.venv/ \ No newline at end of file diff --git a/README.md b/README.md index b3484341..62ecea36 100644 --- a/README.md +++ b/README.md @@ -141,9 +141,93 @@ Since SaT are trained to predict newline probablity, they can segment text into sat.split(text, do_paragraph_segmentation=True) ``` +## (NEW! v2.2+) Length-Constrained Segmentation + +Control segment lengths with `min_length` and `max_length` parameters. This is useful when you need segments within specific size limits (e.g., for embedding models, storage, or downstream processing). + +### Basic Usage + +```python +# segments will be at most 100 characters +sat.split(text, max_length=100) + +# segments will be at least 20 characters (best effort) and at most 100 characters (strict) +sat.split(text, min_length=20, max_length=100) + +# use different algorithms: "viterbi" (optimal, default) or "greedy" (faster) +sat.split(text, max_length=100, algorithm="greedy") +``` + +### Priors for Length Preference + +Use priors to influence segment length distribution. 
Available priors: + +| Prior | Best For | +|-------|----------| +| `"uniform"` (default) | Just enforce max_length, let model decide | +| `"gaussian"` | Prefer segments around a target length (intuitive) | +| `"lognormal"` | Right-skewed preference (more tolerant of longer segments) | +| `"clipped_polynomial"` | Must be very close to target length | + +```python +# Gaussian prior (recommended): prefer segments around target_length +sat.split(text, max_length=100, prior_type="gaussian", + prior_kwargs={"target_length": 50, "spread": 10}) + +# Log-normal prior: right-skewed (more tolerant of longer segments) +sat.split(text, max_length=100, prior_type="lognormal", + prior_kwargs={"target_length": 70, "spread": 25}) + +# Clipped polynomial: hard cutoff at ±spread from target +sat.split(text, max_length=100, prior_type="clipped_polynomial", + prior_kwargs={"target_length": 60, "spread": 25}) +``` + +### Language-Aware Defaults + +Pass `lang_code` to use language-specific defaults for `target_length` and `spread`: + +```python +# German has longer average sentences → auto-uses target_length=90, spread=35 +sat.split(text, max_length=150, prior_type="gaussian", + prior_kwargs={"lang_code": "de"}) + +# Chinese has shorter sentences → auto-uses target_length=45, spread=15 +sat.split(text, max_length=100, prior_type="gaussian", + prior_kwargs={"lang_code": "zh"}) +``` + +When using LoRA with a language, this happens automatically: + +```python +sat = SaT("sat-3l", style_or_domain="ud", language="de") +sat.split(text, max_length=150, prior_type="gaussian") # auto-uses German defaults +``` + +### How It Works + +The Viterbi algorithm finds globally optimal segmentation points that balance: +- The model's sentence boundary predictions (where natural splits occur) +- Your length preferences (via the prior; if provided) + +**Text Reconstruction:** +```python +# With constraints (max_length or min_length): +original_text = "".join(segments) # segments may contain newlines + +# Without constraints (SaT default with split_on_input_newlines=True): +original_text = "\n".join(segments) +``` + +> **Note**: When using length constraints, segments may contain newlines. If you want to remove them, you can just post-process the output. + +> **Note**: When `max_length` is set, the `threshold` parameter is ignored. The Viterbi/greedy algorithms use raw model probabilities directly instead of threshold-based filtering. + +For more details, see the [Length Constraints Documentation](./docs/LENGTH_CONSTRAINTS.md). + ## Adaptation -SaT can be domain- and style-adapted via LoRA. We provide trained LoRA modules for Universal Dependencies, OPUS100, Ersatz, and TED (i.e., ASR-style transcribed speecjes) sentence styles in 81 languages for `sat-3l`and `sat-12l`. Additionally, we provide LoRA modules for legal documents (laws and judgements) in 6 languages, code-switching in 4 language pairs, and tweets in 3 languages. For details, we refer to our [paper](https://arxiv.org/abs/2406.16678). +SaT can be domain- and style-adapted via LoRA. We provide trained LoRA modules for Universal Dependencies, OPUS100, Ersatz, and TED (i.e., ASR-style transcribed speeches) sentence styles in 81 languages for `sat-3l`and `sat-12l`. Additionally, we provide LoRA modules for legal documents (laws and judgements) in 6 languages, code-switching in 4 language pairs, and tweets in 3 languages. For details, we refer to our [paper](https://arxiv.org/abs/2406.16678). 
We also provided verse segmentation modules for 16 genres for `sat-12-no-limited-lookahead`. diff --git a/docs/LENGTH_CONSTRAINTS.md b/docs/LENGTH_CONSTRAINTS.md new file mode 100644 index 00000000..90e3241e --- /dev/null +++ b/docs/LENGTH_CONSTRAINTS.md @@ -0,0 +1,295 @@ +# Length-Constrained Segmentation + +This supplementary document explains the theory and implementation of length-constrained segmentation in wtpsplit (NB: auto-generated). + +## Overview + +A text segmenter like SaT gives us a probability score for every character, indicating how likely that position is to be a segment boundary (e.g., end of a sentence). + +### The Basic Approach + +Simply split whenever `probability > threshold`. + +**Problem**: No control over segment lengths! + +### The Controllable Approach + +> **Note**: When using length-constrained segmentation (`max_length` is set), the `threshold` parameter is **ignored**. The algorithms use raw model probabilities directly to find optimal split points. + +Define a **prior probability distribution** over chunk lengths, then solve an optimization problem that balances: +- The model's boundary predictions +- Your length preferences + +## The Maths + +We view segmentation as selecting a subset of positions C = {C₁, C₂, ..., Cₖ} where each Cᵢ marks the end of a chunk. + +**Optimization Problem:** + +``` +argmax_C ∏ᵢ Prior(Cᵢ - Cᵢ₋₁) × P(Cᵢ) +``` + +Where: +- `Cᵢ` = position of the i-th split point +- `Prior(length)` = how much we prefer chunks of that length +- `P(Cᵢ)` = model's probability at position Cᵢ + +In Bayesian terms: +- `Prior(Cᵢ - Cᵢ₋₁)` is the prior +- `P(Cᵢ)` is the evidence +- Their product is the posterior + +## Prior Functions + +### Which Prior Should I Use? + +| Prior | Recommendation | Best For | +|-------|---------------|----------| +| **Uniform** | Default | Just enforce max_length, let model decide split points | +| **Gaussian** ⭐ | Recommended | Prefer segments around a target length (intuitive, easy to tune) | +| **Log-Normal** | Advanced | Right-skewed preference (more tolerant of longer segments) | +| **Clipped Polynomial** | Special cases | When segments MUST be very close to target length | + +**TL;DR**: Use `uniform` (default) if you just need max_length. Use `gaussian` if you want to prefer a specific segment size. + +### Prior Visualization + +The figure below shows how each prior behaves with `target_length=50`, `spread=15`, and `max_length=100`: + +![Prior Functions Comparison](prior_functions.png) + +**Key differences:** +- **Uniform**: Flat until max_length, then drops to zero +- **Gaussian**: Symmetric bell curve centered at target_length +- **Clipped Polynomial**: Parabola that clips to exactly zero at ±spread from target +- **Log-Normal**: Asymmetric (right-skewed) — more tolerant of longer segments, stricter on shorter ones + +### 1. Uniform Prior (Default) + +```python +Prior(length) = 1.0 if length ≤ max_length else 0.0 +``` + +- All lengths equally good up to `max_length` +- Hard cutoff at maximum +- **Use case**: Simple length limiting — let the model decide optimal splits + +### 2. Gaussian Prior ⭐ Recommended + +```python +Prior(length) = exp(-0.5 × ((length - target_length) / spread)²) +``` + +- Peaks at `target_length` (preferred length) +- Falls off smoothly based on `spread` (standard deviation) +- Symmetric around target +- **Use case**: Prefer specific chunk sizes (e.g., ~512 chars for embedding models) + +### 3. 
Clipped Polynomial Prior + +```python +Prior(length) = max(1 - (1/spread²) × (length - target_length)², 0) +``` + +- Peaks at `target_length` +- Clips to exactly zero at `±spread` characters from target +- More aggressive than Gaussian — creates a "hard window" around target +- **Use case**: Strong enforcement when segments must be close to target + +### 4. Log-Normal Prior + +```python +Prior(length) = exp(-0.5 × ((log(length) - μ) / σ)²) / length +``` + +- Right-skewed distribution (asymmetric around target) +- More tolerant of segments longer than target (long right tail) +- Drops off faster for segments shorter than target +- `target_length` sets the peak (mode) +- `spread` in characters (same scale as other priors for consistency) +- **Use case**: When you want to be more lenient with longer segments while preferring target length + +## Algorithms + +### Greedy Search + +At each step, pick the locally best split point based on `Prior(length) × P(position)`. + +**Pros:** +- Fast (O(n × max_length)) +- Simple to understand + +**Cons:** +- Not globally optimal +- May miss better overall segmentations + +### Viterbi Algorithm (Recommended) + +Dynamic programming to find the globally optimal sequence of splits. + +**Algorithm:** +``` +dp[i] = best score achievable for text[0:i] +dp[i] = max over j of: dp[j] + log(Prior(i-j)) + log(P[i-1]) + +Where j ranges from max(0, i-max_length) to i-min_length +``` + +**Pros:** +- Globally optimal solution +- Respects sentence boundaries when possible + +**Cons:** +- Slightly slower (O(n × max_length)) + +## Key Guarantees + +1. **`max_length` is STRICT**: No segment will ever exceed `max_length` characters +2. **`min_length` is BEST EFFORT**: Segments may be shorter if merging would violate `max_length` +3. **Text preservation** — use the correct join method: + +```python +# With constraints (max_length or min_length set): +original_text = "".join(segments) # ← newlines may be inside segments + +# Without constraints (SaT default with split_on_input_newlines=True): +original_text = "\n".join(segments) + +# Without constraints (WtP or split_on_input_newlines=False): +original_text = "".join(segments) +``` + +> **Why the difference?** When constraints force mid-line splits (line > max_length), we can't split exactly at newlines. Newlines stay embedded in segments to preserve text. The algorithm boosts probabilities at newline positions to prefer splitting there when possible. + +## Usage Examples + +### Basic Length Limiting + +```python +from wtpsplit import SaT + +sat = SaT("sat-3l-sm") + +# Limit segments to 100 characters +segments = sat.split(text, max_length=100) +``` + +### Both Min and Max + +```python +# Segments between 20-100 characters +segments = sat.split(text, min_length=20, max_length=100) +``` + +### Using Gaussian Prior + +```python +# Prefer ~50 character segments +segments = sat.split( + text, + max_length=100, + prior_type="gaussian", + prior_kwargs={"target_length": 50, "spread": 15} +) +``` + +### Algorithm Selection + +```python +# Use greedy for speed (slightly suboptimal) +segments = sat.split(text, max_length=100, algorithm="greedy") + +# Use viterbi for optimal results (default) +segments = sat.split(text, max_length=100, algorithm="viterbi") +``` + +## How It Respects Sentence Boundaries + +The Viterbi algorithm naturally prefers splitting at high-probability positions (sentence boundaries) because: + +1. `P(position)` is high at sentence boundaries +2. 
The product `Prior(length) × P(position)` is maximized when both factors are high +3. The algorithm finds the global optimum, so it won't make a locally good choice that leads to bad splits later + +**Example:** + +Text: `"The quick brown fox jumps. Pack my box with jugs."` + +With `max_length=100`: +- Position 25 (after "jumps.") has P ≈ 0.95 +- Position 50 (after "jugs.") has P ≈ 0.98 +- The algorithm will split at these natural boundaries + +With `max_length=30`: +- Must split somewhere before position 30 +- Will choose position 25 (sentence boundary) over position 28 (mid-word) + +## When Word Cuts Happen + +Word cuts only occur when there is **no sentence boundary within `max_length`** characters. This happens with: + +1. Very long sentences without punctuation +2. Very restrictive `max_length` values +3. Text without natural break points (e.g., code, URLs) + +## Implementation Details + +### Post-Processing (`_enforce_segment_constraints`) + +After the algorithm finds split points, post-processing ensures: +- Segments don't exceed `max_length` (force-splits if needed) +- Short segments are merged when possible +- All whitespace is preserved + +## Parameters Reference + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `min_length` | int | 1 | Minimum segment length (best effort) | +| `max_length` | int | None | Maximum segment length (strict) | +| `algorithm` | str | "viterbi" | "viterbi" (optimal) or "greedy" (faster) | +| `prior_type` | str | "uniform" | "uniform", "gaussian", "clipped_polynomial", or "lognormal" | +| `prior_kwargs` | dict | None | Parameters for prior function | + +### Prior Parameters + +**All priors support:** +- `lang_code`: Optional language code (e.g., "en", "zh", "de") for language-aware defaults + +**Gaussian:** +- `target_length`: Preferred length in characters (language-aware default if `lang_code` provided) +- `spread`: Standard deviation in characters (language-aware default if `lang_code` provided) + +**Clipped Polynomial:** +- `target_length`: Peak position in characters (language-aware default if `lang_code` provided) +- `spread`: Tolerance in characters — clips to zero at ±spread from target (language-aware default if `lang_code` provided) + +**Log-Normal:** +- `target_length`: Peak/mode in characters (language-aware default if `lang_code` provided) +- `spread`: Tolerance in characters (language-aware default if `lang_code` provided) — same scale as gaussian/clipped_polynomial + +### Language-Aware Defaults + +When `lang_code` is provided in `prior_kwargs` but `target_length` is not specified, +the prior uses empirically-derived defaults for that language: + +```python +# Uses Chinese defaults: target_length=45, spread=15 +sat.split(text, max_length=100, prior_type="gaussian", + prior_kwargs={"lang_code": "zh"}) + +# Uses German defaults: target_length=90, spread=35 +sat.split(text, max_length=150, prior_type="gaussian", + prior_kwargs={"lang_code": "de"}) +``` + +Supported languages include: zh, ja, ko (East Asian), de, en, fr, es, it, pt, ru, ar, and many more. +See `LANG_SENTENCE_STATS` in `wtpsplit/utils/priors.py` for the full list. 
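+If you need the effective `target_length`/`spread` values outside of `split()` (e.g., for logging, or for building a prior yourself), the helpers in `wtpsplit/utils/priors.py` can be called directly. The sketch below is a minimal example that only relies on the names exercised in the test suite (`get_language_defaults`, `create_prior_function`, `LANG_SENTENCE_STATS`); the printed values are illustrative, not exact.
+
+```python
+from wtpsplit.utils.priors import LANG_SENTENCE_STATS, create_prior_function, get_language_defaults
+
+# Empirically derived defaults for German; unknown codes fall back to
+# target_length=70, spread=25 and emit a warning.
+de_defaults = get_language_defaults("de")
+print(de_defaults)  # {"target_length": ..., "spread": ...}
+
+# The Gaussian prior that split(..., prior_type="gaussian",
+# prior_kwargs={"lang_code": "de"}) should end up using.
+prior_fn = create_prior_function("gaussian", {"lang_code": "de"})
+print(prior_fn(de_defaults["target_length"]))  # peaks at ~1.0 at the language's target length
+
+# The full per-language table is a plain dict keyed by language code.
+print(sorted(LANG_SENTENCE_STATS)[:10])
+```
+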
+ +## See Also + +- [Interactive Demo](../length_constrained_segmentation_demo.py) - Run examples and experiment +- [Test Suite](../test_length_constraints.py) - Comprehensive tests +- [README](../README.md) - Quick start guide + diff --git a/docs/prior_functions.png b/docs/prior_functions.png new file mode 100644 index 00000000..3ade7fd3 Binary files /dev/null and b/docs/prior_functions.png differ diff --git a/length_constrained_segmentation_demo.py b/length_constrained_segmentation_demo.py new file mode 100644 index 00000000..43bd57bf --- /dev/null +++ b/length_constrained_segmentation_demo.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 +""" +Length-Constrained Segmentation Demo + +Interactive demo for experimenting with length-constrained segmentation. +For detailed documentation, see docs/LENGTH_CONSTRAINTS.md + +Usage: + python length_constrained_segmentation_demo.py # Run all examples + python length_constrained_segmentation_demo.py --interactive # Interactive mode + python length_constrained_segmentation_demo.py --example news # Specific example +""" + +import argparse +import sys + +# ============================================================================= +# SETUP +# ============================================================================= + + +def load_model(): + """Load SaT model.""" + from wtpsplit import SaT + + print("Loading model...", end=" ", flush=True) + sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) + print("✓") + return sat + + +# ============================================================================= +# DISPLAY UTILITIES +# ============================================================================= + + +class C: + """ANSI colors.""" + + RESET = "\033[0m" + BOLD = "\033[1m" + RED = "\033[91m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + CYAN = "\033[96m" + GRAY = "\033[90m" + + +def show_segments(segments, max_length=None, label=""): + """Display segments with length info and quality indicators.""" + if label: + print(f"\n{C.BOLD}{label}{C.RESET}") + + for i, seg in enumerate(segments, 1): + length = len(seg) + + # Check for word cuts (ends with letter preceded by letter) + has_cut = len(seg) > 1 and seg[-1].isalpha() and seg[-2].isalpha() + + # Status indicator + if has_cut: + status = f"{C.YELLOW}~{C.RESET}" # Word cut + elif max_length and length > max_length: + status = f"{C.RED}!{C.RESET}" # Exceeds max + else: + status = f"{C.GREEN}✓{C.RESET}" # OK + + # Length display + if max_length: + len_str = f"[{length:3d}/{max_length}]" + else: + len_str = f"[{length:3d}]" + + # Truncate for display + display = repr(seg[:50]) + ("..." if len(seg) > 50 else "") + + print(f" {status} {len_str} {display}") + + +def verify_preservation(original, segments): + """Check if text is preserved.""" + rejoined = "".join(segments) + if rejoined == original: + print(f" {C.GREEN}✓ Text preserved{C.RESET}") + return True + else: + print(f" {C.RED}✗ Text NOT preserved!{C.RESET}") + print(f" Original: {len(original)} chars") + print(f" Rejoined: {len(rejoined)} chars") + return False + + +# ============================================================================= +# EXAMPLES +# ============================================================================= + +EXAMPLES = { + "basic": { + "name": "Basic Sentences", + "text": "The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs. 
How vexingly quick daft zebras jump!", + "configs": [ + {"max_length": 50}, + {"max_length": 80}, + {"min_length": 30, "max_length": 100}, + ], + }, + "news": { + "name": "News Article", + "text": """Breaking News: Scientists at CERN have announced a groundbreaking discovery that could revolutionize our understanding of particle physics. The team, led by Dr. Elena Rodriguez, observed unexpected behavior in proton collisions at energies never before achieved. "This is the most significant finding in our field since the Higgs boson," Dr. Rodriguez stated at a press conference in Geneva. The discovery has implications for theories of dark matter and could lead to new technologies within the next decade.""", + "configs": [ + {"max_length": 100}, + {"max_length": 150}, + {"max_length": 200}, + ], + }, + "legal": { + "name": "Legal Text (Long Sentences)", + "text": """WHEREAS the Party of the First Part (hereinafter referred to as "Licensor") is the owner of certain intellectual property rights including but not limited to patents, trademarks, copyrights, and trade secrets relating to the technology described herein, and WHEREAS the Party of the Second Part (hereinafter referred to as "Licensee") desires to obtain a license to use, manufacture, and distribute products incorporating said technology, NOW THEREFORE in consideration of the mutual covenants and agreements set forth herein, the parties agree as follows.""", + "configs": [ + {"max_length": 100}, + {"max_length": 150}, + {"max_length": 250}, + ], + }, + "technical": { + "name": "Technical Documentation", + "text": """The function accepts three parameters: input_data (required), config (optional), and callback (optional). When input_data is a string, it will be parsed as JSON; when it's an object, it will be used directly. The config parameter supports the following options: timeout (default: 30000ms), retries (default: 3), and verbose (default: false). If callback is provided, the function operates asynchronously.""", + "configs": [ + {"max_length": 80}, + {"max_length": 120}, + {"min_length": 50, "max_length": 150}, + ], + }, + "stream": { + "name": "Stream of Consciousness (No Punctuation)", + "text": """I was walking down the street thinking about what to have for dinner maybe pasta or perhaps something lighter like a salad but then again it was cold outside and soup sounded really good especially that tomato soup from the place around the corner which reminded me I needed to call my mother""", + "configs": [ + {"max_length": 60}, + {"max_length": 100}, + {"max_length": 150}, + ], + }, + "dialogue": { + "name": "Dialogue with Quotes", + "text": '''"Have you seen the news?" asked Maria. "About the merger?" replied John. "No, I mean about the earthquake." Maria shook her head. "It's terrible." John sighed. "Sometimes I wonder if things will ever get better."''', + "configs": [ + {"max_length": 50}, + {"max_length": 80}, + {"max_length": 120}, + ], + }, + "mixed": { + "name": "Mixed Content (Numbers, Abbreviations)", + "text": """Dr. Smith earned $125,000 in Q4 2023, a 15.7% increase. The U.S. Department of Labor reported unemployment at 3.5%. Mr. Johnson's company, ABC Corp., plans to expand to the U.K. and E.U. by mid-2024. The CEO stated: "We expect revenues of $50M-$75M." The S&P 500 closed at 4,769.83 pts.""", + "configs": [ + {"max_length": 80}, + {"max_length": 120}, + {"min_length": 40, "max_length": 160}, + ], + }, + "priors": { + "name": "Prior Functions Comparison", + "text": "One. Two. Three. Four. Five. Six. Seven. Eight. 
Nine. Ten. Eleven. Twelve.", + "configs": [ + {"max_length": 100, "prior_type": "uniform"}, + {"max_length": 100, "prior_type": "gaussian", "prior_kwargs": {"target_length": 30, "spread": 10}}, + {"max_length": 100, "prior_type": "gaussian", "prior_kwargs": {"target_length": 60, "spread": 15}}, + ], + }, + "algorithms": { + "name": "Algorithm Comparison (Viterbi vs Greedy)", + "text": "The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs. How vexingly quick daft zebras jump! The five boxing wizards jump quickly. Sphinx of black quartz, judge my vow.", + "configs": [ + {"max_length": 80, "algorithm": "viterbi"}, + {"max_length": 80, "algorithm": "greedy"}, + ], + }, +} + + +def run_example(sat, name, example): + """Run a single example.""" + print(f"\n{'=' * 70}") + print(f"{C.BOLD}{example['name']}{C.RESET}") + print(f"{'=' * 70}") + print(f"\n{C.GRAY}Text ({len(example['text'])} chars):{C.RESET}") + print(f' "{example["text"][:80]}{"..." if len(example["text"]) > 80 else ""}"') + + for config in example["configs"]: + # Build label + parts = [] + if "max_length" in config: + parts.append(f"max={config['max_length']}") + if "min_length" in config: + parts.append(f"min={config['min_length']}") + if config.get("algorithm"): + parts.append(f"algo={config['algorithm']}") + if config.get("prior_type") and config["prior_type"] != "uniform": + parts.append(f"prior={config['prior_type']}") + label = ", ".join(parts) + + # Run segmentation + segments = sat.split(example["text"], threshold=0.025, **config) + + show_segments(segments, config.get("max_length"), label) + verify_preservation(example["text"], segments) + + +def run_all_examples(sat): + """Run all examples.""" + print(f"\n{C.CYAN}{'=' * 70}") + print(" LENGTH-CONSTRAINED SEGMENTATION EXAMPLES") + print(f"{'=' * 70}{C.RESET}") + + for name, example in EXAMPLES.items(): + run_example(sat, name, example) + + print(f"\n{C.CYAN}{'=' * 70}{C.RESET}") + print(f"\nFor interactive mode: {C.BOLD}python {sys.argv[0]} --interactive{C.RESET}") + print(f"For documentation: {C.BOLD}see docs/LENGTH_CONSTRAINTS.md{C.RESET}") + + +# ============================================================================= +# INTERACTIVE MODE +# ============================================================================= + + +def interactive_mode(sat): + """Interactive segmentation playground.""" + print(f"\n{C.CYAN}{'=' * 70}") + print(" INTERACTIVE MODE") + print(f"{'=' * 70}{C.RESET}") + + print(""" +Commands: + - Segment the text + max=N - Set max_length + min=N - Set min_length + algo=X - Set algorithm (viterbi/greedy) + prior=X - Set prior (uniform/gaussian/polynomial) + reset - Reset to defaults + examples - List available examples + run NAME - Run an example + q - Quit +""") + + # Settings + settings = { + "max_length": 100, + "min_length": 1, + "algorithm": "viterbi", + "prior_type": "uniform", + "prior_kwargs": None, + } + + while True: + # Show current settings + print( + f"\n{C.GRAY}[max={settings['max_length']}, min={settings['min_length']}, " + f"algo={settings['algorithm']}, prior={settings['prior_type']}]{C.RESET}" + ) + + try: + user_input = input(f"{C.BOLD}> {C.RESET}").strip() + except (EOFError, KeyboardInterrupt): + print("\nGoodbye!") + break + + if not user_input: + continue + + # Commands + if user_input.lower() == "q": + print("Goodbye!") + break + + if user_input.lower() == "reset": + settings = { + "max_length": 100, + "min_length": 1, + "algorithm": "viterbi", + "prior_type": "uniform", + "prior_kwargs": None, 
+ } + print("Settings reset.") + continue + + if user_input.lower() == "examples": + print("\nAvailable examples:") + for name, ex in EXAMPLES.items(): + print(f" {name:12s} - {ex['name']}") + continue + + if user_input.lower().startswith("run "): + name = user_input[4:].strip() + if name in EXAMPLES: + run_example(sat, name, EXAMPLES[name]) + else: + print(f"Unknown example: {name}") + continue + + if user_input.startswith("max="): + try: + settings["max_length"] = int(user_input[4:]) + print(f"max_length = {settings['max_length']}") + except ValueError: + print("Invalid number") + continue + + if user_input.startswith("min="): + try: + settings["min_length"] = int(user_input[4:]) + print(f"min_length = {settings['min_length']}") + except ValueError: + print("Invalid number") + continue + + if user_input.startswith("algo="): + algo = user_input[5:].strip() + if algo in ["viterbi", "greedy"]: + settings["algorithm"] = algo + print(f"algorithm = {algo}") + else: + print("Unknown algorithm (use: viterbi, greedy)") + continue + + if user_input.startswith("prior="): + prior = user_input[6:].strip() + if prior in ["uniform", "gaussian", "polynomial"]: + settings["prior_type"] = prior if prior != "polynomial" else "clipped_polynomial" + if prior == "gaussian": + settings["prior_kwargs"] = {"target_length": settings["max_length"] * 0.7, "spread": 15} + elif prior == "polynomial": + settings["prior_kwargs"] = {"target_length": settings["max_length"] * 0.7, "spread": 30} + else: + settings["prior_kwargs"] = None + print(f"prior = {prior}") + else: + print("Unknown prior (use: uniform, gaussian, polynomial)") + continue + + # Treat as text to segment + text = user_input + + try: + kwargs = { + "threshold": 0.025, + "max_length": settings["max_length"], + "min_length": settings["min_length"], + "algorithm": settings["algorithm"], + "prior_type": settings["prior_type"], + } + if settings["prior_kwargs"]: + kwargs["prior_kwargs"] = settings["prior_kwargs"] + + segments = sat.split(text, **kwargs) + + print(f"\n{C.BOLD}Result: {len(segments)} segments{C.RESET}") + show_segments(segments, settings["max_length"]) + verify_preservation(text, segments) + + except Exception as e: + print(f"{C.RED}Error: {e}{C.RESET}") + + +# ============================================================================= +# PROBABILITY VISUALIZATION +# ============================================================================= + + +def show_probabilities(sat): + """Visualize model probabilities for a sample text.""" + print(f"\n{C.CYAN}{'=' * 70}") + print(" PROBABILITY VISUALIZATION") + print(f"{'=' * 70}{C.RESET}") + + text = "The quick brown fox jumps. Pack my box with jugs. How vexingly quick!" 
+ + print(f'\n{C.BOLD}Text:{C.RESET} "{text}"') + + # Get probabilities + probs = list(sat.predict_proba([text]))[0] + + # Build visualization + viz = "" + for p in probs: + if p > 0.9: + viz += f"{C.GREEN}█{C.RESET}" + elif p > 0.5: + viz += f"{C.YELLOW}▓{C.RESET}" + elif p > 0.1: + viz += f"{C.GRAY}▒{C.RESET}" + else: + viz += f"{C.GRAY}░{C.RESET}" + + print(f"\n{C.BOLD}Probabilities:{C.RESET}") + print(f" {text}") + print(f" {viz}") + print( + f"\n Legend: {C.GREEN}█{C.RESET}>0.9 {C.YELLOW}▓{C.RESET}>0.5 {C.GRAY}▒{C.RESET}>0.1 {C.GRAY}░{C.RESET}≤0.1" + ) + + # Show high-probability positions + print(f"\n{C.BOLD}Detected boundaries (prob > 0.5):{C.RESET}") + for i, p in enumerate(probs): + if p > 0.5: + ctx = text[max(0, i - 5) : i + 3] + print(f' Position {i:2d}: p={p:.3f} "...{ctx}..."') + + +# ============================================================================= +# MAIN +# ============================================================================= + + +def main(): + parser = argparse.ArgumentParser( + description="Length-Constrained Segmentation Demo", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python %(prog)s # Run all examples + python %(prog)s --interactive # Interactive playground + python %(prog)s --example news # Run specific example + python %(prog)s --probs # Show probability visualization + +Available examples: """ + + ", ".join(EXAMPLES.keys()), + ) + parser.add_argument("-i", "--interactive", action="store_true", help="Interactive mode") + parser.add_argument("-e", "--example", choices=list(EXAMPLES.keys()), help="Run specific example") + parser.add_argument("-p", "--probs", action="store_true", help="Show probability visualization") + + args = parser.parse_args() + + sat = load_model() + + if args.interactive: + interactive_mode(sat) + elif args.example: + run_example(sat, args.example, EXAMPLES[args.example]) + elif args.probs: + show_probabilities(sat) + else: + run_all_examples(sat) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index c7ad162a..174a5fc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,12 @@ [tool.black] line-length = 120 -target-version = ["py38", "py39", "py310"] +target-version = ["py39", "py310", "py311", "py312"] [tool.ruff] line-length = 120 -ignore = ["E741"] +lint.ignore = ["E741"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "test.py" = ["E501"] [dependency-groups] diff --git a/requirements.txt b/requirements.txt index c299089c..921794a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,6 @@ replicate onnx onnxruntime mosestokenizer -cached_property tqdm skops pandas diff --git a/scripts/compute_sentence_stats.py b/scripts/compute_sentence_stats.py new file mode 100644 index 00000000..a441684b --- /dev/null +++ b/scripts/compute_sentence_stats.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +Compute sentence length statistics from Universal Dependencies treebanks. 
+ +Usage: + python scripts/compute_sentence_stats.py --output_dir wtpsplit/data/ -v +""" + +import argparse +import io +import json +import os +import sys +import tarfile +import urllib.request +from collections import defaultdict +from datetime import datetime + +import conllu +import numpy as np + + +def compute_stats(lengths: list[int]) -> dict: + """Compute target_length and spread from sentence lengths.""" + if len(lengths) < 10: + return None + + arr = np.array(lengths) + target_length = int(np.median(arr)) + + # IQR-based spread (robust to outliers) + q75, q25 = np.percentile(arr, [75, 25]) + spread = max(int((q75 - q25) / 1.35), 10) + + return { + "target_length": target_length, + "spread": spread, + "n_sentences": len(lengths), + "min": int(np.min(arr)), + "max": int(np.max(arr)), + } + + +def load_ud_from_hf() -> dict[str, list[int]]: + """Download and process UD treebanks from official release.""" + lang_lengths = defaultdict(list) + + ud_version = "2.14" + url = f"https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-5502/ud-treebanks-v{ud_version}.tgz" + + print(f"Downloading UD v{ud_version} (~500MB)...", file=sys.stderr) + + with urllib.request.urlopen(url) as response: + print("Reading archive into memory...", file=sys.stderr) + fileobj = io.BytesIO(response.read()) + + print("Processing CONLL-U files...", file=sys.stderr) + file_count = 0 + + with tarfile.open(fileobj=fileobj, mode="r:gz") as tar: + for member in tar: + if member.name.endswith(".conllu"): + filename = os.path.basename(member.name) + lang_code = filename.split("_")[0] + + try: + f = tar.extractfile(member) + if f is not None: + content = f.read().decode("utf-8") + data = conllu.parse(content) + + for sentence in data: + if "text" in sentence.metadata: + lang_lengths[lang_code].append(len(sentence.metadata["text"])) + + file_count += 1 + if file_count % 50 == 0: + print(f" Processed {file_count} files...", file=sys.stderr) + except Exception as e: + print(f"Warning: Could not parse {member.name}: {e}", file=sys.stderr) + + print(f" Processed {file_count} CONLL-U files total", file=sys.stderr) + + return dict(lang_lengths) + + +def json_to_python(json_path: str) -> str: + """Convert JSON stats to Python dict format for priors.py.""" + with open(json_path, "r") as f: + data = json.load(f) + + lines = ["LANG_SENTENCE_STATS = {"] + for lang_code, s in sorted(data["stats"].items()): + lines.append(f' "{lang_code}": {{"target_length": {s["target_length"]}, "spread": {s["spread"]}}},') + lines.append("}") + return "\n".join(lines) + + +def main(): + parser = argparse.ArgumentParser(description="Compute sentence length statistics from UD") + parser.add_argument("--output_dir", "-o", type=str, help="Output directory (for downloading)") + parser.add_argument("--to-python", type=str, metavar="JSON_FILE", help="Convert JSON to Python dict format") + parser.add_argument("--min_sentences", type=int, default=100, help="Minimum sentences per language") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + args = parser.parse_args() + + # Convert existing JSON to Python format + if args.to_python: + print(json_to_python(args.to_python)) + return + + if not args.output_dir: + parser.error("--output_dir required (or use --to-python to convert existing JSON)") + + # Load data + lang_lengths = load_ud_from_hf() + print(f"Loaded data for {len(lang_lengths)} languages", file=sys.stderr) + + # Compute statistics + stats = {} + for lang_code, lengths in sorted(lang_lengths.items()): + 
if len(lengths) >= args.min_sentences: + stats[lang_code] = compute_stats(lengths) + if args.verbose: + s = stats[lang_code] + print( + f" {lang_code}: {len(lengths):>6} sentences, median={s['target_length']:>3}, spread={s['spread']}", + file=sys.stderr, + ) + elif args.verbose: + print(f" {lang_code}: skipped ({len(lengths)} < {args.min_sentences} sentences)", file=sys.stderr) + + print(f"\nComputed statistics for {len(stats)} languages", file=sys.stderr) + + # Write output + os.makedirs(args.output_dir, exist_ok=True) + output = { + "metadata": { + "source": "Universal Dependencies v2.14", + "generated_at": datetime.utcnow().isoformat() + "Z", + }, + "stats": {k: {"target_length": v["target_length"], "spread": v["spread"]} for k, v in stats.items()}, + } + + json_path = os.path.join(args.output_dir, "sentence_stats.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump(output, f, indent=2) + print(f"Wrote {json_path}", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/scripts/export_to_onnx_sat.py b/scripts/export_to_onnx_sat.py index c6fadf24..583d6e40 100644 --- a/scripts/export_to_onnx_sat.py +++ b/scripts/export_to_onnx_sat.py @@ -131,4 +131,4 @@ class Args: path_or_fileobj=output_dir / "model.onnx", path_in_repo="model.onnx", repo_id=args.model_name_or_path, - ) \ No newline at end of file + ) diff --git a/setup.py b/setup.py index d6b15967..37ab98d6 100644 --- a/setup.py +++ b/setup.py @@ -2,27 +2,27 @@ setup( name="wtpsplit", - version="2.1.7", + version="2.2.0", packages=find_packages(), description="Universal Robust, Efficient and Adaptable Sentence Segmentation", author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer", author_email="markus.frohmann@gmail.com", + python_requires=">=3.9", install_requires=[ # "onnxruntime>=1.13.1", # can make conflicts between onnxruntime and onnxruntime-gpu - "transformers>=4.22.2", - "huggingface-hub", + "transformers>=4.22.2,<5.0", # v5.0 has breaking changes; adapters library needs update first + "huggingface-hub<1.0", # v1.0 has breaking changes (HfFolder removed) "numpy>=1.0", "scikit-learn>=1", "tqdm", "skops", "pandas>=1", - "cached_property", # for Py37 "mosestokenizer", "adapters>=1.0.1", ], extras_require={ - 'onnx-gpu': ['onnxruntime-gpu>=1.13.1'], - 'onnx-cpu': ['onnxruntime>=1.13.1'], + "onnx-gpu": ["onnxruntime-gpu>=1.13.1"], + "onnx-cpu": ["onnxruntime>=1.13.1"], }, url="https://github.com/segment-any-text/wtpsplit", package_data={"wtpsplit": ["data/*"]}, diff --git a/test.py b/test.py index 20f662e1..df256a0e 100644 --- a/test.py +++ b/test.py @@ -1,5 +1,7 @@ # noqa: E501 from wtpsplit import WtP, SaT +import numpy as np + def test_weighting(): sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) @@ -8,7 +10,7 @@ def test_weighting(): splits_default = sat.split(text, threshold=0.25) splits_uniform = sat.split(text, threshold=0.25, weighting="uniform") splits_hat = sat.split(text, threshold=0.25, weighting="hat") - expected_splits = ["This is a test sentence ", "This is another test sentence."] + expected_splits = ["This is a test sentence ", "This is another test sentence."] assert splits_default == splits_uniform == splits_hat == expected_splits assert "".join(splits_default) == text @@ -55,13 +57,12 @@ def test_strip_newline_behaviour(): "Yes\nthis is a test sentence. This is another test sentence.", ) assert splits == ["Yes", "this is a test sentence. 
", "This is another test sentence."] - + + def test_strip_newline_behaviour_as_spaces(): sat = SaT("segment-any-text/sat-3l", hub_prefix=None) - splits = sat.split( - "Yes\nthis is a test sentence. This is another test sentence.", treat_newline_as_space=True - ) + splits = sat.split("Yes\nthis is a test sentence. This is another test sentence.", treat_newline_as_space=True) assert splits == ["Yes\nthis is a test sentence. ", "This is another test sentence."] @@ -114,10 +115,11 @@ def test_split_paragraphs(): assert paragraph1.startswith("Text segmentation is") assert paragraph2.startswith("Daniel Wroughton Craig CMG (born 2 March 1968) is") - + + def test_split_empty_strings(): sat = SaT("segment-any-text/sat-3l", hub_prefix=None) - + text = " " splits = sat.split(text) assert splits == [" "] @@ -253,4 +255,152 @@ def test_split_threshold_wtp(): splits = wtp.split("This is a test sentence. This is another test sentence.", threshold=-1e-3) # space might still be included in a character split - assert splits[:3] == list("Thi") \ No newline at end of file + assert splits[:3] == list("Thi") + + +# ============================================================================ +# Length-Constrained Segmentation Tests +# ============================================================================ + + +def test_min_length_constraint_wtp(): + """Test minimum length constraint with WtP""" + wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) + + text = "Short. Test. Hello. World. This is longer." + splits = wtp.split(text, min_length=15, threshold=0.005) + + # All segments should be >= 15 characters + for segment in splits: + assert len(segment) >= 15, f"Segment '{segment}' is shorter than min_length" + + # Text should be preserved + assert "".join(splits) == text + + +def test_max_length_constraint_sat(): + """Test maximum length constraint with SaT""" + sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) + + text = "This is a test sentence. " * 10 + splits = sat.split(text, max_length=60, threshold=0.025) + + # All segments should be <= 60 characters + for segment in splits: + assert len(segment) <= 60, f"Segment '{segment}' is longer than max_length" + + # Text should be preserved + assert "".join(splits) == text + + +def test_min_max_constraints_together(): + """Test both constraints simultaneously""" + wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) + + text = "Hello world. " * 15 + splits = wtp.split(text, min_length=25, max_length=65, threshold=0.005) + + # All segments should satisfy both constraints + for segment in splits: + assert 25 <= len(segment) <= 65, f"Segment '{segment}' violates constraints" + + # Text should be preserved + assert "".join(splits) == text + + +def test_gaussian_prior(): + """Test Gaussian prior preference""" + sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) + + text = "Sentence. " * 30 + splits = sat.split( + text, + min_length=20, + max_length=60, + prior_type="gaussian", + prior_kwargs={"target_length": 40.0, "spread": 5.0}, + threshold=0.025, + ) + + # Should produce valid splits + for segment in splits: + assert 20 <= len(segment) <= 60 + + # Text should be preserved + assert "".join(splits) == text + + +def test_greedy_algorithm(): + """Test greedy algorithm""" + wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) + + text = "Test sentence. 
" * 10 + splits = wtp.split(text, min_length=20, max_length=50, algorithm="greedy", threshold=0.005) + + # Should produce valid splits + for segment in splits: + assert 20 <= len(segment) <= 50 + + # Text should be preserved + assert "".join(splits) == text + + +def test_constraints_with_paragraph_segmentation(): + """Test constraints with nested paragraph segmentation""" + wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) + + text = " ".join( + [ + "First paragraph first sentence. First paragraph second sentence.", + "Second paragraph first sentence. Second paragraph second sentence.", + ] + ) + + paragraphs = wtp.split(text, do_paragraph_segmentation=True, min_length=20, max_length=70) + + # Check structure + assert isinstance(paragraphs, list) + for paragraph in paragraphs: + assert isinstance(paragraph, list) + for sentence in paragraph: + assert 20 <= len(sentence) <= 70 + + +def test_constraints_preserved_in_batched(): + """Test constraints work with batched processing""" + sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) + + texts = [ + "First batch text. " * 5, + "Second batch text. " * 5, + ] + + results = list(sat.split(texts, min_length=25, max_length=55, threshold=0.025)) + + assert len(results) == 2 + for splits in results: + for segment in splits: + assert 25 <= len(segment) <= 55 + + +def test_constraint_low_level(): + """Test constrained_segmentation directly""" + from wtpsplit.utils.constraints import constrained_segmentation + from wtpsplit.utils.priors import create_prior_function + + probs = np.array([0.1, 0.3, 0.5, 0.7, 0.9, 0.2, 0.4, 0.6, 0.8, 1.0]) + prior_fn = create_prior_function("uniform", {"max_length": 5}) + + indices = constrained_segmentation(probs, prior_fn, min_length=3, max_length=5, algorithm="viterbi") + + # Verify constraints on chunk lengths + prev = 0 + for idx in indices: + chunk_len = idx - prev + assert 3 <= chunk_len <= 5, f"Chunk length {chunk_len} violates constraints" + prev = idx + + # Check last chunk + if prev < len(probs): + last_len = len(probs) - prev + assert 3 <= last_len <= 5, f"Last chunk length {last_len} violates constraints" diff --git a/test_length_constraints.py b/test_length_constraints.py new file mode 100644 index 00000000..c714f490 --- /dev/null +++ b/test_length_constraints.py @@ -0,0 +1,1121 @@ +# noqa: E501 +""" +Comprehensive tests for length-constrained segmentation in wtpsplit. 
+ +This test suite covers: +- Text preservation guarantee +- Strict max_length enforcement +- min_length best-effort behavior +- Viterbi and greedy algorithms +- Prior functions (uniform, gaussian, clipped_polynomial) +- Edge cases and special characters +- Both WtP and SaT models +- Real-world scenarios +- Regression tests for fixed bugs + +Run with: pytest test_length_constraints.py -v +""" + +import pytest +import numpy as np +import warnings + +from wtpsplit import WtP, SaT +from wtpsplit.utils.constraints import constrained_segmentation +from wtpsplit.utils.priors import create_prior_function, get_language_defaults, LANG_SENTENCE_STATS + + +# ============================================================================= +# FIXTURES +# ============================================================================= + + +@pytest.fixture(scope="module") +def sat_model(): + """Load SaT model once for all tests.""" + return SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) + + +@pytest.fixture(scope="module") +def wtp_model(): + """Load WtP model once for all tests.""" + return WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) + + +# ============================================================================= +# BASIC CONSTRAINT ENFORCEMENT (Low-level) +# ============================================================================= + + +class TestBasicConstraints: + """Test basic constraint enforcement at the algorithm level.""" + + def test_min_length_enforcement(self): + """Verify all chunks are >= min_length.""" + probs = np.random.random(100) + min_len = 10 + prior_fn = create_prior_function("uniform", {"max_length": 100}) + + indices = constrained_segmentation(probs, prior_fn, min_length=min_len, max_length=100) + + prev = 0 + for idx in indices: + chunk_len = idx - prev + assert chunk_len >= min_len, f"Chunk length {chunk_len} < min {min_len}" + prev = idx + + last_len = 100 - prev + assert last_len >= min_len, f"Last chunk length {last_len} < min {min_len}" + + def test_max_length_enforcement(self): + """Verify all chunks are <= max_length.""" + probs = np.random.random(100) + max_len = 20 + prior_fn = create_prior_function("uniform", {"max_length": max_len}) + + indices = constrained_segmentation(probs, prior_fn, min_length=1, max_length=max_len) + + prev = 0 + for idx in indices: + chunk_len = idx - prev + assert chunk_len <= max_len, f"Chunk length {chunk_len} > max {max_len}" + prev = idx + + last_len = 100 - prev + assert last_len <= max_len, f"Last chunk length {last_len} > max {max_len}" + + def test_min_max_together(self): + """Both constraints simultaneously.""" + probs = np.random.random(100) + min_len = 5 + max_len = 15 + prior_fn = create_prior_function("uniform", {"max_length": max_len}) + + indices = constrained_segmentation(probs, prior_fn, min_length=min_len, max_length=max_len) + + prev = 0 + for idx in indices: + chunk_len = idx - prev + assert min_len <= chunk_len <= max_len, f"Chunk length {chunk_len} not in [{min_len}, {max_len}]" + prev = idx + + last_len = 100 - prev + assert min_len <= last_len <= max_len, f"Last chunk length {last_len} not in [{min_len}, {max_len}]" + + def test_no_constraints(self): + """Default behavior (min=1, max=None) should work.""" + probs = np.array([0.1, 0.3, 0.7, 0.9]) + + def prior_fn(length): + return 1.0 + + indices = constrained_segmentation(probs, prior_fn, min_length=1, max_length=None) + assert isinstance(indices, list) + + def test_large_text_with_constraints(self): + """Test with large text.""" + probs = 
np.random.random(1000) + prior_fn = create_prior_function("uniform", {"max_length": 50}) + + indices = constrained_segmentation(probs, prior_fn, min_length=20, max_length=50) + + prev = 0 + for idx in indices: + chunk_len = idx - prev + assert 20 <= chunk_len <= 50 + prev = idx + + +# ============================================================================= +# TEXT PRESERVATION TESTS +# ============================================================================= + + +class TestTextPreservation: + """Verify that segmentation preserves original text exactly.""" + + def test_simple_text(self, sat_model): + text = "Hello world. How are you? I am fine." + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_with_max_length(self, sat_model): + text = "The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs." + segments = sat_model.split(text, max_length=50, threshold=0.025) + assert "".join(segments) == text + + def test_multiline_preserved(self, sat_model): + text = "Line one.\n\nLine two.\n\nLine three." + segments = sat_model.split(text, threshold=0.5, split_on_input_newlines=False) + assert "".join(segments) == text + + def test_whitespace_variations(self, sat_model): + text = "Word1 Word2. Word3 Word4." + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_unicode_preserved(self, sat_model): + text = "Привет мир. 你好世界。مرحبا بالعالم." + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_special_characters(self, sat_model): + text = "Price: $100.00! Email: test@example.com? URL: https://example.com/path?q=1" + segments = sat_model.split(text, threshold=0.5, max_length=100) + assert "".join(segments) == text + + def test_long_document(self, sat_model): + text = "This is sentence one. " * 50 + segments = sat_model.split(text, max_length=150, threshold=0.025) + assert "".join(segments) == text + + +# ============================================================================= +# MAX_LENGTH TESTS +# ============================================================================= + + +class TestMaxLength: + """Verify strict max_length enforcement.""" + + def test_all_segments_within_max(self, sat_model): + text = "The quick brown fox jumps over the lazy dog. " * 10 + max_length = 100 + segments = sat_model.split(text, max_length=max_length, threshold=0.025) + + for segment in segments: + assert len(segment) <= max_length, f"Segment too long: {len(segment)} > {max_length}" + + def test_various_max_lengths(self, sat_model): + text = "Hello world. How are you today? I am doing well. Thanks for asking!" + + for max_length in [30, 50, 80, 100, 150]: + segments = sat_model.split(text, max_length=max_length, threshold=0.025) + for segment in segments: + assert len(segment) <= max_length + assert "".join(segments) == text + + def test_max_length_forces_split(self, sat_model): + """Text longer than max_length must be split.""" + text = "This is a single very long sentence without any natural break points whatsoever." 
+ segments = sat_model.split(text, max_length=40, threshold=0.025) + + assert len(segments) > 1 + for segment in segments: + assert len(segment) <= 40 + + def test_max_length_one(self, sat_model): + """Edge case: max_length=1.""" + text = "Hi" + segments = sat_model.split(text, max_length=1, threshold=0.025) + for segment in segments: + assert len(segment) <= 1 + + +# ============================================================================= +# MIN_LENGTH TESTS +# ============================================================================= + + +class TestMinLength: + """Verify min_length best-effort behavior.""" + + def test_min_length_merges_short(self, sat_model): + text = "A. B. C. D. E." + segments_no_min = sat_model.split(text, threshold=0.5) + segments_with_min = sat_model.split(text, threshold=0.5, min_length=5) + + assert len(segments_with_min) <= len(segments_no_min) + + def test_min_length_with_max_length(self, sat_model): + text = "Short. Another short. Yet another. And more." + segments = sat_model.split(text, min_length=10, max_length=50, threshold=0.025) + + for segment in segments: + assert len(segment) <= 50 + assert "".join(segments) == text + + def test_tiny_fragments_merging(self, wtp_model): + """Tiny fragments should be merged to meet min_length.""" + text = "A. B. C. D. E. F. G. H. I. J." + splits = wtp_model.split(text, min_length=10, threshold=0.005) + + for segment in splits: + assert len(segment) >= 10, f"Segment '{segment}' is too short" + + def test_very_short_sentences(self, wtp_model): + """Very short sentences should be merged when needed.""" + text = "Hi. Bye. Go. Stop. Run. Walk. Jump. Sit." + splits = wtp_model.split(text, min_length=15, threshold=0.005) + + for segment in splits: + assert len(segment) >= 15, f"Segment '{segment}' is too short" + + +# ============================================================================= +# ALGORITHM TESTS +# ============================================================================= + + +class TestAlgorithms: + """Test Viterbi and greedy algorithms.""" + + def test_viterbi_deterministic(self, sat_model): + text = "The quick brown fox. Pack my box. How vexingly quick!" + + results = [sat_model.split(text, max_length=80, algorithm="viterbi", threshold=0.025) for _ in range(3)] + + assert all(r == results[0] for r in results) + + def test_greedy_deterministic(self, sat_model): + text = "The quick brown fox. Pack my box. How vexingly quick!" + + results = [sat_model.split(text, max_length=80, algorithm="greedy", threshold=0.025) for _ in range(3)] + + assert all(r == results[0] for r in results) + + def test_both_algorithms_preserve_text(self, sat_model): + text = "First sentence here. Second sentence follows. Third one ends it." + + for algo in ["viterbi", "greedy"]: + segments = sat_model.split(text, max_length=100, algorithm=algo, threshold=0.025) + assert "".join(segments) == text + + def test_both_algorithms_respect_max_length(self, sat_model): + text = "The quick brown fox jumps. 
" * 20 + max_length = 80 + + for algo in ["viterbi", "greedy"]: + segments = sat_model.split(text, max_length=max_length, algorithm=algo, threshold=0.025) + for segment in segments: + assert len(segment) <= max_length + + def test_viterbi_vs_greedy_both_valid(self): + """Both algorithms should produce valid segmentations.""" + probs = np.random.rand(50) + probs[15] = 0.95 + probs[30] = 0.95 + probs[45] = 0.95 + + prior_fn = create_prior_function("uniform", {"max_length": 20}) + + greedy = constrained_segmentation(probs, prior_fn, min_length=1, max_length=20, algorithm="greedy") + viterbi = constrained_segmentation(probs, prior_fn, min_length=1, max_length=20, algorithm="viterbi") + + for boundaries in [greedy, viterbi]: + prev = 0 + for b in boundaries + [50]: + assert b - prev <= 20 + prev = b + + +# ============================================================================= +# PRIOR FUNCTION TESTS +# ============================================================================= + + +class TestPriors: + """Test prior functions behavior.""" + + def test_uniform_prior(self): + prior_fn = create_prior_function("uniform", {"max_length": 100}) + + assert prior_fn(50) == 1.0 + assert prior_fn(100) == 1.0 + assert prior_fn(101) == 0.0 + assert prior_fn(200) == 0.0 + + def test_gaussian_prior(self): + prior_fn = create_prior_function("gaussian", {"target_length": 50, "spread": 10}) + + assert prior_fn(50) == pytest.approx(1.0) + assert prior_fn(30) < prior_fn(50) + assert prior_fn(70) < prior_fn(50) + # Symmetric + assert prior_fn(40) == pytest.approx(prior_fn(60), rel=1e-5) + + def test_polynomial_prior(self): + # spread=20 means tolerance of ±20 chars before clipping to zero + prior_fn = create_prior_function("clipped_polynomial", {"target_length": 50, "spread": 20}) + + assert prior_fn(50) == pytest.approx(1.0) + assert prior_fn(40) < prior_fn(50) + assert prior_fn(60) < prior_fn(50) + + def test_polynomial_clips_to_zero(self): + # spread=30 means clips to zero at ±30 chars from target + prior_fn = create_prior_function("clipped_polynomial", {"target_length": 50, "spread": 30}) + + assert prior_fn(50) == 1.0 + assert prior_fn(100) == 0.0 # 50 chars away, definitely clipped + + def test_gaussian_affects_segmentation(self, sat_model): + text = "One. Two. Three. Four. Five. Six. Seven. Eight. Nine. Ten." 
+ + segments_small = sat_model.split( + text, + max_length=200, + prior_type="gaussian", + prior_kwargs={"target_length": 20, "spread": 5}, + threshold=0.025, + ) + + segments_large = sat_model.split( + text, + max_length=200, + prior_type="gaussian", + prior_kwargs={"target_length": 100, "spread": 20}, + threshold=0.025, + ) + + assert "".join(segments_small) == text + assert "".join(segments_large) == text + + def test_language_stats_loaded_from_json(self): + """Verify sentence stats are loaded from JSON file.""" + # Should have loaded stats for common languages + assert len(LANG_SENTENCE_STATS) > 50, "Expected stats for 50+ languages" + # Check a few known languages + assert "en" in LANG_SENTENCE_STATS + assert "de" in LANG_SENTENCE_STATS + assert "zh" in LANG_SENTENCE_STATS + + def test_get_language_defaults_known_language(self): + """Test getting defaults for a known language.""" + defaults = get_language_defaults("en") + assert "target_length" in defaults + assert "spread" in defaults + assert isinstance(defaults["target_length"], int) + assert isinstance(defaults["spread"], int) + + def test_get_language_defaults_unknown_language_warns(self): + """Test that unknown language triggers a warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + defaults = get_language_defaults("xyz_unknown") + assert len(w) == 1 + assert "No sentence statistics" in str(w[0].message) + assert defaults["target_length"] == 70 # default + assert defaults["spread"] == 25 # default + + def test_get_language_defaults_none(self): + """Test getting defaults with no language code.""" + defaults = get_language_defaults(None) + assert defaults["target_length"] == 70 + assert defaults["spread"] == 25 + + def test_prior_with_lang_code(self): + """Test creating prior with lang_code parameter.""" + prior_fn = create_prior_function("gaussian", {"lang_code": "en"}) + # Should use English defaults, peak should be near English target_length + en_defaults = get_language_defaults("en") + assert prior_fn(en_defaults["target_length"]) == pytest.approx(1.0) + + +# ============================================================================= +# INPUT VALIDATION TESTS +# ============================================================================= + + +class TestInputValidation: + """Test input parameter validation.""" + + def test_min_greater_than_max_raises(self, sat_model): + with pytest.raises(ValueError, match="min_length.*cannot be greater than max_length"): + sat_model.split("Hello", min_length=100, max_length=50) + + def test_invalid_prior_type_raises(self, sat_model): + with pytest.raises(ValueError, match="Unknown prior_type"): + sat_model.split("Hello", prior_type="invalid_prior") + + def test_invalid_algorithm_raises(self, sat_model): + with pytest.raises(ValueError, match="Unknown algorithm"): + sat_model.split("Hello", algorithm="invalid_algo") + + def test_min_length_zero_raises(self, sat_model): + with pytest.raises(ValueError, match="min_length must be >= 1"): + sat_model.split("Hello", min_length=0) + + +# ============================================================================= +# EDGE CASES +# ============================================================================= + + +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_empty_string(self, sat_model): + segments = sat_model.split("", threshold=0.5) + assert segments == [] or segments == [""] + + def test_single_character(self, sat_model): + segments = sat_model.split("A", 
threshold=0.5) + assert "".join(segments) == "A" + + def test_only_whitespace(self, sat_model): + text = " \n\t " + segments = sat_model.split(text, threshold=0.5) + assert isinstance(segments, list) + + def test_only_punctuation(self, sat_model): + text = "!?!.!?.!" + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_very_long_text(self, sat_model): + text = "This is a test sentence. " * 200 + segments = sat_model.split(text, max_length=200, threshold=0.025) + + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= 200 + + def test_empty_probabilities(self): + """Handle empty input.""" + probs = np.array([]) + + def prior_fn(length): + return 1.0 + + indices = constrained_segmentation(probs, prior_fn, min_length=1, max_length=10) + assert indices == [] + + def test_min_length_larger_than_text(self): + """Handle impossible constraints gracefully.""" + probs = np.array([0.5, 0.5, 0.5]) + + def prior_fn(length): + return 1.0 + + indices = constrained_segmentation(probs, prior_fn, min_length=10, max_length=None) + assert len(indices) <= 1 + + +# ============================================================================= +# BOTH MODELS TEST +# ============================================================================= + + +class TestBothModels: + """Test that both WtP and SaT work with constraints.""" + + def test_wtp_preserves_text(self, wtp_model): + text = "Hello world. How are you?" + segments = wtp_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_wtp_max_length(self, wtp_model): + text = "The quick brown fox jumps over the lazy dog. Pack my box." + segments = wtp_model.split(text, max_length=30, threshold=0.025) + + for segment in segments: + assert len(segment) <= 30 + assert "".join(segments) == text + + def test_wtp_with_both_constraints(self, wtp_model): + text = "Hello world. " * 20 + splits = wtp_model.split(text, min_length=30, max_length=80, threshold=0.005) + + for segment in splits: + assert len(segment) <= 80 + assert sum(1 for s in splits if len(s) >= 30) >= len(splits) * 0.7 + + def test_sat_preserves_text(self, sat_model): + text = "Hello world. How are you?" + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_sat_max_length(self, sat_model): + text = "The quick brown fox jumps over the lazy dog. Pack my box." + segments = sat_model.split(text, max_length=30, threshold=0.025) + + for segment in segments: + assert len(segment) <= 30 + assert "".join(segments) == text + + +# ============================================================================= +# BATCH PROCESSING TESTS +# ============================================================================= + + +class TestBatchProcessing: + """Test batch processing with constraints.""" + + def test_batch_preserves_all(self, sat_model): + texts = [ + "First document here. With sentences.", + "Second document. Also with sentences. Multiple ones.", + "Third. Short.", + ] + + results = list(sat_model.split(texts, max_length=100, threshold=0.025)) + + for text, segments in zip(texts, results): + assert "".join(segments) == text + + def test_batch_respects_max_length(self, sat_model): + texts = ["Long text here. " * 10, "Another long one. 
" * 15] + max_length = 80 + + results = list(sat_model.split(texts, max_length=max_length, threshold=0.025)) + + for segments in results: + for segment in segments: + assert len(segment) <= max_length + + +# ============================================================================= +# WHITESPACE HANDLING TESTS +# ============================================================================= + + +class TestWhitespaceHandling: + """Comprehensive whitespace handling tests.""" + + def test_single_space_between_words(self, sat_model): + text = "Hello world." + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_multiple_spaces_preserved(self, sat_model): + text = "Hello world. How are you?" + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_tabs_preserved(self, sat_model): + text = "Hello\tworld.\tHow\tare\tyou?" + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_mixed_whitespace(self, sat_model): + text = "Hello \t world. \n\n How are you?" + segments = sat_model.split(text, threshold=0.5, split_on_input_newlines=False) + assert "".join(segments) == text + + def test_leading_whitespace(self, sat_model): + text = " Leading spaces. Then more text." + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_trailing_whitespace(self, sat_model): + text = "Text here. More text. " + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_only_newlines(self, sat_model): + text = "\n\n\n" + segments = sat_model.split(text, threshold=0.5, split_on_input_newlines=False) + assert "".join(segments) == text or segments == [] + + def test_windows_line_endings(self, sat_model): + text = "Line one.\r\nLine two.\r\nLine three." + segments = sat_model.split(text, threshold=0.5, split_on_input_newlines=False) + assert "".join(segments) == text + + +# ============================================================================= +# REAL-WORLD TEXT TESTS +# ============================================================================= + + +class TestRealWorldText: + """Tests with realistic text content.""" + + def test_news_article(self, sat_model): + text = """Breaking News: Scientists at CERN have announced a groundbreaking discovery that could revolutionize our understanding of particle physics. The team, led by Dr. Elena Rodriguez, observed unexpected behavior in proton collisions at energies never before achieved. "This is the most significant finding in our field since the Higgs boson," Dr. 
Rodriguez stated at a press conference in Geneva.""" + + for max_len in [100, 150, 200]: + segments = sat_model.split(text, max_length=max_len, threshold=0.025) + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= max_len + + def test_legal_text(self, sat_model): + text = """WHEREAS the Party of the First Part (hereinafter referred to as "Licensor") is the owner of certain intellectual property rights including but not limited to patents, trademarks, copyrights, and trade secrets relating to the technology described herein, and WHEREAS the Party of the Second Part (hereinafter referred to as "Licensee") desires to obtain a license to use said technology.""" + + segments = sat_model.split(text, max_length=150, threshold=0.025) + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= 150 + + def test_technical_documentation(self, sat_model): + text = """The function accepts three parameters: input_data (required), config (optional), and callback (optional). When input_data is a string, it will be parsed as JSON; when it's an object, it will be used directly. The config parameter supports the following options: timeout (default: 30000ms), retries (default: 3), and verbose (default: false).""" + + segments = sat_model.split(text, max_length=120, threshold=0.025) + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= 120 + + def test_dialogue(self, sat_model): + text = '''"Have you seen the news?" asked Maria. "About the merger?" replied John. "No, I mean about the earthquake." Maria shook her head sadly. "It's terrible."''' + + segments = sat_model.split(text, max_length=80, threshold=0.025) + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= 80 + + def test_email_style_text(self, wtp_model): + """Email-style text should be segmented appropriately.""" + text = "Hi John. Thanks for your email yesterday. I reviewed the documents you sent. Everything looks good. We can proceed with the next phase. Let me know if you have questions. Best regards." + splits = wtp_model.split(text, min_length=20, max_length=70, threshold=0.005) + + for segment in splits: + assert 20 <= len(segment) <= 70 + + +# ============================================================================= +# STRESS TESTS +# ============================================================================= + + +class TestStress: + """Stress tests for edge conditions and performance.""" + + def test_very_many_short_sentences(self, sat_model): + text = "A. B. C. D. E. F. G. H. I. J. K. L. M. N. O. P. Q. R. S. T. U. V. W. X. Y. Z." + + segments = sat_model.split(text, max_length=50, threshold=0.025) + assert "".join(segments) == text + + def test_alternating_long_short(self, sat_model): + text = ( + "Short. " + + "This is a much longer sentence with many words. " * 5 + + "Short again. " + + "Another very long sentence. 
" * 3 + ) + + segments = sat_model.split(text, max_length=100, threshold=0.025) + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= 100 + + def test_no_natural_breaks(self, sat_model): + """Text with no punctuation at all.""" + text = "This text has no punctuation at all and just keeps going and going without any natural break points whatsoever and the algorithm needs to handle this gracefully" + + segments = sat_model.split(text, max_length=50, threshold=0.025) + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= 50 + + def test_repeated_sentence(self, sat_model): + sentence = "The quick brown fox jumps over the lazy dog. " + text = sentence * 100 + + segments = sat_model.split(text, max_length=100, threshold=0.025) + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= 100 + + def test_rapid_repeated_calls(self, sat_model): + """Ensure consistency across rapid repeated calls.""" + text = "Hello world. How are you today?" + + results = [sat_model.split(text, max_length=50, threshold=0.025) for _ in range(10)] + + for result in results[1:]: + assert result == results[0] + + def test_10k_characters(self, sat_model): + """Test with ~10,000 character document.""" + text = "This is a test sentence with some content. " * 250 + + segments = sat_model.split(text, max_length=200, threshold=0.025) + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= 200 + + def test_extreme_tiny_sentences(self, wtp_model): + """Many tiny sentences should be merged appropriately.""" + text = "A. " * 100 + splits = wtp_model.split(text, min_length=20, max_length=100, threshold=0.005) + + for segment in splits: + assert len(segment) <= 100 + + segments_meeting_min = sum(1 for s in splits if len(s) >= 20) + assert segments_meeting_min >= len(splits) - 1 + + assert "".join(splits) == text + + +# ============================================================================= +# CONSTRAINT COMBINATION TESTS +# ============================================================================= + + +class TestConstraintCombinations: + """Test various combinations of constraints.""" + + def test_tight_constraints(self, sat_model): + """min_length close to max_length.""" + text = "First sentence here. Second one follows. Third sentence ends." + + segments = sat_model.split(text, min_length=15, max_length=25, threshold=0.025) + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= 25 + + def test_equal_min_max(self, sat_model): + """min_length equals max_length.""" + text = "Hello world." + + segments = sat_model.split(text, min_length=12, max_length=12, threshold=0.025) + assert "".join(segments) == text + + def test_large_min_length(self, sat_model): + text = "A. B. C. D. E. F. G. H. I. J." + + segments = sat_model.split(text, min_length=20, max_length=100, threshold=0.025) + assert "".join(segments) == text + assert len(segments) < 10 + + def test_very_small_max_length(self, sat_model): + text = "Hello world. Test." 
+ + segments = sat_model.split(text, max_length=10, threshold=0.025) + assert "".join(segments) == text + for segment in segments: + assert len(segment) <= 10 + + +# ============================================================================= +# UNICODE AND INTERNATIONALIZATION TESTS +# ============================================================================= + + +class TestUnicodeAndI18n: + """Tests for unicode and international text.""" + + def test_chinese(self, sat_model): + text = "你好世界。今天天气很好。我很高兴见到你。" + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_russian(self, sat_model): + text = "Привет мир. Как дела? Хорошо, спасибо!" + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_arabic(self, sat_model): + text = "مرحبا بالعالم. كيف حالك؟" + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_mixed_scripts(self, sat_model): + text = "Hello 世界. Привет мир. مرحبا world." + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_emojis(self, sat_model): + text = "Hello! 😀 How are you? 🎉 Great!" + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_special_unicode_characters(self, sat_model): + text = "Price: €100. Temperature: 25°C. Copyright © 2024." + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_multilingual_with_constraints(self, sat_model): + text = "Hello world. 你好世界。Bonjour le monde. Hola mundo. こんにちは世界。" + splits = sat_model.split(text, min_length=10, max_length=50, threshold=0.025) + + for segment in splits: + assert len(segment) <= 50 + + +# ============================================================================= +# CONSTRAINED SEGMENTATION ALGORITHM TESTS +# ============================================================================= + + +class TestConstrainedSegmentationAlgorithm: + """Direct tests of the constrained_segmentation function.""" + + def test_high_prob_at_boundaries(self): + """Algorithm should prefer positions with high probability.""" + probs = np.zeros(30) + probs[9] = 0.9 + probs[19] = 0.9 + probs[29] = 0.9 + + prior_fn = create_prior_function("uniform", {"max_length": 30}) + boundaries = constrained_segmentation(probs, prior_fn, min_length=1, max_length=30) + + assert isinstance(boundaries, list) + + def test_uniform_probs_uses_max_length(self): + """With uniform probabilities, should split at max_length.""" + probs = np.ones(100) * 0.5 + prior_fn = create_prior_function("uniform", {"max_length": 25}) + + boundaries = constrained_segmentation(probs, prior_fn, min_length=1, max_length=25) + + assert len(boundaries) >= 3 + + prev = 0 + for b in boundaries + [100]: + assert b - prev <= 25 + prev = b + + def test_consistency(self): + """Same input should produce same output.""" + probs = np.array([0.1, 0.3, 0.5, 0.7, 0.9]) + + def prior_fn(length): + return 1.0 + + result1 = constrained_segmentation(probs, prior_fn, min_length=2, max_length=3, algorithm="viterbi") + result2 = constrained_segmentation(probs, prior_fn, min_length=2, max_length=3, algorithm="viterbi") + + assert result1 == result2 + + +# ============================================================================= +# REGRESSION TESTS +# ============================================================================= + + +class TestRegressions: + """Regression tests for previously fixed bugs.""" + + def 
test_viterbi_finds_sentence_boundaries(self, sat_model): + """Viterbi should prefer sentence boundaries over arbitrary positions.""" + text = "The quick brown fox jumps. Pack my box with jugs. How vexingly quick!" + segments = sat_model.split(text, max_length=150, algorithm="viterbi", threshold=0.025) + + for segment in segments: + if segment.strip(): + last_char = segment.rstrip()[-1] + assert last_char in ".!?,;:'\"" or segment[-1].isspace() or segment == segments[-1] + + def test_trailing_whitespace_preserved(self, sat_model): + text = "Sentence one. Sentence two. End." + segments = sat_model.split(text, threshold=0.5) + assert "".join(segments) == text + + def test_viterbi_backtracking_bug_fixed(self, sat_model): + """Test that Viterbi correctly traces back to start (bug fix verification).""" + text = "The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs. How vexingly quick daft zebras jump! The five boxing wizards jump quickly." + segments = sat_model.split(text, max_length=150, algorithm="viterbi", threshold=0.025) + + for segment in segments: + if len(segment) > 1: + if segment[-1].isalpha() and segment[-2].isalpha(): + pytest.fail(f"Word cut detected: segment ends with '{segment[-10:]}'") + + assert "".join(segments) == text + + def test_min_length_merge_no_text_duplication(self): + """ + Regression test for bug where min_length merging caused text duplication. + """ + sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"]) + + text = "A. " * 20 + segments = sat.split(text, min_length=15, max_length=30, threshold=0.5) + + rejoined = "".join(segments) + assert rejoined == text, f"Text corrupted! Expected {len(text)} chars, got {len(rejoined)}" + + original_a_count = text.count("A") + rejoined_a_count = rejoined.count("A") + assert original_a_count == rejoined_a_count + + def test_viterbi_backtracking_prev_zero_valid(self): + """ + Regression test for Viterbi backtracking bug where prev=0 was incorrectly + treated as error when it's a valid state for first chunk. + """ + probs = np.zeros(100) + probs[49] = 0.99 + probs[99] = 0.99 + + prior_fn = create_prior_function("uniform", {"max_length": 60}) + + boundaries = constrained_segmentation(probs, prior_fn, min_length=1, max_length=60, algorithm="viterbi") + + assert 50 in boundaries, f"Viterbi should find boundary at 50, got {boundaries}" + + def test_greedy_final_segment_max_length(self): + """ + Regression test for greedy algorithm final segment bug. + """ + probs = np.ones(100) * 0.1 + probs[54] = 0.99 + probs[84] = 0.99 + + prior_fn = create_prior_function("uniform", {"max_length": 40}) + + boundaries = constrained_segmentation(probs, prior_fn, min_length=20, max_length=40, algorithm="greedy") + + prev = 0 + for b in boundaries + [100]: + seg_len = b - prev + assert seg_len <= 40, f"Segment [{prev}:{b}] length {seg_len} exceeds max_length=40" + prev = b + + def test_equal_min_max_viterbi_fallback(self): + """ + Regression test: When min_length == max_length and DP fails, + the fallback should still produce valid segments. 
+ """ + probs = np.zeros(15) + prior_fn = create_prior_function("uniform", {"max_length": 5}) + + indices = constrained_segmentation(probs, prior_fn, min_length=5, max_length=5, algorithm="viterbi") + + prev = 0 + chunks = [] + for idx in indices: + chunks.append(idx - prev) + prev = idx + if prev < 15: + chunks.append(15 - prev) + + assert all(c == 5 for c in chunks), f"Chunks should all be 5, got {chunks}" + + def test_newline_preservation_with_constraints(self, sat_model): + """ + Regression test: When using length constraints, newlines stay embedded + in segments and text is preserved with "".join() (not "\\n".join()). + + With constraints: "".join(segments) == text (newlines embedded) + Without constraints: "\\n".join(segments) == text (split on newlines) + """ + # Test basic newline + text1 = "Hello world.\nGoodbye world." + segments1 = sat_model.split(text1, max_length=50) + assert "".join(segments1) == text1, f"Basic newline failed: {segments1}" + + # Test trailing newline + text2 = "Hello world.\nGoodbye world.\n" + segments2 = sat_model.split(text2, max_length=50) + assert "".join(segments2) == text2, f"Trailing newline failed: {segments2}" + + # Test consecutive newlines + text3 = "Hello.\n\nWorld." + segments3 = sat_model.split(text3, max_length=50) + assert "".join(segments3) == text3, f"Consecutive newlines failed: {segments3}" + + # Test triple newline + text4 = "A.\n\n\nB." + segments4 = sat_model.split(text4, max_length=50) + assert "".join(segments4) == text4, f"Triple newline failed: {segments4}" + + # Test consecutive + trailing + text5 = "Hello.\n\nWorld.\n" + segments5 = sat_model.split(text5, max_length=50) + assert "".join(segments5) == text5, f"Consecutive + trailing failed: {segments5}" + + def test_viterbi_min_length_adjustment_when_possible(self): + """ + Regression test: Viterbi algorithm should adjust split points to satisfy + min_length when mathematically possible. + + With 10 chars, min=4, max=6: valid splits exist (e.g., [4,6] or [5,5] or [6,4]) + The algorithm should find one that satisfies both constraints. + """ + probs = np.array([0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.9, 0.1, 0.9]) + prior_fn = create_prior_function("uniform", {"max_length": 6}) + + indices = constrained_segmentation(probs, prior_fn, min_length=4, max_length=6, algorithm="viterbi") + + # Calculate chunks + prev = 0 + chunks = [] + for idx in indices: + chunks.append(idx - prev) + prev = idx + if prev < len(probs): + chunks.append(len(probs) - prev) + + # All chunks should satisfy constraints when possible + for chunk in chunks: + assert chunk <= 6, f"max_length violated: chunk={chunk}" + # min_length should be satisfied when mathematically possible + assert chunk >= 4, f"min_length violated when valid solution exists: chunks={chunks}" + + def test_viterbi_min_length_best_effort_impossible(self): + """ + Regression test: When min_length cannot be satisfied for all segments + (mathematically impossible), the algorithm should still return valid + segments with max_length strictly enforced. + + With 7 chars, min=4, max=5: impossible (needs 4+4=8 chars minimum) + Algorithm should return best-effort result with max_length enforced. 
+ """ + probs = np.array([0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.9]) + prior_fn = create_prior_function("uniform", {"max_length": 5}) + + indices = constrained_segmentation(probs, prior_fn, min_length=4, max_length=5, algorithm="viterbi") + + # Calculate chunks + prev = 0 + chunks = [] + for idx in indices: + chunks.append(idx - prev) + prev = idx + if prev < len(probs): + chunks.append(len(probs) - prev) + + # max_length must ALWAYS be enforced (strict) + for chunk in chunks: + assert chunk <= 5, f"max_length violated: chunk={chunk}" + + # min_length is best-effort - some chunk may be short when impossible + # Just verify we got a valid segmentation + assert sum(chunks) == 7, f"Total length wrong: {sum(chunks)}" + + def test_viterbi_single_split_adjustment(self): + """ + Regression test: When there's only one split point and the final chunk + is too short, the algorithm should try to adjust or remove the split. + """ + # 8 chars with min=3, max=5 + # Possible valid: [3,5], [4,4], [5,3] + probs = np.array([0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.9]) + prior_fn = create_prior_function("uniform", {"max_length": 5}) + + indices = constrained_segmentation(probs, prior_fn, min_length=3, max_length=5, algorithm="viterbi") + + prev = 0 + chunks = [] + for idx in indices: + chunks.append(idx - prev) + prev = idx + if prev < len(probs): + chunks.append(len(probs) - prev) + + # Should find a valid solution + for chunk in chunks: + assert chunk <= 5, "max_length violated" + assert chunk >= 3, "min_length violated when valid solution exists" + + +# ============================================================================= +# PARAGRAPH SEGMENTATION WITH CONSTRAINTS +# ============================================================================= + + +class TestParagraphSegmentation: + """Test nested paragraph and sentence segmentation with constraints.""" + + def test_paragraph_segmentation_with_constraints(self, wtp_model): + text = "Paragraph one sentence one. Paragraph one sentence two.\n\nParagraph two sentence one. Paragraph two sentence two." 
+ + paragraphs = wtp_model.split(text, do_paragraph_segmentation=True, min_length=10, max_length=40) + + assert isinstance(paragraphs, list) + for paragraph in paragraphs: + assert isinstance(paragraph, list) + for sentence in paragraph: + assert 10 <= len(sentence) <= 40 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/wtpsplit/__init__.py b/wtpsplit/__init__.py index 35af3395..5abb349a 100644 --- a/wtpsplit/__init__.py +++ b/wtpsplit/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import math import os @@ -18,8 +20,13 @@ from wtpsplit.extract import BertCharORTWrapper, SaTORTWrapper, PyTorchWrapper, extract from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs +from wtpsplit.utils.constraints import ( + constrained_segmentation, + _enforce_segment_constraints, +) +from wtpsplit.utils.priors import create_prior_function -__version__ = "2.1.7" +__version__ = "2.2.0" # suppress docopt syntax warnings (triggered in Python 3.14+) warnings.filterwarnings("ignore", category=SyntaxWarning, module="docopt") @@ -27,7 +34,8 @@ warnings.filterwarnings("ignore", category=UserWarning, message="Torchaudio's I/O functions now support.*") warnings.simplefilter("default", DeprecationWarning) # show by default -warnings.simplefilter("ignore", category=FutureWarning) # for tranformers +warnings.simplefilter("ignore", category=FutureWarning) # for transformers + class WtP: def __init__( @@ -39,10 +47,12 @@ def __init__( mixtures=None, hub_prefix="benjamin", ignore_legacy_warning=False, + language: str = None, ): self.model_name_or_model = model_name_or_model self.ort_providers = ort_providers self.ort_kwargs = ort_kwargs + self.language = language # Store for language-aware prior defaults mixture_path = None @@ -295,7 +305,7 @@ def split( text_or_texts, lang_code: str = None, style: str = None, - threshold: float = None, + threshold: float = None, # ignored when max_length is set stride=64, block_size: int = 512, batch_size=32, @@ -307,7 +317,33 @@ def split( strip_whitespace: bool = False, do_paragraph_segmentation=False, verbose: bool = False, + min_length: int = 1, + max_length: int = None, # when set, segments may contain newlines; use ''.join(segments) + prior_type: str = "uniform", + prior_kwargs: dict = None, + algorithm: str = "viterbi", ): + # Input validation + if max_length is not None and min_length > max_length: + raise ValueError(f"min_length ({min_length}) cannot be greater than max_length ({max_length})") + if min_length < 1: + raise ValueError(f"min_length must be >= 1, got {min_length}") + if max_length is not None and max_length < 1: + raise ValueError(f"max_length must be >= 1, got {max_length}") + valid_priors = ["uniform", "gaussian", "clipped_polynomial", "lognormal"] + if prior_type not in valid_priors: + raise ValueError(f"Unknown prior_type: '{prior_type}'. Must be one of {valid_priors}") + valid_algorithms = ["viterbi", "greedy"] + if algorithm not in valid_algorithms: + raise ValueError(f"Unknown algorithm: '{algorithm}'. Must be one of {valid_algorithms}") + + if max_length is not None and threshold is not None: + warnings.warn( + "Both 'threshold' and 'max_length' are set. 
When using length-constrained " + "segmentation (max_length), the threshold parameter is ignored.", + UserWarning, + ) + if isinstance(text_or_texts, str): return next( self._split( @@ -326,6 +362,11 @@ def split( strip_whitespace=strip_whitespace, do_paragraph_segmentation=do_paragraph_segmentation, verbose=verbose, + min_length=min_length, + max_length=max_length, + prior_type=prior_type, + prior_kwargs=prior_kwargs, + algorithm=algorithm, ) ) else: @@ -345,6 +386,11 @@ def split( strip_whitespace=strip_whitespace, do_paragraph_segmentation=do_paragraph_segmentation, verbose=verbose, + min_length=min_length, + max_length=max_length, + prior_type=prior_type, + prior_kwargs=prior_kwargs, + algorithm=algorithm, ) def get_threshold(self, lang_code: str, style: str, return_punctuation_threshold: bool = False): @@ -361,9 +407,9 @@ def get_threshold(self, lang_code: str, style: str, return_punctuation_threshold def _split( self, texts, - lang_code: str, - style: str, - threshold: float, + lang_code: str | None, + style: str | None, + threshold: float | None, stride: int, block_size: int, batch_size: int, @@ -375,6 +421,11 @@ def _split( do_paragraph_segmentation: bool, strip_whitespace: bool, verbose: bool, + min_length: int, + max_length: int | None, + prior_type: str, + prior_kwargs: dict | None, + algorithm: str, ): if style is not None: if lang_code is None: @@ -422,23 +473,68 @@ def _split( for paragraph in indices_to_sentences(text, np.where(newline_probs > paragraph_threshold)[0]): sentences = [] - for sentence in indices_to_sentences( - paragraph, - np.where( - sentence_probs[offset : offset + len(paragraph)] > sentence_threshold, - )[0], - strip_whitespace=strip_whitespace, - ): - sentences.append(sentence) + if max_length is not None or min_length > 1: + paragraph_probs = sentence_probs[offset : offset + len(paragraph)] + # Create fresh copy each iteration to avoid state leakage + local_prior_kwargs = {} if prior_kwargs is None else prior_kwargs.copy() + if max_length is not None: + local_prior_kwargs["max_length"] = max_length + # Use model's language for prior defaults if not explicitly set + if ( + self.language + and "lang_code" not in local_prior_kwargs + and "target_length" not in local_prior_kwargs + ): + local_prior_kwargs["lang_code"] = self.language + prior_fn = create_prior_function(prior_type, local_prior_kwargs) + + boundaries = constrained_segmentation( + paragraph_probs, prior_fn, min_length=min_length, max_length=max_length, algorithm=algorithm + ) + indices = [b - 1 for b in boundaries] + + sentences = _enforce_segment_constraints( + paragraph, indices, min_length, max_length, strip_whitespace=strip_whitespace + ) + else: + for sentence in indices_to_sentences( + paragraph, + np.where( + sentence_probs[offset : offset + len(paragraph)] > sentence_threshold, + )[0], + strip_whitespace=strip_whitespace, + ): + sentences.append(sentence) paragraphs.append(sentences) offset += len(paragraph) yield paragraphs else: - sentences = indices_to_sentences( - text, np.where(probs > sentence_threshold)[0], strip_whitespace=strip_whitespace - ) + if max_length is not None or min_length > 1: + # Create fresh copy each iteration to avoid state leakage + local_prior_kwargs = {} if prior_kwargs is None else prior_kwargs.copy() + if max_length is not None: + local_prior_kwargs["max_length"] = max_length + # Use model's language for prior defaults if not explicitly set + if ( + self.language + and "lang_code" not in local_prior_kwargs + and "target_length" not in local_prior_kwargs + ): 
+ local_prior_kwargs["lang_code"] = self.language + prior_fn = create_prior_function(prior_type, local_prior_kwargs) + boundaries = constrained_segmentation( + probs, prior_fn, min_length=min_length, max_length=max_length, algorithm=algorithm + ) + indices = [b - 1 for b in boundaries] + sentences = _enforce_segment_constraints( + text, indices, min_length, max_length, strip_whitespace=strip_whitespace + ) + else: + sentences = indices_to_sentences( + text, np.where(probs > sentence_threshold)[0], strip_whitespace=strip_whitespace + ) yield sentences @@ -466,6 +562,7 @@ def __init__( self.model_name_or_model = model_name_or_model self.ort_providers = ort_providers self.ort_kwargs = ort_kwargs + self.language = language # Store for language-aware prior defaults self.use_lora = False @@ -771,7 +868,7 @@ def newline_probability_fn(logits): def split( self, text_or_texts, - threshold: float = None, + threshold: float = None, # ignored when max_length is set stride=64, block_size: int = 512, batch_size=32, @@ -782,9 +879,14 @@ def split( paragraph_threshold: float = 0.5, strip_whitespace: bool = False, do_paragraph_segmentation: bool = False, - split_on_input_newlines: bool = True, + split_on_input_newlines: bool = True, # only applies when max_length is not set treat_newline_as_space=None, # Deprecated verbose: bool = False, + min_length: int = 1, + max_length: int = None, # when set, segments may contain newlines; use ''.join(segments) + prior_type: str = "uniform", + prior_kwargs: dict = None, + algorithm: str = "viterbi", ): if treat_newline_as_space is not None: warnings.warn( @@ -793,6 +895,36 @@ def split( DeprecationWarning, ) split_on_input_newlines = not treat_newline_as_space + + # Input validation + if max_length is not None and min_length > max_length: + raise ValueError(f"min_length ({min_length}) cannot be greater than max_length ({max_length})") + if min_length < 1: + raise ValueError(f"min_length must be >= 1, got {min_length}") + if max_length is not None and max_length < 1: + raise ValueError(f"max_length must be >= 1, got {max_length}") + valid_priors = ["uniform", "gaussian", "clipped_polynomial", "lognormal"] + if prior_type not in valid_priors: + raise ValueError(f"Unknown prior_type: '{prior_type}'. Must be one of {valid_priors}") + valid_algorithms = ["viterbi", "greedy"] + if algorithm not in valid_algorithms: + raise ValueError(f"Unknown algorithm: '{algorithm}'. Must be one of {valid_algorithms}") + + if max_length is not None and threshold is not None: + warnings.warn( + "Both 'threshold' and 'max_length' are set. When using length-constrained " + "segmentation (max_length), the threshold parameter is ignored.", + UserWarning, + ) + + if (max_length is not None or min_length > 1) and split_on_input_newlines: + warnings.warn( + "When using length constraints (max_length/min_length), segments may contain newlines. " + "split_on_input_newlines is ignored; use ''.join(segments) to reconstruct the original text. 
" + "To split at newlines with constraints, pre-split your text at newlines and process each line.", + UserWarning, + ) + if isinstance(text_or_texts, str): return next( self._split( @@ -810,6 +942,11 @@ def split( do_paragraph_segmentation=do_paragraph_segmentation, split_on_input_newlines=split_on_input_newlines, verbose=verbose, + min_length=min_length, + max_length=max_length, + prior_type=prior_type, + prior_kwargs=prior_kwargs, + algorithm=algorithm, ) ) else: @@ -828,12 +965,17 @@ def split( do_paragraph_segmentation=do_paragraph_segmentation, split_on_input_newlines=split_on_input_newlines, verbose=verbose, + min_length=min_length, + max_length=max_length, + prior_type=prior_type, + prior_kwargs=prior_kwargs, + algorithm=algorithm, ) def _split( self, texts, - threshold: float, + threshold: float | None, stride: int, block_size: int, batch_size: int, @@ -844,8 +986,13 @@ def _split( outer_batch_size: int, do_paragraph_segmentation: bool, split_on_input_newlines: bool, + min_length: int, + max_length: int | None, strip_whitespace: bool, verbose: bool, + prior_type: str, + prior_kwargs: dict | None, + algorithm: str, ): def get_default_threshold(model_str: str): # basic type check for safety @@ -889,35 +1036,89 @@ def get_default_threshold(model_str: str): for paragraph in indices_to_sentences(text, np.where(newline_probs > paragraph_threshold)[0]): sentences = [] - for sentence in indices_to_sentences( - paragraph, - np.where( - sentence_probs[offset : offset + len(paragraph)] > sentence_threshold, - )[0], - strip_whitespace=strip_whitespace, - ): - sentences.append(sentence) + if max_length is not None or min_length > 1: + paragraph_probs = sentence_probs[offset : offset + len(paragraph)] + # Create fresh copy each iteration to avoid state leakage + local_prior_kwargs = {} if prior_kwargs is None else prior_kwargs.copy() + if max_length is not None: + local_prior_kwargs["max_length"] = max_length + # Use model's language for prior defaults if not explicitly set + if ( + self.language + and "lang_code" not in local_prior_kwargs + and "target_length" not in local_prior_kwargs + ): + local_prior_kwargs["lang_code"] = self.language + prior_fn = create_prior_function(prior_type, local_prior_kwargs) + + boundaries = constrained_segmentation( + paragraph_probs, prior_fn, min_length=min_length, max_length=max_length, algorithm=algorithm + ) + indices = [b - 1 for b in boundaries] + + sentences = _enforce_segment_constraints( + paragraph, indices, min_length, max_length, strip_whitespace=strip_whitespace + ) + else: + for sentence in indices_to_sentences( + paragraph, + np.where( + sentence_probs[offset : offset + len(paragraph)] > sentence_threshold, + )[0], + strip_whitespace=strip_whitespace, + ): + sentences.append(sentence) paragraphs.append(sentences) offset += len(paragraph) yield paragraphs else: - sentences = indices_to_sentences( - text, np.where(probs > sentence_threshold)[0], strip_whitespace=strip_whitespace - ) - if split_on_input_newlines: - # within the model, newlines in the text were ignored - they were treated as spaces. 
- # this is the default behavior: additionally split on newlines as provided in the input - new_sentences = [] - for sentence in sentences: - new_sentences.extend(sentence.split("\n")) - sentences = new_sentences + if max_length is not None or min_length > 1: + # Create fresh copy each iteration to avoid state leakage + local_prior_kwargs = {} if prior_kwargs is None else prior_kwargs.copy() + if max_length is not None: + local_prior_kwargs["max_length"] = max_length + # Use model's language for prior defaults if not explicitly set + if ( + self.language + and "lang_code" not in local_prior_kwargs + and "target_length" not in local_prior_kwargs + ): + local_prior_kwargs["lang_code"] = self.language + prior_fn = create_prior_function(prior_type, local_prior_kwargs) + + boundaries = constrained_segmentation( + probs, prior_fn, min_length=min_length, max_length=max_length, algorithm=algorithm + ) + indices = [b - 1 for b in boundaries] + sentences = _enforce_segment_constraints( + text, indices, min_length, max_length, strip_whitespace=strip_whitespace + ) + # Note: when constraints are used, newlines may appear inside segments. + # Use "".join(segments) == text for reconstruction (not "\n".join()). else: - warnings.warn( - "split_on_input_newlines=False will lead to newlines in the output " - "if they were present in the input. Within the model, such newlines are " - "treated as spaces. " - "If you want to split on such newlines, set split_on_input_newlines=False." + sentences = indices_to_sentences( + text, np.where(probs > sentence_threshold)[0], strip_whitespace=strip_whitespace ) + + if split_on_input_newlines: + # within the model, newlines in the text were ignored - they were treated as spaces. + # this is the default behavior: additionally split on newlines as provided in the input + # Note: use "\n".join(segments) to reconstruct text (not "".join()) + new_sentences = [] + for i, sentence in enumerate(sentences): + # Strip ONE trailing newline from non-final segments to avoid + # duplicate delimiters when joined (but preserve internal newlines) + if i < len(sentences) - 1 and sentence.endswith("\n"): + sentence = sentence[:-1] + new_sentences.extend(sentence.split("\n")) + sentences = new_sentences + else: + warnings.warn( + "split_on_input_newlines=False will lead to newlines in the output " + "if they were present in the input. Within the model, such newlines are " + "treated as spaces. " + "If you want to split on such newlines, set split_on_input_newlines=True." 
+ ) yield sentences diff --git a/wtpsplit/data/sentence_stats.json b/wtpsplit/data/sentence_stats.json new file mode 100644 index 00000000..b02e9aa3 --- /dev/null +++ b/wtpsplit/data/sentence_stats.json @@ -0,0 +1,576 @@ +{ + "metadata": { + "source": "Universal Dependencies v2.14", + "generated_at": "2026-01-21T12:19:34.907700Z" + }, + "stats": { + "ab": { + "target_length": 48, + "spread": 22 + }, + "af": { + "target_length": 132, + "spread": 70 + }, + "ajp": { + "target_length": 27, + "spread": 14 + }, + "akk": { + "target_length": 51, + "spread": 36 + }, + "aln": { + "target_length": 67, + "spread": 35 + }, + "am": { + "target_length": 18, + "spread": 10 + }, + "apu": { + "target_length": 30, + "spread": 12 + }, + "aqz": { + "target_length": 16, + "spread": 10 + }, + "ar": { + "target_length": 67, + "spread": 55 + }, + "arr": { + "target_length": 16, + "spread": 10 + }, + "az": { + "target_length": 29, + "spread": 14 + }, + "azz": { + "target_length": 30, + "spread": 28 + }, + "bar": { + "target_length": 51, + "spread": 40 + }, + "be": { + "target_length": 58, + "spread": 48 + }, + "bej": { + "target_length": 58, + "spread": 46 + }, + "bg": { + "target_length": 66, + "spread": 46 + }, + "bho": { + "target_length": 72, + "spread": 42 + }, + "bm": { + "target_length": 38, + "spread": 22 + }, + "bor": { + "target_length": 29, + "spread": 22 + }, + "br": { + "target_length": 42, + "spread": 33 + }, + "bxr": { + "target_length": 55, + "spread": 37 + }, + "ca": { + "target_length": 151, + "spread": 85 + }, + "ceb": { + "target_length": 27, + "spread": 13 + }, + "ckt": { + "target_length": 34, + "spread": 19 + }, + "cop": { + "target_length": 70, + "spread": 42 + }, + "cpg": { + "target_length": 40, + "spread": 28 + }, + "cs": { + "target_length": 86, + "spread": 62 + }, + "cu": { + "target_length": 39, + "spread": 29 + }, + "cy": { + "target_length": 82, + "spread": 62 + }, + "da": { + "target_length": 82, + "spread": 61 + }, + "de": { + "target_length": 111, + "spread": 61 + }, + "egy": { + "target_length": 29, + "spread": 16 + }, + "el": { + "target_length": 94, + "spread": 74 + }, + "eme": { + "target_length": 19, + "spread": 11 + }, + "en": { + "target_length": 63, + "spread": 53 + }, + "es": { + "target_length": 135, + "spread": 87 + }, + "ess": { + "target_length": 29, + "spread": 11 + }, + "et": { + "target_length": 74, + "spread": 55 + }, + "eu": { + "target_length": 79, + "spread": 45 + }, + "fa": { + "target_length": 72, + "spread": 48 + }, + "fi": { + "target_length": 63, + "spread": 47 + }, + "fo": { + "target_length": 63, + "spread": 48 + }, + "fr": { + "target_length": 90, + "spread": 65 + }, + "frm": { + "target_length": 87, + "spread": 66 + }, + "fro": { + "target_length": 37, + "spread": 25 + }, + "ga": { + "target_length": 111, + "spread": 54 + }, + "gd": { + "target_length": 66, + "spread": 72 + }, + "gl": { + "target_length": 160, + "spread": 57 + }, + "got": { + "target_length": 49, + "spread": 37 + }, + "grc": { + "target_length": 66, + "spread": 48 + }, + "gsw": { + "target_length": 73, + "spread": 25 + }, + "gu": { + "target_length": 36, + "spread": 26 + }, + "gub": { + "target_length": 33, + "spread": 16 + }, + "gun": { + "target_length": 20, + "spread": 11 + }, + "gv": { + "target_length": 26, + "spread": 10 + }, + "ha": { + "target_length": 29, + "spread": 24 + }, + "hbo": { + "target_length": 118, + "spread": 45 + }, + "he": { + "target_length": 86, + "spread": 48 + }, + "hi": { + "target_length": 93, + "spread": 45 + }, + "hit": { + "target_length": 67, + 
"spread": 37 + }, + "hr": { + "target_length": 118, + "spread": 68 + }, + "hsb": { + "target_length": 86, + "spread": 48 + }, + "ht": { + "target_length": 89, + "spread": 47 + }, + "hu": { + "target_length": 138, + "spread": 79 + }, + "hy": { + "target_length": 91, + "spread": 77 + }, + "hyw": { + "target_length": 88, + "spread": 69 + }, + "id": { + "target_length": 117, + "spread": 64 + }, + "is": { + "target_length": 87, + "spread": 64 + }, + "it": { + "target_length": 108, + "spread": 62 + }, + "ja": { + "target_length": 18, + "spread": 15 + }, + "jv": { + "target_length": 69, + "spread": 38 + }, + "ka": { + "target_length": 85, + "spread": 60 + }, + "kk": { + "target_length": 52, + "spread": 28 + }, + "kmr": { + "target_length": 57, + "spread": 21 + }, + "ko": { + "target_length": 44, + "spread": 22 + }, + "koi": { + "target_length": 34, + "spread": 20 + }, + "kpv": { + "target_length": 51, + "spread": 36 + }, + "krl": { + "target_length": 73, + "spread": 40 + }, + "ky": { + "target_length": 58, + "spread": 25 + }, + "la": { + "target_length": 73, + "spread": 58 + }, + "lij": { + "target_length": 55, + "spread": 49 + }, + "lt": { + "target_length": 102, + "spread": 74 + }, + "lv": { + "target_length": 84, + "spread": 62 + }, + "lzh": { + "target_length": 5, + "spread": 10 + }, + "mdf": { + "target_length": 49, + "spread": 22 + }, + "mk": { + "target_length": 37, + "spread": 14 + }, + "ml": { + "target_length": 69, + "spread": 47 + }, + "mr": { + "target_length": 32, + "spread": 17 + }, + "mt": { + "target_length": 98, + "spread": 75 + }, + "myu": { + "target_length": 25, + "spread": 14 + }, + "myv": { + "target_length": 46, + "spread": 31 + }, + "nds": { + "target_length": 80, + "spread": 59 + }, + "nhi": { + "target_length": 47, + "spread": 31 + }, + "nl": { + "target_length": 79, + "spread": 62 + }, + "no": { + "target_length": 77, + "spread": 54 + }, + "olo": { + "target_length": 70, + "spread": 41 + }, + "orv": { + "target_length": 43, + "spread": 38 + }, + "ota": { + "target_length": 82, + "spread": 69 + }, + "pad": { + "target_length": 26, + "spread": 13 + }, + "pcm": { + "target_length": 46, + "spread": 37 + }, + "pl": { + "target_length": 55, + "spread": 42 + }, + "pt": { + "target_length": 74, + "spread": 57 + }, + "qaf": { + "target_length": 68, + "spread": 39 + }, + "qfn": { + "target_length": 47, + "spread": 18 + }, + "qpm": { + "target_length": 61, + "spread": 40 + }, + "qtd": { + "target_length": 79, + "spread": 45 + }, + "quc": { + "target_length": 26, + "spread": 10 + }, + "ro": { + "target_length": 102, + "spread": 60 + }, + "ru": { + "target_length": 79, + "spread": 60 + }, + "sa": { + "target_length": 37, + "spread": 27 + }, + "sah": { + "target_length": 25, + "spread": 10 + }, + "say": { + "target_length": 32, + "spread": 25 + }, + "si": { + "target_length": 44, + "spread": 10 + }, + "sjo": { + "target_length": 62, + "spread": 59 + }, + "sk": { + "target_length": 44, + "spread": 31 + }, + "sl": { + "target_length": 79, + "spread": 64 + }, + "sme": { + "target_length": 41, + "spread": 29 + }, + "sms": { + "target_length": 51, + "spread": 30 + }, + "sr": { + "target_length": 120, + "spread": 62 + }, + "sv": { + "target_length": 83, + "spread": 53 + }, + "swl": { + "target_length": 66, + "spread": 46 + }, + "ta": { + "target_length": 52, + "spread": 58 + }, + "te": { + "target_length": 25, + "spread": 10 + }, + "th": { + "target_length": 95, + "spread": 40 + }, + "tl": { + "target_length": 34, + "spread": 21 + }, + "tpn": { + "target_length": 38, + "spread": 23 + }, + 
"tr": { + "target_length": 49, + "spread": 37 + }, + "tt": { + "target_length": 93, + "spread": 44 + }, + "ug": { + "target_length": 63, + "spread": 35 + }, + "uk": { + "target_length": 75, + "spread": 59 + }, + "ur": { + "target_length": 108, + "spread": 59 + }, + "vep": { + "target_length": 69, + "spread": 36 + }, + "vi": { + "target_length": 80, + "spread": 40 + }, + "wo": { + "target_length": 78, + "spread": 48 + }, + "xav": { + "target_length": 42, + "spread": 25 + }, + "xcl": { + "target_length": 73, + "spread": 42 + }, + "xnr": { + "target_length": 35, + "spread": 14 + }, + "xum": { + "target_length": 26, + "spread": 26 + }, + "yo": { + "target_length": 96, + "spread": 45 + }, + "yrl": { + "target_length": 38, + "spread": 30 + }, + "yue": { + "target_length": 13, + "spread": 14 + }, + "zh": { + "target_length": 29, + "spread": 19 + } + } +} \ No newline at end of file diff --git a/wtpsplit/evaluation/adapt.py b/wtpsplit/evaluation/adapt.py index 52e6852b..fe886572 100644 --- a/wtpsplit/evaluation/adapt.py +++ b/wtpsplit/evaluation/adapt.py @@ -312,7 +312,7 @@ def main(args): save_model_path = args.model_path if args.adapter_path: save_model_path = args.adapter_path - save_str = f"{save_model_path.replace('/','_')}_b{args.block_size}_s{args.stride}" + save_str = f"{save_model_path.replace('/', '_')}_b{args.block_size}_s{args.stride}" eval_data = torch.load(args.eval_data_path) if "canine" in args.model_path and "no-adapters" not in args.model_path: diff --git a/wtpsplit/evaluation/evaluate_sepp_nlg_subtask1.py b/wtpsplit/evaluation/evaluate_sepp_nlg_subtask1.py index aa7fec51..a54571ad 100644 --- a/wtpsplit/evaluation/evaluate_sepp_nlg_subtask1.py +++ b/wtpsplit/evaluation/evaluate_sepp_nlg_subtask1.py @@ -52,9 +52,9 @@ def evaluate_subtask1(splits, langs, prediction_dir: str, supervisions, include_ rows = [line.split("\t") for line in lines] pred_labels = [row[1] for row in rows] - assert ( - len(gt_labels) == len(pred_labels) - ), f"unequal no. of labels for files {gt_tsv_file} and {os.path.join(prediction_dir, lang_code, split, basename)}" + assert len(gt_labels) == len(pred_labels), ( + f"unequal no. 
of labels for files {gt_tsv_file} and {os.path.join(prediction_dir, lang_code, split, basename)}" + ) all_gt_labels.extend(gt_labels) all_predicted_labels.extend(pred_labels) diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py index e6183d8d..b659b16d 100644 --- a/wtpsplit/evaluation/intrinsic_pairwise.py +++ b/wtpsplit/evaluation/intrinsic_pairwise.py @@ -286,7 +286,7 @@ def main(args): save_model_path = args.model_path if args.adapter_path: save_model_path = args.adapter_path - save_str = f"{save_model_path.replace('/','_')}_b{args.block_size}_k{args.k}{args.save_suffix}" + save_str = f"{save_model_path.replace('/', '_')}_b{args.block_size}_k{args.k}{args.save_suffix}" print(save_str) eval_data = torch.load(args.eval_data_path) diff --git a/wtpsplit/evaluation/intrinsic_ted.py b/wtpsplit/evaluation/intrinsic_ted.py index 9183f29a..f68cca3d 100644 --- a/wtpsplit/evaluation/intrinsic_ted.py +++ b/wtpsplit/evaluation/intrinsic_ted.py @@ -246,7 +246,7 @@ def main(args): save_model_path = args.model_path if args.adapter_path: save_model_path = args.adapter_path - save_str = f"{save_model_path.replace('/','_')}_b{args.block_size}_s{args.stride}" + save_str = f"{save_model_path.replace('/', '_')}_b{args.block_size}_s{args.stride}" eval_data = torch.load(args.eval_data_path) if args.valid_text_path is not None: diff --git a/wtpsplit/evaluation/legal_baselines.py b/wtpsplit/evaluation/legal_baselines.py index 88141d88..a17b5400 100644 --- a/wtpsplit/evaluation/legal_baselines.py +++ b/wtpsplit/evaluation/legal_baselines.py @@ -173,7 +173,7 @@ def compute_statistics(values): def main(args): save_model_path = f"rcds/distilbert-SBD-{args.lang_support}-{args.type}_s{args.stride}" - save_str = f"{save_model_path.replace('/','_')}" + save_str = f"{save_model_path.replace('/', '_')}" eval_data = torch.load(args.eval_data_path) diff --git a/wtpsplit/evaluation/punct_annotation_wtp.py b/wtpsplit/evaluation/punct_annotation_wtp.py index 3a8df910..fed6c944 100644 --- a/wtpsplit/evaluation/punct_annotation_wtp.py +++ b/wtpsplit/evaluation/punct_annotation_wtp.py @@ -104,7 +104,7 @@ def load_iwslt(path, fix_space=True): json.dump( results, open( - Constants.CACHE_DIR / "extrinsic" / f"iwslt_{args.model_path.replace('/','_')}_{args.lang}.json", + Constants.CACHE_DIR / "extrinsic" / f"iwslt_{args.model_path.replace('/', '_')}_{args.lang}.json", "w", ), ) diff --git a/wtpsplit/evaluation/stat_tests/permutation_test.py b/wtpsplit/evaluation/stat_tests/permutation_test.py index 4eb442b9..c3d8694f 100644 --- a/wtpsplit/evaluation/stat_tests/permutation_test.py +++ b/wtpsplit/evaluation/stat_tests/permutation_test.py @@ -115,9 +115,9 @@ _, _, f1 = compute_prf(y_true, y_pred, num_docs) - assert np.allclose( - f1, val_results[args.lang][dataset][model] - ), f" MISMATCH! {args.lang} {dataset} {model} F1: {f1} intrinsic_py_out: {val_results[args.lang][dataset][model]}" + assert np.allclose(f1, val_results[args.lang][dataset][model]), ( + f" MISMATCH! 
{args.lang} {dataset} {model} F1: {f1} intrinsic_py_out: {val_results[args.lang][dataset][model]}" + ) for i in range(num_systems): for j in range(i + 1, num_systems): diff --git a/wtpsplit/evaluation/stat_tests/permutation_test_data.py b/wtpsplit/evaluation/stat_tests/permutation_test_data.py index 58b52d12..aa460e3a 100644 --- a/wtpsplit/evaluation/stat_tests/permutation_test_data.py +++ b/wtpsplit/evaluation/stat_tests/permutation_test_data.py @@ -85,9 +85,9 @@ data_lengths = [data_lengths] if "lengths" in raw_data[lang][dataset]: - assert ( - raw_data[lang][dataset]["lengths"] == data_lengths - ), f"{lang}, {dataset}, {model_type}... [lengths assertion] before: {raw_data[lang][dataset]['lengths']} after: {data_lengths}" + assert raw_data[lang][dataset]["lengths"] == data_lengths, ( + f"{lang}, {dataset}, {model_type}... [lengths assertion] before: {raw_data[lang][dataset]['lengths']} after: {data_lengths}" + ) else: raw_data[lang][dataset]["lengths"] = data_lengths @@ -212,16 +212,16 @@ raw_data[lang]["main_table_mean"][system] = preds_main_results_indicies if "true_indices" in raw_data[lang]["main_table_mean"]: - assert ( - raw_data[lang]["main_table_mean"]["true_indices"] == trues_main_results_indicies - ), f"{lang} {system}, {[len(i) for i in trues_main_results_indicies]}, {[len(i) for i in raw_data[lang]['main_table_mean']['true_indices']]}" + assert raw_data[lang]["main_table_mean"]["true_indices"] == trues_main_results_indicies, ( + f"{lang} {system}, {[len(i) for i in trues_main_results_indicies]}, {[len(i) for i in raw_data[lang]['main_table_mean']['true_indices']]}" + ) else: raw_data[lang]["main_table_mean"]["true_indices"] = trues_main_results_indicies if "lengths" in raw_data[lang]["main_table_mean"]: - assert ( - raw_data[lang]["main_table_mean"]["lengths"] == lengths_main_results - ), f"{lang} {system} {raw_data[lang]['main_table_mean']['lengths']} {lengths_main_results}" + assert raw_data[lang]["main_table_mean"]["lengths"] == lengths_main_results, ( + f"{lang} {system} {raw_data[lang]['main_table_mean']['lengths']} {lengths_main_results}" + ) else: raw_data[lang]["main_table_mean"]["lengths"] = lengths_main_results diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py index 26938026..e8595096 100644 --- a/wtpsplit/extract.py +++ b/wtpsplit/extract.py @@ -220,7 +220,7 @@ def extract( ) else: language_ids = None - + # compute weights for the given weighting scheme if weighting == "uniform": weights = np.ones(block_size, dtype=np.float16) diff --git a/wtpsplit/models.py b/wtpsplit/models.py index b68d0e85..c00727c9 100644 --- a/wtpsplit/models.py +++ b/wtpsplit/models.py @@ -70,9 +70,9 @@ def _embed_hash_buckets(self, input_ids=None, hashed_ids=None): num_hashes = self.config.num_hash_functions num_buckets = self.config.num_hash_buckets - assert (input_ids is None) + ( - hashed_ids is None - ) == 1, "Either `input_ids` or `hashed_ids` must be provided (and not both!)." + assert (input_ids is None) + (hashed_ids is None) == 1, ( + "Either `input_ids` or `hashed_ids` must be provided (and not both!)." + ) """Converts IDs (e.g. 
codepoints) into embeddings via multiple hashing.""" if embedding_size % num_hashes != 0: diff --git a/wtpsplit/utils/__init__.py b/wtpsplit/utils/__init__.py index b535bfb9..c17c06e3 100644 --- a/wtpsplit/utils/__init__.py +++ b/wtpsplit/utils/__init__.py @@ -9,7 +9,7 @@ import numpy as np import pandas as pd -from cached_property import cached_property +from functools import cached_property from mosestokenizer import MosesTokenizer from transformers import AutoTokenizer diff --git a/wtpsplit/utils/constraints.py b/wtpsplit/utils/constraints.py new file mode 100644 index 00000000..c4f950e9 --- /dev/null +++ b/wtpsplit/utils/constraints.py @@ -0,0 +1,360 @@ +import numpy as np + +from wtpsplit.utils import indices_to_sentences + + +def _enforce_segment_constraints(text, indices, min_length, max_length, strip_whitespace=False): + """ + Extract segments from text using indices, enforcing STRICT length constraints. + + NOTE: This post-processing is necessary because Viterbi operates on raw character + indices, but text extraction extends segments to include trailing whitespace (which + can exceed max_length) and optionally strips whitespace (which can go below min_length). + + Guarantees: + - All segments are strictly <= max_length characters + - All segments are >= min_length characters (best effort) + - "".join(segments) == original_text (text preservation) + + Args: + text: Original text string + indices: List of split indices (positions where segments end) + min_length: Minimum segment length + max_length: Maximum segment length (None for no limit) + strip_whitespace: Whether to strip whitespace from final segments + + Returns: + List of segments that respect the constraints + """ + if not text: + return [] + + # For whitespace-only text, return empty if strip_whitespace, otherwise preserve + if not text.strip(): + if strip_whitespace: + return [] + # Text is whitespace-only but we need to preserve it + if max_length is not None and len(text) > max_length: + return [text[i : i + max_length] for i in range(0, len(text), max_length)] + return [text] + + # No constraints - use standard extraction + if min_length <= 1 and max_length is None: + return indices_to_sentences(text, indices, strip_whitespace=strip_whitespace) + + # Build initial segment boundaries from indices + boundaries = [] + offset = 0 + for idx in indices: + end = idx + 1 + # Extend to include trailing whitespace + while end < len(text) and text[end].isspace(): + end += 1 + if end > offset: + boundaries.append((offset, end)) + offset = end + # Add final segment + if offset < len(text): + boundaries.append((offset, len(text))) + + if not boundaries: + seg = text.strip() if strip_whitespace else text + if max_length is not None and len(seg) > max_length: + # Force split - only filter whitespace-only chunks if strip_whitespace is True + chunks = [seg[i : i + max_length] for i in range(0, len(seg), max_length)] + if strip_whitespace: + chunks = [c for c in chunks if c.strip()] + else: + chunks = [c for c in chunks if c] # Only filter truly empty strings + return chunks + return [seg] if seg else [] + + # Process boundaries to enforce strict max_length while preserving text + result = [] + pending_prefix = "" # Whitespace to prepend to next segment + i = 0 + + while i < len(boundaries): + start, end = boundaries[i] + segment = pending_prefix + text[start:end] + pending_prefix = "" + + # STRICT max_length enforcement + if max_length is not None and len(segment) > max_length: + # Split this segment to fit max_length + while 
len(segment) > max_length: + # Find a good split point (prefer splitting at whitespace) + split_at = max_length + # Look for whitespace near the end to split at + for j in range(max_length - 1, max(0, max_length - 20), -1): + if segment[j].isspace(): + split_at = j + 1 + break + + chunk = segment[:split_at] + segment = segment[split_at:] + + if strip_whitespace: + chunk = chunk.strip() + if chunk: + result.append(chunk) + + # Handle remaining part + if segment: + # Check if remaining can be merged with next segment + if i + 1 < len(boundaries): + pending_prefix = segment + else: + if strip_whitespace: + segment = segment.strip() + if segment: + result.append(segment) + i += 1 + continue + + # Check min_length - merge with next if too short + seg_len = len(segment.strip()) if strip_whitespace else len(segment) + if seg_len < min_length and i + 1 < len(boundaries): + # Try to merge with next segment + j = i + 1 + + while j < len(boundaries) and seg_len < min_length: + _, next_end = boundaries[j] + # Merge by appending text from current end to next boundary + merged = segment + text[end:next_end] if segment else text[start:next_end] + merged_len = len(merged.strip()) if strip_whitespace else len(merged) + + # Check strict max_length + if max_length is not None and merged_len > max_length: + break + + segment = merged + end = next_end # Update end to track where we've merged up to + seg_len = merged_len + j += 1 + + if strip_whitespace: + segment = segment.strip() + if segment: + result.append(segment) + i = j + pending_prefix = "" + continue + + # Segment is valid + if strip_whitespace: + segment = segment.strip() + if segment: + result.append(segment) + i += 1 + + # Handle any remaining prefix + if pending_prefix: + if result: + # Try to append to last segment + last = result[-1] + merged = last + pending_prefix + if max_length is None or len(merged) <= max_length: + result[-1] = merged + else: + result.append(pending_prefix.strip() if strip_whitespace else pending_prefix) + else: + result.append(pending_prefix.strip() if strip_whitespace else pending_prefix) + + # Final cleanup: merge last segment if too short + if len(result) > 1: + last = result[-1] + last_len = len(last.strip()) if strip_whitespace else len(last) + if last_len < min_length: + prev = result[-2] + merged = prev + last + if max_length is None or len(merged) <= max_length: + result[-2] = merged + result.pop() + + # Return all segments to preserve text (don't filter whitespace-only) + return result + + +def constrained_segmentation( + probs, + prior_fn, + min_length=1, + max_length=None, + algorithm="viterbi", +): + """ + Segments text based on probabilities and length constraints. + + Args: + probs: Array of probabilities (scores) for each unit. + prior_fn: Function that takes a length and returns a prior probability. + min_length: Minimum length of a chunk. + max_length: Maximum length of a chunk. + algorithm: "viterbi" or "greedy". + + Returns: + List of indices where splits occur (end of chunk). 
+ """ + n = len(probs) + if max_length is None: + max_length = n + + if algorithm == "greedy": + # Simple greedy approach (not optimal) + indices = [] + current_idx = 0 + while current_idx < n: + best_score = -float("inf") + best_end = -1 + + start_search = current_idx + min_length + end_search = min(current_idx + max_length + 1, n + 1) + + if start_search >= end_search: + remaining = n - current_idx + if remaining < min_length and indices: + # Want to merge remaining with previous chunk by removing last split + # But must verify the resulting final segment fits in max_length + new_last_split = indices[-2] if len(indices) >= 2 else 0 + if n - new_last_split <= max_length: + indices.pop() + return indices + best_end = n + else: + for end in range(start_search, end_search): + if end == n: + score = prior_fn(end - current_idx) + else: + score = probs[end - 1] * prior_fn(end - current_idx) + + if score > best_score: + best_score = score + best_end = end + + if best_end == -1: + best_end = min(current_idx + max_length, n) + + if best_end == n: + remaining = n - current_idx + if remaining < min_length and indices: + # Want to merge remaining with previous chunk by removing last split + # But must verify the resulting final segment fits in max_length + new_last_split = indices[-2] if len(indices) >= 2 else 0 + if n - new_last_split <= max_length: + indices.pop() + return indices + break + + indices.append(best_end) + current_idx = best_end + + return indices + + elif algorithm == "viterbi": + dp = np.full(n + 1, -float("inf")) + dp[0] = 0.0 + backpointers = np.zeros(n + 1, dtype=int) + + with np.errstate(divide="ignore"): + log_probs = np.log(probs) + + for i in range(1, n + 1): + start_j = max(0, i - max_length) + end_j = i - min_length + + if end_j < start_j: + continue + + for j in range(start_j, end_j + 1): + length = i - j + prior = prior_fn(length) + if prior <= 0: + continue + + log_prior = np.log(prior) + current_score = dp[j] + log_prior + + if i < n: + current_score += log_probs[i - 1] + + if current_score > dp[i]: + dp[i] = current_score + backpointers[i] = j + + indices = [] + curr = n + + if dp[n] == -float("inf"): + curr_idx = 0 + while curr_idx < n: + next_split = min(curr_idx + max_length, n) + # Use >= to handle min_length == max_length case + if next_split >= curr_idx + min_length: + indices.append(next_split) + curr_idx = next_split + + if indices and n - indices[-1] < min_length: + if len(indices) > 1: + prev_split = indices[-2] + # Try to merge with previous: remove last split if result fits max_length + if n - prev_split <= max_length: + indices.pop() + else: + # Can't merge - try to move split point to satisfy min_length + # New split should give final chunk >= min_length + desired_split = n - min_length + # But previous chunk must stay <= max_length + min_valid_split = prev_split + 1 # at least 1 char in prev chunk after prev_split + # And previous chunk must stay >= min_length (best effort) + adjusted_split = max(desired_split, min_valid_split) + # Ensure we don't exceed max_length for previous chunk + if adjusted_split - prev_split <= max_length: + indices[-1] = adjusted_split + # else: keep current split (best effort - one constraint must give) + elif n <= max_length: + # Single split that leaves short final - just remove it + return [] + return indices + + while curr > 0: + prev = backpointers[curr] + indices.append(curr) + curr = prev + + result = indices[::-1] + + if result and result[-1] == n: + result = result[:-1] + + if result: + last_chunk_len = n - result[-1] + 
if last_chunk_len < min_length: + if len(result) > 1: + prev_split = result[-2] + # Try to merge with previous: remove last split if result fits max_length + if n - prev_split <= max_length: + result.pop() + else: + # Can't merge - try to move split point to satisfy min_length + desired_split = n - min_length + min_valid_split = prev_split + 1 + adjusted_split = max(desired_split, min_valid_split) + # Ensure previous chunk doesn't exceed max_length + if adjusted_split - prev_split <= max_length: + result[-1] = adjusted_split + # else: keep current split (best effort) + else: + # Single split - try to adjust or remove + if n <= max_length: + return [] + else: + # Try to move split to satisfy min_length for final chunk + desired_split = n - min_length + if desired_split >= min_length: # first chunk also needs min_length + result[-1] = desired_split + + return result + + else: + raise ValueError(f"Unknown algorithm: {algorithm}") diff --git a/wtpsplit/utils/create_dummy_data.py b/wtpsplit/utils/create_dummy_data.py index 92cd1450..49d25359 100644 --- a/wtpsplit/utils/create_dummy_data.py +++ b/wtpsplit/utils/create_dummy_data.py @@ -8,12 +8,10 @@ "meta": { "train_data": ["train sentence 1", "train sentence 2"], }, - "data": [ - "train sentence 1", "train sentence 2" - ] + "data": ["train sentence 1", "train sentence 2"], } } } }, - "dummy-dataset.pth" -) \ No newline at end of file + "dummy-dataset.pth", +) diff --git a/wtpsplit/utils/priors.py b/wtpsplit/utils/priors.py new file mode 100644 index 00000000..9dfd62d4 --- /dev/null +++ b/wtpsplit/utils/priors.py @@ -0,0 +1,109 @@ +import json +import warnings +from pathlib import Path + +import numpy as np + +# Default for unknown languages +DEFAULT_SENTENCE_STATS = {"target_length": 70, "spread": 25} + +# Sentence statistics loaded from JSON (computed from Universal Dependencies) +# See: scripts/compute_sentence_stats.py +# +# Note: Some wtpsplit-supported languages are not in UD and will use defaults: +# bn, eo, fy, ig, km, kn, ku, mg, mn, ms, my, ne, pa, ps, sq, tg, uz, xh, yi, zu +_stats_path = Path(__file__).parent.parent / "data" / "sentence_stats.json" +if _stats_path.exists(): + with open(_stats_path, "r") as _f: + LANG_SENTENCE_STATS = json.load(_f).get("stats", {}) +else: + LANG_SENTENCE_STATS = {} + + +def get_language_defaults(lang_code=None): + """Get recommended target_length and spread for a given language.""" + if lang_code is None: + return DEFAULT_SENTENCE_STATS.copy() + if lang_code not in LANG_SENTENCE_STATS: + warnings.warn( + f"No sentence statistics for '{lang_code}', using defaults " + f"(target_length={DEFAULT_SENTENCE_STATS['target_length']}, " + f"spread={DEFAULT_SENTENCE_STATS['spread']}). 
" + f"You can override with explicit prior_kwargs.", + stacklevel=3, + ) + return DEFAULT_SENTENCE_STATS.copy() + return LANG_SENTENCE_STATS.get(lang_code, DEFAULT_SENTENCE_STATS).copy() + + +def create_prior_function(name, kwargs): + if name == "uniform": + max_length = kwargs.get("max_length") + + def prior(length): + if max_length is not None and length > max_length: + return 0.0 + return 1.0 + + return prior + + elif name == "clipped_polynomial": + # Quadratic falloff from target_length, clips to zero far from peak + # Use language-aware defaults if lang_code provided and target_length not specified + lang_defaults = get_language_defaults(kwargs.get("lang_code")) + target_length = kwargs.get("target_length", lang_defaults["target_length"]) + # Convert spread (tolerance in chars) to falloff coefficient + # Clips to zero at |length - target| = spread + spread = kwargs.get("spread", lang_defaults["spread"]) + falloff = 1.0 / (spread**2) + max_length = kwargs.get("max_length") + + def prior(length): + if max_length is not None and length > max_length: + return 0.0 + val = 1.0 - falloff * ((length - target_length) ** 2) + return max(val, 0.0) + + return prior + + elif name == "gaussian": + # Gaussian prior centered at target_length + # Use language-aware defaults if lang_code provided and target_length not specified + lang_defaults = get_language_defaults(kwargs.get("lang_code")) + target_length = kwargs.get("target_length", lang_defaults["target_length"]) + spread = kwargs.get("spread", lang_defaults["spread"]) + max_length = kwargs.get("max_length") + + def prior(length): + if max_length is not None and length > max_length: + return 0.0 + return np.exp(-0.5 * ((length - target_length) / spread) ** 2) + + return prior + + elif name == "lognormal": + # Log-normal prior - right-skewed distribution (more tolerant of longer segments) + # Use language-aware defaults if lang_code provided + lang_defaults = get_language_defaults(kwargs.get("lang_code")) + target_length = kwargs.get("target_length", lang_defaults["target_length"]) + # spread is in characters (like gaussian/clipped_polynomial) for consistency + spread = kwargs.get("spread", lang_defaults["spread"]) + max_length = kwargs.get("max_length") + + # Convert character-based spread to lognormal sigma + # sigma ≈ spread / target_length gives values in sensible 0.3-0.5 range + sigma = spread / target_length + mu = np.log(target_length) + sigma**2 + + def prior(length): + if length <= 0: + return 0.0 + if max_length is not None and length > max_length: + return 0.0 + log_len = np.log(length) + return np.exp(-0.5 * ((log_len - mu) / sigma) ** 2) / length + + return prior + + else: + raise ValueError(f"Unknown prior: {name}")