Commit e5cf75c

Merge pull request #1 from zain/codex/add-support-for-logprobs-in-agent-responses

Add logprob support for Responses API

2 parents: 18cb55e + 41c2ffb

6 files changed (+104, -12 lines)

docs/ja/models/index.md (15 additions, 1 deletion)

````diff
@@ -103,7 +103,7 @@ When using the OpenAI Responses API, parameters such as `user` and `service_tier` ...
 ```python
 from agents import Agent, ModelSettings
 
-english_agent = Agent(
+english_agent = Agent(
     name="English agent",
     instructions="You only speak English",
     model="gpt-4o",
@@ -114,6 +114,20 @@ english_agent = Agent(
 )
 ```
 
+If you want token log probabilities from the Responses API,
+set `top_logprobs` in `ModelSettings`.
+
+```python
+from agents import Agent, ModelSettings
+
+agent = Agent(
+    name="English agent",
+    instructions="You only speak English",
+    model="gpt-4o",
+    model_settings=ModelSettings(top_logprobs=2),
+)
+```
+
 ## Common issues when using other LLM providers
 
 ### 401 errors from the Tracing client
````

docs/models/index.md (14 additions, 0 deletions)

````diff
@@ -109,6 +109,20 @@ english_agent = Agent(
 )
 ```
 
+You can also request token log probabilities when using the Responses API by
+setting `top_logprobs` in `ModelSettings`.
+
+```python
+from agents import Agent, ModelSettings
+
+agent = Agent(
+    name="English agent",
+    instructions="You only speak English",
+    model="gpt-4o",
+    model_settings=ModelSettings(top_logprobs=2),
+)
+```
+
 ## Common issues with using other LLM providers
 
 ### Tracing client error 401
````
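For orientation, a minimal sketch of reading those log probabilities back out of a run. The traversal below (`result.raw_responses`, then each output item's `content` parts and their `logprobs` entries) is an assumption about the Responses API payload shape once `"message.output_text.logprobs"` is included, not something this commit defines, hence the defensive `getattr` calls:

```python
import asyncio

from agents import Agent, ModelSettings, Runner


async def main() -> None:
    agent = Agent(
        name="English agent",
        instructions="You only speak English",
        model="gpt-4o",
        model_settings=ModelSettings(top_logprobs=2),
    )
    result = await Runner.run(agent, "Say hello")
    # raw_responses holds the unmodified model responses for the run;
    # a logprobs attribute on output-text parts is assumed, so probe defensively.
    for response in result.raw_responses:
        for item in response.output:
            for part in getattr(item, "content", None) or []:
                for lp in getattr(part, "logprobs", None) or []:
                    print(lp.token, lp.logprob)


asyncio.run(main())
```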

src/agents/model_settings.py (11 additions, 6 deletions)

```diff
@@ -17,9 +17,9 @@
 class _OmitTypeAnnotation:
     @classmethod
     def __get_pydantic_core_schema__(
-        cls,
-        _source_type: Any,
-        _handler: GetCoreSchemaHandler,
+        cls,
+        _source_type: Any,
+        _handler: GetCoreSchemaHandler,
     ) -> core_schema.CoreSchema:
         def validate_from_none(value: None) -> _Omit:
             return _Omit()
@@ -39,13 +39,14 @@ def validate_from_none(value: None) -> _Omit:
                 from_none_schema,
             ]
         ),
-        serialization=core_schema.plain_serializer_function_ser_schema(
-            lambda instance: None
-        ),
+        serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
     )
+
+
 Omit = Annotated[_Omit, _OmitTypeAnnotation]
 Headers: TypeAlias = Mapping[str, Union[str, Omit]]
 
+
 @dataclass
 class ModelSettings:
     """Settings to use when calling an LLM.
@@ -107,6 +108,10 @@ class ModelSettings:
     """Additional output data to include in the model response.
     [include parameter](https://platform.openai.com/docs/api-reference/responses/create#responses-create-include)"""
 
+    top_logprobs: int | None = None
+    """Number of top tokens to return logprobs for. Setting this will
+    automatically include ``"message.output_text.logprobs"`` in the response."""
+
     extra_query: Query | None = None
     """Additional query fields to provide with the request.
     Defaults to None if not provided."""
```
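Because `ModelSettings` is a plain dataclass, the new field rides along with the existing `resolve` merge (the serialization tests below exercise it: non-None override fields win over the base). A quick sketch under that assumption:

```python
from agents import ModelSettings

base = ModelSettings(temperature=0.5)
override = ModelSettings(top_logprobs=2)

# resolve() layers the override's non-None fields over the base settings,
# so temperature survives and top_logprobs comes from the override.
resolved = base.resolve(override)
assert resolved.temperature == 0.5
assert resolved.top_logprobs == 2
```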

src/agents/models/openai_responses.py (11 additions, 4 deletions)

```diff
@@ -3,7 +3,7 @@
 import json
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Literal, overload
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
 
 from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
 from openai.types import ChatModel
@@ -246,9 +246,12 @@ async def _fetch_response(
         converted_tools = Converter.convert_tools(tools, handoffs)
         response_format = Converter.get_response_format(output_schema)
 
-        include: list[ResponseIncludable] = converted_tools.includes
+        include_set: set[str] = set(converted_tools.includes)
         if model_settings.response_include is not None:
-            include = list({*include, *model_settings.response_include})
+            include_set.update(model_settings.response_include)
+        if model_settings.top_logprobs is not None:
+            include_set.add("message.output_text.logprobs")
+        include = cast(list[ResponseIncludable], list(include_set))
 
         if _debug.DONT_LOG_MODEL_DATA:
             logger.debug("Calling LLM")
@@ -263,6 +266,10 @@ async def _fetch_response(
                 f"Previous response id: {previous_response_id}\n"
             )
 
+        extra_args = dict(model_settings.extra_args or {})
+        if model_settings.top_logprobs is not None:
+            extra_args["top_logprobs"] = model_settings.top_logprobs
+
         return await self._client.responses.create(
             previous_response_id=self._non_null_or_not_given(previous_response_id),
             instructions=self._non_null_or_not_given(system_instructions),
@@ -285,7 +292,7 @@ async def _fetch_response(
             store=self._non_null_or_not_given(model_settings.store),
             reasoning=self._non_null_or_not_given(model_settings.reasoning),
             metadata=self._non_null_or_not_given(model_settings.metadata),
-            **(model_settings.extra_args or {}),
+            **extra_args,
         )
 
     def _get_client(self) -> AsyncOpenAI:
```
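Pulled out of context, the include-merging behavior this hunk introduces looks roughly like the standalone sketch below (`merge_includes` is a hypothetical name, and the plain `list[str]` types simplify away the `ResponseIncludable` cast):

```python
def merge_includes(
    tool_includes: list[str],
    response_include: list[str] | None,
    top_logprobs: int | None,
) -> list[str]:
    # A set collapses duplicate includables arriving from different sources.
    include_set = set(tool_includes)
    if response_include is not None:
        include_set.update(response_include)
    if top_logprobs is not None:
        # Requesting top_logprobs implies asking the API to return them.
        include_set.add("message.output_text.logprobs")
    return list(include_set)


merged = merge_includes(["reasoning.encrypted_content"], None, top_logprobs=2)
assert "message.output_text.logprobs" in merged
```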

tests/model_settings/test_serialization.py (3 additions, 1 deletion)

```diff
@@ -47,6 +47,7 @@ def test_all_fields_serialization() -> None:
         store=False,
         include_usage=False,
         response_include=["reasoning.encrypted_content"],
+        top_logprobs=1,
         extra_query={"foo": "bar"},
         extra_body={"foo": "bar"},
         extra_headers={"foo": "bar"},
@@ -135,8 +136,8 @@ def test_extra_args_resolve_both_none() -> None:
     assert resolved.temperature == 0.5
     assert resolved.top_p == 0.9
 
-def test_pydantic_serialization() -> None:
 
+def test_pydantic_serialization() -> None:
     """Tests whether ModelSettings can be serialized with Pydantic."""
 
     # First, lets create a ModelSettings instance
@@ -153,6 +154,7 @@ def test_pydantic_serialization() -> None:
         metadata={"foo": "bar"},
         store=False,
         include_usage=False,
+        top_logprobs=1,
         extra_query={"foo": "bar"},
         extra_body={"foo": "bar"},
         extra_headers={"foo": "bar"},
```
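For reference, the round trip that `test_pydantic_serialization` performs can be reproduced in a few lines. Using `TypeAdapter` is an assumption here (it is the idiomatic Pydantic v2 way to validate a dataclass), but the check is the same either way:

```python
from pydantic import TypeAdapter

from agents import ModelSettings

adapter = TypeAdapter(ModelSettings)
settings = ModelSettings(top_logprobs=1)

# Dump to JSON and validate back; the new field should survive the round trip.
payload = adapter.dump_json(settings)
restored = adapter.validate_json(payload)
assert restored.top_logprobs == 1
```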

tests/test_logprobs.py (new file, 50 additions)

```python
import pytest
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from agents import ModelSettings, ModelTracing, OpenAIResponsesModel


class DummyResponses:
    async def create(self, **kwargs):
        self.kwargs = kwargs

        class DummyResponse:
            id = "dummy"
            output = []
            usage = type(
                "Usage",
                (),
                {
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0,
                    "input_tokens_details": InputTokensDetails(cached_tokens=0),
                    "output_tokens_details": OutputTokensDetails(reasoning_tokens=0),
                },
            )()

        return DummyResponse()


class DummyClient:
    def __init__(self):
        self.responses = DummyResponses()


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_top_logprobs_param_passed():
    client = DummyClient()
    model = OpenAIResponsesModel(model="gpt-4", openai_client=client)  # type: ignore
    await model.get_response(
        system_instructions=None,
        input="hi",
        model_settings=ModelSettings(top_logprobs=2),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
    )
    assert client.responses.kwargs["top_logprobs"] == 2
    assert "message.output_text.logprobs" in client.responses.kwargs["include"]
```
