@@ -42,6 +42,8 @@
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
@@ -99,6 +101,7 @@
* @author Mariusz Bernacki
* @author luocongqiu
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @see ChatModel
* @see StreamingChatModel
* @see OpenAiApi
@@ -215,6 +218,10 @@ public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,

@Override
public ChatResponse call(Prompt prompt) {
return this.internalCall(prompt, null);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {

ChatCompletionRequest request = createRequest(prompt, false);

@@ -259,8 +266,12 @@ public ChatResponse call(Prompt prompt) {

// Non function calling.
RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody(), rateLimit));
// Current usage
OpenAiApi.Usage usage = completionEntity.getBody().usage();
Usage currentChatResponseUsage = usage != null ? OpenAiUsage.from(usage) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations,
from(completionEntity.getBody(), rateLimit, accumulatedUsage));

observationContext.setResponse(chatResponse);

@@ -274,14 +285,18 @@ && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.n
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call internalCall with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
}

return response;
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return internalStream(prompt, null);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

@@ -337,15 +352,43 @@ public Flux<ChatResponse> stream(Prompt prompt) {
return buildGeneration(choice, metadata, request);
}).toList();
// @formatter:on

return new ChatResponse(generations, from(chatCompletion2, null));
OpenAiApi.Usage usage = chatCompletion2.usage();
Usage currentChatResponseUsage = usage != null ? OpenAiUsage.from(usage) : new EmptyUsage();
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage,
previousChatResponse);
return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage));
}
catch (Exception e) {
logger.error("Error processing chat completion", e);
return new ChatResponse(List.of());
}

}));
// When streaming with usage reporting enabled, OpenAI sets the
// usage only on the final chunk of a completion. The overlapping
// buffer below pairs each response with the one that follows it,
// so the usage carried by that final chunk can be folded back
// into the preceding response.
}))
.buffer(2, 1)
.map(bufferList -> {
ChatResponse firstResponse = bufferList.get(0);
if (request.streamOptions() != null && request.streamOptions().includeUsage()) {
if (bufferList.size() == 2) {
ChatResponse secondResponse = bufferList.get(1);
if (secondResponse != null && secondResponse.getMetadata() != null) {
// This is the usage from the final Chat response for a
// given Chat request.
Usage usage = secondResponse.getMetadata().getUsage();
if (!UsageUtils.isEmpty(usage)) {
// Copy the usage from the final response onto the
// penultimate response for accumulation.
return new ChatResponse(firstResponse.getResults(),
from(firstResponse.getMetadata(), usage));
}
}
}
}
return firstResponse;
});

// @formatter:off
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
@@ -355,7 +398,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call internalStream with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
}
else {
return Flux.just(response);
@@ -412,11 +455,11 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata,
return new Generation(assistantMessage, generationMetadataBuilder.build());
}

private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit) {
private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit, Usage usage) {
Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
var builder = ChatResponseMetadata.builder()
.withId(result.id() != null ? result.id() : "")
.withUsage(result.usage() != null ? OpenAiUsage.from(result.usage()) : new EmptyUsage())
.withUsage(usage)
.withModel(result.model() != null ? result.model() : "")
.withKeyValue("created", result.created() != null ? result.created() : 0L)
.withKeyValue("system-fingerprint", result.systemFingerprint() != null ? result.systemFingerprint() : "");
@@ -426,6 +469,18 @@ private ChatResponseMetadata from(OpenAiApi.ChatCompletion rat
return builder.build();
}

private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usage usage) {
Assert.notNull(chatResponseMetadata, "OpenAI ChatResponseMetadata must not be null");
var builder = ChatResponseMetadata.builder()
.withId(chatResponseMetadata.getId() != null ? chatResponseMetadata.getId() : "")
.withUsage(usage)
.withModel(chatResponseMetadata.getModel() != null ? chatResponseMetadata.getModel() : "");
if (chatResponseMetadata.getRateLimit() != null) {
builder.withRateLimit(chatResponseMetadata.getRateLimit());
}
return builder.build();
}

/**
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
* @param chunk the ChatCompletionChunk to convert
@@ -533,7 +588,6 @@ else if (message.getMessageType() == MessageType.TOOL) {
OpenAiChatOptions.builder().withTools(this.getFunctionTools(enabledToolsToUse)).build(), request,
ChatCompletionRequest.class);
}

// Remove `streamOptions` from the request if it is not a streaming request
if (request.streamOptions() != null && !stream) {
logger.warn("Removing streamOptions from the request as it is not a streaming request!");
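For reference, here is a minimal, self-contained Reactor sketch of the overlapping buffer(2, 1) technique used in internalStream above. The class name OverlappingBufferDemo is hypothetical; this is an illustrative sketch, not part of the change:

import java.util.List;

import reactor.core.publisher.Flux;

public class OverlappingBufferDemo {

	public static void main(String[] args) {
		// buffer(2, 1) emits overlapping windows: [1, 2], [2, 3], [3].
		// The trailing single-element window is why internalStream checks
		// bufferList.size() == 2 before reading the second element.
		Flux.just(1, 2, 3)
			.buffer(2, 1)
			.subscribe((List<Integer> window) -> System.out.println(window));
	}

}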
@@ -39,6 +39,7 @@
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
@@ -385,6 +386,35 @@ void streamFunctionCallTest() {
assertThat(content).containsAnyOf("15.0", "15");
}

@Test
void functionCallUsageTest() {

UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");

List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = OpenAiChatOptions.builder()
// .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
.withFunctionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
.build()))
.build();

ChatResponse chatResponse = this.chatModel.call(new Prompt(messages, promptOptions));
logger.info("Response: {}", chatResponse);
Usage usage = chatResponse.getMetadata().getUsage();

logger.info("Usage: {}", usage);
assertThat(usage).isNotNull();
assertThat(usage).isNotInstanceOf(EmptyUsage.class);
assertThat(usage).isInstanceOf(DefaultUsage.class);
assertThat(usage.getPromptTokens()).isGreaterThan(450L).isLessThan(600L);
assertThat(usage.getGenerationTokens()).isGreaterThan(230L).isLessThan(360L);
assertThat(usage.getTotalTokens()).isGreaterThan(680L).isLessThan(900L);
}

@Test
void streamFunctionCallUsageTest() {

@@ -403,13 +433,15 @@ void streamFunctionCallUsageTest() {
.build();

Flux<ChatResponse> response = this.streamingChatModel.stream(new Prompt(messages, promptOptions));

Usage usage = response.blockLast().getMetadata().getUsage();
Usage usage = response.last().block().getMetadata().getUsage();

logger.info("Usage: {}", usage);
assertThat(usage).isNotNull();
assertThat(usage).isNotInstanceOf(EmptyUsage.class);
assertThat(usage).isInstanceOf(OpenAiUsage.class);
assertThat(usage).isInstanceOf(DefaultUsage.class);
assertThat(usage.getPromptTokens()).isGreaterThan(450L).isLessThan(600L);
assertThat(usage.getGenerationTokens()).isGreaterThan(230L).isLessThan(360L);
assertThat(usage.getTotalTokens()).isGreaterThan(680L).isLessThan(960L);
}

@ParameterizedTest(name = "{0} : {displayName} ")
@@ -0,0 +1,53 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.metadata;

import org.springframework.ai.chat.model.ChatResponse;

/**
* A utility class providing support methods for handling {@link Usage}.
*
* @author Ilayaperumal Gopinathan
*/
public class UsageUtils {

public static Usage getCumulativeUsage(final Usage currentUsage, final ChatResponse previousChatResponse) {
Long promptTokens = currentUsage.getPromptTokens().longValue();
Long generationTokens = currentUsage.getGenerationTokens().longValue();
Long totalTokens = currentUsage.getTotalTokens().longValue();
// Make sure to accumulate the usage from the previous chat response.
if (previousChatResponse != null && previousChatResponse.getMetadata() != null
&& previousChatResponse.getMetadata().getUsage() != null) {
Usage usageFromPreviousChatResponse = previousChatResponse.getMetadata().getUsage();
promptTokens += usageFromPreviousChatResponse.getPromptTokens();
generationTokens += usageFromPreviousChatResponse.getGenerationTokens();
totalTokens += usageFromPreviousChatResponse.getTotalTokens();
}
return new DefaultUsage(promptTokens, generationTokens, totalTokens);
}

public static boolean isEmpty(Usage usage) {
return usage == null || usage.getTotalTokens() == 0L;
}

}
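A short sketch of how the accumulation behaves, using hypothetical token counts and the three-argument DefaultUsage constructor already used above (UsageUtilsDemo is an illustrative class name, not part of the change):

import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;

public class UsageUtilsDemo {

	public static void main(String[] args) {
		// Usage reported by the latest (current) chat response.
		Usage current = new DefaultUsage(120L, 40L, 160L);
		// With no previous response, the cumulative usage equals the current usage.
		Usage cumulative = UsageUtils.getCumulativeUsage(current, null);
		System.out.println(cumulative.getTotalTokens()); // 160
		// If a previous tool-call round had reported 100 prompt, 30 generation
		// and 130 total tokens, the accumulated result would instead be
		// 220 prompt, 70 generation and 290 total tokens.
	}

}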