Skip to content

Commit 6401569

Browse files
authored
Handle gpt-5-nano as reasoning model (#8693)
1 parent 43e241e commit 6401569

File tree

2 files changed

+13
-7
lines changed

dspy/clients/lm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def __init__(
7878
# Handle model-specific configuration for different model families
7979
model_family = model.split("/")[-1].lower() if "/" in model else model.lower()
8080

81-
# Match pattern: o[1,3,4] at the start, optionally followed by -mini and anything else
82-
model_pattern = re.match(r"^(?:o([1345])|gpt-(5))(?:-mini)?", model_family)
81+
# Recognize OpenAI reasoning models (o1, o3, o4, gpt-5 family)
82+
model_pattern = re.match(r"^(?:o[1345]|gpt-5)(?:-(?:mini|nano))?", model_family)
8383

8484
if model_pattern:
8585
if max_tokens < 20000 or temperature != 1.0:

tests/clients/test_lm.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ def test_reasoning_model_token_parameter():
226226
("openai/o1-2023-01-01", True),
227227
("openai/o3", True),
228228
("openai/o3-mini-2023-01-01", True),
229+
("openai/gpt-5", True),
230+
("openai/gpt-5-mini", True),
231+
("openai/gpt-5-nano", True),
229232
("openai/gpt-4", False),
230233
("anthropic/claude-2", False),
231234
]
@@ -245,19 +248,22 @@ def test_reasoning_model_token_parameter():
245248
assert "max_tokens" in lm.kwargs
246249
assert lm.kwargs["max_tokens"] == 1000
247250

248-
249-
def test_reasoning_model_requirements():
251+
@pytest.mark.parametrize("model_name", ["openai/o1", "openai/gpt-5-nano"])
252+
def test_reasoning_model_requirements(model_name):
250253
# Should raise assertion error if temperature or max_tokens requirements not met
251-
with pytest.raises(ValueError, match="reasoning models require passing temperature=1.0 and max_tokens >= 20000"):
254+
with pytest.raises(
255+
ValueError,
256+
match="reasoning models require passing temperature=1.0 and max_tokens >= 20000",
257+
):
252258
dspy.LM(
253-
model="openai/o1",
259+
model=model_name,
254260
temperature=0.7, # Should be 1.0
255261
max_tokens=1000, # Should be >= 20_000
256262
)
257263

258264
# Should pass with correct parameters
259265
lm = dspy.LM(
260-
model="openai/o1",
266+
model=model_name,
261267
temperature=1.0,
262268
max_tokens=20_000,
263269
)

0 commit comments

Comments (0)