diff --git a/.changeset/abort-stream-usage.md b/.changeset/abort-stream-usage.md new file mode 100644 index 000000000..042c3597f --- /dev/null +++ b/.changeset/abort-stream-usage.md @@ -0,0 +1,6 @@ +--- +'@openai/agents-core': patch +'@openai/agents-openai': patch +--- + +fix: preserve streaming usage snapshots across aborts and response_done replacement (#995) diff --git a/packages/agents-core/src/run.ts b/packages/agents-core/src/run.ts index ef96e09e8..1bf810225 100644 --- a/packages/agents-core/src/run.ts +++ b/packages/agents-core/src/run.ts @@ -59,6 +59,7 @@ import { processModelResponseAsync } from './runner/modelOutputs'; import { addStepToRunResult, streamStepItemsToRunResult, + getUsageSnapshotFromStreamEvent, isAbortError, } from './runner/streaming'; import { @@ -1149,6 +1150,7 @@ export class Runner extends RunHooks> { guardrailTracker.throwIfError(); let finalResponse: ModelResponse | undefined = undefined; + let latestStreamingUsageSnapshot: Usage | undefined = undefined; let inputMarked = false; const markInputOnce = () => { if (inputMarked || !serverConversationTracker) { @@ -1168,6 +1170,13 @@ export class Runner extends RunHooks> { ); inputMarked = true; }; + const applyStreamingUsageSnapshot = (usageSnapshot: Usage) => { + result.state._context.usage.replaceCurrentRequestSnapshot( + usageSnapshot, + latestStreamingUsageSnapshot, + ); + latestStreamingUsageSnapshot = usageSnapshot; + }; sentInputToModel = true; if (!delayStreamInputPersistence) { @@ -1203,6 +1212,10 @@ export class Runner extends RunHooks> { )) { guardrailTracker.throwIfError(); markInputOnce(); + const usageSnapshot = getUsageSnapshotFromStreamEvent(event); + if (usageSnapshot) { + applyStreamingUsageSnapshot(usageSnapshot); + } if (event.type === 'response_done') { const parsed = StreamEventResponseCompleted.parse(event); finalResponse = { @@ -1211,7 +1224,6 @@ export class Runner extends RunHooks> { responseId: parsed.response.id, requestId: parsed.response.requestId, }; - result.state._context.usage.add(finalResponse.usage); } if (result.cancelled) { // When the user's code exits a loop to consume the stream, we need to break diff --git a/packages/agents-core/src/runner/modelRetry.ts b/packages/agents-core/src/runner/modelRetry.ts index 58ea7fccb..08b9fe3b5 100644 --- a/packages/agents-core/src/runner/modelRetry.ts +++ b/packages/agents-core/src/runner/modelRetry.ts @@ -89,6 +89,75 @@ function addFailedRetryAttemptsToUsage( }); } +function toUsageSnapshotData( + usage: Usage, +): ConstructorParameters[0] { + return { + requests: usage.requests, + inputTokens: usage.inputTokens, + outputTokens: usage.outputTokens, + totalTokens: usage.totalTokens, + inputTokensDetails: usage.inputTokensDetails, + outputTokensDetails: usage.outputTokensDetails, + ...(usage.requestUsageEntries && usage.requestUsageEntries.length > 0 + ? { + requestUsageEntries: usage.requestUsageEntries.map((entry) => ({ + inputTokens: entry.inputTokens, + outputTokens: entry.outputTokens, + totalTokens: entry.totalTokens, + inputTokensDetails: entry.inputTokensDetails, + outputTokensDetails: entry.outputTokensDetails, + ...(entry.endpoint ? { endpoint: entry.endpoint } : {}), + })), + } + : {}), + }; +} + +function addFailedRetryAttemptsToStreamEvent( + event: StreamEvent, + failedRetryAttempts: number, +): StreamEvent { + if (failedRetryAttempts <= 0) { + return event; + } + + if (event.type === 'response_done') { + return { + ...event, + response: { + ...event.response, + usage: addFailedRetryAttemptsToUsage( + new Usage(event.response.usage), + failedRetryAttempts, + ), + }, + }; + } + + if (event.type !== 'model') { + return event; + } + + const usageSnapshot = event.providerData?.usageSnapshot; + if (!isRecord(usageSnapshot)) { + return event; + } + + return { + ...event, + providerData: { + ...(event.providerData ?? {}), + usageSnapshot: toUsageSnapshotData( + addFailedRetryAttemptsToUsage( + new Usage(usageSnapshot as ConstructorParameters[0]), + failedRetryAttempts, + ), + ), + }, + }; +} + function withRunnerManagedRetry(request: ModelRequest): ModelRequest { return Object.assign({}, request, { _internal: { @@ -736,20 +805,7 @@ export async function* getStreamedResponseWithRetry( emittedRawModelEvent = true; } emittedVisibleEvent = true; - if (event.type === 'response_done' && attempt > 1) { - yield { - ...event, - response: { - ...event.response, - usage: addFailedRetryAttemptsToUsage( - new Usage(event.response.usage), - attempt - 1, - ), - }, - }; - continue; - } - yield event; + yield addFailedRetryAttemptsToStreamEvent(event, attempt - 1); } return; } catch (error) { diff --git a/packages/agents-core/src/runner/streaming.ts b/packages/agents-core/src/runner/streaming.ts index 61511fe86..7e896684c 100644 --- a/packages/agents-core/src/runner/streaming.ts +++ b/packages/agents-core/src/runner/streaming.ts @@ -1,4 +1,5 @@ import logger from '../logger'; +import { Usage } from '../usage'; import { RunItemStreamEvent, RunItemStreamEventName } from '../events'; import { RunHandoffCallItem, @@ -12,6 +13,8 @@ import { RunToolSearchCallItem, RunToolSearchOutputItem, } from '../items'; +import type { StreamEvent } from '../types/protocol'; +import { StreamEventResponseCompleted } from '../types/protocol'; import { StreamedRunResult } from '../result'; export const isAbortError = (error: unknown): boolean => { @@ -33,6 +36,26 @@ export const isAbortError = (error: unknown): boolean => { return false; }; +export function getUsageSnapshotFromStreamEvent( + event: StreamEvent, +): Usage | undefined { + if (event.type === 'response_done') { + return new Usage(StreamEventResponseCompleted.parse(event).response.usage); + } + + if ( + event.type !== 'model' || + !event.providerData?.usageSnapshot || + typeof event.providerData.usageSnapshot !== 'object' + ) { + return undefined; + } + + return new Usage( + event.providerData.usageSnapshot as ConstructorParameters[0], + ); +} + function getRunItemStreamEventName( item: RunItem, ): RunItemStreamEventName | undefined { diff --git a/packages/agents-core/src/usage.ts b/packages/agents-core/src/usage.ts index ecd53029d..3cfffb8c7 100644 --- a/packages/agents-core/src/usage.ts +++ b/packages/agents-core/src/usage.ts @@ -216,6 +216,92 @@ export class Usage { ); } } + + /** + * Replaces the latest in-flight request usage snapshot with a newer snapshot. + * + * This is used for streaming providers that surface provisional usage before + * emitting a terminal response event. + */ + replaceCurrentRequestSnapshot(nextUsage: Usage, previousUsage?: Usage) { + if (!previousUsage) { + this.add(nextUsage); + return; + } + + this.requests += nextUsage.requests - previousUsage.requests; + this.inputTokens += nextUsage.inputTokens - previousUsage.inputTokens; + this.outputTokens += nextUsage.outputTokens - previousUsage.outputTokens; + this.totalTokens += nextUsage.totalTokens - previousUsage.totalTokens; + + this.#replaceLatestDetails( + this.inputTokensDetails, + previousUsage.inputTokensDetails, + nextUsage.inputTokensDetails, + ); + this.#replaceLatestDetails( + this.outputTokensDetails, + previousUsage.outputTokensDetails, + nextUsage.outputTokensDetails, + ); + + this.#replaceLatestRequestUsageEntries( + Usage.#getTrackedRequestUsageEntries(previousUsage), + Usage.#getTrackedRequestUsageEntries(nextUsage), + ); + } + + static #getTrackedRequestUsageEntries(usage: Usage): RequestUsage[] { + if (usage.requestUsageEntries?.length) { + return usage.requestUsageEntries.map((entry) => new RequestUsage(entry)); + } + + if (usage.requests === 1 && usage.totalTokens > 0) { + return [ + new RequestUsage({ + inputTokens: usage.inputTokens, + outputTokens: usage.outputTokens, + totalTokens: usage.totalTokens, + inputTokensDetails: usage.inputTokensDetails[0], + outputTokensDetails: usage.outputTokensDetails[0], + }), + ]; + } + + return []; + } + + #replaceLatestDetails( + details: Array>, + previous: Array>, + next: Array>, + ) { + details.splice( + Math.max(details.length - previous.length, 0), + previous.length, + ...next, + ); + } + + #replaceLatestRequestUsageEntries( + previous: RequestUsage[], + next: RequestUsage[], + ) { + if (previous.length > 0) { + const retainedEntries = + this.requestUsageEntries?.slice( + 0, + Math.max(this.requestUsageEntries.length - previous.length, 0), + ) ?? []; + this.requestUsageEntries = + retainedEntries.length > 0 ? retainedEntries : undefined; + } + + if (next.length > 0) { + this.requestUsageEntries ??= []; + this.requestUsageEntries.push(...next); + } + } } export { RequestUsageData, UsageData }; diff --git a/packages/agents-core/test/run.stream.test.ts b/packages/agents-core/test/run.stream.test.ts index 6856d2a9d..01f26af4d 100644 --- a/packages/agents-core/test/run.stream.test.ts +++ b/packages/agents-core/test/run.stream.test.ts @@ -27,6 +27,7 @@ import { OutputGuardrailTripwireTriggered, RunState, shellTool, + retryPolicies, } from '../src'; import { FakeModel, FakeModelProvider, fakeModelMessage } from './stubs'; import * as protocol from '../src/types/protocol'; @@ -755,6 +756,420 @@ describe('Runner.run (streaming)', () => { warnSpy.mockRestore(); }); + it('preserves the latest usage snapshot when AbortSignal cancels a streaming run', async () => { + const waitWithAbort = (ms: number, signal?: AbortSignal) => + new Promise((resolve, reject) => { + const timer = setTimeout(resolve, ms); + if (!signal) { + return; + } + if (signal.aborted) { + clearTimeout(timer); + const error = new Error('Aborted'); + error.name = 'AbortError'; + reject(error); + return; + } + signal.addEventListener( + 'abort', + () => { + clearTimeout(timer); + const error = new Error('Aborted'); + error.name = 'AbortError'; + reject(error); + }, + { once: true }, + ); + }); + + class AbortableUsageSnapshotStreamingModel implements Model { + async getResponse(_req: ModelRequest): Promise { + return { + output: [fakeModelMessage('unused')], + usage: new Usage(), + }; + } + + async *getStreamedResponse( + request: ModelRequest, + ): AsyncIterable { + yield { + type: 'model', + event: { + type: 'response.in_progress', + response: { id: 'resp_usage_in_progress' }, + }, + providerData: { + usageSnapshot: { + inputTokens: 21, + outputTokens: 13, + totalTokens: 34, + requestUsageEntries: [ + { + inputTokens: 21, + outputTokens: 13, + totalTokens: 34, + endpoint: 'responses.create', + }, + ], + }, + }, + } as any; + await waitWithAbort(500, request.signal); + } + } + + const controller = new AbortController(); + const agent = new Agent({ + name: 'AbortWithUsageSnapshot', + model: new AbortableUsageSnapshotStreamingModel(), + }); + + const result = await run(agent, 'go', { + stream: true, + signal: controller.signal, + }); + + const reader = (result.toStream() as any).getReader(); + const first = await reader.read(); + expect(first.done).toBe(false); + + controller.abort(); + + await expect(result.completed).resolves.toBeUndefined(); + + expect(result.state.usage.inputTokens).toBe(21); + expect(result.state.usage.outputTokens).toBe(13); + expect(result.state.usage.totalTokens).toBe(34); + expect(result.state.usage.requestUsageEntries).toEqual([ + { + inputTokens: 21, + outputTokens: 13, + totalTokens: 34, + inputTokensDetails: {}, + outputTokensDetails: {}, + endpoint: 'responses.create', + }, + ]); + + const done = await reader.read(); + expect(done.done).toBe(true); + }); + + it('preserves retry-adjusted usage when AbortSignal cancels after a retried streaming snapshot', async () => { + const waitWithAbort = (ms: number, signal?: AbortSignal) => + new Promise((resolve, reject) => { + const timer = setTimeout(resolve, ms); + if (!signal) { + return; + } + if (signal.aborted) { + clearTimeout(timer); + const error = new Error('Aborted'); + error.name = 'AbortError'; + reject(error); + return; + } + signal.addEventListener( + 'abort', + () => { + clearTimeout(timer); + const error = new Error('Aborted'); + error.name = 'AbortError'; + reject(error); + }, + { once: true }, + ); + }); + + class RetryThenAbortableUsageSnapshotStreamingModel implements Model { + #attempt = 0; + + async getResponse(_req: ModelRequest): Promise { + return { + output: [fakeModelMessage('unused')], + usage: new Usage(), + }; + } + + async *getStreamedResponse( + request: ModelRequest, + ): AsyncIterable { + this.#attempt += 1; + if (this.#attempt === 1) { + const error = new Error('Rate limited'); + (error as Error & { statusCode?: number }).statusCode = 429; + throw error; + } + + yield { + type: 'model', + event: { + type: 'response.in_progress', + response: { id: 'resp_retry_abort_snapshot' }, + }, + providerData: { + usageSnapshot: { + requests: 1, + inputTokens: 11, + outputTokens: 7, + totalTokens: 18, + requestUsageEntries: [ + { + inputTokens: 11, + outputTokens: 7, + totalTokens: 18, + endpoint: 'responses.create', + }, + ], + }, + }, + } as any; + await waitWithAbort(500, request.signal); + } + } + + const controller = new AbortController(); + const agent = new Agent({ + name: 'RetryThenAbortSnapshot', + model: new RetryThenAbortableUsageSnapshotStreamingModel(), + modelSettings: { + retry: { + maxRetries: 1, + backoff: { initialDelayMs: 0, jitter: false }, + policy: retryPolicies.httpStatus([429]), + }, + }, + }); + + const result = await run(agent, 'go', { + stream: true, + signal: controller.signal, + }); + + const reader = (result.toStream() as any).getReader(); + const first = await reader.read(); + expect(first.done).toBe(false); + + controller.abort(); + + await expect(result.completed).resolves.toBeUndefined(); + + expect(result.state.usage.requests).toBe(2); + expect(result.state.usage.inputTokens).toBe(11); + expect(result.state.usage.outputTokens).toBe(7); + expect(result.state.usage.totalTokens).toBe(18); + expect(result.state.usage.requestUsageEntries).toEqual([ + { + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + inputTokensDetails: {}, + outputTokensDetails: {}, + endpoint: 'responses.create', + }, + { + inputTokens: 11, + outputTokens: 7, + totalTokens: 18, + inputTokensDetails: {}, + outputTokensDetails: {}, + endpoint: 'responses.create', + }, + ]); + + const done = await reader.read(); + expect(done.done).toBe(true); + }); + + it('preserves retry-adjusted request usage when a streaming snapshot is replaced by response_done', async () => { + class RetryThenSnapshotStreamingModel implements Model { + #attempt = 0; + + async getResponse(_req: ModelRequest): Promise { + return { + output: [fakeModelMessage('unused')], + usage: new Usage(), + }; + } + + async *getStreamedResponse( + _request: ModelRequest, + ): AsyncIterable { + this.#attempt += 1; + if (this.#attempt === 1) { + const error = new Error('Rate limited'); + (error as Error & { statusCode?: number }).statusCode = 429; + throw error; + } + + yield { + type: 'model', + event: { + type: 'response.in_progress', + response: { id: 'resp_retry_snapshot' }, + }, + providerData: { + usageSnapshot: { + requests: 1, + inputTokens: 11, + outputTokens: 7, + totalTokens: 18, + requestUsageEntries: [ + { + inputTokens: 11, + outputTokens: 7, + totalTokens: 18, + endpoint: 'responses.create', + }, + ], + }, + }, + } as any; + yield { + type: 'response_done', + response: { + id: 'resp_retry_snapshot', + usage: { + requests: 1, + inputTokens: 13, + outputTokens: 8, + totalTokens: 21, + requestUsageEntries: [ + { + inputTokens: 13, + outputTokens: 8, + totalTokens: 21, + endpoint: 'responses.create', + }, + ], + }, + output: [fakeModelMessage('Recovered with retried snapshot')], + }, + } as any; + } + } + + const agent = new Agent({ + name: 'RetryThenSnapshot', + model: new RetryThenSnapshotStreamingModel(), + modelSettings: { + retry: { + maxRetries: 1, + backoff: { initialDelayMs: 0, jitter: false }, + policy: retryPolicies.httpStatus([429]), + }, + }, + }); + + const result = await run(agent, 'go', { stream: true }); + + for await (const _event of result.toStream()) { + // Exhaust the stream so completion reflects the final usage state. + } + await result.completed; + + expect(result.state.usage.requests).toBe(2); + expect(result.state.usage.inputTokens).toBe(13); + expect(result.state.usage.outputTokens).toBe(8); + expect(result.state.usage.totalTokens).toBe(21); + expect(result.state.usage.requestUsageEntries).toEqual([ + { + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + inputTokensDetails: {}, + outputTokensDetails: {}, + endpoint: 'responses.create', + }, + { + inputTokens: 13, + outputTokens: 8, + totalTokens: 21, + inputTokensDetails: {}, + outputTokensDetails: {}, + endpoint: 'responses.create', + }, + ]); + expect(result.finalOutput).toBe('Recovered with retried snapshot'); + }); + + it('replaces entire detail arrays when response_done supersedes a streaming usage snapshot', async () => { + class DetailArraySnapshotStreamingModel implements Model { + async getResponse(_req: ModelRequest): Promise { + return { + output: [fakeModelMessage('unused')], + usage: new Usage(), + }; + } + + async *getStreamedResponse( + _request: ModelRequest, + ): AsyncIterable { + yield { + type: 'model', + event: { + type: 'response.in_progress', + response: { id: 'resp_detail_array_snapshot' }, + }, + providerData: { + usageSnapshot: { + requests: 1, + inputTokens: 5, + outputTokens: 2, + totalTokens: 7, + inputTokensDetails: [{ cached_tokens: 1 }, { audio_tokens: 2 }], + outputTokensDetails: [{ reasoning_tokens: 1 }], + }, + }, + } as any; + yield { + type: 'response_done', + response: { + id: 'resp_detail_array_snapshot', + usage: { + requests: 1, + inputTokens: 8, + outputTokens: 4, + totalTokens: 12, + inputTokensDetails: [{ cached_tokens: 6 }], + outputTokensDetails: [ + { reasoning_tokens: 2 }, + { accepted_prediction_tokens: 1 }, + ], + }, + output: [fakeModelMessage('Final snapshot')], + }, + } as any; + } + } + + const agent = new Agent({ + name: 'DetailArraySnapshot', + model: new DetailArraySnapshotStreamingModel(), + }); + + const result = await run(agent, 'go', { stream: true }); + + for await (const _event of result.toStream()) { + // Exhaust the stream so completion reflects the final usage state. + } + await result.completed; + + expect(result.state.usage.requests).toBe(1); + expect(result.state.usage.inputTokens).toBe(8); + expect(result.state.usage.outputTokens).toBe(4); + expect(result.state.usage.totalTokens).toBe(12); + expect(result.state.usage.inputTokensDetails).toEqual([ + { cached_tokens: 6 }, + ]); + expect(result.state.usage.outputTokensDetails).toEqual([ + { reasoning_tokens: 2 }, + { accepted_prediction_tokens: 1 }, + ]); + expect(result.finalOutput).toBe('Final snapshot'); + }); + it('cancels streaming promptly when the consumer cancels the stream', async () => { const waitWithAbort = (ms: number, signal?: AbortSignal) => new Promise((resolve, reject) => { diff --git a/packages/agents-core/test/usage.test.ts b/packages/agents-core/test/usage.test.ts index bdeea2165..6db52627a 100644 --- a/packages/agents-core/test/usage.test.ts +++ b/packages/agents-core/test/usage.test.ts @@ -195,4 +195,162 @@ describe('Usage', () => { }, ]); }); + + it('replaces the latest in-flight request usage snapshot without double counting', () => { + const aggregated = new Usage(); + const partial = new Usage({ + requests: 1, + inputTokens: 5, + outputTokens: 2, + totalTokens: 7, + requestUsageEntries: [ + new RequestUsage({ + inputTokens: 5, + outputTokens: 2, + totalTokens: 7, + endpoint: 'responses.create', + }), + ], + }); + const final = new Usage({ + requests: 1, + inputTokens: 8, + outputTokens: 4, + totalTokens: 12, + requestUsageEntries: [ + new RequestUsage({ + inputTokens: 8, + outputTokens: 4, + totalTokens: 12, + endpoint: 'responses.create', + }), + ], + }); + + aggregated.replaceCurrentRequestSnapshot(partial); + aggregated.replaceCurrentRequestSnapshot(final, partial); + + expect(aggregated.requests).toBe(1); + expect(aggregated.inputTokens).toBe(8); + expect(aggregated.outputTokens).toBe(4); + expect(aggregated.totalTokens).toBe(12); + expect(aggregated.requestUsageEntries).toEqual([ + { + inputTokens: 8, + outputTokens: 4, + totalTokens: 12, + inputTokensDetails: {}, + outputTokensDetails: {}, + endpoint: 'responses.create', + }, + ]); + }); + + it('replaces a streaming snapshot with the full retry-adjusted usage', () => { + const aggregated = new Usage(); + const partial = new Usage({ + requests: 1, + inputTokens: 5, + outputTokens: 2, + totalTokens: 7, + requestUsageEntries: [ + new RequestUsage({ + inputTokens: 5, + outputTokens: 2, + totalTokens: 7, + endpoint: 'responses.create', + }), + ], + }); + const final = new Usage({ + requests: 2, + inputTokens: 8, + outputTokens: 4, + totalTokens: 12, + requestUsageEntries: [ + new RequestUsage({ + endpoint: 'responses.create', + }), + new RequestUsage({ + inputTokens: 8, + outputTokens: 4, + totalTokens: 12, + endpoint: 'responses.create', + }), + ], + }); + + aggregated.replaceCurrentRequestSnapshot(partial); + aggregated.replaceCurrentRequestSnapshot(final, partial); + + expect(aggregated.requests).toBe(2); + expect(aggregated.inputTokens).toBe(8); + expect(aggregated.outputTokens).toBe(4); + expect(aggregated.totalTokens).toBe(12); + expect(aggregated.requestUsageEntries).toEqual([ + { + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + inputTokensDetails: {}, + outputTokensDetails: {}, + endpoint: 'responses.create', + }, + { + inputTokens: 8, + outputTokens: 4, + totalTokens: 12, + inputTokensDetails: {}, + outputTokensDetails: {}, + endpoint: 'responses.create', + }, + ]); + }); + + it('replaces the full trailing detail slices when a snapshot changes length', () => { + const aggregated = new Usage({ + requests: 1, + inputTokens: 4, + outputTokens: 3, + totalTokens: 7, + inputTokensDetails: [{ baseline_input: 4 }], + outputTokensDetails: [{ baseline_output: 3 }], + }); + const partial = new Usage({ + requests: 1, + inputTokens: 5, + outputTokens: 2, + totalTokens: 7, + inputTokensDetails: [{ cached_tokens: 1 }, { audio_tokens: 2 }], + outputTokensDetails: [{ reasoning_tokens: 1 }], + }); + const final = new Usage({ + requests: 1, + inputTokens: 6, + outputTokens: 4, + totalTokens: 10, + inputTokensDetails: [{ cached_tokens: 6 }], + outputTokensDetails: [ + { reasoning_tokens: 2 }, + { accepted_prediction_tokens: 1 }, + ], + }); + + aggregated.replaceCurrentRequestSnapshot(partial); + aggregated.replaceCurrentRequestSnapshot(final, partial); + + expect(aggregated.requests).toBe(2); + expect(aggregated.inputTokens).toBe(10); + expect(aggregated.outputTokens).toBe(7); + expect(aggregated.totalTokens).toBe(17); + expect(aggregated.inputTokensDetails).toEqual([ + { baseline_input: 4 }, + { cached_tokens: 6 }, + ]); + expect(aggregated.outputTokensDetails).toEqual([ + { baseline_output: 3 }, + { reasoning_tokens: 2 }, + { accepted_prediction_tokens: 1 }, + ]); + }); }); diff --git a/packages/agents-openai/src/openaiResponsesModel.ts b/packages/agents-openai/src/openaiResponsesModel.ts index 609d142f7..2df2ba03b 100644 --- a/packages/agents-openai/src/openaiResponsesModel.ts +++ b/packages/agents-openai/src/openaiResponsesModel.ts @@ -3088,16 +3088,7 @@ export class OpenAIResponsesModel implements Model { }); const output: ModelResponse = { - usage: new Usage({ - inputTokens: response.usage?.input_tokens ?? 0, - outputTokens: response.usage?.output_tokens ?? 0, - totalTokens: response.usage?.total_tokens ?? 0, - inputTokensDetails: { ...response.usage?.input_tokens_details }, - outputTokensDetails: { ...response.usage?.output_tokens_details }, - requestUsageEntries: [ - toRequestUsageEntry(response.usage, 'responses.create'), - ], - }), + usage: new Usage(toProtocolUsage(response.usage)), output: convertToOutputItem(response.output), responseId: response.id, requestId: getOpenAIResponseRequestId(response), @@ -3150,20 +3141,7 @@ export class OpenAIResponsesModel implements Model { id: id, requestId: getOpenAIResponseRequestId(response), output: convertToOutputItem(output), - usage: { - inputTokens: usage?.input_tokens ?? 0, - outputTokens: usage?.output_tokens ?? 0, - totalTokens: usage?.total_tokens ?? 0, - inputTokensDetails: { - ...usage?.input_tokens_details, - }, - outputTokensDetails: { - ...usage?.output_tokens_details, - }, - requestUsageEntries: [ - toRequestUsageEntry(usage, 'responses.create'), - ], - }, + usage: toProtocolUsage(usage), providerData: remainingResponse, }, providerData: remainingEvent, @@ -3193,6 +3171,18 @@ export class OpenAIResponsesModel implements Model { event: event, providerData: { rawModelEventSource: OPENAI_RESPONSES_RAW_MODEL_EVENT_SOURCE, + ...(eventType === 'response.in_progress' && + (event as OpenAI.Responses.ResponseInProgressEvent).response?.usage + ? { + // Preserve the latest normalized usage snapshot so agents-core + // can retain usage when a streaming request is aborted before + // the terminal response event arrives. + usageSnapshot: toProtocolUsage( + (event as OpenAI.Responses.ResponseInProgressEvent).response + ?.usage, + ), + } + : {}), }, }; } @@ -3912,3 +3902,16 @@ function toRequestUsageEntry( endpoint, }); } + +function toProtocolUsage( + usage: OpenAI.Responses.ResponseUsage | undefined, +): protocol.UsageData { + return { + inputTokens: usage?.input_tokens ?? 0, + outputTokens: usage?.output_tokens ?? 0, + totalTokens: usage?.total_tokens ?? 0, + inputTokensDetails: { ...usage?.input_tokens_details }, + outputTokensDetails: { ...usage?.output_tokens_details }, + requestUsageEntries: [toRequestUsageEntry(usage, 'responses.create')], + }; +} diff --git a/packages/agents-openai/test/openaiResponsesModel.test.ts b/packages/agents-openai/test/openaiResponsesModel.test.ts index 1e225f8b1..07875bafa 100644 --- a/packages/agents-openai/test/openaiResponsesModel.test.ts +++ b/packages/agents-openai/test/openaiResponsesModel.test.ts @@ -3957,6 +3957,98 @@ describe('OpenAIResponsesModel', () => { }); }); + it('getStreamedResponse exposes response.in_progress usage snapshots on raw model events', async () => { + await withTrace('test', async () => { + const createdEvent: OpenAIResponseStreamEvent = { + type: 'response.created', + response: { id: 'res-stream-init' } as any, + sequence_number: 0, + }; + const inProgressEvent: OpenAIResponseStreamEvent = { + type: 'response.in_progress', + response: { + id: 'res-stream-progress', + output: [], + usage: { + input_tokens: 7, + output_tokens: 4, + total_tokens: 11, + input_tokens_details: { cached_tokens: 1 }, + output_tokens_details: { reasoning_tokens: 2 }, + }, + }, + sequence_number: 1, + } as any; + const completedEvent: OpenAIResponseStreamEvent = { + type: 'response.completed', + response: { + id: 'res-stream-progress', + output: [], + usage: { + input_tokens: 7, + output_tokens: 4, + total_tokens: 11, + input_tokens_details: { cached_tokens: 1 }, + output_tokens_details: { reasoning_tokens: 2 }, + }, + }, + sequence_number: 2, + } as any; + async function* fakeStream() { + yield createdEvent; + yield inProgressEvent; + yield completedEvent; + } + const createMock = vi.fn().mockResolvedValue(fakeStream()); + const fakeClient = { + responses: { create: createMock }, + } as unknown as OpenAI; + const model = new OpenAIResponsesModel(fakeClient, 'model-usage'); + + const request = { + systemInstructions: undefined, + input: 'payload', + modelSettings: {}, + tools: [], + outputType: 'text', + handoffs: [], + tracing: false, + signal: undefined, + }; + + const received: ResponseStreamEvent[] = []; + for await (const ev of model.getStreamedResponse(request as any)) { + received.push(ev); + } + + const inProgressRawEvent = received.find( + (ev) => + ev.type === 'model' && + (ev as any).event?.type === 'response.in_progress', + ); + expect(inProgressRawEvent).toBeDefined(); + expect( + (inProgressRawEvent as any).providerData?.usageSnapshot, + ).toMatchObject({ + inputTokens: 7, + outputTokens: 4, + totalTokens: 11, + inputTokensDetails: { cached_tokens: 1 }, + outputTokensDetails: { reasoning_tokens: 2 }, + requestUsageEntries: [ + { + inputTokens: 7, + outputTokens: 4, + totalTokens: 11, + inputTokensDetails: { cached_tokens: 1 }, + outputTokensDetails: { reasoning_tokens: 2 }, + endpoint: 'responses.create', + }, + ], + }); + }); + }); + it('getStreamedResponse preserves request IDs from HTTP streaming responses', async () => { await withTrace('test', async () => { const createdEvent: OpenAIResponseStreamEvent = {