Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/split_markdown4gpt/splitter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
import sys
from collections import defaultdict
from functools import lru_cache
from io import TextIOWrapper
Expand Down Expand Up @@ -35,6 +36,8 @@ def meta_data(md: str) -> tuple:


OPENAI_MODELS = {
"gpt-4o": 128000,
"gpt-4-turbo": 128000,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
Expand Down Expand Up @@ -70,7 +73,16 @@ class MarkdownLLMSplitter:
def __init__(
self, gptok_model: str = "gpt-3.5-turbo", gptok_limit: int = None
) -> None:
self.gptoker = tiktoken.encoding_for_model(gptok_model)
try:
self.gptoker = tiktoken.encoding_for_model(gptok_model)
except KeyError:
self.gptoker = tiktoken.get_encoding("cl100k_base")
if gptok_model not in OPENAI_MODELS:
print(
f"Warning: Model '{gptok_model}' not found in the list of known models. "
f"Token limits may be inaccurate.",
file=sys.stderr,
)
self.gptok_limit = gptok_limit or OPENAI_MODELS.get(gptok_model, 2048)
self.md_meta = {}
self.md_str = ""
Expand Down
22 changes: 22 additions & 0 deletions tests/test_splitter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
from pathlib import Path
import sys
from io import StringIO
from split_markdown4gpt.splitter import MarkdownLLMSplitter, split

def test_split():
Expand All @@ -9,3 +11,23 @@ def test_split():
assert isinstance(sections, list)
assert all(isinstance(section, str) for section in sections)

def test_new_openai_models():
    """Check that the newly registered OpenAI models get their default token limits."""
    # Each entry pairs a model name with the limit the splitter should pick
    # when no explicit gptok_limit is supplied.
    expected_limits = (
        ("gpt-4o", 128000),
        ("gpt-4-turbo", 128000),
    )
    for model_name, token_limit in expected_limits:
        splitter = MarkdownLLMSplitter(gptok_model=model_name)
        assert splitter.gptok_limit == token_limit

def test_unknown_model_warning():
    """Test that a warning is printed to stderr for unknown models.

    The splitter is constructed with a model name that is not an OpenAI
    model; the test then checks that the text written to stderr names that
    model in a "not found" warning.
    """
    # Imported locally so this test does not depend on the module's import block.
    from contextlib import redirect_stderr

    captured_stderr = StringIO()
    # redirect_stderr restores sys.stderr even if the constructor raises, so a
    # failure here cannot leave stderr hijacked for the remaining tests.
    with redirect_stderr(captured_stderr):
        MarkdownLLMSplitter(gptok_model="claude-3-opus-20240229")

    warning_message = captured_stderr.getvalue()
    assert "Warning: Model 'claude-3-opus-20240229' not found" in warning_message