
Commit 7bbd3ef

ilayaperumalg authored and tzolov committed
Fix Azure OpenAI chat model function calling token usage
- Fix the Azure OpenAI chat model's function calling to report accumulated token usage
- Fix both the call() and stream() operations
- For the streaming operation, use buffering to store the usage from the last response when the "include usage" stream option is enabled
- Add tests
1 parent cadea8b commit 7bbd3ef
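Note on the streaming fix below: when the include-usage stream option is enabled, Azure OpenAI reports token usage only in a trailing chunk that carries no choices, so the stream pipeline uses Flux.buffer(2, 1) to give each response a one-element look-ahead from which it can adopt that usage. A minimal, self-contained Reactor sketch of the idea follows; the Response record and the token values are hypothetical stand-ins, not Spring AI or Azure SDK types.

import java.util.List;

import reactor.core.publisher.Flux;

public class SlidingWindowUsageSketch {

	// Hypothetical stand-in for a chat response: text plus optional token usage.
	record Response(String text, Integer totalTokens) {
	}

	public static void main(String[] args) {
		// Simulated stream: two content chunks, then a usage-only trailing chunk.
		Flux<Response> chunks = Flux.just(
				new Response("Hello", null),
				new Response(" world", null),
				new Response("", 42));

		Flux<Response> withUsage = chunks
			// Overlapping windows of size 2 advancing by 1 pair each element
			// with its successor, so every chunk can look one element ahead.
			.buffer(2, 1)
			.map(window -> {
				Response current = window.get(0);
				Response next = (window.size() == 2) ? window.get(1) : null;
				if (next != null && next.totalTokens() != null) {
					// The look-ahead chunk carries the usage: attach it here.
					return new Response(current.text(), next.totalTokens());
				}
				return current;
			});

		withUsage.subscribe(System.out::println);
	}
}

Here buffer(2, 1) emits the windows [c0, c1], [c1, c2], [c2]; the final single-element window passes the trailing chunk through unchanged, so the last content-bearing response ends up carrying the usage.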

File tree

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java
models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java

2 files changed: +141 −10 lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 107 additions & 10 deletions
@@ -51,6 +51,7 @@
 import com.azure.ai.openai.models.ChatRequestToolMessage;
 import com.azure.ai.openai.models.ChatRequestUserMessage;
 import com.azure.ai.openai.models.CompletionsFinishReason;
+import com.azure.ai.openai.models.CompletionsUsage;
 import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
 import com.azure.ai.openai.models.FunctionCall;
 import com.azure.core.util.BinaryData;
@@ -70,6 +71,7 @@
 import org.springframework.ai.chat.metadata.PromptMetadata;
 import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
 import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.ai.chat.metadata.UsageUtils;
 import org.springframework.ai.chat.model.AbstractToolCallSupport;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
@@ -105,6 +107,7 @@
  * @author timostark
  * @author Soby Chacko
  * @author Jihoon Kim
+ * @author Ilayaperumal Gopinathan
  * @see ChatModel
  * @see com.azure.ai.openai.OpenAIClient
  * @since 1.0.0
@@ -176,10 +179,10 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAi
 		this.observationRegistry = observationRegistry;
 	}
 
-	public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
+	public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata,
+			Usage usage) {
 		Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
 		String id = chatCompletions.getId();
-		Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage();
 		return ChatResponseMetadata.builder()
 			.withId(id)
 			.withUsage(usage)
@@ -189,12 +192,40 @@ public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptM
 			.build();
 	}
 
+	public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
+		Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage();
+		return from(chatCompletions, promptFilterMetadata, usage);
+	}
+
+	public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata,
+			CompletionsUsage usage) {
+		return from(chatCompletions, promptFilterMetadata, AzureOpenAiUsage.from(usage));
+	}
+
+	public static ChatResponseMetadata from(ChatResponse chatResponse, Usage usage) {
+		Assert.notNull(chatResponse, "ChatResponse must not be null");
+		ChatResponseMetadata chatResponseMetadata = chatResponse.getMetadata();
+		ChatResponseMetadata.Builder builder = ChatResponseMetadata.builder();
+		builder.withId(chatResponseMetadata.getId())
+			.withUsage(usage)
+			.withModel(chatResponseMetadata.getModel())
+			.withPromptMetadata(chatResponseMetadata.getPromptMetadata());
+		if (chatResponseMetadata.containsKey("system-fingerprint")) {
+			builder.withKeyValue("system-fingerprint", chatResponseMetadata.get("system-fingerprint"));
+		}
+		return builder.build();
+	}
+
 	public AzureOpenAiChatOptions getDefaultOptions() {
 		return AzureOpenAiChatOptions.fromOptions(this.defaultOptions);
 	}
 
 	@Override
 	public ChatResponse call(Prompt prompt) {
+		return this.internalCall(prompt, null);
+	}
+
+	public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
 
 		ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
 			.prompt(prompt)
@@ -210,7 +241,7 @@ public ChatResponse call(Prompt prompt) {
 			ChatCompletionsOptionsAccessHelper.setStream(options, false);
 
 			ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
-			ChatResponse chatResponse = toChatResponse(chatCompletions);
+			ChatResponse chatResponse = toChatResponse(chatCompletions, previousChatResponse);
 			observationContext.setResponse(chatResponse);
 			return chatResponse;
 		});
@@ -220,14 +251,18 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS
 			var toolCallConversation = handleToolCalls(prompt, response);
 			// Recursively call the call method with the tool call message
 			// conversation that contains the call responses.
-			return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
+			return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
 		}
 
 		return response;
 	}
 
 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
+		return this.internalStream(prompt, null);
+	}
+
+	public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
 
 		return Flux.deferContextual(contextView -> {
 			ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
@@ -279,16 +314,36 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 				})
 				.flatMap(mono -> mono);
 
-			return accessibleChatCompletionsFlux.switchMap(chatCompletions -> {
-
-				ChatResponse chatResponse = toChatResponse(chatCompletions);
+			final Flux<ChatResponse> chatResponseFlux = accessibleChatCompletionsFlux.map(chatCompletion -> {
+				if (previousChatResponse == null) {
+					return toChatResponse(chatCompletion);
+				}
+				// Accumulate the usage from the previous chat response
+				CompletionsUsage usage = chatCompletion.getUsage();
+				Usage currentChatResponseUsage = usage != null ? AzureOpenAiUsage.from(usage) : new EmptyUsage();
+				Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
+				return toChatResponse(chatCompletion, accumulatedUsage);
+			}).buffer(2, 1).map(bufferList -> {
+				ChatResponse chatResponse1 = bufferList.get(0);
+				if (options.getStreamOptions() != null && options.getStreamOptions().isIncludeUsage()) {
+					if (bufferList.size() == 2) {
+						ChatResponse chatResponse2 = bufferList.get(1);
+						if (chatResponse2 != null && chatResponse2.getMetadata() != null
+								&& !UsageUtils.isEmpty(chatResponse2.getMetadata().getUsage())) {
+							return toChatResponse(chatResponse1, chatResponse2.getMetadata().getUsage());
+						}
+					}
+				}
+				return chatResponse1;
+			});
 
+			return chatResponseFlux.flatMap(chatResponse -> {
 				if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
 						Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
 					var toolCallConversation = handleToolCalls(prompt, chatResponse);
 					// Recursively call the call method with the tool call message
 					// conversation that contains the call responses.
-					return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
+					return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
 				}
 
 				Flux<ChatResponse> flux = Flux.just(chatResponse)
@@ -305,6 +360,44 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 
 	private ChatResponse toChatResponse(ChatCompletions chatCompletions) {
 
+		List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> {
+			// @formatter:off
+			Map<String, Object> metadata = Map.of(
+					"id", chatCompletions.getId() != null ? chatCompletions.getId() : "",
+					"choiceIndex", choice.getIndex(),
+					"finishReason", choice.getFinishReason() != null ? String.valueOf(choice.getFinishReason()) : "");
+			// @formatter:on
+			return buildGeneration(choice, metadata);
+		}).toList();
+
+		PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
+
+		return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
+	}
+
+	private ChatResponse toChatResponse(ChatCompletions chatCompletions, Usage usage) {
+
+		List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> {
+			// @formatter:off
+			Map<String, Object> metadata = Map.of(
+					"id", chatCompletions.getId() != null ? chatCompletions.getId() : "",
+					"choiceIndex", choice.getIndex(),
+					"finishReason", choice.getFinishReason() != null ? String.valueOf(choice.getFinishReason()) : "");
+			// @formatter:on
+			return buildGeneration(choice, metadata);
+		}).toList();
+
+		PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
+
+		return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata, usage));
+	}
+
+	private ChatResponse toChatResponse(ChatResponse chatResponse, Usage usage) {
+		return new ChatResponse(chatResponse.getResults(), from(chatResponse, usage));
+	}
+
+	private ChatResponse toChatResponse(ChatCompletions chatCompletions, ChatResponse previousChatResponse) {
+
 		List<Generation> generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> {
 			// @formatter:off
 			Map<String, Object> metadata = Map.of(
@@ -316,8 +409,12 @@ private ChatResponse toChatResponse(ChatCompletions chatCompletions) {
 		}).toList();
 
 		PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
-
-		return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
+		Usage currentUsage = null;
+		if (chatCompletions.getUsage() != null) {
+			currentUsage = AzureOpenAiUsage.from(chatCompletions);
+		}
+		Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
+		return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata, cumulativeUsage));
 	}
 
 	private Generation buildGeneration(ChatChoice choice, Map<String, Object> metadata) {
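A note on the non-streaming half of the fix above: internalCall and internalStream thread the previous ChatResponse through each recursive tool-call round, and UsageUtils.getCumulativeUsage folds the new round's token counts into the running totals. Below is a one-record sketch of that accumulation, using a hypothetical SimpleUsage type rather than Spring AI's Usage API.

// Hypothetical sketch of cumulative token usage across tool-call rounds.
public record SimpleUsage(long promptTokens, long completionTokens) {

	long totalTokens() {
		return this.promptTokens + this.completionTokens;
	}

	// Add this round's counts onto the totals carried over so far.
	SimpleUsage plus(SimpleUsage previous) {
		return new SimpleUsage(this.promptTokens + previous.promptTokens,
				this.completionTokens + previous.completionTokens);
	}

	public static void main(String[] args) {
		SimpleUsage round1 = new SimpleUsage(200, 50); // first model call
		SimpleUsage round2 = new SimpleUsage(300, 80); // follow-up after the tool result
		System.out.println(round2.plus(round1).totalTokens()); // 630
	}
}

This accumulation is why the integration tests below assert on the total token count of the final response rather than on the last round's count alone.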

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java

Lines changed: 34 additions & 0 deletions
@@ -24,6 +24,7 @@
 import java.util.stream.Collectors;
 
 import com.azure.ai.openai.OpenAIClientBuilder;
+import com.azure.ai.openai.models.ChatCompletionStreamOptions;
 import com.azure.core.credential.AzureKeyCredential;
 import org.junit.jupiter.api.Test;
 import org.slf4j.Logger;
@@ -80,7 +81,12 @@ void functionCallTest() {
 
 		logger.info("Response: {}", response);
 
+		assertThat(response.getResult()).isNotNull();
+		assertThat(response.getResult().getOutput()).isNotNull();
 		assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
+		assertThat(response.getMetadata()).isNotNull();
+		assertThat(response.getMetadata().getUsage()).isNotNull();
+		assertThat(response.getMetadata().getUsage().getTotalTokens()).isGreaterThan(600).isLessThan(800);
 	}
 
 	@Test
@@ -142,6 +148,34 @@ void streamFunctionCallTest() {
 
 	}
 
+	@Test
+	void streamFunctionCallUsageTest() {
+		UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
+
+		List<Message> messages = new ArrayList<>(List.of(userMessage));
+
+		ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions();
+		streamOptions.setIncludeUsage(true);
+
+		var promptOptions = AzureOpenAiChatOptions.builder()
+			.withDeploymentName(this.selectedModel)
+			.withFunctionCallbacks(List.of(FunctionCallback.builder()
+				.function("getCurrentWeather", new MockWeatherService())
+				.description("Get the current weather in a given location")
+				.inputType(MockWeatherService.Request.class)
+				.build()))
+			.withStreamOptions(streamOptions)
+			.build();
+
+		Flux<ChatResponse> response = this.chatModel.stream(new Prompt(messages, promptOptions));
+
+		ChatResponse chatResponse = response.last().block();
+		logger.info("Response: {}", chatResponse);
+
+		assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(600).isLessThan(800);
+
+	}
+
 	@Test
 	void functionCallSequentialAndStreamTest() {
 