diff --git a/src/split_markdown4gpt/splitter.py b/src/split_markdown4gpt/splitter.py index 68bd484..91913cb 100644 --- a/src/split_markdown4gpt/splitter.py +++ b/src/split_markdown4gpt/splitter.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import sys from collections import defaultdict from functools import lru_cache from io import TextIOWrapper @@ -35,6 +36,12 @@ def meta_data(md: str) -> tuple: OPENAI_MODELS = { + "gpt-5": 400000, + "gpt-5-mini": 400000, + "gpt-5-nano": 400000, + "gpt-4.1": 1000000, + "gpt-4o": 128000, + "gpt-4-turbo": 128000, "gpt-4": 8192, "gpt-4-32k": 32768, "gpt-4-32k-0613": 32768, @@ -70,7 +77,16 @@ class MarkdownLLMSplitter: def __init__( self, gptok_model: str = "gpt-3.5-turbo", gptok_limit: int = None ) -> None: - self.gptoker = tiktoken.encoding_for_model(gptok_model) + try: + self.gptoker = tiktoken.encoding_for_model(gptok_model) + except KeyError: + self.gptoker = tiktoken.get_encoding("cl100k_base") + if gptok_model not in OPENAI_MODELS: + print( + f"Warning: Model '{gptok_model}' not found in the list of known models. " + f"Token limits may be inaccurate.", + file=sys.stderr, + ) self.gptok_limit = gptok_limit or OPENAI_MODELS.get(gptok_model, 2048) self.md_meta = {} self.md_str = "" diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 2d96678..13e9b33 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -1,5 +1,7 @@ import pytest from pathlib import Path +import sys +from io import StringIO from split_markdown4gpt.splitter import MarkdownLLMSplitter, split def test_split(): @@ -9,3 +11,25 @@ def test_split(): assert isinstance(sections, list) assert all(isinstance(section, str) for section in sections) +def test_new_openai_models(): + """Test that new OpenAI models are recognized and their limits are set correctly.""" + splitter_gpt5 = MarkdownLLMSplitter(gptok_model="gpt-5") + assert splitter_gpt5.gptok_limit == 400000 + splitter_gpt5_mini = MarkdownLLMSplitter(gptok_model="gpt-5-mini") + assert splitter_gpt5_mini.gptok_limit == 400000 + splitter_4_1 = MarkdownLLMSplitter(gptok_model="gpt-4.1") + assert splitter_4_1.gptok_limit == 1000000 + +def test_unknown_model_warning(): + """Test that a warning is printed for unknown models.""" + # Redirect stderr to capture the warning message + old_stderr = sys.stderr + sys.stderr = captured_stderr = StringIO() + + MarkdownLLMSplitter(gptok_model="claude-3-opus-20240229") + + # Restore stderr + sys.stderr = old_stderr + + warning_message = captured_stderr.getvalue() + assert "Warning: Model 'claude-3-opus-20240229' not found" in warning_message