Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/abort-stream-usage.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@openai/agents-core': patch
'@openai/agents-openai': patch
---

fix: preserve streaming usage snapshots after AbortSignal cancellation
14 changes: 13 additions & 1 deletion packages/agents-core/src/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import { processModelResponseAsync } from './runner/modelOutputs';
import {
addStepToRunResult,
streamStepItemsToRunResult,
getUsageSnapshotFromStreamEvent,
isAbortError,
} from './runner/streaming';
import {
Expand Down Expand Up @@ -1149,6 +1150,7 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
guardrailTracker.throwIfError();

let finalResponse: ModelResponse | undefined = undefined;
let latestStreamingUsageSnapshot: Usage | undefined = undefined;
let inputMarked = false;
const markInputOnce = () => {
if (inputMarked || !serverConversationTracker) {
Expand All @@ -1168,6 +1170,13 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
);
inputMarked = true;
};
// Folds the newest usage snapshot into the run's aggregate usage, replacing
// the previously applied snapshot (if any) instead of adding on top of it.
const applyStreamingUsageSnapshot = (usageSnapshot: Usage) => {
  const previouslyApplied = latestStreamingUsageSnapshot;
  const aggregateUsage = result.state._context.usage;
  aggregateUsage.replaceCurrentRequestSnapshot(usageSnapshot, previouslyApplied);
  // Remember what was applied so the next snapshot can be swapped in for it.
  latestStreamingUsageSnapshot = usageSnapshot;
};

sentInputToModel = true;
if (!delayStreamInputPersistence) {
Expand Down Expand Up @@ -1203,6 +1212,10 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
)) {
guardrailTracker.throwIfError();
markInputOnce();
const usageSnapshot = getUsageSnapshotFromStreamEvent(event);
if (usageSnapshot) {
applyStreamingUsageSnapshot(usageSnapshot);
}
Comment thread
wsk-builds marked this conversation as resolved.
if (event.type === 'response_done') {
const parsed = StreamEventResponseCompleted.parse(event);
finalResponse = {
Expand All @@ -1211,7 +1224,6 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
responseId: parsed.response.id,
requestId: parsed.response.requestId,
};
result.state._context.usage.add(finalResponse.usage);
}
if (result.cancelled) {
// When the user's code exits a loop to consume the stream, we need to break
Expand Down
22 changes: 22 additions & 0 deletions packages/agents-core/src/runner/streaming.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logger from '../logger';
import { Usage } from '../usage';
import { RunItemStreamEvent, RunItemStreamEventName } from '../events';
import {
RunHandoffCallItem,
Expand All @@ -12,6 +13,8 @@ import {
RunToolSearchCallItem,
RunToolSearchOutputItem,
} from '../items';
import type { StreamEvent } from '../types/protocol';
import { StreamEventResponseCompleted } from '../types/protocol';
import { StreamedRunResult } from '../result';

export const isAbortError = (error: unknown): boolean => {
Expand All @@ -33,6 +36,25 @@ export const isAbortError = (error: unknown): boolean => {
return false;
};

/**
 * Extracts a usage snapshot from a single stream event, if the event carries one.
 *
 * - `response_done` events are passed through `StreamEventResponseCompleted.parse`
 *   and their `response.usage` is wrapped in a new `Usage`.
 * - `model` events are inspected for a `providerData.usageSnapshot` object.
 * - Every other event yields `undefined`.
 *
 * @param event - The stream event to inspect.
 * @returns A new `Usage` built from the event's usage data, or `undefined`
 * when the event carries no usable snapshot.
 * @throws Propagates any error raised by `StreamEventResponseCompleted.parse`
 * for a `response_done` event.
 */
export function getUsageSnapshotFromStreamEvent(
  event: StreamEvent,
): Usage | undefined {
  if (event.type === 'response_done') {
    return new Usage(StreamEventResponseCompleted.parse(event).response.usage);
  }

  if (event.type !== 'model') {
    return undefined;
  }

  const usageSnapshot = event.providerData?.usageSnapshot;
  // Only a plain object can describe usage. Arrays also report
  // `typeof === 'object'`, so they have to be rejected explicitly.
  if (
    !usageSnapshot ||
    typeof usageSnapshot !== 'object' ||
    Array.isArray(usageSnapshot)
  ) {
    return undefined;
  }

  return new Usage(usageSnapshot as ConstructorParameters<typeof Usage>[0]);
}

function getRunItemStreamEventName(
item: RunItem,
): RunItemStreamEventName | undefined {
Expand Down
94 changes: 94 additions & 0 deletions packages/agents-core/src/usage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,100 @@ export class Usage {
);
}
}

/**
 * Replaces the latest in-flight request usage snapshot with a newer snapshot.
 *
 * This is used for streaming providers that surface provisional usage before
 * emitting a terminal response event.
 *
 * @param nextUsage - The newest snapshot for the in-flight request.
 * @param previousUsage - The snapshot that was previously applied for the same
 * request. When omitted, `nextUsage` is simply added to this instance.
 */
replaceCurrentRequestSnapshot(nextUsage: Usage, previousUsage?: Usage) {
  if (!previousUsage) {
    this.add(nextUsage);
    return;
  }

  // Swap the previous snapshot's contribution for the new one by applying the
  // difference. The request counter is included so the aggregate stays correct
  // even if the two snapshots disagree on how many requests they represent.
  this.requests += nextUsage.requests - previousUsage.requests;
  this.inputTokens += nextUsage.inputTokens - previousUsage.inputTokens;
  this.outputTokens += nextUsage.outputTokens - previousUsage.outputTokens;
  this.totalTokens += nextUsage.totalTokens - previousUsage.totalTokens;

  // Only the first details record of each snapshot is swapped in.
  this.#replaceLatestDetails(
    this.inputTokensDetails,
    previousUsage.inputTokensDetails[0],
    nextUsage.inputTokensDetails[0],
  );
  this.#replaceLatestDetails(
    this.outputTokensDetails,
    previousUsage.outputTokensDetails[0],
    nextUsage.outputTokensDetails[0],
  );

  this.#replaceLatestRequestUsageEntry(
    Usage.#getSingleRequestUsageEntry(previousUsage),
    Usage.#getSingleRequestUsageEntry(nextUsage),
  );
}

/**
 * Derives the single per-request entry that represents a snapshot.
 *
 * Prefers the first explicit entry in `requestUsageEntries`. Otherwise an
 * entry is synthesized from the snapshot totals, but only when the snapshot
 * covers exactly one request and has a positive token total.
 *
 * @returns The representative entry, or `undefined` when neither rule applies.
 */
static #getSingleRequestUsageEntry(usage: Usage): RequestUsage | undefined {
  const explicitEntries = usage.requestUsageEntries;
  if (explicitEntries?.length) {
    return explicitEntries[0];
  }

  const isSingleNonEmptyRequest =
    usage.requests === 1 && usage.totalTokens > 0;
  if (!isSingleNonEmptyRequest) {
    return undefined;
  }

  const { inputTokens, outputTokens, totalTokens } = usage;
  return new RequestUsage({
    inputTokens,
    outputTokens,
    totalTokens,
    inputTokensDetails: usage.inputTokensDetails[0],
    outputTokensDetails: usage.outputTokensDetails[0],
  });
}

/**
 * Swaps the trailing details record of `details` for the in-flight request.
 *
 * @param details - The aggregated details list to mutate in place.
 * @param previous - The record contributed by the previous snapshot, if any.
 * @param next - The record contributed by the new snapshot, if any.
 */
#replaceLatestDetails(
  details: Array<Record<string, number>>,
  previous: Record<string, number> | undefined,
  next: Record<string, number> | undefined,
) {
  if (!next) {
    // The new snapshot has no record: drop the one the previous snapshot added.
    if (previous) {
      details.pop();
    }
    return;
  }

  if (previous && details.length > 0) {
    details[details.length - 1] = next;
    return;
  }

  // Either there was no previous record, or the list is unexpectedly empty.
  // Appending avoids writing to index -1, which would not be an array element.
  details.push(next);
}

/**
 * Swaps the trailing per-request usage entry for the in-flight request.
 *
 * @param previous - The entry contributed by the previous snapshot, if any.
 * @param next - The entry contributed by the new snapshot, if any.
 */
#replaceLatestRequestUsageEntry(
  previous: RequestUsage | undefined,
  next: RequestUsage | undefined,
) {
  if (!next) {
    if (previous) {
      // The new snapshot has no entry: drop the previous one, and go back to
      // `undefined` once the list is empty.
      this.requestUsageEntries?.pop();
      if (this.requestUsageEntries?.length === 0) {
        this.requestUsageEntries = undefined;
      }
    }
    return;
  }

  // Create the list on demand instead of asserting that it exists.
  const entries = (this.requestUsageEntries ??= []);
  if (previous && entries.length > 0) {
    entries[entries.length - 1] = next;
    return;
  }

  // Either there was no previous entry, or the list is unexpectedly empty.
  // Appending avoids writing to index -1, which would not be an array element.
  entries.push(next);
}
}

export { RequestUsageData, UsageData };
100 changes: 100 additions & 0 deletions packages/agents-core/test/run.stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,106 @@ describe('Runner.run (streaming)', () => {
warnSpy.mockRestore();
});

// Regression test: usage surfaced through a provisional `usageSnapshot` on a
// `model` stream event must still be visible on the run state after the caller
// aborts the run, even though no terminal response event was ever emitted.
it('preserves the latest usage snapshot when AbortSignal cancels a streaming run', async () => {
// Resolves after `ms` milliseconds, or rejects with an error named
// 'AbortError' as soon as `signal` is (or becomes) aborted.
const waitWithAbort = (ms: number, signal?: AbortSignal) =>
new Promise<void>((resolve, reject) => {
const timer = setTimeout(resolve, ms);
if (!signal) {
return;
}
// Already aborted before we started waiting: fail immediately.
if (signal.aborted) {
clearTimeout(timer);
const error = new Error('Aborted');
error.name = 'AbortError';
reject(error);
return;
}
signal.addEventListener(
'abort',
() => {
clearTimeout(timer);
const error = new Error('Aborted');
error.name = 'AbortError';
reject(error);
},
{ once: true },
);
});

// Fake model whose stream yields one `model` event carrying a usage snapshot
// and then waits (up to 500ms) for the abort instead of finishing the response.
class AbortableUsageSnapshotStreamingModel implements Model {
// Non-streaming path; present only to satisfy the `Model` interface.
async getResponse(_req: ModelRequest): Promise<ModelResponse> {
return {
output: [fakeModelMessage('unused')],
usage: new Usage(),
};
}

async *getStreamedResponse(
request: ModelRequest,
): AsyncIterable<StreamEvent> {
yield {
type: 'model',
event: {
type: 'response.in_progress',
response: { id: 'resp_usage_in_progress' },
},
providerData: {
usageSnapshot: {
inputTokens: 21,
outputTokens: 13,
totalTokens: 34,
requestUsageEntries: [
{
inputTokens: 21,
outputTokens: 13,
totalTokens: 34,
endpoint: 'responses.create',
},
],
},
},
} as any;
// Rejects with an AbortError once the test aborts the controller.
await waitWithAbort(500, request.signal);
}
}

const controller = new AbortController();
const agent = new Agent({
name: 'AbortWithUsageSnapshot',
model: new AbortableUsageSnapshotStreamingModel(),
});

const result = await run(agent, 'go', {
stream: true,
signal: controller.signal,
});

// Read one chunk first so the snapshot event has been produced before aborting.
const reader = (result.toStream() as any).getReader();
const first = await reader.read();
expect(first.done).toBe(false);

controller.abort();

// The abort is expected to settle `completed` without rejecting.
await expect(result.completed).resolves.toBeUndefined();

// The aggregated usage must match the snapshot from the in-progress event.
expect(result.state.usage.inputTokens).toBe(21);
expect(result.state.usage.outputTokens).toBe(13);
expect(result.state.usage.totalTokens).toBe(34);
expect(result.state.usage.requestUsageEntries).toEqual([
{
inputTokens: 21,
outputTokens: 13,
totalTokens: 34,
inputTokensDetails: {},
outputTokensDetails: {},
endpoint: 'responses.create',
},
]);

// After the abort the consumer-facing stream must be closed.
const done = await reader.read();
expect(done.done).toBe(true);
});

it('cancels streaming promptly when the consumer cancels the stream', async () => {
const waitWithAbort = (ms: number, signal?: AbortSignal) =>
new Promise<void>((resolve, reject) => {
Expand Down
50 changes: 50 additions & 0 deletions packages/agents-core/test/usage.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,54 @@ describe('Usage', () => {
},
]);
});

// Applying a provisional snapshot and then its replacement must leave the
// aggregate equal to the replacement alone.
it('replaces the latest in-flight request usage snapshot without double counting', () => {
  // Builds a single-request snapshot whose one request entry mirrors its totals.
  const snapshotOf = (
    inputTokens: number,
    outputTokens: number,
    totalTokens: number,
  ) =>
    new Usage({
      requests: 1,
      inputTokens,
      outputTokens,
      totalTokens,
      requestUsageEntries: [
        new RequestUsage({
          inputTokens,
          outputTokens,
          totalTokens,
          endpoint: 'responses.create',
        }),
      ],
    });

  const totals = new Usage();
  const provisional = snapshotOf(5, 2, 7);
  const terminal = snapshotOf(8, 4, 12);

  totals.replaceCurrentRequestSnapshot(provisional);
  totals.replaceCurrentRequestSnapshot(terminal, provisional);

  const { requests, inputTokens, outputTokens, totalTokens } = totals;
  expect(requests).toBe(1);
  expect(inputTokens).toBe(8);
  expect(outputTokens).toBe(4);
  expect(totalTokens).toBe(12);

  const expectedEntries = [
    {
      inputTokens: 8,
      outputTokens: 4,
      totalTokens: 12,
      inputTokensDetails: {},
      outputTokensDetails: {},
      endpoint: 'responses.create',
    },
  ];
  expect(totals.requestUsageEntries).toEqual(expectedEntries);
});
});
Loading