
Commit 837484e

Fix OpenAI ChatResponse usage calculation when tool calling is used
- Fix the OpenAI ChatModel's call() operation: when tool calling is used, calculate cumulative usage from the preceding ChatResponses.
- Fix the OpenAI ChatModel's stream() operation: make sure that cumulative usage is calculated from the ChatResponse that holds a valid usage.
- Use an overlapping buffer to check and store the usage from the response that holds it.
- Add tests for both call() and stream().
1 parent 6cfe5e7 commit 837484e

File tree: 3 files changed (+153, -14 lines)
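The core idea of the fix: OpenAI reports usage per API round-trip, but a tool-calling exchange makes several round-trips internally, so call() and stream() now delegate to internal variants that thread the previous ChatResponse through the recursion and sum its usage into the next response. Below is a minimal sketch of that pattern with hypothetical stand-in types (Usage, Response, callApi), not the Spring AI classes:

// Hypothetical, simplified sketch of the accumulation pattern this commit
// introduces. All types and numbers here are stand-ins for illustration.
public class CumulativeUsageSketch {

    record Usage(long promptTokens, long generationTokens) {
        long totalTokens() { return promptTokens + generationTokens; }
    }

    record Response(String content, boolean isToolCall, Usage usage) {}

    // Stand-in for one round-trip to the model.
    static Response callApi(int round) {
        boolean toolCall = round < 2; // pretend the first two rounds ask for tools
        return new Response("round-" + round, toolCall, new Usage(100, 50));
    }

    // Mirrors internalCall(prompt, previousChatResponse): fold the previous
    // response's usage into the current one before returning or recursing.
    static Response internalCall(int round, Response previous) {
        Response current = callApi(round);
        Usage cumulative = previous == null ? current.usage()
                : new Usage(previous.usage().promptTokens() + current.usage().promptTokens(),
                        previous.usage().generationTokens() + current.usage().generationTokens());
        Response withUsage = new Response(current.content(), current.isToolCall(), cumulative);
        if (withUsage.isToolCall()) {
            // Recurse with the tool results, threading the accumulated usage through.
            return internalCall(round + 1, withUsage);
        }
        return withUsage;
    }

    public static void main(String[] args) {
        Response last = internalCall(0, null);
        // Three rounds at 150 tokens each -> prints 450.
        System.out.println(last.usage().totalTokens());
    }
}

The diffs below add exactly this threading: internalCall/internalStream take a previousChatResponse parameter and fold its usage in via the new UsageUtils.getCumulativeUsage.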

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 65 additions & 11 deletions
@@ -42,6 +42,8 @@
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
 import org.springframework.ai.chat.metadata.EmptyUsage;
 import org.springframework.ai.chat.metadata.RateLimit;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.ai.chat.metadata.UsageUtils;
 import org.springframework.ai.chat.model.AbstractToolCallSupport;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
@@ -99,6 +101,7 @@
  * @author Mariusz Bernacki
  * @author luocongqiu
  * @author Thomas Vitale
+ * @author Ilayaperumal Gopinathan
  * @see ChatModel
 * @see StreamingChatModel
 * @see OpenAiApi
@@ -215,6 +218,10 @@ public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,

 	@Override
 	public ChatResponse call(Prompt prompt) {
+		return this.internalCall(prompt, null);
+	}
+
+	public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {

 		ChatCompletionRequest request = createRequest(prompt, false);

@@ -259,8 +266,12 @@ public ChatResponse call(Prompt prompt) {

 		// Non function calling.
 		RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
-
-		ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody(), rateLimit));
+		// Current usage
+		OpenAiApi.Usage usage = completionEntity.getBody().usage();
+		Usage currentChatResponseUsage = usage != null ? OpenAiUsage.from(usage) : new EmptyUsage();
+		Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
+		ChatResponse chatResponse = new ChatResponse(generations,
+				from(completionEntity.getBody(), rateLimit, accumulatedUsage));

 		observationContext.setResponse(chatResponse);

@@ -274,14 +285,18 @@ && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.n
 			var toolCallConversation = handleToolCalls(prompt, response);
 			// Recursively call the call method with the tool call message
 			// conversation that contains the call responses.
-			return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
+			return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
 		}

 		return response;
 	}

 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
+		return internalStream(prompt, null);
+	}
+
+	public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
 		return Flux.deferContextual(contextView -> {
 			ChatCompletionRequest request = createRequest(prompt, true);

@@ -337,15 +352,43 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 					return buildGeneration(choice, metadata, request);
 				}).toList();
 				// @formatter:on
-
-				return new ChatResponse(generations, from(chatCompletion2, null));
+				OpenAiApi.Usage usage = chatCompletion2.usage();
+				Usage currentChatResponseUsage = usage != null ? OpenAiUsage.from(usage) : new EmptyUsage();
+				Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage,
+						previousChatResponse);
+				return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage));
 			}
 			catch (Exception e) {
 				logger.error("Error processing chat completion", e);
 				return new ChatResponse(List.of());
 			}
-
-		}));
+			// When in stream mode and enabled to include the usage, the OpenAI
+			// Chat completion response would have the usage set only in its
+			// final response. Hence, the following overlapping buffer is
+			// created to store both the current and the subsequent response
+			// to accumulate the usage from the subsequent response.
+		}))
+			.buffer(2, 1)
+			.map(bufferList -> {
+				ChatResponse firstResponse = bufferList.get(0);
+				if (request.streamOptions() != null && request.streamOptions().includeUsage()) {
+					if (bufferList.size() == 2) {
+						ChatResponse secondResponse = bufferList.get(1);
+						if (secondResponse != null && secondResponse.getMetadata() != null) {
+							// This is the usage from the final Chat response for a
+							// given Chat request.
+							Usage usage = secondResponse.getMetadata().getUsage();
+							if (!UsageUtils.isEmpty(usage)) {
+								// Store the usage from the final response to the
+								// penultimate response for accumulation.
+								return new ChatResponse(firstResponse.getResults(),
+										from(firstResponse.getMetadata(), usage));
+							}
+						}
+					}
+				}
+				return firstResponse;
+			});

 		// @formatter:off
 		Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
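The overlapping buffer in the hunk above is the crux of the streaming fix: with usage reporting enabled, OpenAI attaches usage only to the final chunk, and buffer(2, 1) produces sliding windows so each response can peek at its successor. Here is a self-contained Reactor sketch of that windowing behavior, with plain integers standing in for the ChatResponse elements:

import reactor.core.publisher.Flux;

public class OverlappingBufferSketch {

    public static void main(String[] args) {
        // buffer(2, 1) -> sliding windows: [1, 2], [2, 3], [3, 4], [4].
        // The last window has size 1, which is why the commit checks
        // bufferList.size() == 2 before reading the "next" element.
        Flux.just(1, 2, 3, 4)
            .buffer(2, 1)
            .map(window -> {
                int current = window.get(0);
                // Peek at the successor when it exists, e.g. to pull the
                // usage that only arrives with the final chunk.
                Integer next = window.size() == 2 ? window.get(1) : null;
                return "current=" + current + ", next=" + next;
            })
            .doOnNext(System.out::println)
            .blockLast();
    }
}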
@@ -355,7 +398,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 				var toolCallConversation = handleToolCalls(prompt, response);
 				// Recursively call the stream method with the tool call message
 				// conversation that contains the call responses.
-				return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
+				return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
 			}
 			else {
 				return Flux.just(response);
@@ -412,11 +455,11 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata,
 		return new Generation(assistantMessage, generationMetadataBuilder.build());
 	}

-	private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit) {
+	private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit, Usage usage) {
 		Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
 		var builder = ChatResponseMetadata.builder()
 			.withId(result.id() != null ? result.id() : "")
-			.withUsage(result.usage() != null ? OpenAiUsage.from(result.usage()) : new EmptyUsage())
+			.withUsage(usage)
 			.withModel(result.model() != null ? result.model() : "")
 			.withKeyValue("created", result.created() != null ? result.created() : 0L)
 			.withKeyValue("system-fingerprint", result.systemFingerprint() != null ? result.systemFingerprint() : "");
@@ -426,6 +469,18 @@ private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rat
 		return builder.build();
 	}

+	private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usage usage) {
+		Assert.notNull(chatResponseMetadata, "OpenAI ChatResponseMetadata must not be null");
+		var builder = ChatResponseMetadata.builder()
+			.withId(chatResponseMetadata.getId() != null ? chatResponseMetadata.getId() : "")
+			.withUsage(usage)
+			.withModel(chatResponseMetadata.getModel() != null ? chatResponseMetadata.getModel() : "");
+		if (chatResponseMetadata.getRateLimit() != null) {
+			builder.withRateLimit(chatResponseMetadata.getRateLimit());
+		}
+		return builder.build();
+	}
+
 	/**
 	 * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
 	 * @param chunk the ChatCompletionChunk to convert
@@ -533,7 +588,6 @@ else if (message.getMessageType() == MessageType.TOOL) {
 				OpenAiChatOptions.builder().withTools(this.getFunctionTools(enabledToolsToUse)).build(), request,
 				ChatCompletionRequest.class);
 		}
-
 		// Remove `streamOptions` from the request if it is not a streaming request
 		if (request.streamOptions() != null && !stream) {
 			logger.warn("Removing streamOptions from the request as it is not a streaming request!");

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java

Lines changed: 35 additions & 3 deletions
@@ -39,6 +39,7 @@
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.metadata.DefaultUsage;
 import org.springframework.ai.chat.metadata.EmptyUsage;
 import org.springframework.ai.chat.metadata.Usage;
 import org.springframework.ai.chat.model.ChatResponse;
@@ -385,6 +386,35 @@ void streamFunctionCallTest() {
 		assertThat(content).containsAnyOf("15.0", "15");
 	}

+	@Test
+	void functionCallUsageTest() {
+
+		UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
+
+		List<Message> messages = new ArrayList<>(List.of(userMessage));
+
+		var promptOptions = OpenAiChatOptions.builder()
+			// .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
+			.withFunctionCallbacks(List.of(FunctionCallback.builder()
+				.function("getCurrentWeather", new MockWeatherService())
+				.description("Get the weather in location")
+				.inputType(MockWeatherService.Request.class)
+				.build()))
+			.build();
+
+		ChatResponse chatResponse = this.chatModel.call(new Prompt(messages, promptOptions));
+		logger.info("Response: {}", chatResponse);
+		Usage usage = chatResponse.getMetadata().getUsage();
+
+		logger.info("Usage: {}", usage);
+		assertThat(usage).isNotNull();
+		assertThat(usage).isNotInstanceOf(EmptyUsage.class);
+		assertThat(usage).isInstanceOf(DefaultUsage.class);
+		assertThat(usage.getPromptTokens()).isGreaterThan(450L).isLessThan(600L);
+		assertThat(usage.getGenerationTokens()).isGreaterThan(230L).isLessThan(360L);
+		assertThat(usage.getTotalTokens()).isGreaterThan(680L).isLessThan(900L);
+	}
+
 	@Test
 	void streamFunctionCallUsageTest() {

@@ -403,13 +433,15 @@ void streamFunctionCallUsageTest() {
 			.build();

 		Flux<ChatResponse> response = this.streamingChatModel.stream(new Prompt(messages, promptOptions));
-
-		Usage usage = response.blockLast().getMetadata().getUsage();
+		Usage usage = response.last().block().getMetadata().getUsage();

 		logger.info("Usage: {}", usage);
 		assertThat(usage).isNotNull();
 		assertThat(usage).isNotInstanceOf(EmptyUsage.class);
-		assertThat(usage).isInstanceOf(OpenAiUsage.class);
+		assertThat(usage).isInstanceOf(DefaultUsage.class);
+		assertThat(usage.getPromptTokens()).isGreaterThan(450L).isLessThan(600L);
+		assertThat(usage.getGenerationTokens()).isGreaterThan(230L).isLessThan(360L);
+		assertThat(usage.getTotalTokens()).isGreaterThan(680L).isLessThan(960L);
 	}

 	@ParameterizedTest(name = "{0} : {displayName} ")
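One test-side detail worth noting: blockLast() and last().block() return the same element for a non-empty Flux, but they differ on an empty one. blockLast() returns null, so the chained getMetadata() would fail with a bare NullPointerException, while last() fails fast with NoSuchElementException. The commit itself doesn't state this motivation, so treat the following sketch of the difference as a plausible reading:

import reactor.core.publisher.Flux;

public class BlockLastVsLast {

    public static void main(String[] args) {
        // Non-empty stream: identical results.
        System.out.println(Flux.just(1, 2, 3).blockLast());    // 3
        System.out.println(Flux.just(1, 2, 3).last().block()); // 3

        // Empty stream: blockLast() yields null, so a chained
        // .getMetadata() would throw an uninformative NPE...
        System.out.println(Flux.<Integer>empty().blockLast()); // null

        // ...while last() fails fast with NoSuchElementException.
        try {
            Flux.<Integer>empty().last().block();
        }
        catch (Exception e) {
            System.out.println(e); // java.util.NoSuchElementException
        }
    }
}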
spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java (new file)

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.chat.metadata;
+
+import org.springframework.ai.chat.model.ChatResponse;
+
+/**
+ * An utility class to provide support methods handling {@link Usage}.
+ *
+ * @author Ilayaperumal Gopinathan
+ */
+public class UsageUtils {
+
+	public static Usage getCumulativeUsage(final Usage currentUsage, final ChatResponse previousChatResponse) {
+		Long promptTokens = currentUsage.getPromptTokens().longValue();
+		Long generationTokens = currentUsage.getGenerationTokens().longValue();
+		Long totalTokens = currentUsage.getTotalTokens().longValue();
+		// Make sure to accumulate the usage from the previous chat response.
+		if (previousChatResponse != null && previousChatResponse.getMetadata() != null
+				&& previousChatResponse.getMetadata().getUsage() != null) {
+			Usage usageFromPreviousChatResponse = previousChatResponse.getMetadata().getUsage();
+			promptTokens += usageFromPreviousChatResponse.getPromptTokens();
+			generationTokens += usageFromPreviousChatResponse.getGenerationTokens();
+			totalTokens += usageFromPreviousChatResponse.getTotalTokens();
+		}
+		return new DefaultUsage(promptTokens, generationTokens, totalTokens);
+	}
+
+	public static boolean isEmpty(Usage usage) {
+		if (usage == null) {
+			return true;
+		}
+		else if (usage != null && usage.getTotalTokens() == 0L) {
+			return true;
+		}
+		return false;
+	}
+
+}
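For a concrete feel of what getCumulativeUsage produces over a multi-round exchange, here is a worked example of the same arithmetic with plain longs; the token counts are invented for illustration, not taken from the tests:

public class UsageAccumulationExample {

    public static void main(String[] args) {
        // (prompt, generation) token counts for three hypothetical
        // round-trips within one tool-calling exchange.
        long[][] rounds = { { 480, 120 }, { 610, 95 }, { 740, 260 } };

        long promptTokens = 0;
        long generationTokens = 0;
        for (long[] round : rounds) {
            // getCumulativeUsage folds the previous response's usage into
            // the current one; summing each round reproduces that result.
            promptTokens += round[0];
            generationTokens += round[1];
        }
        long totalTokens = promptTokens + generationTokens;

        // Prints: prompt=1830, generation=475, total=2305
        System.out.printf("prompt=%d, generation=%d, total=%d%n",
                promptTokens, generationTokens, totalTokens);
    }
}

This is why the new integration tests assert token ranges well above a single round-trip: the final ChatResponse of a tool-calling exchange now reports totals across every round.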
