
Commit fd8c05b

willgdjones authored and committed
Add support for previous_response_id from Responses API (pydantic#2756)
Implement stream cancellation

Co-authored-by: Douwe Maan <[email protected]>
1 parent 773e1be commit fd8c05b

11 files changed, +591 −13 lines changed

docs/models/openai.md

Lines changed: 50 additions & 0 deletions
@@ -143,6 +143,56 @@ As of 7:48 AM on Wednesday, April 2, 2025, in Tokyo, Japan, the weather is cloud

You can learn more about the differences between the Responses API and Chat Completions API in the [OpenAI API docs](https://platform.openai.com/docs/guides/responses-vs-chat-completions).

#### Referencing earlier responses

The Responses API supports referencing earlier model responses in a new request using a `previous_response_id` parameter, ensuring that the full [conversation state](https://platform.openai.com/docs/guides/conversation-state?api-mode=responses#passing-context-from-the-previous-response), including [reasoning items](https://platform.openai.com/docs/guides/reasoning#keeping-reasoning-items-in-context), is kept in context. This is available through the `openai_previous_response_id` field in [`OpenAIResponsesModelSettings`][pydantic_ai.models.openai.OpenAIResponsesModelSettings].

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings

model = OpenAIResponsesModel('gpt-5')
agent = Agent(model=model)

result = agent.run_sync('The secret is 1234')
model_settings = OpenAIResponsesModelSettings(
    openai_previous_response_id=result.all_messages()[-1].provider_response_id
)
result = agent.run_sync('What is the secret code?', model_settings=model_settings)
print(result.output)
#> 1234
```

By passing the `provider_response_id` from an earlier run, you can allow the model to build on its own prior reasoning without needing to resend the full message history.

##### Automatically referencing earlier responses

When the `openai_previous_response_id` field is set to `'auto'`, Pydantic AI will automatically select the most recent `provider_response_id` from the message history and omit the messages that came before it, letting the OpenAI API leverage server-side history instead for improved efficiency.

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings

model = OpenAIResponsesModel('gpt-5')
agent = Agent(model=model)

result1 = agent.run_sync('Tell me a joke.')
print(result1.output)
#> Did you hear about the toothpaste scandal? They called it Colgate.

# When set to 'auto', the most recent provider_response_id
# and the messages after it are sent with the request.
model_settings = OpenAIResponsesModelSettings(openai_previous_response_id='auto')
result2 = agent.run_sync(
    'Explain?',
    message_history=result1.new_messages(),
    model_settings=model_settings
)
print(result2.output)
#> This is an excellent joke invented by Samuel Colvin, it needs no explanation.
```

## OpenAI-compatible Models

Many providers and models are compatible with the OpenAI API, and can be used with `OpenAIChatModel` in Pydantic AI.

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 6 additions & 2 deletions
@@ -407,8 +407,12 @@ async def stream(
            )
            yield agent_stream
            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
-           # otherwise usage won't be properly counted:
-           async for _ in agent_stream:
+           # However, if the stream was cancelled, we should not consume further.
+           try:
+               async for _ in agent_stream:
+                   pass
+           except exceptions.StreamCancelled:
+               # Stream was cancelled - don't consume further
                pass

            model_response = streamed_response.get()

pydantic_ai_slim/pydantic_ai/exceptions.py

Lines changed: 9 additions & 0 deletions
@@ -24,6 +24,7 @@
    'UsageLimitExceeded',
    'ModelHTTPError',
    'FallbackExceptionGroup',
+   'StreamCancelled',
)
@@ -162,6 +163,14 @@ class FallbackExceptionGroup(ExceptionGroup):
    """A group of exceptions that can be raised when all fallback models fail."""


+class StreamCancelled(Exception):
+    """Exception raised when a streaming response is cancelled."""
+
+    def __init__(self, message: str = 'Stream was cancelled'):
+        self.message = message
+        super().__init__(message)
+
+
class ToolRetryError(Exception):
    """Exception used to signal a `ToolRetry` message should be returned to the LLM."""

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -641,6 +641,14 @@ def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        raise NotImplementedError()

+    async def cancel(self) -> None:
+        """Cancel the streaming response.
+
+        This should close the underlying network connection and cause any active iteration
+        to raise a StreamCancelled exception. The default implementation is a no-op.
+        """
+        pass
+

ALLOW_MODEL_REQUESTS = True
"""Whether to allow requests to models.

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 52 additions & 1 deletion
@@ -17,7 +17,7 @@
from .._thinking_part import split_content_into_text_and_thinking
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
from ..builtin_tools import CodeExecutionTool, WebSearchTool
-from ..exceptions import UserError
+from ..exceptions import StreamCancelled, UserError
from ..messages import (
    AudioUrl,
    BinaryContent,
@@ -222,6 +222,17 @@ class OpenAIResponsesModelSettings(OpenAIChatModelSettings, total=False):
    `medium`, and `high`.
    """

+    openai_previous_response_id: Literal['auto'] | str
+    """The ID of a previous response from the model to use as the starting point for a continued conversation.
+
+    When set to `'auto'`, the request automatically uses the most recent
+    `provider_response_id` from the message history and omits earlier messages.
+
+    This enables the model to use server-side conversation state and faithfully reference previous reasoning.
+    See the [OpenAI Responses API documentation](https://platform.openai.com/docs/guides/reasoning#keeping-reasoning-items-in-context)
+    for more information.
+    """
+

@dataclass(init=False)
class OpenAIChatModel(Model):
@@ -977,6 +988,10 @@ async def _responses_create(
        else:
            tool_choice = 'auto'

+        previous_response_id = model_settings.get('openai_previous_response_id')
+        if previous_response_id == 'auto':
+            previous_response_id, messages = self._get_previous_response_id_and_new_messages(messages)
+
        instructions, openai_messages = await self._map_messages(messages, model_settings)
        reasoning = self._get_reasoning(model_settings)

@@ -1027,6 +1042,7 @@
            truncation=model_settings.get('openai_truncation', NOT_GIVEN),
            timeout=model_settings.get('timeout', NOT_GIVEN),
            service_tier=model_settings.get('openai_service_tier', NOT_GIVEN),
+            previous_response_id=previous_response_id,
            reasoning=reasoning,
            user=model_settings.get('openai_user', NOT_GIVEN),
            text=text or NOT_GIVEN,
@@ -1092,6 +1108,28 @@ def _map_tool_definition(self, f: ToolDefinition) -> responses.FunctionToolParam
            ),
        }

+    def _get_previous_response_id_and_new_messages(
+        self, messages: list[ModelMessage]
+    ) -> tuple[str | None, list[ModelMessage]]:
+        # When `openai_previous_response_id` is set to 'auto', the most recent
+        # `provider_response_id` from the message history is selected and all
+        # earlier messages are omitted. This allows the OpenAI SDK to reuse
+        # server-side history for efficiency. The returned tuple contains the
+        # `previous_response_id` (if found) and the trimmed list of messages.
+        previous_response_id = None
+        trimmed_messages: list[ModelMessage] = []
+        for m in reversed(messages):
+            if isinstance(m, ModelResponse) and m.provider_name == self.system:
+                previous_response_id = m.provider_response_id
+                break
+            else:
+                trimmed_messages.append(m)
+
+        if previous_response_id and trimmed_messages:
+            return previous_response_id, list(reversed(trimmed_messages))
+        else:
+            return None, messages
+
    async def _map_messages(  # noqa: C901
        self, messages: list[ModelMessage], model_settings: OpenAIResponsesModelSettings
    ) -> tuple[str | NotGiven, list[responses.ResponseInputItemParam]]:
@@ -1309,9 +1347,14 @@ class OpenAIStreamedResponse(StreamedResponse):
    _response: AsyncIterable[ChatCompletionChunk]
    _timestamp: datetime
    _provider_name: str
+    _cancelled: bool = field(default=False, init=False)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
+            # Check for cancellation before processing each chunk
+            if self._cancelled:
+                raise StreamCancelled('OpenAI stream was cancelled')
+
            self._usage += _map_usage(chunk)

            if chunk.id and self.provider_response_id is None:
@@ -1380,6 +1423,14 @@ def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp

+    async def cancel(self) -> None:
+        """Cancel the streaming response.
+
+        This marks the stream as cancelled, which will cause the iterator to raise
+        a StreamCancelled exception on the next iteration.
+        """
+        self._cancelled = True


@dataclass
class OpenAIResponsesStreamedResponse(StreamedResponse):
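To make the `'auto'` trimming concrete, here is a standalone sketch of the logic `_get_previous_response_id_and_new_messages` applies, using plain tuples in place of `ModelMessage` objects; the `history` data and the `trim_history` helper are illustrative only, not library code. The idea: walk the history backwards, stop at the most recent response from this provider, reuse its `provider_response_id`, and send only the messages that came after it.

```python
# Hypothetical, simplified history: (kind, provider_response_id or None)
history = [
    ('request', None),         # 'Tell me a joke.'
    ('response', 'resp_abc'),  # model reply carrying a provider_response_id
    ('request', None),         # 'Explain?'
]


def trim_history(messages):
    """Mirror the 'auto' logic: most recent response id, plus only the messages after it."""
    previous_response_id = None
    trimmed = []
    for kind, response_id in reversed(messages):
        if kind == 'response' and response_id is not None:
            # Found the most recent provider response; everything before it is omitted.
            previous_response_id = response_id
            break
        trimmed.append((kind, response_id))
    if previous_response_id and trimmed:
        return previous_response_id, list(reversed(trimmed))
    return None, messages


print(trim_history(history))
#> ('resp_abc', [('request', None)])
```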

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 42 additions & 10 deletions
@@ -54,6 +54,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):

    _agent_stream_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
    _initial_run_ctx_usage: RunUsage = field(init=False)
+    _cancelled: bool = field(default=False, init=False)

    def __post_init__(self):
        self._initial_run_ctx_usage = copy(self._run_ctx.usage)
@@ -123,6 +124,19 @@ def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._raw_stream_response.timestamp

+    async def cancel(self) -> None:
+        """Cancel the streaming response.
+
+        This will close the underlying network connection and cause any active iteration
+        over the stream to raise a StreamCancelled exception.
+
+        Subsequent calls to cancel() are safe and will not raise additional exceptions.
+        """
+        if not self._cancelled:
+            self._cancelled = True
+            # Cancel the underlying stream response
+            await self._raw_stream_response.cancel()
+
    async def get_output(self) -> OutputDataT:
        """Stream the whole response, validate the output and return it."""
        async for _ in self:
@@ -227,8 +241,8 @@ async def _stream_text_deltas() -> AsyncIterator[str]:
    def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Stream [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
        if self._agent_stream_iterator is None:
-            self._agent_stream_iterator = _get_usage_checking_stream_response(
-                self._raw_stream_response, self._usage_limits, self.usage
+            self._agent_stream_iterator = _get_cancellation_aware_stream_response(
+                self._raw_stream_response, self._usage_limits, self.usage, lambda: self._cancelled
            )

        return self._agent_stream_iterator
@@ -450,6 +464,18 @@ async def stream_responses(
        else:
            raise ValueError('No stream response or run result provided')  # pragma: no cover

+    async def cancel(self) -> None:
+        """Cancel the streaming response.
+
+        This will close the underlying network connection and cause any active iteration
+        over the stream to raise a StreamCancelled exception.
+
+        Subsequent calls to cancel() are safe and will not raise additional exceptions.
+        """
+        if self._stream_response is not None:
+            await self._stream_response.cancel()
+        # If there's no stream response, this is a no-op (already completed)
+
    async def get_output(self) -> OutputDataT:
        """Stream the whole response, validate and return it."""
        if self._run_result is not None:
@@ -526,21 +552,27 @@ class FinalResult(Generic[OutputDataT]):
    __repr__ = _utils.dataclasses_no_defaults_repr


-def _get_usage_checking_stream_response(
+def _get_cancellation_aware_stream_response(
    stream_response: models.StreamedResponse,
    limits: UsageLimits | None,
    get_usage: Callable[[], RunUsage],
+    is_cancelled: Callable[[], bool],
) -> AsyncIterator[ModelResponseStreamEvent]:
-    if limits is not None and limits.has_token_limits():
+    """Create an iterator that checks for cancellation and usage limits."""

-        async def _usage_checking_iterator():
-            async for item in stream_response:
+    async def _cancellation_aware_iterator():
+        async for item in stream_response:
+            # Check for cancellation first
+            if is_cancelled():
+                raise exceptions.StreamCancelled()
+
+            # Then check usage limits if needed
+            if limits is not None and limits.has_token_limits():
                limits.check_tokens(get_usage())
-                yield item

-        return _usage_checking_iterator()
-    else:
-        return aiter(stream_response)
+            yield item
+
+    return _cancellation_aware_iterator()


def _get_deferred_tool_requests(
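End to end, a minimal usage sketch of the new cancellation API with `Agent.run_stream`, assuming an OpenAI API key is configured; the model name, prompt, and the point at which `cancel()` is called are arbitrary. `StreamedRunResult.cancel()` closes the underlying connection, and any in-flight iteration then surfaces `StreamCancelled`.

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.exceptions import StreamCancelled


async def main() -> None:
    agent = Agent('openai:gpt-5')

    async with agent.run_stream('Write a very long story.') as result:
        try:
            async for chunk in result.stream_text(delta=True):
                print(chunk, end='')
                # Stop after the first delta; cancellation should cause the
                # next iteration to raise StreamCancelled.
                await result.cancel()
        except StreamCancelled:
            print('\n[stream cancelled]')


asyncio.run(main())
```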
