diff --git a/.changeset/replay-aborted-tool-outputs.md b/.changeset/replay-aborted-tool-outputs.md new file mode 100644 index 000000000..2e7351ce4 --- /dev/null +++ b/.changeset/replay-aborted-tool-outputs.md @@ -0,0 +1,5 @@ +--- +'@openai/agents-core': patch +--- + +After an aborted streamed turn in a server-managed conversation, queue a synthetic `aborted` output for each function call the model had already emitted, and include those outputs in the next turn's input. diff --git a/packages/agents-core/src/run.ts b/packages/agents-core/src/run.ts index ef96e09e8..5c364636f 100644 --- a/packages/agents-core/src/run.ts +++ b/packages/agents-core/src/run.ts @@ -70,6 +70,10 @@ import { } from './runner/sessionPersistence'; import { resolveTurnAfterModelResponse } from './runner/turnResolution'; import { prepareTurn } from './runner/turnPreparation'; +import { + clearManagedConversationSupplementalItems, + queueManagedConversationSupplementalItems, +} from './runner/turnPreparation'; import { applyTurnResult, handleInterruptedOutcome, @@ -85,6 +89,71 @@ import type { import { tryHandleRunError } from './runner/errorHandlers'; import type { RunErrorHandlers } from './runner/errorHandlers'; +function extractPendingFunctionCallOutputsFromRawModelEvents( + rawModelEvents: unknown[], +): AgentInputItem[] { + const pendingItems: AgentInputItem[] = []; + const seenCallIds = new Set(); + + const getRawItem = (event: unknown): Record | undefined => { + if (!event || typeof event !== 'object') { + return undefined; + } + + const rawEvent = event as Record; + const eventType = rawEvent.type; + if ( + eventType !== 'response.output_item.added' && + eventType !== 'response.output_item.done' + ) { + return undefined; + } + + const item = + (rawEvent.item as Record | undefined) ?? + (rawEvent.output_item as Record | undefined); + + return item && typeof item === 'object' ? 
item : undefined; + }; + + for (const rawEvent of rawModelEvents) { + const rawItem = getRawItem(rawEvent); + if (!rawItem || rawItem.type !== 'function_call') { + continue; + } + + const callId = + typeof rawItem.call_id === 'string' + ? rawItem.call_id + : typeof rawItem.callId === 'string' + ? rawItem.callId + : undefined; + const name = + typeof rawItem.name === 'string' ? rawItem.name : undefined; + + if (!callId || !name || seenCallIds.has(callId)) { + continue; + } + + seenCallIds.add(callId); + pendingItems.push({ + type: 'function_call_result', + name, + ...(typeof rawItem.namespace === 'string' + ? { namespace: rawItem.namespace } + : {}), + callId, + status: 'completed', + output: { + type: 'text', + text: 'aborted', + }, + }); + } + + return pendingItems; +} + export type { CallModelInputFilter, CallModelInputFilterArgs, @@ -1150,6 +1219,7 @@ export class Runner extends RunHooks> { let finalResponse: ModelResponse | undefined = undefined; let inputMarked = false; + const rawModelEvents: unknown[] = []; const markInputOnce = () => { if (inputMarked || !serverConversationTracker) { return; @@ -1212,6 +1282,8 @@ export class Runner extends RunHooks> { requestId: parsed.response.requestId, }; result.state._context.usage.add(finalResponse.usage); + } else if (event.type === 'model') { + rawModelEvents.push(event.event); } if (result.cancelled) { // When the user's code exits a loop to consume the stream, we need to break @@ -1226,6 +1298,16 @@ export class Runner extends RunHooks> { if (sentInputToModel) { markInputOnce(); } + if (serverConversationTracker?.conversationId) { + const pendingItems = + extractPendingFunctionCallOutputsFromRawModelEvents( + rawModelEvents, + ); + queueManagedConversationSupplementalItems( + serverConversationTracker.conversationId, + pendingItems, + ); + } await awaitGuardrailsAndPersistInput(); return; } @@ -1236,6 +1318,12 @@ export class Runner extends RunHooks> { markInputOnce(); } + if 
(serverConversationTracker?.conversationId && inputMarked) { + clearManagedConversationSupplementalItems( + serverConversationTracker.conversationId, + ); + } + await awaitGuardrailsAndPersistInput(); if (result.cancelled) { diff --git a/packages/agents-core/src/runner/turnPreparation.ts b/packages/agents-core/src/runner/turnPreparation.ts index cb8fd4cce..0a1dcb0f5 100644 --- a/packages/agents-core/src/runner/turnPreparation.ts +++ b/packages/agents-core/src/runner/turnPreparation.ts @@ -135,15 +135,63 @@ const managedConversationSupplementalItemsCache = new WeakMap< ProcessedResponse, AgentInputItem[] >(); +const pendingManagedConversationAbortItems = new Map(); + +export function queueManagedConversationSupplementalItems( + conversationId: string | undefined, + items: AgentInputItem[], +): void { + if (!conversationId || items.length === 0) { + return; + } + + const existing = pendingManagedConversationAbortItems.get(conversationId) ?? []; + const merged = [...existing]; + const seenCallIds = new Set( + existing.flatMap((item) => + item.type === 'function_call_result' && typeof item.callId === 'string' + ? [item.callId] + : [], + ), + ); + + for (const item of items) { + if ( + item.type !== 'function_call_result' || + typeof item.callId !== 'string' || + seenCallIds.has(item.callId) + ) { + continue; + } + merged.push(item); + seenCallIds.add(item.callId); + } + + if (merged.length > 0) { + pendingManagedConversationAbortItems.set(conversationId, merged); + } +} + +export function clearManagedConversationSupplementalItems( + conversationId: string | undefined, +): void { + if (!conversationId) { + return; + } + pendingManagedConversationAbortItems.delete(conversationId); +} export function getManagedConversationSupplementalItems< TContext, TAgent extends Agent, >(state: RunState): AgentInputItem[] { + const pendingAbortItems = state._conversationId + ? pendingManagedConversationAbortItems.get(state._conversationId) ?? 
[] + : []; const processedResponse = state._lastProcessedResponse; const handoffs = processedResponse?.handoffs; if (!handoffs || handoffs.length <= 1) { - return []; + return pendingAbortItems; } const acceptedCallId = handoffs[0]?.toolCall.callId; @@ -156,13 +204,15 @@ export function getManagedConversationSupplementalItems< item.rawItem.callId === acceptedCallId, ); if (!acceptedHandoffOutputStillPresent) { - return []; + return pendingAbortItems; } const cached = managedConversationSupplementalItemsCache.get(processedResponse); if (cached) { - return cached; + return pendingAbortItems.length > 0 + ? [...pendingAbortItems, ...cached] + : cached; } // Server-managed transcripts still contain ignored handoff calls from the last response. @@ -173,7 +223,7 @@ export function getManagedConversationSupplementalItems< getToolCallOutputItem(toolCall, IGNORED_HANDOFF_OUTPUT_MESSAGE), ); managedConversationSupplementalItemsCache.set(processedResponse, items); - return items; + return pendingAbortItems.length > 0 ? [...pendingAbortItems, ...items] : items; } async function runInputGuardrailsForTurn< diff --git a/packages/agents-core/test/run.stream.test.ts b/packages/agents-core/test/run.stream.test.ts index 6856d2a9d..0dbbdcb09 100644 --- a/packages/agents-core/test/run.stream.test.ts +++ b/packages/agents-core/test/run.stream.test.ts @@ -1935,6 +1935,130 @@ describe('Runner.run (streaming)', () => { ]); }); + it('replays synthetic tool outputs after an aborted managed conversation turn', async () => { + class AbortAfterFunctionCallStreamingModel implements Model { + public readonly requests: ModelRequest[] = []; + private attempt = 0; + + async getResponse(): Promise { + throw new Error('not used'); + } + + async *getStreamedResponse( + request: ModelRequest, + ): AsyncIterable { + this.requests.push({ + ...request, + input: Array.isArray(request.input) + ? 
(JSON.parse(JSON.stringify(request.input)) as AgentInputItem[]) + : request.input, + }); + this.attempt += 1; + + if (this.attempt === 1) { + yield { type: 'response_started' } as any; + yield { + type: 'model', + event: { + type: 'response.output_item.done', + item: { + type: 'function_call', + id: 'fc_abort_1', + call_id: 'call_abort_1', + name: 'test', + arguments: '{"test":"abort"}', + }, + }, + providerData: { + rawModelEventSource: 'openai-responses', + }, + } as any; + + const abortError = new Error('aborted'); + (abortError as Error & { name: string }).name = 'AbortError'; + const signal = request.signal as AbortSignal | undefined; + await new Promise((_resolve, reject) => { + if (signal?.aborted) { + reject(abortError); + return; + } + const onAbort = () => { + signal?.removeEventListener('abort', onAbort); + reject(abortError); + }; + signal?.addEventListener('abort', onAbort, { once: true }); + }); + return; + } + + yield { + type: 'response_done', + response: { + id: 'resp-after-abort', + usage: { + requests: 1, + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + }, + output: [fakeModelMessage('done after abort')], + }, + } as any; + } + } + + const model = new AbortAfterFunctionCallStreamingModel(); + const agent = new Agent({ + name: 'ManagedAbortRecovery', + model, + tools: [serverTool], + }); + const runner = new Runner(); + const controller = new AbortController(); + + const firstRun = await runner.run(agent, 'hi', { + stream: true, + conversationId: 'conv-abort-recovery', + signal: controller.signal, + }); + const reader = (firstRun.toStream() as any).getReader(); + + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + if ( + value?.type === 'raw_model_stream_event' && + value.data?.type === 'model' && + value.data?.event?.type === 'response.output_item.done' + ) { + controller.abort(); + break; + } + } + await firstRun.completed; + + const secondRun = await runner.run(agent, 'resume after abort', { + 
stream: true, + conversationId: 'conv-abort-recovery', + }); + await drain(secondRun); + + expect(secondRun.finalOutput).toBe('done after abort'); + expect(model.requests).toHaveLength(2); + expect(getRequestInputItems(model.requests[1])).toEqual([ + expect.objectContaining({ + role: 'user', + }), + expect.objectContaining({ + type: 'function_call_result', + callId: 'call_abort_1', + name: 'test', + }), + ]); + }); + it('does not replay orphan hosted shell calls in default streamed multi-turn runs', async () => { const hostedShell = shellTool({ environment: { type: 'container_auto' },