diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
index ff0d28cc23c..b883f04ee61 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
@@ -42,6 +42,8 @@
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
 import org.springframework.ai.chat.metadata.EmptyUsage;
 import org.springframework.ai.chat.metadata.RateLimit;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.ai.chat.metadata.UsageUtils;
 import org.springframework.ai.chat.model.AbstractToolCallSupport;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
@@ -99,6 +101,7 @@
  * @author Mariusz Bernacki
  * @author luocongqiu
  * @author Thomas Vitale
+ * @author Ilayaperumal Gopinathan
  * @see ChatModel
  * @see StreamingChatModel
  * @see OpenAiApi
@@ -215,6 +218,10 @@ public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
 
 	@Override
 	public ChatResponse call(Prompt prompt) {
+		return this.internalCall(prompt, null);
+	}
+
+	public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
 
 		ChatCompletionRequest request = createRequest(prompt, false);
 
@@ -259,8 +266,12 @@ public ChatResponse call(Prompt prompt) {
 
 			// Non function calling.
 			RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
-
-			ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody(), rateLimit));
+			// Current usage
+			OpenAiApi.Usage usage = completionEntity.getBody().usage();
+			Usage currentChatResponseUsage = usage != null ? OpenAiUsage.from(usage) : new EmptyUsage();
+			Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
+			ChatResponse chatResponse = new ChatResponse(generations,
+					from(completionEntity.getBody(), rateLimit, accumulatedUsage));
 
 			observationContext.setResponse(chatResponse);
 
@@ -274,7 +285,7 @@ && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.n
 			var toolCallConversation = handleToolCalls(prompt, response);
 			// Recursively call the call method with the tool call message
 			// conversation that contains the call responses.
-			return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
+			return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
 		}
 
 		return response;
@@ -282,6 +293,10 @@ && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.n
 
 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
+		return internalStream(prompt, null);
+	}
+
+	public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
 		return Flux.deferContextual(contextView -> {
 			ChatCompletionRequest request = createRequest(prompt, true);
 
@@ -337,15 +352,43 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 					return buildGeneration(choice, metadata, request);
 				}).toList();
 				// @formatter:on
-
-				return new ChatResponse(generations, from(chatCompletion2, null));
+				OpenAiApi.Usage usage = chatCompletion2.usage();
+				Usage currentChatResponseUsage = usage != null ? OpenAiUsage.from(usage) : new EmptyUsage();
+				Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage,
+						previousChatResponse);
+				return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage));
 			}
 			catch (Exception e) {
 				logger.error("Error processing chat completion", e);
 				return new ChatResponse(List.of());
 			}
-
-		}));
+			// When in stream mode and enabled to include the usage, the OpenAI
+			// Chat completion response would have the usage set only in its
+			// final response. Hence, the following overlapping buffer is
+			// created to store both the current and the subsequent response
+			// to accumulate the usage from the subsequent response.
+		}))
+			.buffer(2, 1)
+			.map(bufferList -> {
+				ChatResponse firstResponse = bufferList.get(0);
+				if (request.streamOptions() != null && request.streamOptions().includeUsage()) {
+					if (bufferList.size() == 2) {
+						ChatResponse secondResponse = bufferList.get(1);
+						if (secondResponse != null && secondResponse.getMetadata() != null) {
+							// This is the usage from the final Chat response for a
+							// given Chat request.
+							Usage usage = secondResponse.getMetadata().getUsage();
+							if (!UsageUtils.isEmpty(usage)) {
+								// Store the usage from the final response to the
+								// penultimate response for accumulation.
+								return new ChatResponse(firstResponse.getResults(),
+										from(firstResponse.getMetadata(), usage));
+							}
+						}
+					}
+				}
+				return firstResponse;
+			});
 
 		// @formatter:off
 		Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
@@ -355,7 +398,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 				var toolCallConversation = handleToolCalls(prompt, response);
 				// Recursively call the stream method with the tool call message
 				// conversation that contains the call responses.
-				return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
+				return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
 			}
 			else {
 				return Flux.just(response);
@@ -412,11 +455,11 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata,
 		return new Generation(assistantMessage, generationMetadataBuilder.build());
 	}
 
-	private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit) {
+	private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit, Usage usage) {
 		Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
 		var builder = ChatResponseMetadata.builder()
 			.withId(result.id() != null ? result.id() : "")
-			.withUsage(result.usage() != null ? OpenAiUsage.from(result.usage()) : new EmptyUsage())
+			.withUsage(usage)
 			.withModel(result.model() != null ? result.model() : "")
 			.withKeyValue("created", result.created() != null ? result.created() : 0L)
 			.withKeyValue("system-fingerprint", result.systemFingerprint() != null ? result.systemFingerprint() : "");
@@ -426,6 +469,18 @@ private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rat
 		return builder.build();
 	}
 
+	private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usage usage) {
+		Assert.notNull(chatResponseMetadata, "OpenAI ChatResponseMetadata must not be null");
+		var builder = ChatResponseMetadata.builder()
+			.withId(chatResponseMetadata.getId() != null ? chatResponseMetadata.getId() : "")
+			.withUsage(usage)
+			.withModel(chatResponseMetadata.getModel() != null ? chatResponseMetadata.getModel() : "");
chatResponseMetadata.getModel() : ""); + if (chatResponseMetadata.getRateLimit() != null) { + builder.withRateLimit(chatResponseMetadata.getRateLimit()); + } + return builder.build(); + } + /** * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null. * @param chunk the ChatCompletionChunk to convert @@ -533,7 +588,6 @@ else if (message.getMessageType() == MessageType.TOOL) { OpenAiChatOptions.builder().withTools(this.getFunctionTools(enabledToolsToUse)).build(), request, ChatCompletionRequest.class); } - // Remove `streamOptions` from the request if it is not a streaming request if (request.streamOptions() != null && !stream) { logger.warn("Removing streamOptions from the request as it is not a streaming request!"); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index be0425662bf..4cb946d8681 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -39,6 +39,7 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; @@ -385,6 +386,35 @@ void streamFunctionCallTest() { assertThat(content).containsAnyOf("15.0", "15"); } + @Test + void functionCallUsageTest() { + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = OpenAiChatOptions.builder() + // .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue()) + .withFunctionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse chatResponse = this.chatModel.call(new Prompt(messages, promptOptions)); + logger.info("Response: {}", chatResponse); + Usage usage = chatResponse.getMetadata().getUsage(); + + logger.info("Usage: {}", usage); + assertThat(usage).isNotNull(); + assertThat(usage).isNotInstanceOf(EmptyUsage.class); + assertThat(usage).isInstanceOf(DefaultUsage.class); + assertThat(usage.getPromptTokens()).isGreaterThan(450L).isLessThan(600L); + assertThat(usage.getGenerationTokens()).isGreaterThan(230L).isLessThan(360L); + assertThat(usage.getTotalTokens()).isGreaterThan(680L).isLessThan(900L); + } + @Test void streamFunctionCallUsageTest() { @@ -403,13 +433,15 @@ void streamFunctionCallUsageTest() { .build(); Flux response = this.streamingChatModel.stream(new Prompt(messages, promptOptions)); - - Usage usage = response.blockLast().getMetadata().getUsage(); + Usage usage = response.last().block().getMetadata().getUsage(); logger.info("Usage: {}", usage); assertThat(usage).isNotNull(); assertThat(usage).isNotInstanceOf(EmptyUsage.class); - assertThat(usage).isInstanceOf(OpenAiUsage.class); + assertThat(usage).isInstanceOf(DefaultUsage.class); + assertThat(usage.getPromptTokens()).isGreaterThan(450L).isLessThan(600L); + 
+		assertThat(usage.getGenerationTokens()).isGreaterThan(230L).isLessThan(360L);
+		assertThat(usage.getTotalTokens()).isGreaterThan(680L).isLessThan(960L);
 	}
 
 	@ParameterizedTest(name = "{0} : {displayName} ")
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java
new file mode 100644
index 00000000000..8cc50e91334
--- /dev/null
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/UsageUtils.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.chat.metadata;
+
+import org.springframework.ai.chat.model.ChatResponse;
+
+/**
+ * A utility class providing support methods for handling {@link Usage}.
+ *
+ * @author Ilayaperumal Gopinathan
+ */
+public class UsageUtils {
+
+	public static Usage getCumulativeUsage(final Usage currentUsage, final ChatResponse previousChatResponse) {
+		Long promptTokens = currentUsage.getPromptTokens().longValue();
+		Long generationTokens = currentUsage.getGenerationTokens().longValue();
+		Long totalTokens = currentUsage.getTotalTokens().longValue();
+		// Make sure to accumulate the usage from the previous chat response.
+		if (previousChatResponse != null && previousChatResponse.getMetadata() != null
+				&& previousChatResponse.getMetadata().getUsage() != null) {
+			Usage usageFromPreviousChatResponse = previousChatResponse.getMetadata().getUsage();
+			promptTokens += usageFromPreviousChatResponse.getPromptTokens();
+			generationTokens += usageFromPreviousChatResponse.getGenerationTokens();
+			totalTokens += usageFromPreviousChatResponse.getTotalTokens();
+		}
+		return new DefaultUsage(promptTokens, generationTokens, totalTokens);
+	}
+
+	public static boolean isEmpty(Usage usage) {
+		if (usage == null) {
+			return true;
+		}
+		else if (usage.getTotalTokens() == 0L) {
+			return true;
+		}
+		return false;
+	}
+
+}
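
Reviewer note (not part of the patch): the buffer(2, 1) operator used in internalStream relies on Reactor's overlapping-buffer semantics. Flux.buffer(maxSize, skip) with skip smaller than maxSize pairs every element with its successor, and the final element arrives alone in a one-element list. A minimal standalone sketch of that behavior, with made-up values and a hypothetical class name:

import java.util.List;

import reactor.core.publisher.Flux;

public class OverlappingBufferDemo {

	public static void main(String[] args) {
		// buffer(2, 1) pairs each element with the one that follows it;
		// the last element is emitted alone in a singleton list.
		List<List<Integer>> windows = Flux.just(1, 2, 3, 4).buffer(2, 1).collectList().block();
		System.out.println(windows); // [[1, 2], [2, 3], [3, 4], [4]]
		// This is why the map callback in the patch checks bufferList.size() == 2:
		// only a two-element window has a "next" response whose final usage can be
		// copied back onto the current (penultimate) response.
	}

}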
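
A second sketch, illustrating the accumulation semantics that UsageUtils.getCumulativeUsage introduces for the recursive internalCall/internalStream tool-calling paths. The token counts and the class name are hypothetical; the constructors and builder calls are the ones the patch itself uses:

import java.util.List;

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.ChatResponse;

public class CumulativeUsageDemo {

	public static void main(String[] args) {
		// First round: the model responds with a tool call (hypothetical counts).
		Usage firstRound = new DefaultUsage(500L, 100L, 600L);
		ChatResponse previous = new ChatResponse(List.of(),
				ChatResponseMetadata.builder().withUsage(firstRound).build());

		// Second round: the follow-up request carrying the tool results.
		Usage secondRound = new DefaultUsage(650L, 120L, 770L);
		Usage cumulative = UsageUtils.getCumulativeUsage(secondRound, previous);

		// 1150 prompt + 220 generation = 1370 total tokens across both rounds,
		// which is what the caller now sees on the final ChatResponse.
		System.out.println(cumulative.getTotalTokens()); // 1370
	}

}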